[
  {
    "path": ".github/CODEOWNERS",
    "content": "# https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners\n\n* @rusty1s @akihironitta\n\n*.py @rusty1s @wsad1 @akihironitta\n\n/.github/ @rusty1s @akihironitta\n\n/.github/CODEOWNERS @rusty1s\n\n/torch_geometric/data/ @rusty1s @mananshah99 @akihironitta\n\n/torch_geometric/loader/ @rusty1s @mananshah99 @akihironitta\n\n/torch_geometric/sampler/ @rusty1s @mananshah99 @akihironitta\n\n/docs/ @rusty1s @akihironitta\n\n/torch_geometric/nn/conv/cugraph @tingyu66\n\n/examples/llm @puririshi98\n\n/torch_geometric/llm @puririshi98\n"
  },
  {
    "path": ".github/CONTRIBUTING.md",
    "content": "# Contributing to PyG\n\nIf you are interested in contributing to PyG, your contributions will likely fall into one of the following two categories:\n\n1. You want to implement a new feature:\n   - In general, we accept any features as long as they fit the scope of this package. If you are unsure about this or need help on the design/implementation of your feature, post about it in an issue.\n1. You want to fix a bug:\n   - Feel free to send a Pull Request (PR) any time you encounter a bug. Please provide a clear and concise description of what the bug was. If you are unsure about if this is a bug at all or how to fix, post about it in an issue.\n\nOnce you finish implementing a feature or bug-fix, please send a PR to https://github.com/pyg-team/pytorch_geometric.\nYour PR will be merged after one or more rounds of reviews by the [pyg-team](https://github.com/pyg-team).\nIf your PR isn't merged anytime soon (*e.g.,* due to its large size, complexity or unavailability of reviewers), try moving your contribution to the [`torch_geometric.contrib`](https://pytorch-geometric.readthedocs.io/en/latest/modules/contrib.html) package.\n[`torch_geometric.contrib`](https://pytorch-geometric.readthedocs.io/en/latest/modules/contrib.html) has less rigourous review requirements and might lead to your PR getting merged faster.\n\n## Developing PyG\n\nTo develop PyG on your machine, here are some tips:\n\n1. Ensure that you are running on one of the two latest PyTorch releases (*e.g.*, `2.8.0`):\n\n   ```python\n   import torch\n   print(torch.__version__)\n   ```\n\n1. 
*(Optional)* Follow the [installation instructions](https://github.com/pyg-team/pytorch_geometric#installation) to install `pyg-lib`, `torch-scatter`, `torch-sparse`, and `torch-cluster` (if you haven't already).\n   Note that this step is optional and only necessary if you develop a feature that uses one of these libraries.\n\n   ```bash\n   pip install pyg-lib torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html\n   ```\n\n   where `${TORCH}` should be replaced by your PyTorch version (*e.g.*, `2.8.0`), and `${CUDA}` should be replaced by your CUDA version (*e.g.*, `cpu`, `cu126`, `cu128`, or `cu129`).\n\n1. Uninstall all existing PyG installations.\n   It is advised to run this command repeatedly to confirm that installations across all locations are properly removed.\n\n   ```bash\n   pip uninstall torch-geometric\n   pip uninstall torch-geometric  # run this command twice\n   ```\n\n1. Fork and clone the PyG repository:\n\n   ```bash\n   git clone https://github.com/<your_username>/pytorch_geometric\n   cd pytorch_geometric\n   ```\n\n1. If you already cloned PyG from source, update it:\n\n   ```bash\n   git pull\n   ```\n\n1. Install PyG in editable mode:\n\n   ```bash\n   pip install -e \".[dev,full]\"\n   ```\n\n   This mode will symlink the Python files from the current local source tree into the Python install.\n   Hence, if you modify a Python file, you do not need to re-install PyG.\n\n1. Ensure that you have a working PyG installation by running the entire test suite with\n\n   ```bash\n   pytest\n   ```\n\n   In case an error occurs, please first check if all sub-packages ([`pyg-lib`](https://github.com/pyg-team/pyg-lib), [`torch-scatter`](https://github.com/rusty1s/pytorch_scatter), [`torch-sparse`](https://github.com/rusty1s/pytorch_sparse) and [`torch-cluster`](https://github.com/rusty1s/pytorch_cluster)) are on their latest reported versions.\n\n1. 
Install pre-commit hooks:\n\n   ```bash\n   pre-commit install\n   ```\n\n## Unit Testing\n\nThe PyG testing suite is located under `test/`.\nRun the entire test suite with\n\n```bash\npytest\n```\n\nor test individual files via, *e.g.*, `pytest test/utils/test_convert.py`.\n\n## Continuous Integration\n\nPyG uses [GitHub Actions](https://github.com/pyg-team/pytorch_geometric/actions) in combination with [CodeCov](https://codecov.io/github/pyg-team/pytorch_geometric?branch=master) for continuous integration.\n\nEvery time you send a Pull Request, your commit will be built and checked against the PyG guidelines:\n\n1. Ensure that your code is formatted correctly by testing against the style guide of [`flake8`](https://github.com/PyCQA/flake8).\n   We use the [`Flake8-pyproject`](https://pypi.org/project/Flake8-pyproject/) plugin for configuration:\n\n   ```bash\n   flake8 .\n   ```\n\n   If you do not want to format your code manually, we recommend using [`yapf`](https://github.com/google/yapf).\n\n1. Ensure that the entire test suite passes and that code coverage roughly stays the same.\n   Please feel encouraged to provide a test with your submitted code.\n   To test, either run\n\n   ```bash\n   pytest --cov\n   ```\n\n   or\n\n   ```bash\n   FULL_TEST=1 pytest --cov\n   ```\n\n   (which runs a set of additional but time-consuming tests) depending on your needs.\n\n1. Add your feature/bugfix to the [`CHANGELOG.md`](https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md?plain=1).\n   If multiple PRs move towards integrating a single feature, it is advised to group them together into one bullet point.\n\n## Building Documentation\n\nTo build the documentation:\n\n1. [Build and install](#developing-pyg) PyG from source.\n1. Install the [Sphinx](https://www.sphinx-doc.org/en/master/) theme via\n   ```bash\n   pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git\n   ```\n1. 
Generate the documentation via:\n   ```bash\n   cd docs\n   make html\n   ```\n\nThe documentation is now available to view by opening `docs/build/html/index.html`.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug-report.yml",
    "content": "name: \"🐛 Bug Report\"\ndescription: \"Submit a report to help us reproduce and fix the bug\"\nlabels: bug\n\nbody:\n  - type: markdown\n    attributes:\n      value: >\n        #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the existing and past issues](https://github.com/pyg-team/pytorch_geometric/issues).\n        #\n  - type: textarea\n    attributes:\n      label: 🐛 Describe the bug\n      description: |\n        Please provide a clear and concise description of the bug.\n\n        If relevant, add a minimal example so that we can reproduce the error by running the code. It is very important for the snippet to be as minimal as possible, so please take time to trim down any irrelevant code to help us debug efficiently. We are going to copy-paste your code and we expect to get the same result as you did: avoid any external data, and include the relevant imports:\n\n        ```python\n        # All necessary imports at the beginning\n        import torch\n        from torch_geometric.utils import to_undirected\n\n        # A minimal reproducing example trimmed down to the essential parts:\n        edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])\n        edge_index = to_undirected(edge_index, num_nodes=1)\n        assert edge_index.size(1) == 6  # We expect that the number of edges is doubled.\n\n        # NOTE: the bug is that num_nodes < edge_index.max() + 1\n        ```\n\n        Please also paste or describe the results you observe instead of the expected results. If you observe an error, please paste the error message including the **full** traceback of the exception. 
It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.\n      placeholder: |\n        A clear and concise description of the bug.\n\n        ```python\n        # Sample code to reproduce the problem\n        ```\n\n        ```\n        The error message you got, with the full traceback.\n        ```\n    validations:\n      required: true\n  - type: textarea\n    attributes:\n      label: Versions\n      description: |\n        Please run the following and paste the output below.\n        ```sh\n        curl -OL https://raw.githubusercontent.com/pytorch/pytorch/main/torch/utils/collect_env.py\n        # For security purposes, please check the contents of collect_env.py before running it.\n        python3 collect_env.py\n        ```\n    validations:\n      required: true\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/config.yml",
    "content": "blank_issues_enabled: false\ncontact_links:\n  - name: 🙏 Ask a Question\n    url: https://github.com/pyg-team/pytorch_geometric/discussions/new\n    about: Ask and answer PyG related questions\n  - name: 💬 Slack\n    url: https://join.slack.com/t/torchgeometricco/shared_invite/zt-p6br3yuo-BxRoe36OHHLF6jYU8xHtBA\n    about: Chat with our community\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/documentation.yml",
    "content": "name: \"📚 Typos and Doc Fixes\"\ndescription: \"Tell us about how we can improve our documentation\"\nlabels: documentation\n\nbody:\n  - type: textarea\n    attributes:\n      label: 📚 Describe the documentation issue\n      description: |\n        A clear and concise description of the issue.\n    validations:\n      required: true\n  - type: textarea\n    attributes:\n      label: Suggest a potential alternative/fix\n      description: |\n        Tell us how we could improve the documentation in this regard.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature-request.yml",
    "content": "name: \"🚀 Feature Request\"\ndescription: \"Propose a new PyG feature\"\nlabels: feature\n\nbody:\n  - type: textarea\n    attributes:\n      label: 🚀 The feature, motivation and pitch\n      description: >\n        A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *\"I'm working on X and would like Y to be possible\"*. If this is related to another GitHub issue, please link here too.\n    validations:\n      required: true\n  - type: textarea\n    attributes:\n      label: Alternatives\n      description: >\n        A description of any alternative solutions or features you've considered, if any.\n  - type: textarea\n    attributes:\n      label: Additional context\n      description: >\n        Add any other context or screenshots about the feature request.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/installation.yml",
    "content": "name: \"😵 Installation\"\ndescription: \"Report an installation problem\"\nlabels: installation\n\nbody:\n  - type: markdown\n    attributes:\n      value: >\n        #### Before submitting an installation problem, please make sure the issue hasn't been already reported by searching through [the existing and past issues](https://github.com/pyg-team/pytorch_geometric/issues).\n        #\n  - type: textarea\n    attributes:\n      label: 😵 Describe the installation problem\n      description: |\n        Please provide a clear and concise description of the installation problem. If you have installation log files, please prvoide them here as well. It may be relevant to wrap the log files in ```` ```triple quotes blocks``` ````.\n      placeholder: |\n        A clear and concise description of the installation problem.\n    validations:\n      required: true\n  - type: textarea\n    attributes:\n      label: Environment\n      description: |\n        Please provide as much information as possible about your environment, such as your PyG (`print(torch_geometric.__version__)`) and PyTorch version (`print(torch.__version__)`), your OS (*e.g.*, Linux), and your Python version (*e.g.*, `3.13`):\n      value: |\n        * PyG version:\n        * PyTorch version:\n        * OS:\n        * Python version:\n        * CUDA/cuDNN version:\n        * How you installed PyTorch and PyG (`conda`, `pip`, source):\n        * Any other relevant information (*e.g.*, version of `torch-scatter`):\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/refactor.yml",
    "content": "name: \"🛠 Refactor\"\ndescription: \"Suggest a code refactor or deprecation\"\nlabels: refactor\n\nbody:\n  - type: textarea\n    attributes:\n      label: 🛠 Proposed Refactor\n      description: |\n        A clear and concise description of the refactor proposal. Please outline the motivation for the proposal. If this is related to another GitHub issue, please link here too.\n    validations:\n      required: true\n  - type: textarea\n    attributes:\n      label: Suggest a potential alternative/fix\n      description: |\n        Tell us how we could improve the code in this regard.\n    validations:\n      required: true\n"
  },
  {
    "path": ".github/actions/setup/action.yml",
    "content": "name: Setup\n\ninputs:\n  python-version:\n    required: false\n    default: '3.10'\n  torch-version:\n    required: false\n    default: '2.10.0'\n  cuda-version:\n    required: false\n    default: cpu\n  full_install:\n    required: false\n    default: true\n\nruns:\n  using: composite\n\n  steps:\n    - name: Install uv\n      uses: astral-sh/setup-uv@v7\n      with:\n        python-version: ${{ inputs.python-version }}\n        activate-environment: true\n\n    - name: Set up Python ${{ inputs.python-version }}\n      run: |\n        uv pip install --upgrade pip setuptools\n      shell: bash\n\n    - name: Install numpy\n      run: |\n        uv pip install \"numpy<2\"\n      shell: bash\n\n    - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }}\n      if: ${{ inputs.torch-version != 'nightly' }}\n      run: |\n        uv pip install torch==${{ inputs.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }}\n      shell: bash\n\n    - name: Install PyTorch ${{ inputs.torch-version }}+${{ inputs.cuda-version }}\n      if: ${{ inputs.torch-version == 'nightly' }}\n      run: |\n        uv pip install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/${{ inputs.cuda-version }}\n      shell: bash\n\n    - name: Check installation\n      run: |\n        uv run --no-project python -c \"import torch; print('PyTorch:', torch.__version__)\"\n        uv run --no-project python -c \"import torch; print('CUDA available:', torch.cuda.is_available())\"\n        uv run --no-project python -c \"import torch; print('CUDA:', torch.version.cuda)\"\n      shell: bash\n\n    - name: Install pyg-lib\n      if: ${{ inputs.torch-version != 'nightly' }}\n      run: |\n        uv pip install --no-index --upgrade pyg-lib -f https://data.pyg.org/whl/nightly/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html\n      shell: bash\n\n    - name: Install faiss-cpu\n      if: ${{ 
inputs.full_install == 'true' && inputs.cuda-version == 'cpu' && runner.os == 'Linux' }}\n      run: |\n        uv pip install faiss-cpu==1.7.2\n      shell: bash\n\n    - name: Install faiss-gpu\n      if: ${{ inputs.full_install == 'true' && inputs.cuda-version != 'cpu' && runner.os == 'Linux' }}\n      run: |\n        uv pip install faiss-gpu==1.7.2\n      shell: bash\n\n    - name: Install torchvision\n      if: ${{ inputs.full_install == 'true' && inputs.torch-version != 'nightly' }}\n      run: |\n        if [ ${{ inputs.torch-version }} == '2.10.0' ]; then\n          TORCHVISION_VERSION=0.25.0\n        elif [ ${{ inputs.torch-version }} == '2.9.0' ]; then\n          TORCHVISION_VERSION=0.24.0\n        elif [ ${{ inputs.torch-version }} == '2.8.0' ]; then\n          TORCHVISION_VERSION=0.23.0\n        elif [ ${{ inputs.torch-version }} == '2.7.0' ]; then\n          TORCHVISION_VERSION=0.22.0\n        elif [ ${{ inputs.torch-version }} == '2.6.0' ]; then\n          TORCHVISION_VERSION=0.21.0\n        fi\n        uv pip install torchvision==${TORCHVISION_VERSION} --extra-index-url https://download.pytorch.org/whl/${{ inputs.cuda-version }}\n      shell: bash\n\n    - name: Install extension packages\n      if: ${{ inputs.full_install == 'true' && inputs.torch-version != 'nightly' }}\n      run: |\n        uv pip install scipy\n        uv pip install --no-index --upgrade torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-${{ inputs.torch-version }}+${{ inputs.cuda-version }}.html\n      shell: bash\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# https://docs.github.com/en/code-security/dependabot/working-with-dependabot/dependabot-options-reference\nversion: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directories:\n      - \"/\"\n      - \"/.github/actions/setup\"\n    schedule:\n      interval: \"daily\"\n      time: \"00:00\"\n    labels:\n      - \"ci\"\n      - \"skip-changelog\"\n    pull-request-branch-name:\n      separator: \"-\"\n    open-pull-requests-limit: 10\n"
  },
  {
    "path": ".github/labeler.yml",
    "content": "installation:\n  - changed-files:\n      - any-glob-to-any-file: [\"pyproject.toml\"]\n\nci:\n  - changed-files:\n      - any-glob-to-any-file: [\".github/**/*\", \"codecov.yaml\", \".pre-commit-config.yaml\"]\n\ndocumentation:\n  - changed-files:\n      - any-glob-to-any-file: [\"docs/**/*\", \"readthedocs.yml\", \"README.MD\"]\n\nexample:\n  - changed-files:\n      - any-glob-to-any-file: \"examples/**/*\"\n\ndata:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/data/**/*\"\n\ndataset:\n  - changed-files:\n      - any-glob-to-any-file: [\"torch_geometric/io/**/*\", \"torch_geometric/datasets/**/*\"]\n\nsampler:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/sampler/**/*\"\n\nloader:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/loader/**/*\"\n\nnn:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/nn/**/*\"\n\nexplain:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/explain/**/*\"\n\ntransform:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/transforms/**/*\"\n\nmetrics:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/metrics/**/*\"\n\nutils:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/utils/**/*\"\n\ndistributed:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/distributed/**/*\"\n\ncontrib:\n  - changed-files:\n      - any-glob-to-any-file: \"torch_geometric/contrib/**/*\"\n\ngraphgym:\n  - changed-files:\n      - any-glob-to-any-file: [\"graphgym/**/*\", \"torch_geometric/graphgym/**/*\"]\n\nbenchmark:\n  - changed-files:\n      - any-glob-to-any-file: [\"benchmark/**/*\", \"torch_geometric/profile/**/*\"]\n"
  },
  {
    "path": ".github/workflows/_testing.yml",
    "content": "name: Reusable Testing\n\non:  # yamllint disable-line rule:truthy\n  workflow_call:\n    inputs:\n      test-matrix:\n        type: string\n        required: true\n\n\ndefaults:\n  run:\n    shell: bash\n\njobs:\n  test:\n    runs-on: ${{ matrix.os }}\n\n    strategy:\n      fail-fast: false\n      matrix: ${{ fromJSON(inputs.test-matrix) }}\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n\n      - name: Set up packages\n        uses: ./.github/actions/setup\n        with:\n          python-version: ${{ matrix.python-version }}\n          torch-version: ${{ matrix.torch-version }}\n          cuda-version: ${{ matrix.cuda-version }}\n\n      - name: Install graphviz\n        if: ${{ runner.os == 'Linux' && matrix.full_test == '1' }}\n        run: |\n          sudo apt-get install graphviz\n\n      - name: Install main package (Windows)\n        if: ${{ runner.os == 'Windows' }}\n        run: |\n          uv pip install -e \".[test]\"\n\n      - name: Install main package\n        if: ${{ runner.os != 'Windows' }}\n        run: |\n          uv pip install -e \".[full,test]\"\n\n      - name: Check installation\n        run: |\n          python -c \"import torch; print('PyTorch:', torch.__version__)\"\n          python -c \"import torch; print('CUDA available:', torch.cuda.is_available())\"\n          python -c \"import torch; print('CUDA:', torch.version.cuda)\"\n        shell: bash\n\n      - name: Run tests\n        timeout-minutes: 20\n        run: |\n          pytest -m \"not rag\" --cov --cov-report=xml --durations 10\n        env:\n          FULL_TEST: ${{ matrix.full_test }}\n\n      - name: Upload coverage\n        if: ${{ runner.os == 'Linux' }}\n        uses: codecov/codecov-action@v5\n        with:\n          fail_ci_if_error: false\n"
  },
  {
    "path": ".github/workflows/auto-merge.yml",
    "content": "name: Auto-merge Bot PRs\n\non:  # yamllint disable-line rule:truthy\n  pull_request_target:\n    types: [opened, reopened]\n\npermissions:\n  contents: write\n  pull-requests: write\n\njobs:\n  auto-merge:\n    runs-on: ubuntu-latest\n    if: ${{ github.event.pull_request.user.login == 'dependabot[bot]' || github.event.pull_request.user.login == 'pre-commit-ci[bot]' }}\n    steps:\n      - uses: actions/checkout@v6\n\n      - name: Label bot PRs\n        run: gh pr edit --add-label \"ci,skip-changelog\" ${{ github.event.pull_request.html_url }}\n        env:\n          GITHUB_TOKEN: ${{ secrets.PAT }}\n\n      - name: Auto-approve\n        uses: hmarr/auto-approve-action@v4\n        with:\n          github-token: ${{ secrets.PAT }}\n\n      - name: Enable auto-merge\n        run: gh pr merge --auto --squash ${{ github.event.pull_request.html_url }}\n        env:\n          GITHUB_TOKEN: ${{ secrets.PAT }}\n"
  },
  {
    "path": ".github/workflows/building_nightly.yml",
    "content": "name: Nightly Build\n\non:  # yamllint disable-line rule:truthy\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 6 * * *\"  # Everyday at 6:00am UTC/10:00pm PST\n\njobs:\n\n  build:\n    if: github.repository == 'pyg-team/pytorch_geometric'\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: '3.10'\n\n      - name: Set version\n        run: echo \"VERSION=$(sed -n \"s/^__version__ = '\\(.*\\)'/\\1/p\" torch_geometric/__init__.py)\" >> ${GITHUB_ENV}\n\n      - name: Set time\n        run: echo \"TODAY=$(date +'%Y%m%d')\" >> ${GITHUB_ENV}\n\n      - name: Customize build version\n        run: |\n          sed -i \"s/$VERSION/$VERSION.dev$TODAY/\" torch_geometric/__init__.py\n          sed -i '0,/name=\"torch-geometric\"/s//name=\"pyg-nightly\"/' pyproject.toml # Only change first occurence\n          sed -i \"s/version=\\\"$VERSION\\\"/version=\\\"$VERSION.dev$TODAY\\\"/\" pyproject.toml\n\n      - name: Build package\n        run: |\n          pip install build\n          python -m build\n\n      - name: Publish package\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          user: __token__\n          password: ${{ secrets.PYPI_API_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/changelog.yml",
    "content": "name: Changelog Enforcer\n\non:  # yamllint disable-line rule:truthy\n  pull_request:\n    types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled]\n\njobs:\n\n  changelog:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Enforce changelog entry\n        uses: dangoslen/changelog-enforcer@v3\n        with:\n          skipLabels: skip-changelog\n"
  },
  {
    "path": ".github/workflows/documentation.yml",
    "content": "name: Documentation\n\non:  # yamllint disable-line rule:truthy\n  push:\n    branches:\n      - master\n    paths:\n      - 'torch_geometric/**'\n      - 'docs/**'\n      - 'pyproject.toml'\n      - '.github/actions/setup/action.yml'\n      - '.github/workflows/documentation.yml'\n  pull_request:\n    paths:\n      - 'torch_geometric/**'\n      - 'docs/**'\n      - 'pyproject.toml'\n      - '.github/actions/setup/action.yml'\n      - '.github/workflows/documentation.yml'\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }}  # yamllint disable-line\n  # Only cancel intermediate builds if on a PR:\n  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}\n\njobs:\n\n  make_html:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n\n      - name: Setup packages\n        uses: ./.github/actions/setup\n        with:\n          full_install: false\n\n      - name: Install main package\n        run: |\n          uv pip install -e . nbsphinx git+https://github.com/pyg-team/pyg_sphinx_theme.git\n\n      - name: Build documentation\n        working-directory: ./docs\n        run: |\n          sphinx-build -M html \"source\" \"build\" -W  # Fail on warning.\n"
  },
  {
    "path": ".github/workflows/examples.yml",
    "content": "name: Examples\n\non:  # yamllint disable-line rule:truthy\n  workflow_dispatch:\n  schedule:\n    - cron: \"0 7 * * *\"  # Everyday at 7:00am UTC/11:00pm PST\n\njobs:\n\n  pytest:\n    if: github.repository == 'pyg-team/pytorch_geometric'\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n\n      - name: Setup packages\n        uses: ./.github/actions/setup\n\n      - name: Install main package\n        run: |\n          uv pip install \".[benchmark]\"\n\n      - name: Run GCN on Cora\n        run: |\n          python examples/gcn.py --wandb\n        env:\n          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}\n\n      - name: Run GAT on Cora\n        run: |\n          python examples/gat.py --wandb\n        env:\n          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}\n\n      - name: Run GIN on MUTAG\n        run: |\n          python examples/mutag_gin.py --wandb\n        env:\n          WANDB_API_KEY: ${{ secrets.WANDB_API_KEY }}\n\n      - name: Run GNNExplainer\n        run: |\n          python examples/explain/gnn_explainer.py\n"
  },
  {
    "path": ".github/workflows/labeler.yml",
    "content": "name: PR Labeler\n\non:  # yamllint disable-line rule:truthy\n  pull_request:\n\njobs:\n\n  assign-labels:\n    if: github.repository == 'pyg-team/pytorch_geometric'\n    runs-on: ubuntu-latest\n\n    permissions:\n      contents: read\n      pull-requests: write\n\n    steps:\n      - name: Add PR labels\n        uses: actions/labeler@v6\n        continue-on-error: true\n        with:\n          repo-token: \"${{ secrets.GITHUB_TOKEN }}\"\n          sync-labels: true\n\n  assign-author:\n    if: github.repository == 'pyg-team/pytorch_geometric'\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Add PR author\n        uses: samspills/assign-pr-to-author@v1.0\n        continue-on-error: true\n        if: github.event_name == 'pull_request'\n        with:\n          repo-token: \"${{ secrets.GITHUB_TOKEN }}\"\n"
  },
  {
    "path": ".github/workflows/linting.yml",
    "content": "name: Linting\n\non:  # yamllint disable-line rule:truthy\n  push:\n    branches:\n      - master\n    paths:\n      - 'torch_geometric/**'\n      - 'test/**'\n      - 'examples/**'\n      - 'benchmark/**'\n      - 'pyproject.toml'\n      - '.github/actions/setup/action.yml'\n      - '.github/workflows/linting.yml'\n  pull_request:\n    paths:\n      - 'torch_geometric/**'\n      - 'test/**'\n      - 'examples/**'\n      - 'benchmark/**'\n      - 'pyproject.toml'\n      - '.github/actions/setup/action.yml'\n      - '.github/workflows/linting.yml'\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }}  # yamllint disable-line\n  # Only cancel intermediate builds if on a PR:\n  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}\n\njobs:\n\n  mypy:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n\n      - name: Setup packages\n        uses: ./.github/actions/setup\n\n      - name: Install main package\n        run: |\n          uv pip install -e \".[full,test]\" mypy types-requests\n\n      - name: Check type hints\n        run: |\n          mypy -v --cache-dir=/dev/null\n"
  },
  {
    "path": ".github/workflows/testing.yml",
    "content": "name: Testing\n\non:  # yamllint disable-line rule:truthy\n  push:\n    branches:\n      - master\n  pull_request:\n  schedule:\n    - cron: \"0 6 * * *\"  # Everyday at 6:00am UTC/10:00pm PST\n\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }}  # yamllint disable-line\n  # Only cancel intermediate builds if on a PR:\n  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}\n\njobs:\n  trigger:\n    runs-on: ubuntu-latest\n    outputs:\n      triggered: ${{ steps.check.outputs.triggered }}\n    steps:\n      - uses: actions/checkout@v6\n        with:\n          fetch-depth: 2\n      - id: check\n        run: |\n          if [ \"${{ github.event_name }}\" = \"schedule\" ]; then\n            echo \"triggered=true\" >> \"$GITHUB_OUTPUT\"\n            exit 0\n          fi\n\n          if [ \"${{ github.event_name }}\" = \"push\" ]; then\n            BASE=${{ github.event.before }}\n          else\n            BASE=${{ github.event.pull_request.base.sha }}\n          fi\n\n          PATHS=(\n            'torch_geometric/'\n            'test/'\n            'pyproject.toml'\n            'codecov.yml'\n            '.github/actions/setup/action.yml'\n            '.github/workflows/testing.yml'\n            '.github/workflows/_testing.yml'\n          )\n\n          CHANGED=$(git diff --name-only \"$BASE\" HEAD)\n          for path in \"${PATHS[@]}\"; do\n            if echo \"$CHANGED\" | grep -q \"^${path}\"; then\n              echo \"triggered=true\" >> \"$GITHUB_OUTPUT\"\n              exit 0\n            fi\n          done\n\n          echo \"triggered=false\" >> \"$GITHUB_OUTPUT\"\n\n  pytest:\n    needs: trigger\n    if: ${{ needs.trigger.outputs.triggered == 'true' && github.repository == 'pyg-team/pytorch_geometric' && github.event_name != 'schedule' }}\n    uses: ./.github/workflows/_testing.yml\n    with:\n      
test-matrix: |\n        { \"include\": [\n          { \"os\": \"ubuntu-22.04\", \"python-version\": \"3.10\", \"torch-version\": \"nightly\", \"cuda-version\": \"cpu\" },\n          { \"os\": \"ubuntu-22.04\", \"python-version\": \"3.10\", \"torch-version\": \"2.10.0\", \"cuda-version\": \"cpu\" },\n          { \"os\": \"ubuntu-22.04\", \"python-version\": \"3.10\", \"torch-version\": \"2.9.0\", \"cuda-version\": \"cpu\" },\n          { \"os\": \"ubuntu-22.04\", \"python-version\": \"3.10\", \"torch-version\": \"2.8.0\", \"cuda-version\": \"cpu\" },\n          { \"os\": \"macos-14\", \"python-version\": \"3.10\", \"torch-version\": \"2.10.0\", \"cuda-version\": \"cpu\" },\n          { \"os\": \"windows-2022\", \"python-version\": \"3.10\", \"torch-version\": \"2.10.0\", \"cuda-version\": \"cpu\" }\n        ]}\n\n  pytest-full:\n    # Only run this on nightly schedule for now.\n    if: ${{ github.repository == 'pyg-team/pytorch_geometric' && github.event_name == 'schedule' }}\n    uses: ./.github/workflows/_testing.yml\n    with:\n      test-matrix: |\n        {\n          \"os\": [\"ubuntu-22.04\", \"macos-14\", \"windows-2022\"],\n          \"python-version\": [\"3.10\"],\n          \"torch-version\": [\"2.8.0\", \"2.9.0\", \"2.10.0\", \"nightly\"],\n          \"cuda-version\": [\"cpu\"],\n          \"full_test\": [\"1\"]\n        }\n\n  # gpu:\n  #   if: github.repository == 'pyg-team/pytorch_geometric'\n  #   uses: ./.github/workflows/_testing.yml\n  #   with:\n  #     test-matrix: |\n  #       {\n  #         \"os\": [[\"self-hosted, nvidia\"]],\n  #         \"python-version\": [\"3.10\"],\n  #         \"torch-version\": [\"2.8.0\"],\n  #         \"cuda-version\": [\"cu128\"]\n  #       }\n\n  status:\n    if: always() && github.event_name == 'pull_request'\n    needs:\n      - pytest\n      - pytest-full\n    runs-on: ubuntu-latest\n    steps:\n      - run: |\n          if [ \"${{ needs.pytest.result }}\" = \"failure\" ] || \\\n             [ \"${{ 
needs.pytest-full.result }}\" = \"failure\" ] || \\\n             [ \"${{ needs.pytest.result }}\" = \"cancelled\" ] || \\\n             [ \"${{ needs.pytest-full.result }}\" = \"cancelled\" ]; then\n            exit 1\n          fi\n"
  },
  {
    "path": ".github/workflows/testing_rag.yml",
    "content": "name: Testing RAG\n\non:  # yamllint disable-line rule:truthy\n  push:\n    branches:\n      - master\n    paths:\n      - 'torch_geometric/datasets/web_qsp_dataset.py'\n      - 'torch_geometric/llm/**'\n      - '.github/workflows/testing_rag.yml'\n  pull_request:\n    paths:\n      - 'torch_geometric/datasets/web_qsp_dataset.py'\n      - 'torch_geometric/llm/**'\n      - '.github/workflows/testing_rag.yml'\n\nconcurrency:\n  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}-${{ startsWith(github.ref, 'refs/pull/') || github.run_number }}  # yamllint disable-line\n  # Only cancel intermediate builds if on a PR:\n  cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}\n\njobs:\n  rag_pytest:\n    runs-on: ubuntu-latest\n\n    steps:\n      - name: Checkout repository\n        uses: actions/checkout@v6\n\n      - name: Setup packages\n        uses: ./.github/actions/setup\n        with:\n          full_install: false\n\n      - name: Install main package\n        run: |\n          uv pip install -e \".[test,rag]\"\n\n      - name: Run tests\n        timeout-minutes: 15\n        run: |\n          # ignore mysterious segfault (139) if tests pass since this does not repro locally\n          bash -c 'pytest -m rag --cov --cov-report=xml -v; E=$?; [[ $E == 0 || $E == 139 ]] && exit 0 || exit $E'\n        shell: bash\n        env:\n          TOKENIZERS_PARALLELISM: \"false\"\n          OMP_NUM_THREADS: \"1\"\n          MKL_NUM_THREADS: \"1\"\n"
  },
  {
    "path": ".gitignore",
    "content": "__pycache__/\n.pytest_cache/\n.DS_Store\ndata/\nbuild/\ndist/\nalpha/\nruns/\nwandb/\n.cache/\n.eggs/\nlightning_logs/\noutputs/\ngraphgym/datasets/\ngraphgym/results/\n*.egg-info/\n.ipynb_checkpoints\n.coverage\n.coverage.*\ncoverage.xml\n.vscode\n.idea\n.venv\n*.out\n*.pt\n*.onnx\nexamples/**/*.png\nexamples/**/*.pdf\nbenchmark/results/\n.mypy_cache/\nuv.lock\n\n!torch_geometric/data/\n!test/data/\n"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "ci:\n  # https://pre-commit.ci/#configuration\n  autofix_prs: true\n  autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'\n  autoupdate_schedule: monthly\n\nrepos:\n  - repo: https://github.com/pre-commit/pre-commit-hooks\n    rev: v6.0.0\n    hooks:\n      - id: no-commit-to-branch\n        name: No commits to master\n      - id: end-of-file-fixer\n        name: End-of-file fixer\n      - name: mixed-line-ending\n        id: mixed-line-ending\n        args: [--fix, lf]\n      - id: trailing-whitespace\n        name: Remove trailing whitespaces\n      - id: check-toml\n        name: Check toml\n      - id: check-yaml\n        name: Check yaml\n\n  - repo: https://github.com/adrienverge/yamllint.git\n    rev: v1.38.0\n    hooks:\n      - id: yamllint\n        name: Lint yaml\n        args: [-d, '{extends: default, rules: {line-length: disable, document-start: disable, truthy: {level: error}, braces: {max-spaces-inside: 1}}}']\n\n  - repo: https://github.com/asottile/pyupgrade\n    rev: v3.21.2\n    hooks:\n      - id: pyupgrade\n        name: Upgrade Python syntax\n        args: [--py38-plus]\n\n  - repo: https://github.com/PyCQA/autoflake\n    rev: v2.3.3\n    hooks:\n      - id: autoflake\n        name: Remove unused imports and variables\n        args: [\n          --remove-all-unused-imports,\n          --remove-unused-variables,\n          --remove-duplicate-keys,\n          --ignore-init-module-imports,\n          --in-place,\n        ]\n\n  - repo: https://github.com/google/yapf\n    rev: v0.43.0\n    hooks:\n      - id: yapf\n        name: Format code\n        additional_dependencies: [toml]\n\n  - repo: https://github.com/pycqa/isort\n    rev: 8.0.1\n    hooks:\n      - id: isort\n        name: Sort imports\n\n  - repo: https://github.com/PyCQA/flake8\n    rev: 7.3.0\n    hooks:\n      - id: flake8\n        name: Check PEP8\n        additional_dependencies: [Flake8-pyproject]\n\n  - repo: 
https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.15.4\n    hooks:\n      - id: ruff\n        name: Ruff formatting\n        args: [--fix, --exit-non-zero-on-fix]\n\n  - repo: https://github.com/executablebooks/mdformat\n    rev: 1.0.0\n    hooks:\n      - id: mdformat\n        name: Format Markdown\n        additional_dependencies:\n          - mdformat-gfm\n          - mdformat-front-matters\n          - mdformat-footnote\n\n  - repo: https://github.com/sphinx-contrib/sphinx-lint\n    rev: v1.0.2\n    hooks:\n      - id: sphinx-lint\n        name: Check Sphinx\n"
  },
  {
    "path": "CHANGELOG.md",
    "content": "# Changelog\n\nAll notable changes to this project will be documented in this file.\nThe format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).\n\n## [Unreleased] - YYYY-MM-DD\n\n### Added\n\n### Changed\n\n- Dropped support for TorchScript in `GATConv` and `GATv2Conv` for correctness ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596))\n\n### Deprecated\n\n- Deprecated support for `torch-spline-conv` in favor of `pyg-lib>=0.6.0` ([#10622](https://github.com/pyg-team/pytorch_geometric/pull/10622))\n\n### Removed\n\n### Fixed\n\n- Fixed `return_attention_weights: bool` not being respected in `GATConv` and `GATv2Conv` ([#10596](https://github.com/pyg-team/pytorch_geometric/pull/10596))\n- Fixed download links for politifact and gossipcop datasets of `UPFD` ([#10558](https://github.com/pyg-team/pytorch_geometric/pull/10558))\n\n### Security\n\n## [2.7.0] - 2025-10-14\n\n### Added\n\n- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))\n- Added `torch_geometric.llm` and its examples ([#10436](https://github.com/pyg-team/pytorch_geometric/pull/10436))\n- Added support for negative weights in `sparse_cross_entropy` ([#10432](https://github.com/pyg-team/pytorch_geometric/pull/10432))\n- Added `connected_components()` method to `Data` and `HeteroData` ([#10388](https://github.com/pyg-team/pytorch_geometric/pull/10388))\n- Added LPFormer Graph Transformer for Link Prediction ([#9956](https://github.com/pyg-team/pytorch_geometric/pull/9956))\n- Added `BidirectionalSampler`, which samples both forwards and backwards on graph edges ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126))\n- Enable Sampling both forwards and reverse edges on `NeighborSampler` ([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126))\n- Added ability to merge together `SamplerOutput` objects 
([#10126](https://github.com/pyg-team/pytorch_geometric/pull/10126))\n- Added ability to get global row and col ids from `SamplerOutput` ([#10200](https://github.com/pyg-team/pytorch_geometric/pull/10200))\n- Added PyTorch 2.8 support ([#10403](https://github.com/pyg-team/pytorch_geometric/pull/10403))\n- Added `Polynormer` model and example ([#9908](https://github.com/pyg-team/pytorch_geometric/pull/9908))\n- Added `ProteinMPNN` model and example ([#10289](https://github.com/pyg-team/pytorch_geometric/pull/10289))\n- Added the `Teeth3DS` dataset, an extended benchmark for intraoral 3D scan analysis ([#9833](https://github.com/pyg-team/pytorch_geometric/pull/9833))\n- Added `torch.device` to `PatchTransformerAggregation` [#10342](https://github.com/pyg-team/pytorch_geometric/pull/10342)\n- Added `torch.device` to normalization layers [#10341](https://github.com/pyg-team/pytorch_geometric/pull/10341)\n- Added `total_influence` for quantifying long-range dependency ([#10263](https://github.com/pyg-team/pytorch_geometric/pull/10263))\n- Added `MedShapeNet` Dataset ([#9823](https://github.com/pyg-team/pytorch_geometric/pull/9823))\n- Added RelBench example ([#10230](https://github.com/pyg-team/pytorch_geometric/pull/10230))\n- Added `CityNetwork` dataset ([#10115](https://github.com/pyg-team/pytorch_geometric/pull/10115))\n- Added `visualize_graph` to HeteroExplanation ([#10207](https://github.com/pyg-team/pytorch_geometric/pull/10207))\n- Added PyTorch 2.6 support ([#10170](https://github.com/pyg-team/pytorch_geometric/pull/10170))\n- Added support for heterogenous graphs in `AttentionExplainer` ([#10169](https://github.com/pyg-team/pytorch_geometric/pull/10169))\n- Added support for heterogenous graphs in `PGExplainer` ([#10168](https://github.com/pyg-team/pytorch_geometric/pull/10168))\n- Added support for heterogenous graphs in `GNNExplainer` ([#10158](https://github.com/pyg-team/pytorch_geometric/pull/10158))\n- Added Graph Positional and Structural Encoder (GPSE) 
and example ([#9018](https://github.com/pyg-team/pytorch_geometric/pull/9018)) ([#10118](https://github.com/pyg-team/pytorch_geometric/pull/10118))\n- Added attract-repel link prediction example ([#10107](https://github.com/pyg-team/pytorch_geometric/pull/10107))\n- Added `ARLinkPredictor` for implementing Attract-Repel embeddings for link prediction ([#10105](https://github.com/pyg-team/pytorch_geometric/pull/10105))\n- Improving documentation for [cuGraph](https://github.com/rapidsai/cugraph) ([#10083](https://github.com/pyg-team/pytorch_geometric/pull/10083))\n- Added `HashTensor` ([#10072](https://github.com/pyg-team/pytorch_geometric/pull/10072))\n- Added `SGFormer` model and example ([#9904](https://github.com/pyg-team/pytorch_geometric/pull/9904))\n- Added `AveragePopularity` metric for link prediction ([#10022](https://github.com/pyg-team/pytorch_geometric/pull/10022))\n- Added `Personalization` metric for link prediction ([#10015](https://github.com/pyg-team/pytorch_geometric/pull/10015))\n- Added `HitRatio` metric for link prediction ([#10013](https://github.com/pyg-team/pytorch_geometric/pull/10013))\n- Added Data Splitting Tutorial ([#8366](https://github.com/pyg-team/pytorch_geometric/pull/8366))\n- Added `Diversity` metric for link prediction ([#10009](https://github.com/pyg-team/pytorch_geometric/pull/10009))\n- Added `Coverage` metric for link prediction ([#10006](https://github.com/pyg-team/pytorch_geometric/pull/10006))\n- Added Graph Transformer Tutorial ([#8144](https://github.com/pyg-team/pytorch_geometric/pull/8144))\n- Consolidate Cugraph examples into `ogbn_train_cugraph.py` and `ogbn_train_cugraph_multigpu.py` for `ogbn-arxiv`, `ogbn-products` and `ogbn-papers100M` ([#9953](https://github.com/pyg-team/pytorch_geometric/pull/9953))\n- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))\n- Added support for weighted `LinkPredRecall` metric 
([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))\n- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))\n- Added `LinkPredMetricCollection` ([#9941](https://github.com/pyg-team/pytorch_geometric/pull/9941))\n- Added various `GRetriever` architecture benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))\n- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))\n- Added `loader.RagQueryLoader` with Remote Backend Example ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))\n- Added `data.LargeGraphIndexer` ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))\n- Added `GIT-Mol` ([#9730](https://github.com/pyg-team/pytorch_geometric/pull/9730))\n- Added comment in `g_retriever.py` pointing to `Neo4j` Graph DB integration demo ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9797))\n- Added `MoleculeGPT` example ([#9710](https://github.com/pyg-team/pytorch_geometric/pull/9710))\n- Added `nn.models.GLEM` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))\n- Added `TAGDataset` ([#9662](https://github.com/pyg-team/pytorch_geometric/pull/9662))\n- Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748))\n- Added PyTorch 2.5 support ([#9779](https://github.com/pyg-team/pytorch_geometric/pull/9779), [#9779](https://github.com/pyg-team/pytorch_geometric/pull/9780))\n- Support 3D tetrahedral mesh elements of shape `[4, num_faces]` in the `FaceToEdge` transformation ([#9776](https://github.com/pyg-team/pytorch_geometric/pull/9776))\n- Added the `use_pcst` option to `WebQSPDataset` ([#9722](https://github.com/pyg-team/pytorch_geometric/pull/9722))\n- Allowed users to pass `edge_weight` to `GraphUNet` models 
([#9737](https://github.com/pyg-team/pytorch_geometric/pull/9737))\n- Consolidated `examples/ogbn_{papers_100m,products_gat,products_sage}.py` into `examples/ogbn_train.py` ([#9467](https://github.com/pyg-team/pytorch_geometric/pull/9467))\n- Add ComplexWebQuestions (CWQ) dataset ([#9950](https://github.com/pyg-team/pytorch_geometric/pull/9950))\n\n### Changed\n\n- Added `edge_attr` in `CuGraphGATConv` ([#10383](https://github.com/pyg-team/pytorch_geometric/pull/10383))\n- Adapt `dgcnn_classification` example to work with `ModelNet` and `MedShapeNet` Datasets ([#9823](https://github.com/pyg-team/pytorch_geometric/pull/9823))\n- Chained exceptions explicitly instead of implicitly ([#10242](https://github.com/pyg-team/pytorch_geometric/pull/10242))\n- Updated cuGraph examples to use buffered sampling which keeps data in memory and is significantly faster than the deprecated buffered sampling ([#10079](https://github.com/pyg-team/pytorch_geometric/pull/10079))\n- Updated Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))\n- Dropped Python 3.8 support ([#9696](https://github.com/pyg-team/pytorch_geometric/pull/9606))\n- Added a check that confirms that custom edge types of `NumNeighbors` actually exist in the graph ([#9807](https://github.com/pyg-team/pytorch_geometric/pull/9807))\n- Automatic num_params in LLM + update `GRetriever` default llm ([#9938](https://github.com/pyg-team/pytorch_geometric/pull/9938))\n- Updated calls to NumPy's deprecated `np.in1d` to `np.isin` ([#10283](https://github.com/pyg-team/pytorch_geometric/pull/10283))\n\n### Deprecated\n\n- Deprecated `torch_geometric.distributed` ([#10411](https://github.com/pyg-team/pytorch_geometric/pull/10411))\n\n### Fixed\n\n- Fixed `ogbn_train_cugraph` example for distributed cuGraph ([#10439](https://github.com/pyg-team/pytorch_geometric/pull/10439))\n- Added `safe_onnx_export` function with workarounds for `onnx_ir.serde.SerdeError` issues in ONNX export 
([#10422](https://github.com/pyg-team/pytorch_geometric/pull/10422))\n- Fixed importing PyTorch Lightning in `torch_geometric.graphgym` and `torch_geometric.data.lightning` when using `lightning` instead of `pytorch-lightning` ([#10404](https://github.com/pyg-team/pytorch_geometric/pull/10404), [#10417](https://github.com/pyg-team/pytorch_geometric/pull/10417))\n- Fixed `detach()` warnings in example scripts involving tensor conversions ([#10357](https://github.com/pyg-team/pytorch_geometric/pull/10357))\n- Fixed non-tuple indexing to resolve PyTorch deprecation warning ([#10389](https://github.com/pyg-team/pytorch_geometric/pull/10389))\n- Fixed conversion to/from `cuGraph` graph objects by ensuring `cudf` column names are correctly specified ([#10343](https://github.com/pyg-team/pytorch_geometric/pull/10343))\n- Fixed `_recursive_config()` for `torch.nn.ModuleList` and `torch.nn.ModuleDict` ([#10124](https://github.com/pyg-team/pytorch_geometric/pull/10124), [#10129](https://github.com/pyg-team/pytorch_geometric/pull/10129))\n- Fixed the `k_hop_subgraph()` method for directed graphs ([#9756](https://github.com/pyg-team/pytorch_geometric/pull/9756))\n- Fixed `utils.group_cat` concatenating dimension ([#9766](https://github.com/pyg-team/pytorch_geometric/pull/9766))\n- Fixed `WebQSPDataset.process` raising exceptions ([#9665](https://github.com/pyg-team/pytorch_geometric/pull/9665))\n- Fixed `is_node_attr()` and `is_edge_attr()` errors when `cat_dim` is a tuple ([#9895](https://github.com/pyg-team/pytorch_geometric/issues/9895))\n- Avoid GRetriever instantiation when num_gnn_layers == 0 ([#10156](https://github.com/pyg-team/pytorch_geometric/pull/10156))\n\n### Removed\n\n- Removed `proxies` and `resume_download` arguments from `PyGModelHubMixin` ([#10521](https://github.com/pyg-team/pytorch_geometric/pull/10521))\n- Dropped support for Python 3.9 ([#10461](https://github.com/pyg-team/pytorch_geometric/pull/10461))\n- Dropped support for PyTorch 1.13 - 2.5 
([#00000](https://github.com/pyg-team/pytorch_geometric/pull/00000))\n- Dropped support for PyTorch 1.12 ([#10248](https://github.com/pyg-team/pytorch_geometric/pull/10248))\n- Dropped support for PyTorch 1.11 ([#10247](https://github.com/pyg-team/pytorch_geometric/pull/10247))\n\n## [2.6.0] - 2024-09-13\n\n### Added\n\n- Added the `WebQSPDataset` dataset ([#9481](https://github.com/pyg-team/pytorch_geometric/pull/9481))\n- Added the `GRetriever` model and an example ([#9480](https://github.com/pyg-team/pytorch_geometric/pull/9480), [#9167](https://github.com/pyg-team/pytorch_geometric/pull/9167))\n- Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))\n- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))\n- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))\n- Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554))\n- Added the `RemoveSelfLoops` transformation ([#9562](https://github.com/pyg-team/pytorch_geometric/pull/9562))\n- Added ONNX export for `scatter` with min/max reductions ([#9587](https://github.com/pyg-team/pytorch_geometric/pull/9587))\n- Added a `residual` option in `GATConv` and `GATv2Conv` ([#9515](https://github.com/pyg-team/pytorch_geometric/pull/9515))\n- Added the `PatchTransformerAggregation` layer ([#9487](https://github.com/pyg-team/pytorch_geometric/pull/9487))\n- Added the `nn.nlp.LLM` model ([#9462](https://github.com/pyg-team/pytorch_geometric/pull/9462))\n- Added an example of training GNNs for a graph-level regression task ([#9070](https://github.com/pyg-team/pytorch_geometric/pull/9070))\n- Added `utils.from_rdmol`/`utils.to_rdmol` functionality ([#9452](https://github.com/pyg-team/pytorch_geometric/pull/9452))\n- Added the `OPFDataset` 
([#9379](https://github.com/pyg-team/pytorch_geometric/pull/9379))\n- Added the heterogeneous `HeteroJumpingKnowledge` module ([#9380](https://github.com/pyg-team/pytorch_geometric/pull/9380))\n- Started work on GNN+LLM package ([#9350](https://github.com/pyg-team/pytorch_geometric/pull/9350))\n- Added support for negative sampling in `LinkLoader` according to source and destination node weights ([#9316](https://github.com/pyg-team/pytorch_geometric/pull/9316))\n- Added support for `EdgeIndex.unbind` ([#9298](https://github.com/pyg-team/pytorch_geometric/pull/9298))\n- Integrate `torch_geometric.Index` into `torch_geometric.EdgeIndex` ([#9296](https://github.com/pyg-team/pytorch_geometric/pull/9296))\n- Support `EdgeIndex.sparse_narrow` for non-sorted edge indices ([#9291](https://github.com/pyg-team/pytorch_geometric/pull/9291))\n- Added `torch_geometric.Index` ([#9276](https://github.com/pyg-team/pytorch_geometric/pull/9276), [#9277](https://github.com/pyg-team/pytorch_geometric/pull/9277), [#9278](https://github.com/pyg-team/pytorch_geometric/pull/9278), [#9279](https://github.com/pyg-team/pytorch_geometric/pull/9279), [#9280](https://github.com/pyg-team/pytorch_geometric/pull/9280), [#9281](https://github.com/pyg-team/pytorch_geometric/pull/9281), [#9284](https://github.com/pyg-team/pytorch_geometric/pull/9284), [#9285](https://github.com/pyg-team/pytorch_geometric/pull/9285), [#9286](https://github.com/pyg-team/pytorch_geometric/pull/9286), [#9287](https://github.com/pyg-team/pytorch_geometric/pull/9287), [#9288](https://github.com/pyg-team/pytorch_geometric/pull/9288), [#9289](https://github.com/pyg-team/pytorch_geometric/pull/9289), [#9297](https://github.com/pyg-team/pytorch_geometric/pull/9297))\n- Added support for PyTorch 2.3 ([#9240](https://github.com/pyg-team/pytorch_geometric/pull/9240))\n- Added support for `EdgeIndex` in `message_and_aggregate` ([#9131](https://github.com/pyg-team/pytorch_geometric/pull/9131))\n- Added 
`CornellTemporalHyperGraphDataset` ([#9090](https://github.com/pyg-team/pytorch_geometric/pull/9090))\n- Added support for cuGraph data loading and `GAT` in single node Papers100m examples ([#8173](https://github.com/pyg-team/pytorch_geometric/pull/8173))\n- Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075))\n- Added option to pass custom `from_smiles` functionality to `PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073))\n- Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029))\n- Added support for `EdgeIndex` in `spmm` ([#9026](https://github.com/pyg-team/pytorch_geometric/pull/9026))\n- Added option to pre-allocate memory in GPU-based `ApproxKNN` ([#9046](https://github.com/pyg-team/pytorch_geometric/pull/9046))\n- Added support for `EdgeIndex` in `MessagePassing` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))\n- Added support for `torch.compile` in combination with `EdgeIndex` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))\n- Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))\n- Added `EdgeIndex.sparse_resize_` functionality ([#8983](https://github.com/pyg-team/pytorch_geometric/pull/8983))\n- Added approximate `faiss`-based KNN-search ([#8952](https://github.com/pyg-team/pytorch_geometric/pull/8952))\n- Added documentation on environment setup on XPU device ([#9407](https://github.com/pyg-team/pytorch_geometric/pull/9407))\n\n### Changed\n\n- Add args to Taobao multi-GPU example and move item-item compute to dataset ([#9550](https://github.com/pyg-team/pytorch_geometric/pull/9550))\n- Use `torch.load(weights_only=True)` by default ([#9618](https://github.com/pyg-team/pytorch_geometric/pull/9618))\n- Adapt `cugraph` examples to its new API ([#9541](https://github.com/pyg-team/pytorch_geometric/pull/9541))\n- Allow optional but 
untyped tensors in `MessagePassing` ([#9494](https://github.com/pyg-team/pytorch_geometric/pull/9494))\n- Added support for modifying `filename` of the stored partitioned file in `ClusterLoader` ([#9448](https://github.com/pyg-team/pytorch_geometric/pull/9448))\n- Support other than two-dimensional inputs in `AttentionalAggregation` ([#9433](https://github.com/pyg-team/pytorch_geometric/pull/9433))\n- Improved model performance of the `examples/ogbn_papers_100m.py` script ([#9386](https://github.com/pyg-team/pytorch_geometric/pull/9386), [#9445](https://github.com/pyg-team/pytorch_geometric/pull/9445))\n- Added the `fmt` arg to `Dataset.get_summary` ([#9408](https://github.com/pyg-team/pytorch_geometric/pull/9408))\n- Skipped zero atom molecules in `MoleculeNet` ([#9318](https://github.com/pyg-team/pytorch_geometric/pull/9318))\n- Ensure proper parallelism in `OnDiskDataset` for multi-threaded `get` calls ([#9140](https://github.com/pyg-team/pytorch_geometric/pull/9140))\n- Allow `None` outputs in `FeatureStore.get_tensor()` - `KeyError` should now be raised based on the implementation in `FeatureStore._get_tensor()` ([#9102](https://github.com/pyg-team/pytorch_geometric/pull/9102))\n- Allow mini-batching of uncoalesced sparse matrices ([#9099](https://github.com/pyg-team/pytorch_geometric/pull/9099))\n- Improvements to multi-node `ogbn-papers100m` default hyperparameters and adding evaluation on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823))\n- Changed distributed sampler and loader tests to correctly report failures in subprocesses to `pytest` ([#8978](https://github.com/pyg-team/pytorch_geometric/pull/8978))\n- Remove filtering of node/edge types in `trim_to_layer` functionality ([#9021](https://github.com/pyg-team/pytorch_geometric/pull/9021))\n- Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009))\n- Made 
`MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))\n- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))\n- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))\n- Added XPU support to basic GNN examples ([#9421](https://github.com/pyg-team/pytorch_geometric/pull/9421), [#9439](https://github.com/pyg-team/pytorch_geometric/pull/9439))\n\n### Deprecated\n\n### Fixed\n\n- Fixed `VirtualNode` transform for empty edge indices ([#9605](https://github.com/pyg-team/pytorch_geometric/pull/9605))\n- Fixed an issue where import order in the multi-GPU `cugraph` example could cause an `rmm` error ([#9577](https://github.com/pyg-team/pytorch_geometric/pull/9577))\n- Made the output of the single-GPU `cugraph` example more readable ([#9577](https://github.com/pyg-team/pytorch_geometric/pull/9577))\n- Fixed `load_state_dict` behavior with lazy parameters in `HeteroDictLinear` ([#9493](https://github.com/pyg-team/pytorch_geometric/pull/9493))\n- `Sequential` can now be properly pickled ([#9369](https://github.com/pyg-team/pytorch_geometric/pull/9369))\n- Fixed `pickle.load` for jittable `MessagePassing` modules ([#9368](https://github.com/pyg-team/pytorch_geometric/pull/9368))\n- Fixed batching of sparse tensors saved via `data.edge_index` ([#9317](https://github.com/pyg-team/pytorch_geometric/pull/9317))\n- Fixed arbitrary keyword ordering in `MessagePassing.propagate` ([#9245](https://github.com/pyg-team/pytorch_geometric/pull/9245))\n- Fixed node mapping bug in `RCDD` dataset ([#9234](https://github.com/pyg-team/pytorch_geometric/pull/9234))\n- Fixed incorrect treatment of `edge_label` and `edge_label_index` in `ToSparseTensor` transform ([#9199](https://github.com/pyg-team/pytorch_geometric/pull/9199))\n- Fixed `EgoData` processing in `SnapDataset` in case filenames are unsorted 
([#9195](https://github.com/pyg-team/pytorch_geometric/pull/9195))\n- Fixed empty graph and isolated node handling in `to_dgl` ([#9188](https://github.com/pyg-team/pytorch_geometric/pull/9188))\n- Fixed bug in `to_scipy_sparse_matrix` when cuda is set as default torch device ([#9146](https://github.com/pyg-team/pytorch_geometric/pull/9146))\n- Fixed `MetaPath2Vec` in case the last node is isolated ([#9145](https://github.com/pyg-team/pytorch_geometric/pull/9145))\n- Ensure backward compatibility in `MessagePassing` via `torch.load` ([#9105](https://github.com/pyg-team/pytorch_geometric/pull/9105))\n- Prevent model compilation on custom `propagate` functions ([#9079](https://github.com/pyg-team/pytorch_geometric/pull/9079))\n- Ignore `self.propagate` appearances in comments when parsing `MessagePassing` implementation ([#9044](https://github.com/pyg-team/pytorch_geometric/pull/9044))\n- Fixed `OSError` on read-only file systems within `MessagePassing` ([#9032](https://github.com/pyg-team/pytorch_geometric/pull/9032))\n- Fixed metaclass conflict in `Dataset` ([#8999](https://github.com/pyg-team/pytorch_geometric/pull/8999))\n- Fixed import errors on `MessagePassing` modules with nested inheritance ([#8973](https://github.com/pyg-team/pytorch_geometric/pull/8973))\n- Fixed bug in multi XPU training ([#9456](https://github.com/pyg-team/pytorch_geometric/pull/9456))\n- Fixed TorchScript compilation error for `MessagePassing._check_input` on older torch versions ([#9564](https://github.com/pyg-team/pytorch_geometric/pull/9564))\n\n### Removed\n\n## [2.5.0] - 2024-02-16\n\n### Added\n\n- Added an example for recommender systems, including k-NN search and retrieval metrics ([#8546](https://github.com/pyg-team/pytorch_geometric/pull/8546))\n- Added multi-GPU evaluation in distributed sampling example ([#8880](https://github.com/pyg-team/pytorch_geometric/pull/8880))\n- Added end-to-end example for distributed CPU training 
([#8713](https://github.com/pyg-team/pytorch_geometric/pull/8713))\n- Added PyTorch 2.2 support ([#8857](https://github.com/pyg-team/pyg-lib/pull/8857))\n- Added fallback code path for `segment` in case `torch-scatter` is not installed ([#8852](https://github.com/pyg-team/pytorch_geometric/pull/8852))\n- Added support for custom node labels in `visualize_graph()` ([#8816](https://github.com/pyg-team/pytorch_geometric/pull/8816))\n- Added support for graph partitioning for temporal data in `torch_geometric.distributed` ([#8718](https://github.com/pyg-team/pytorch_geometric/pull/8718), [#8815](https://github.com/pyg-team/pytorch_geometric/pull/8815), [#8874](https://github.com/pyg-team/pytorch_geometric/pull/8874))\n- Added `TreeGraph` and `GridMotif` generators ([#8736](https://github.com/pyg-team/pytorch_geometric/pull/8736))\n- Added two examples for edge-level temporal sampling on a heterogenous graph, with and without distributed training ([#8383](https://github.com/pyg-team/pytorch_geometric/pull/8383), [#8820](https://github.com/pyg-team/pytorch_geometric/pull/8820))\n- Added the `num_graphs` option to the `StochasticBlockModelDataset` ([#8648](https://github.com/pyg-team/pytorch_geometric/pull/8648))\n- Added noise scheduler utility for diffusion based graph generative models ([#8347](https://github.com/pyg-team/pytorch_geometric/pull/8347))\n- Added the equivariant `ViSNet` model ([#8287](https://github.com/pyg-team/pytorch_geometric/pull/8287))\n- Added temporal-related capabilities to `Data` ([#8454](https://github.com/pyg-team/pytorch_geometric/pull/8454))\n- Added support for returning multi graphs in `to_networkx` ([#8575](https://github.com/pyg-team/pytorch_geometric/pull/8575))\n- Added support for XPU device in `profileit` decorator ([#8532](https://github.com/pyg-team/pytorch_geometric/pull/8532))\n- Added `KNNIndex` exclusion logic ([#8573](https://github.com/pyg-team/pytorch_geometric/pull/8573))\n- Added warning when calling `dataset.num_classes` 
on regression problems ([#8550](https://github.com/pyg-team/pytorch_geometric/pull/8550))\n- Added relabel node functionality to `dropout_node` ([#8524](https://github.com/pyg-team/pytorch_geometric/pull/8524))\n- Added support for type checking via `mypy` ([#8254](https://github.com/pyg-team/pytorch_geometric/pull/8254))\n- Added support for link-prediction retrieval metrics ([#8499](https://github.com/pyg-team/pytorch_geometric/pull/8499), [#8326](https://github.com/pyg-team/pytorch_geometric/pull/8326), [#8566](https://github.com/pyg-team/pytorch_geometric/pull/8566), [#8647](https://github.com/pyg-team/pytorch_geometric/pull/8647))\n- Added METIS partitioning with CSC/CSR format selection in `ClusterData` ([#8438](https://github.com/pyg-team/pytorch_geometric/pull/8438))\n- Added `is_torch_instance` to check against the original class of compiled models ([#8461](https://github.com/pyg-team/pytorch_geometric/pull/8461))\n- Added dense computation for `AddRandomWalkPE` ([#8431](https://github.com/pyg-team/pytorch_geometric/pull/8431))\n- Added a tutorial for point cloud processing ([#8015](https://github.com/pyg-team/pytorch_geometric/pull/8015))\n- Added `fsspec` as file system backend ([#8379](https://github.com/pyg-team/pytorch_geometric/pull/8379), [#8426](https://github.com/pyg-team/pytorch_geometric/pull/8426), [#8434](https://github.com/pyg-team/pytorch_geometric/pull/8434), [#8474](https://github.com/pyg-team/pytorch_geometric/pull/8474))\n- Added support for floating-point average degree numbers in `FakeDataset` and `FakeHeteroDataset` ([#8404](https://github.com/pyg-team/pytorch_geometric/pull/8404))\n- Added support for device conversions of `InMemoryDataset` ([#8402](https://github.com/pyg-team/pytorch_geometric/pull/8402))\n- Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372](https://github.com/pyg-team/pytorch_geometric/pull/8372), [#8428](https://github.com/pyg-team/pytorch_geometric/pull/8428))\n- 
Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))\n- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357), [#8436](https://github.com/pyg-team/pytorch_geometric/pull/8436))\n- Added support for `torch.compile` in `MultiAggregation` ([#8345](https://github.com/pyg-team/pytorch_geometric/pull/8345))\n- Added support for `torch.compile` in `HeteroConv` ([#8344](https://github.com/pyg-team/pytorch_geometric/pull/8344))\n- Added support for weighted `sparse_cross_entropy` ([#8340](https://github.com/pyg-team/pytorch_geometric/pull/8340))\n- Added multi-GPU training benchmarks for CUDA and XPU devices ([#8288](https://github.com/pyg-team/pytorch_geometric/pull/8288), [#8382](https://github.com/pyg-team/pytorch_geometric/pull/8382), [#8386](https://github.com/pyg-team/pytorch_geometric/pull/8386))\n- Support MRR computation in `KGEModel.test()` ([#8298](https://github.com/pyg-team/pytorch_geometric/pull/8298))\n- Added an example for model parallelism (`examples/multi_gpu/model_parallel.py`) ([#8309](https://github.com/pyg-team/pytorch_geometric/pull/8309))\n- Added a tutorial for multi-node multi-GPU training with pure PyTorch ([#8071](https://github.com/pyg-team/pytorch_geometric/pull/8071))\n- Added a multinode-multigpu example on `ogbn-papers100M` ([#8070](https://github.com/pyg-team/pytorch_geometric/pull/8070))\n- Added support for `to_hetero_with_bases` on static graphs ([#8247](https://github.com/pyg-team/pytorch_geometric/pull/8247))\n- Added the `RCDD` dataset ([#8196](https://github.com/pyg-team/pytorch_geometric/pull/8196))\n- Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032))\n- Added the option to skip explanations of certain 
message passing layers via `conv.explain = False` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))\n\n### Changed\n\n- Changed the default inference mode for `use_segment_matmul` based on benchmarking (from a heuristic-based version) ([#8615](https://github.com/pyg-team/pytorch_geometric/pull/8615))\n- Return an empty tensor from `utils.group_argsort` if its input tensor is empty ([#8752](https://github.com/pyg-team/pytorch_geometric/pull/8752))\n- GNN layers are now jittable by default ([#8745](https://github.com/pyg-team/pytorch_geometric/pull/8745))\n- Sparse node features in `NELL` and `AttributedGraphDataset` are now represented as `torch.sparse_csr_tensor` instead of `torch_sparse.SparseTensor` ([#8679](https://github.com/pyg-team/pytorch_geometric/pull/8679))\n- Accelerated mini-batching of `torch.sparse` tensors ([#8670](https://github.com/pyg-team/pytorch_geometric/pull/8670))\n- Fixed RPC timeout due to worker closing in `DistLoader` with `atexit` not executed correctly in `worker_init_fn` ([#8605](https://github.com/pyg-team/pytorch_geometric/pull/8605))\n- `ExplainerDataset` will now contain node labels for any motif generator ([#8519](https://github.com/pyg-team/pytorch_geometric/pull/8519))\n- Made `utils.softmax` faster via `softmax_csr` ([#8399](https://github.com/pyg-team/pytorch_geometric/pull/8399))\n- Made `utils.mask.mask_select` faster ([#8369](https://github.com/pyg-team/pytorch_geometric/pull/8369))\n- Updated `DistNeighborSampler` ([#8209](https://github.com/pyg-team/pytorch_geometric/pull/8209), [#8367](https://github.com/pyg-team/pytorch_geometric/pull/8367), [#8375](https://github.com/pyg-team/pytorch_geometric/pull/8375), [#8624](https://github.com/pyg-team/pytorch_geometric/pull/8624), [#8722](https://github.com/pyg-team/pytorch_geometric/pull/8722))\n- Updated `GraphStore` and `FeatureStore` to support distributed training ([#8083](https://github.com/pyg-team/pytorch_geometric/pull/8083))\n- Disallow the usage of 
`add_self_loops=True` in `GCNConv(normalize=False)` ([#8210](https://github.com/pyg-team/pytorch_geometric/pull/8210))\n- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))\n\n### Deprecated\n\n- Deprecated `MessagePassing.jittable` ([#8781](https://github.com/pyg-team/pytorch_geometric/pull/8781))\n- Deprecated the usage of `torch_geometric.compile`; Use `torch.compile` instead ([#8780](https://github.com/pyg-team/pytorch_geometric/pull/8780))\n- Deprecated the `typing` argument in `MessagePassing.jittable()` ([#8731](https://github.com/pyg-team/pytorch_geometric/pull/8731))\n- Deprecated `torch_geometric.data.makedirs` in favor of `os.makedirs` ([#8421](https://github.com/pyg-team/pytorch_geometric/pull/8421))\n- Deprecated `DataParallel` in favor of `DistributedDataParallel` ([#8250](https://github.com/pyg-team/pytorch_geometric/pull/8250))\n\n### Fixed\n\n- Fixed dummy value creation of boolean tensors in `to_homogeneous()` ([#8858](https://github.com/pyg-team/pytorch_geometric/pull/8858))\n- Fixed Google Drive download issues ([#8804](https://github.com/pyg-team/pytorch_geometric/pull/8804))\n- Fixed a bug in which `InMemoryDataset` did not reconstruct the correct data class when a `pre_transform` has modified it ([#8692](https://github.com/pyg-team/pytorch_geometric/pull/8692))\n- Fixed a bug in which transforms were not applied for `OnDiskDataset` ([#8663](https://github.com/pyg-team/pytorch_geometric/pull/8663))\n- Fixed mini-batch computation in `DMoNPooling` loss function ([#8285](https://github.com/pyg-team/pytorch_geometric/pull/8285))\n- Fixed `NaN` handling in `SQLDatabase` ([#8479](https://github.com/pyg-team/pytorch_geometric/pull/8479))\n- Fixed `CaptumExplainer` in case no `index` is passed ([#8440](https://github.com/pyg-team/pytorch_geometric/pull/8440))\n- Fixed `edge_index` construction in the `UPFD` dataset 
([#8413](https://github.com/pyg-team/pytorch_geometric/pull/8413))\n- Fixed TorchScript support in `AttentionalAggregation` and `DeepSetsAggregation` ([#8406](https://github.com/pyg-team/pytorch_geometric/pull/8406))\n- Fixed `GraphMaskExplainer` for GNNs with more than two layers ([#8401](https://github.com/pyg-team/pytorch_geometric/pull/8401))\n- Breaking Change: Properly initialize modules in `GATConv` depending on whether the input is bipartite or non-bipartite ([#8397](https://github.com/pyg-team/pytorch_geometric/pull/8397))\n- Fixed `input_id` computation in `NeighborLoader` in case a `mask` is given ([#8312](https://github.com/pyg-team/pytorch_geometric/pull/8312))\n- Respect current device when deep-copying `Linear` layers ([#8311](https://github.com/pyg-team/pytorch_geometric/pull/8311))\n- Fixed `Data.subgraph()`/`HeteroData.subgraph()` in case `edge_index` is not defined ([#8277](https://github.com/pyg-team/pytorch_geometric/pull/8277))\n- Fixed empty edge handling in `MetaPath2Vec` ([#8248](https://github.com/pyg-team/pytorch_geometric/pull/8248))\n- Fixed `AttentionExplainer` usage within `AttentiveFP` ([#8244](https://github.com/pyg-team/pytorch_geometric/pull/8244))\n- Fixed `load_from_state_dict` in lazy `Linear` modules ([#8242](https://github.com/pyg-team/pytorch_geometric/pull/8242))\n- Fixed pre-trained `DimeNet++` performance on `QM9` ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239))\n- Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))\n- Fixed `to_networkx(to_undirected=True)` in case the input graph is not undirected ([#8204](https://github.com/pyg-team/pytorch_geometric/pull/8204))\n- Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197), [#8225](https://github.com/pyg-team/pytorch_geometric/pull/8225))\n- Fixed batching of `HeteroData` 
converted using `ToSparseTensor()` when `torch_sparse` is not installed ([#8356](https://github.com/pyg-team/pytorch_geometric/pull/8356))\n\n### Removed\n\n- Removed disabling of extension packages during `torch_geometric.compile` ([#8698](https://github.com/pyg-team/pytorch_geometric/pull/8698))\n\n## [2.4.0] - 2023-10-12\n\n### Added\n\n- Add the `ogc` method as example ([#8168](https://github.com/pyg-team/pytorch_geometric/pull/8168))\n- Added a tutorial on `NeighborLoader` ([#7931](https://github.com/pyg-team/pytorch_geometric/pull/7931))\n- Added the option to override usage of `segment_matmul`/`grouped_matmul` via the `torch_geometric.backend.use_segment_matmul` flag ([#8148](https://github.com/pyg-team/pytorch_geometric/pull/8148))\n- Added support for PyTorch 2.1.0 ([#8134](https://github.com/pyg-team/pytorch_geometric/pull/8134))\n- Added the `NeuroGraphDataset` benchmark collection ([#8122](https://github.com/pyg-team/pytorch_geometric/pull/8122))\n- Added support for a node-level `mask` tensor in `dense_to_sparse` ([#8117](https://github.com/pyg-team/pytorch_geometric/pull/8117))\n- Added the `to_on_disk_dataset()` method to convert `InMemoryDataset` instances to `OnDiskDataset` instances ([#8116](https://github.com/pyg-team/pytorch_geometric/pull/8116))\n- Added `torch-frame` support ([#8110](https://github.com/pyg-team/pytorch_geometric/pull/8110), [#8118](https://github.com/pyg-team/pytorch_geometric/pull/8118), [#8151](https://github.com/pyg-team/pytorch_geometric/pull/8151), [#8152](https://github.com/pyg-team/pytorch_geometric/pull/8152))\n- Added the `DistLoader` base class ([#8079](https://github.com/pyg-team/pytorch_geometric/pull/8079))\n- Added `HyperGraphData` to support hypergraphs ([#7611](https://github.com/pyg-team/pytorch_geometric/pull/7611))\n- Added the `PCQM4Mv2` dataset as a reference implementation for `OnDiskDataset` ([#8102](https://github.com/pyg-team/pytorch_geometric/pull/8102))\n- Added `module_headers` property to 
`nn.Sequential` models ([#8093](https://github.com/pyg-team/pytorch_geometric/pull/8093))\n- Added `OnDiskDataset` interface with data loader support ([#8066](https://github.com/pyg-team/pytorch_geometric/pull/8066), [#8088](https://github.com/pyg-team/pytorch_geometric/pull/8088), [#8092](https://github.com/pyg-team/pytorch_geometric/pull/8092), [#8106](https://github.com/pyg-team/pytorch_geometric/pull/8106))\n- Added a tutorial for `Node2Vec` and `MetaPath2Vec` usage ([#7938](https://github.com/pyg-team/pytorch_geometric/pull/7938))\n- Added a tutorial for multi-GPU training with pure PyTorch ([#7894](https://github.com/pyg-team/pytorch_geometric/pull/7894))\n- Added `edge_attr` support to `ResGatedGraphConv` ([#8048](https://github.com/pyg-team/pytorch_geometric/pull/8048))\n- Added a `Database` interface and `SQLiteDatabase`/`RocksDatabase` implementations ([#8028](https://github.com/pyg-team/pytorch_geometric/pull/8028), [#8044](https://github.com/pyg-team/pytorch_geometric/pull/8044), [#8046](https://github.com/pyg-team/pytorch_geometric/pull/8046), [#8051](https://github.com/pyg-team/pytorch_geometric/pull/8051), [#8052](https://github.com/pyg-team/pytorch_geometric/pull/8052), [#8054](https://github.com/pyg-team/pytorch_geometric/pull/8054), [#8057](https://github.com/pyg-team/pytorch_geometric/pull/8057), [#8058](https://github.com/pyg-team/pytorch_geometric/pull/8058))\n- Added support for weighted/biased sampling in `NeighborLoader`/`LinkNeighborLoader` ([#8038](https://github.com/pyg-team/pytorch_geometric/pull/8038))\n- Added the `MixHopConv` layer and a corresponding example ([#8025](https://github.com/pyg-team/pytorch_geometric/pull/8025))\n- Added the option to pass keyword arguments to the underlying normalization layers within `BasicGNN` and `MLP` ([#8024](https://github.com/pyg-team/pytorch_geometric/pull/8024), [#8033](https://github.com/pyg-team/pytorch_geometric/pull/8033))\n- Added `IBMBNodeLoader` and `IBMBBatchLoader` data loaders 
([#6230](https://github.com/pyg-team/pytorch_geometric/pull/6230))\n- Added the `NeuralFingerprint` model for learning fingerprints of molecules ([#7919](https://github.com/pyg-team/pytorch_geometric/pull/7919))\n- Added `SparseTensor` support to `WLConvContinuous`, `GeneralConv`, `PDNConv` and `ARMAConv` ([#8013](https://github.com/pyg-team/pytorch_geometric/pull/8013))\n- Added `LCMAggregation`, an implementation of Learnable Commutative Monoids, along with an example ([#7976](https://github.com/pyg-team/pytorch_geometric/pull/7976), [#8020](https://github.com/pyg-team/pytorch_geometric/pull/8020), [#8023](https://github.com/pyg-team/pytorch_geometric/pull/8023), [#8026](https://github.com/pyg-team/pytorch_geometric/pull/8026), [#8075](https://github.com/pyg-team/pytorch_geometric/pull/8075))\n- Added a warning for isolated/non-existing node types in `HeteroData.validate()` ([#7995](https://github.com/pyg-team/pytorch_geometric/pull/7995))\n- Added `utils.cumsum` implementation ([#7994](https://github.com/pyg-team/pytorch_geometric/pull/7994))\n- Added the `BrcaTcga` dataset ([#7905](https://github.com/pyg-team/pytorch_geometric/pull/7905))\n- Added the `MyketDataset` ([#7959](https://github.com/pyg-team/pytorch_geometric/pull/7959))\n- Added a multi-GPU `ogbn-papers100M` example ([#7921](https://github.com/pyg-team/pytorch_geometric/pull/7921))\n- Added `group_argsort` implementation ([#7948](https://github.com/pyg-team/pytorch_geometric/pull/7948))\n- Added `CachedLoader` implementation ([#7896](https://github.com/pyg-team/pytorch_geometric/pull/7896), [#7897](https://github.com/pyg-team/pytorch_geometric/pull/7897))\n- Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925))\n- Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917))\n- Added support for XPU device in `PrefetchLoader` 
([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))\n- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))\n- Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895))\n- Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827))\n- Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864))\n- Added TorchScript support inside `BasicGNN` models ([#7865](https://github.com/pyg-team/pytorch_geometric/pull/7865))\n- Added a `batch_size` argument to `unbatch` functionalities ([#7851](https://github.com/pyg-team/pytorch_geometric/pull/7851))\n- Added a distributed example using `graphlearn-for-pytorch` ([#7402](https://github.com/pyg-team/pytorch_geometric/pull/7402))\n- Integrate `neg_sampling_ratio` into `TemporalDataLoader` ([#7644](https://github.com/pyg-team/pytorch_geometric/pull/7644))\n- Added `faiss`-based `KNNIndex` classes for L2 or maximum inner product search ([#7842](https://github.com/pyg-team/pytorch_geometric/pull/7842))\n- Added the `OSE_GVCS` dataset ([#7811](https://github.com/pyg-team/pytorch_geometric/pull/7811))\n- Added `output_initializer` argument to `DimeNet` models ([#7774](https://github.com/pyg-team/pytorch_geometric/pull/7774), [#7780](https://github.com/pyg-team/pytorch_geometric/pull/7780))\n- Added `lexsort` implementation ([#7775](https://github.com/pyg-team/pytorch_geometric/pull/7775))\n- Added possibility to run inference benchmarks on XPU device ([#7705](https://github.com/pyg-team/pytorch_geometric/pull/7705))\n- Added `HeteroData` support in `to_networkx` ([#7713](https://github.com/pyg-team/pytorch_geometric/pull/7713))\n- Added `FlopsCount` support via `fvcore` ([#7693](https://github.com/pyg-team/pytorch_geometric/pull/7693))\n- Added back support for PyTorch >= 1.11.0 
([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656))\n- Added `Data.sort()` and `HeteroData.sort()` functionalities ([#7649](https://github.com/pyg-team/pytorch_geometric/pull/7649))\n- Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647))\n- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700))\n- Added a `LightGCN` example on the `AmazonBook` dataset ([#7603](https://github.com/pyg-team/pytorch_geometric/pull/7603))\n- Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594))\n- Enabled different attention modes in `HypergraphConv` via the `attention_mode` argument ([#7601](https://github.com/pyg-team/pytorch_geometric/pull/7601))\n- Added the `FilterEdges` graph coarsening operator ([#7361](https://github.com/pyg-team/pytorch_geometric/pull/7361))\n- Added the `DirGNN` model for learning on directed graphs ([#7458](https://github.com/pyg-team/pytorch_geometric/pull/7458))\n- Allow GPU tensors as input to `NodeLoader` and `LinkLoader` ([#7572](https://github.com/pyg-team/pytorch_geometric/pull/7572))\n- Added an `embedding_device` option to allow for GPU inference in `BasicGNN` ([#7548](https://github.com/pyg-team/pytorch_geometric/pull/7548), [#7829](https://github.com/pyg-team/pytorch_geometric/pull/7829))\n- Added `Performer` to `GPSConv` and removed the `attn_dropout` argument from `GPSConv` ([#7465](https://github.com/pyg-team/pytorch_geometric/pull/7465))\n- Enabled `LinkNeighborLoader` to return number of sampled nodes and edges per hop ([#7516](https://github.com/pyg-team/pytorch_geometric/pull/7516))\n- Added the `HM` 
personalized fashion recommendation dataset ([#7515](https://github.com/pyg-team/pytorch_geometric/pull/7515))\n- Added the `GraphMixer` model ([#7501](https://github.com/pyg-team/pytorch_geometric/pull/7501), [#7459](https://github.com/pyg-team/pytorch_geometric/pull/7459))\n- Added the `disable_dynamic_shape` experimental flag ([#7246](https://github.com/pyg-team/pytorch_geometric/pull/7246), [#7534](https://github.com/pyg-team/pytorch_geometric/pull/7534))\n- Added the `MovieLens-1M` heterogeneous dataset ([#7479](https://github.com/pyg-team/pytorch_geometric/pull/7479))\n- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493), [#7764](https://github.com/pyg-team/pytorch_geometric/pull/7764), [#7765](https://github.com/pyg-team/pytorch_geometric/pull/7765))\n- Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))\n- Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))\n- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846), [#7715](https://github.com/pyg-team/pytorch_geometric/pull/7715), [#7974](https://github.com/pyg-team/pytorch_geometric/pull/7974))\n- Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))\n- Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))\n- Added the 
`IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441))\n- Added a sparse `cross_entropy` implementation ([#7447](https://github.com/pyg-team/pytorch_geometric/pull/7447), [#7466](https://github.com/pyg-team/pytorch_geometric/pull/7466))\n- Added the `MovieLens-100K` heterogeneous dataset ([#7398](https://github.com/pyg-team/pytorch_geometric/pull/7398))\n- Added the `PMLP` model and an example ([#7370](https://github.com/pyg-team/pytorch_geometric/pull/7370), [#7543](https://github.com/pyg-team/pytorch_geometric/pull/7543))\n- Added padding capabilities to `HeteroData.to_homogeneous()` in case feature dimensionalities do not match ([#7374](https://github.com/pyg-team/pytorch_geometric/pull/7374))\n- Added an optional `batch_size` argument to `fps`, `knn`, `knn_graph`, `radius` and `radius_graph` ([#7368](https://github.com/pyg-team/pytorch_geometric/pull/7368))\n- Added `PrefetchLoader` capabilities ([#7376](https://github.com/pyg-team/pytorch_geometric/pull/7376), [#7378](https://github.com/pyg-team/pytorch_geometric/pull/7378), [#7383](https://github.com/pyg-team/pytorch_geometric/pull/7383))\n- Added an example for hierarchical sampling ([#7244](https://github.com/pyg-team/pytorch_geometric/pull/7244))\n- Added Kùzu remote backend examples ([#7298](https://github.com/pyg-team/pytorch_geometric/pull/7298))\n- Added an optional `add_pad_mask` argument to the `Pad` transform ([#7339](https://github.com/pyg-team/pytorch_geometric/pull/7339))\n- Added `keep_inter_cluster_edges` option to `ClusterData` to support inter-subgraph edge connections when doing graph partitioning ([#7326](https://github.com/pyg-team/pytorch_geometric/pull/7326))\n- Unify graph pooling framework ([#7308](https://github.com/pyg-team/pytorch_geometric/pull/7308), [#7625](https://github.com/pyg-team/pytorch_geometric/pull/7625))\n- Added support for tuples as keys in `ModuleDict`/`ParameterDict` ([#7294](https://github.com/pyg-team/pytorch_geometric/pull/7294))\n- 
Added `NodePropertySplit` transform for creating node-level splits using structural node properties ([#6894](https://github.com/pyg-team/pytorch_geometric/pull/6894))\n- Added an option to preserve directed graphs in `CitationFull` datasets ([#7275](https://github.com/pyg-team/pytorch_geometric/pull/7275))\n- Added support for `torch.sparse.Tensor` in `DataLoader` ([#7252](https://github.com/pyg-team/pytorch_geometric/pull/7252))\n- Added `save` and `load` methods to `InMemoryDataset` ([#7250](https://github.com/pyg-team/pytorch_geometric/pull/7250), [#7413](https://github.com/pyg-team/pytorch_geometric/pull/7413))\n- Added an example for heterogeneous GNN explanation via `CaptumExplainer` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096))\n- Added `visualize_feature_importance` functionality to `HeteroExplanation` ([#7096](https://github.com/pyg-team/pytorch_geometric/pull/7096))\n- Added an `AddRemainingSelfLoops` transform ([#7192](https://github.com/pyg-team/pytorch_geometric/pull/7192))\n- Added `optimizer_resolver` ([#7209](https://github.com/pyg-team/pytorch_geometric/pull/7209))\n- Added `type_ptr` argument to `HeteroLayerNorm` ([#7208](https://github.com/pyg-team/pytorch_geometric/pull/7208))\n- Added an option to benchmark scripts to write PyTorch profiler results to CSV ([#7114](https://github.com/pyg-team/pytorch_geometric/pull/7114))\n- Added subgraph type sampling option with bidirectional edge support ([#7199](https://github.com/pyg-team/pytorch_geometric/pull/7199), [#7200](https://github.com/pyg-team/pytorch_geometric/pull/7200))\n- Added support for `\"any\"`-reductions in `scatter` ([#7198](https://github.com/pyg-team/pytorch_geometric/pull/7198))\n- Added manual sampling interface to `NodeLoader` and `LinkLoader` ([#7197](https://github.com/pyg-team/pytorch_geometric/pull/7197))\n- Extended `torch.sparse` support ([#7155](https://github.com/pyg-team/pytorch_geometric/pull/7155))\n- Added edge weight support to `LightGCN` 
([#7157](https://github.com/pyg-team/pytorch_geometric/pull/7157))\n- Added `SparseTensor` support to `trim_to_layer` function ([#7089](https://github.com/pyg-team/pytorch_geometric/pull/7089))\n- Added instructions for ROCm build wheels ([#7143](https://github.com/pyg-team/pytorch_geometric/pull/7143))\n- Added a `ComposeFilters` class to compose `pre_filter` functions in `Dataset` ([#7097](https://github.com/pyg-team/pytorch_geometric/pull/7097))\n- Added a time-step aware variant of the `EllipticBitcoinDataset` called `EllipticBitcoinTemporalDataset` ([#7011](https://github.com/pyg-team/pytorch_geometric/pull/7011))\n- Added `to_dgl` and `from_dgl` conversion functions ([#7053](https://github.com/pyg-team/pytorch_geometric/pull/7053))\n- Added support for `torch.jit.script` within `MessagePassing` layers without `torch_sparse` being installed ([#7061](https://github.com/pyg-team/pytorch_geometric/pull/7061), [#7062](https://github.com/pyg-team/pytorch_geometric/pull/7062))\n- Added unbatching logic for `torch.sparse` tensors ([#7037](https://github.com/pyg-team/pytorch_geometric/pull/7037))\n- Added the `RotatE` KGE model ([#7026](https://github.com/pyg-team/pytorch_geometric/pull/7026))\n- Added support for Apple silicon GPU acceleration in some main examples ([#7770](https://github.com/pyg-team/pytorch_geometric/pull/7770), [#7711](https://github.com/pyg-team/pytorch_geometric/pull/7711), [#7784](https://github.com/pyg-team/pytorch_geometric/pull/7784), [#7785](https://github.com/pyg-team/pytorch_geometric/pull/7785))\n\n### Changed\n\n- Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166))\n- Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163))\n- Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 
([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145))\n- Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143))\n- Fixed `DynamicBatchSampler.__len__` to raise an error in case `num_steps` is undefined ([#8137](https://github.com/pyg-team/pytorch_geometric/pull/8137))\n- Enabled pickling of `DimeNet` models ([#8019](https://github.com/pyg-team/pytorch_geometric/pull/8019))\n- Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942))\n- Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737))\n- Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955))\n- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956))\n- Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953))\n- Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941))\n- Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923))\n- Fixed broken links in `HGBDataset` ([#7907](https://github.com/pyg-team/pytorch_geometric/pull/7907))\n- Fixed an issue where `SetTransformerAggregation` produced `NaN` values for isolated nodes ([#7902](https://github.com/pyg-team/pytorch_geometric/pull/7902))\n- Fixed `model_summary` on modules with uninitialized parameters ([#7884](https://github.com/pyg-team/pytorch_geometric/pull/7884))\n- 
Updated `QM9` data pre-processing to include the SMILES string ([#7867](https://github.com/pyg-team/pytorch_geometric/pull/7867))\n- Fixed tracing of `add_self_loops` for a dynamic number of nodes ([#7330](https://github.com/pyg-team/pytorch_geometric/pull/7330))\n- Fixed device issue in `PNAConv.get_degree_histogram` ([#7830](https://github.com/pyg-team/pytorch_geometric/pull/7830))\n- Fixed the shape of `edge_label_time` when using temporal sampling on homogeneous graphs ([#7807](https://github.com/pyg-team/pytorch_geometric/pull/7807))\n- Moved `torch_geometric.contrib.explain.GraphMaskExplainer` to `torch_geometric.explain.algorithm.GraphMaskExplainer` ([#7779](https://github.com/pyg-team/pytorch_geometric/pull/7779))\n- Made `FieldStatus` enum picklable to avoid `PicklingError` in a multi-process setting ([#7808](https://github.com/pyg-team/pytorch_geometric/pull/7808))\n- Fixed `edge_label_index` computation in `LinkNeighborLoader` for the homogeneous+`disjoint` mode ([#7791](https://github.com/pyg-team/pytorch_geometric/pull/7791))\n- Fixed `CaptumExplainer` for `binary_classification` tasks ([#7787](https://github.com/pyg-team/pytorch_geometric/pull/7787))\n- Warn user when using the `training` flag in `to_hetero` modules ([#7772](https://github.com/pyg-team/pytorch_geometric/pull/7772))\n- Unchained exceptions raised when accessing non-existent data attributes for better readability ([#7734](https://github.com/pyg-team/pytorch_geometric/pull/7734))\n- Raise error when collecting non-existing attributes in `HeteroData` ([#7714](https://github.com/pyg-team/pytorch_geometric/pull/7714))\n- Renamed `dest` argument to `dst` in `utils.geodesic_distance` ([#7708](https://github.com/pyg-team/pytorch_geometric/pull/7708))\n- Changed `add_random_edge` to only add true negative edges ([#7654](https://github.com/pyg-team/pytorch_geometric/pull/7654))\n- Allowed the usage of `BasicGNN` models in `DeepGraphInfomax` 
([#7648](https://github.com/pyg-team/pytorch_geometric/pull/7648))\n- Breaking Change: Made `Data.keys` a method rather than a property ([#7629](https://github.com/pyg-team/pytorch_geometric/pull/7629))\n- Added a `num_edges` parameter to the forward method of `HypergraphConv` ([#7560](https://github.com/pyg-team/pytorch_geometric/pull/7560))\n- Fixed `get_mesh_laplacian` for `normalization=\"sym\"` ([#7544](https://github.com/pyg-team/pytorch_geometric/pull/7544))\n- Use `dim_size` to initialize output size of the `EquilibriumAggregation` layer ([#7530](https://github.com/pyg-team/pytorch_geometric/pull/7530))\n- Added a `max_num_elements` parameter to the forward method of `GraphMultisetTransformer`, `GRUAggregation`, `LSTMAggregation` and `SetTransformerAggregation` ([#7529](https://github.com/pyg-team/pytorch_geometric/pull/7529))\n- Fixed empty edge indices handling in `SparseTensor` ([#7519](https://github.com/pyg-team/pytorch_geometric/pull/7519))\n- Move the `scaler` tensor in `GeneralConv` to the correct device ([#7484](https://github.com/pyg-team/pytorch_geometric/pull/7484))\n- Fixed `HeteroLinear` bug when used via mixed precision ([#7473](https://github.com/pyg-team/pytorch_geometric/pull/7473))\n- All transforms are now immutable, i.e., they perform a shallow-copy of the data and therefore no longer modify data in-place ([#7429](https://github.com/pyg-team/pytorch_geometric/pull/7429))\n- Set `output_size` in the `repeat_interleave` operation in `QuantileAggregation` ([#7426](https://github.com/pyg-team/pytorch_geometric/pull/7426))\n- Fixed gradient computation of edge weights in `utils.spmm` ([#7428](https://github.com/pyg-team/pytorch_geometric/pull/7428))\n- Re-factored `ClusterLoader` to integrate `pyg-lib` METIS routine ([#7416](https://github.com/pyg-team/pytorch_geometric/pull/7416))\n- Fixed an index-out-of-range bug in `QuantileAggregation` when `dim_size` is passed ([#7407](https://github.com/pyg-team/pytorch_geometric/pull/7407))\n- 
The `filter_per_worker` option will now get automatically inferred by default based on the device of the underlying data ([#7399](https://github.com/pyg-team/pytorch_geometric/pull/7399))\n- Fixed a bug in `LightGCN.recommendation_loss()` to only use the embeddings of the nodes involved in the current mini-batch ([#7384](https://github.com/pyg-team/pytorch_geometric/pull/7384))\n- Added an optional `max_num_elements` argument to `SortAggregation` ([#7367](https://github.com/pyg-team/pytorch_geometric/pull/7367))\n- Added the option to pass `fill_value` as a `torch.tensor` to `utils.to_dense_batch` ([#7367](https://github.com/pyg-team/pytorch_geometric/pull/7367))\n- Fixed a bug in which inputs were modified in-place in `to_hetero_with_bases` ([#7363](https://github.com/pyg-team/pytorch_geometric/pull/7363))\n- Do not load `node_default` and `edge_default` attributes in `from_networkx` ([#7348](https://github.com/pyg-team/pytorch_geometric/pull/7348))\n- Updated examples to use `NeighborLoader` instead of `NeighborSampler` ([#7152](https://github.com/pyg-team/pytorch_geometric/pull/7152))\n- Fixed `HGTConv` utility function `_construct_src_node_feat` ([#7194](https://github.com/pyg-team/pytorch_geometric/pull/7194))\n- Extend dataset summary to create stats for each node/edge type ([#7203](https://github.com/pyg-team/pytorch_geometric/pull/7203))\n- Added an optional `batch_size` argument to `avg_pool_x` and `max_pool_x` ([#7216](https://github.com/pyg-team/pytorch_geometric/pull/7216))\n- Fixed `subgraph` on unordered inputs ([#7187](https://github.com/pyg-team/pytorch_geometric/pull/7187))\n- Allow missing node types in `HeteroDictLinear` ([#7185](https://github.com/pyg-team/pytorch_geometric/pull/7185))\n- Optimized `from_networkx` memory footprint by reducing unnecessary copies ([#7119](https://github.com/pyg-team/pytorch_geometric/pull/7119))\n- Added an optional `batch_size` argument to `LayerNorm`, `GraphNorm`, `InstanceNorm`, `GraphSizeNorm` and `PairNorm` 
([#7135](https://github.com/pyg-team/pytorch_geometric/pull/7135))\n- Improved code coverage ([#7093](https://github.com/pyg-team/pytorch_geometric/pull/7093), [#7195](https://github.com/pyg-team/pytorch_geometric/pull/7195))\n- Fix `numpy` incompatibility when reading files for `Planetoid` datasets ([#7141](https://github.com/pyg-team/pytorch_geometric/pull/7141))\n- Added support for `Data.num_edges` for native `torch.sparse.Tensor` adjacency matrices ([#7104](https://github.com/pyg-team/pytorch_geometric/pull/7104))\n- Fixed crash of heterogeneous data loaders if node or edge types are missing ([#7060](https://github.com/pyg-team/pytorch_geometric/pull/7060), [#7087](https://github.com/pyg-team/pytorch_geometric/pull/7087))\n- Accelerated attention-based `MultiAggregation` ([#7077](https://github.com/pyg-team/pytorch_geometric/pull/7077))\n- Edges in `HeterophilousGraphDataset` are now undirected by default ([#7065](https://github.com/pyg-team/pytorch_geometric/pull/7065))\n- Fixed a bug in `FastHGTConv` that computed values via parameters used to compute the keys ([#7050](https://github.com/pyg-team/pytorch_geometric/pull/7050))\n- Accelerated sparse tensor conversion routines ([#7042](https://github.com/pyg-team/pytorch_geometric/pull/7042), [#7043](https://github.com/pyg-team/pytorch_geometric/pull/7043))\n- Change `torch_sparse.SparseTensor` logic to utilize `torch.sparse_csr` instead ([#7041](https://github.com/pyg-team/pytorch_geometric/pull/7041))\n- Added optional `batch_size` and `max_num_nodes` arguments to the `MemPooling` layer ([#7239](https://github.com/pyg-team/pytorch_geometric/pull/7239))\n- Fixed training issues of the GraphGPS example ([#7377](https://github.com/pyg-team/pytorch_geometric/pull/7377))\n- Allowed `CaptumExplainer` to be called multiple times in a row ([#7391](https://github.com/pyg-team/pytorch_geometric/pull/7391))\n\n### Removed\n\n- Dropped Python 3.7 support ([#7939](https://github.com/pyg-team/pytorch_geometric/pull/7939))\n- 
Removed `layer_type` argument in `contrib.explain.GraphMaskExplainer` ([#7445](https://github.com/pyg-team/pytorch_geometric/pull/7445))\n- Replaced `FastHGTConv` with `HGTConv` ([#7117](https://github.com/pyg-team/pytorch_geometric/pull/7117))\n\n## [2.3.0] - 2023-03-23\n\n### Added\n\n- Added a memory-efficient `utils.one_hot` implementation ([#7005](https://github.com/pyg-team/pytorch_geometric/pull/7005))\n- Added `HeteroDictLinear` and an optimized `FastHGTConv` module ([#6178](https://github.com/pyg-team/pytorch_geometric/pull/6178), [#6998](https://github.com/pyg-team/pytorch_geometric/pull/6998))\n- Added the `DenseGATConv` module ([#6928](https://github.com/pyg-team/pytorch_geometric/pull/6928))\n- Added `trim_to_layer` utility function for more efficient `NeighborLoader` use-cases ([#6661](https://github.com/pyg-team/pytorch_geometric/pull/6661))\n- Added the `DistMult` KGE model ([#6958](https://github.com/pyg-team/pytorch_geometric/pull/6958))\n- Added `HeteroData.set_value_dict` functionality ([#6961](https://github.com/pyg-team/pytorch_geometric/pull/6961), [#6974](https://github.com/pyg-team/pytorch_geometric/pull/6974))\n- Added PyTorch >= 2.0 support ([#6934](https://github.com/pyg-team/pytorch_geometric/pull/6934), [#7000](https://github.com/pyg-team/pytorch_geometric/pull/7000))\n- Added PyTorch Lightning >= 2.0 support ([#6929](https://github.com/pyg-team/pytorch_geometric/pull/6929))\n- Added the `ComplEx` KGE model ([#6898](https://github.com/pyg-team/pytorch_geometric/pull/6898))\n- Added option to write benchmark results to csv ([#6888](https://github.com/pyg-team/pytorch_geometric/pull/6888))\n- Added `HeteroLayerNorm` and `HeteroBatchNorm` layers ([#6838](https://github.com/pyg-team/pytorch_geometric/pull/6838))\n- Added the `HeterophilousGraphDataset` suite ([#6846](https://github.com/pyg-team/pytorch_geometric/pull/6846))\n- Added support for sparse tensor in full batch mode inference benchmark 
([#6843](https://github.com/pyg-team/pytorch_geometric/pull/6843))\n- Enabled `NeighborLoader` to return number of sampled nodes and edges per hop ([#6834](https://github.com/pyg-team/pytorch_geometric/pull/6834))\n- Added `ZipLoader` to execute multiple `NodeLoader` or `LinkLoader` instances ([#6829](https://github.com/pyg-team/pytorch_geometric/issues/6829))\n- Added common `utils.select` and `utils.narrow` functionality to support filtering of both tensors and lists ([#6162](https://github.com/pyg-team/pytorch_geometric/issues/6162))\n- Support `normalization` customization in `get_mesh_laplacian` ([#6790](https://github.com/pyg-team/pytorch_geometric/issues/6790))\n- Added the `TemporalEncoding` module ([#6785](https://github.com/pyg-team/pytorch_geometric/pull/6785))\n- Added CPU-optimized `spmm_reduce` functionality via CSR format ([#6699](https://github.com/pyg-team/pytorch_geometric/pull/6699), [#6759](https://github.com/pyg-team/pytorch_geometric/pull/6759))\n- Added support for the revised version of the `MD17` dataset ([#6734](https://github.com/pyg-team/pytorch_geometric/pull/6734))\n- Added TorchScript support to the `RECT_L` model ([#6727](https://github.com/pyg-team/pytorch_geometric/pull/6727))\n- Added TorchScript support to the `Node2Vec` model ([#6726](https://github.com/pyg-team/pytorch_geometric/pull/6726))\n- Added `utils.to_edge_index` to convert sparse tensors to edge indices and edge attributes ([#6728](https://github.com/pyg-team/pytorch_geometric/issues/6728))\n- Fixed expected data format in `PolBlogs` dataset ([#6714](https://github.com/pyg-team/pytorch_geometric/issues/6714))\n- Added `SimpleConv` to perform non-trainable propagation ([#6718](https://github.com/pyg-team/pytorch_geometric/pull/6718))\n- Added a `RemoveDuplicatedEdges` transform ([#6709](https://github.com/pyg-team/pytorch_geometric/pull/6709))\n- Added TorchScript support to the `LINKX` model ([#6712](https://github.com/pyg-team/pytorch_geometric/pull/6712))\n- Added 
`torch.jit` examples for `example/film.py` and `example/gcn.py` ([#6602](https://github.com/pyg-team/pytorch_geometric/pull/6692))\n- Added `Pad` transform ([#5940](https://github.com/pyg-team/pytorch_geometric/pull/5940), [#6697](https://github.com/pyg-team/pytorch_geometric/pull/6697), [#6731](https://github.com/pyg-team/pytorch_geometric/pull/6731), [#6758](https://github.com/pyg-team/pytorch_geometric/pull/6758))\n- Added full batch mode to the inference benchmark ([#6631](https://github.com/pyg-team/pytorch_geometric/pull/6631))\n- Added `cat` aggregation type to the `HeteroConv` class so that features can be concatenated during grouping ([#6634](https://github.com/pyg-team/pytorch_geometric/pull/6634))\n- Added `torch.compile` support and benchmark study ([#6610](https://github.com/pyg-team/pytorch_geometric/pull/6610), [#6952](https://github.com/pyg-team/pytorch_geometric/pull/6952), [#6953](https://github.com/pyg-team/pytorch_geometric/pull/6953), [#6980](https://github.com/pyg-team/pytorch_geometric/pull/6980), [#6983](https://github.com/pyg-team/pytorch_geometric/pull/6983), [#6984](https://github.com/pyg-team/pytorch_geometric/pull/6984), [#6985](https://github.com/pyg-team/pytorch_geometric/pull/6985), [#6986](https://github.com/pyg-team/pytorch_geometric/pull/6986), [#6989](https://github.com/pyg-team/pytorch_geometric/pull/6989), [#7002](https://github.com/pyg-team/pytorch_geometric/pull/7002))\n- Added the `AntiSymmetricConv` layer ([#6577](https://github.com/pyg-team/pytorch_geometric/pull/6577))\n- Added a mixin for Huggingface model hub integration ([#5930](https://github.com/pyg-team/pytorch_geometric/pull/5930), [#6591](https://github.com/pyg-team/pytorch_geometric/pull/6591))\n- Added support for accelerated GNN layers in `nn.conv.cugraph` via `cugraph-ops` ([#6278](https://github.com/pyg-team/pytorch_geometric/pull/6278), [#6388](https://github.com/pyg-team/pytorch_geometric/pull/6388), 
[#6412](https://github.com/pyg-team/pytorch_geometric/pull/6412))\n- Added accelerated `index_sort` function from `pyg-lib` for faster sorting ([#6554](https://github.com/pyg-team/pytorch_geometric/pull/6554))\n- Fix incorrect device in `EquilibriumAggregation` ([#6560](https://github.com/pyg-team/pytorch_geometric/pull/6560))\n- Added bipartite graph support in `dense_to_sparse()` ([#6546](https://github.com/pyg-team/pytorch_geometric/pull/6546))\n- Add CPU affinity support for more data loaders ([#6534](https://github.com/pyg-team/pytorch_geometric/pull/6534), [#6922](https://github.com/pyg-team/pytorch_geometric/pull/6922))\n- Added the `BAMultiShapesDataset` ([#6541](https://github.com/pyg-team/pytorch_geometric/pull/6541))\n- Added the interfaces of a graph pooling framework ([#6540](https://github.com/pyg-team/pytorch_geometric/pull/6540))\n- Added automatic `n_id` and `e_id` attributes to mini-batches produced by `NodeLoader` and `LinkLoader` ([#6524](https://github.com/pyg-team/pytorch_geometric/pull/6524))\n- Added `PGMExplainer` to `torch_geometric.contrib` ([#6149](https://github.com/pyg-team/pytorch_geometric/pull/6149), [#6588](https://github.com/pyg-team/pytorch_geometric/pull/6588), [#6589](https://github.com/pyg-team/pytorch_geometric/pull/6589))\n- Added a `NumNeighbors` helper class for specifying the number of neighbors when sampling ([#6501](https://github.com/pyg-team/pytorch_geometric/pull/6501), [#6505](https://github.com/pyg-team/pytorch_geometric/pull/6505), [#6690](https://github.com/pyg-team/pytorch_geometric/pull/6690))\n- Added caching to `is_node_attr()` and `is_edge_attr()` calls ([#6492](https://github.com/pyg-team/pytorch_geometric/pull/6492))\n- Added `ToHeteroLinear` and `ToHeteroMessagePassing` modules to accelerate `to_hetero` functionality ([#5992](https://github.com/pyg-team/pytorch_geometric/pull/5992), [#6456](https://github.com/pyg-team/pytorch_geometric/pull/6456))\n- Added `GraphMaskExplainer` 
([#6284](https://github.com/pyg-team/pytorch_geometric/pull/6284))\n- Added the `GRBCD` and `PRBCD` adversarial attack models ([#5972](https://github.com/pyg-team/pytorch_geometric/pull/5972))\n- Added `dropout` option to `SetTransformer` and `GraphMultisetTransformer` ([#6484](https://github.com/pyg-team/pytorch_geometric/pull/6484))\n- Added option to customize loader arguments for evaluation in `LightningNodeData` and `LightningLinkData` ([#6450](https://github.com/pyg-team/pytorch_geometric/pull/6450), [#6456](https://github.com/pyg-team/pytorch_geometric/pull/6456))\n- Added option to customize `num_neighbors` in `NeighborSampler` after instantiation ([#6446](https://github.com/pyg-team/pytorch_geometric/pull/6446))\n- Added the `Taobao` dataset and a corresponding example for it ([#6144](https://github.com/pyg-team/pytorch_geometric/pull/6144))\n- Added `pyproject.toml` ([#6431](https://github.com/pyg-team/pytorch_geometric/pull/6431))\n- Added the `torch_geometric.contrib` sub-package ([#6422](https://github.com/pyg-team/pytorch_geometric/pull/6422))\n- Warn on using latest documentation ([#6418](https://github.com/pyg-team/pytorch_geometric/pull/6418))\n- Added basic `pyright` type checker support ([#6415](https://github.com/pyg-team/pytorch_geometric/pull/6415))\n- Added a new external resource for link prediction ([#6396](https://github.com/pyg-team/pytorch_geometric/pull/6396))\n- Added `CaptumExplainer` ([#6383](https://github.com/pyg-team/pytorch_geometric/pull/6383), [#6387](https://github.com/pyg-team/pytorch_geometric/pull/6387), [#6433](https://github.com/pyg-team/pytorch_geometric/pull/6433), [#6487](https://github.com/pyg-team/pytorch_geometric/pull/6487), [#6966](https://github.com/pyg-team/pytorch_geometric/pull/6966))\n- Added support for custom `HeteroData` mini-batch class in remote backends ([#6377](https://github.com/pyg-team/pytorch_geometric/pull/6377))\n- Added the `GNNFF` model 
([#5866](https://github.com/pyg-team/pytorch_geometric/pull/5866))\n- Added `MLPAggregation`, `SetTransformerAggregation`, `GRUAggregation`, and `DeepSetsAggregation` as adaptive readout functions ([#6301](https://github.com/pyg-team/pytorch_geometric/pull/6301), [#6336](https://github.com/pyg-team/pytorch_geometric/pull/6336), [#6338](https://github.com/pyg-team/pytorch_geometric/pull/6338))\n- Added `Dataset.to_datapipe` for converting PyG datasets into a torchdata `DataPipe` ([#6141](https://github.com/pyg-team/pytorch_geometric/pull/6141))\n- Added `to_nested_tensor` and `from_nested_tensor` functionality ([#6329](https://github.com/pyg-team/pytorch_geometric/pull/6329), [#6330](https://github.com/pyg-team/pytorch_geometric/pull/6330), [#6331](https://github.com/pyg-team/pytorch_geometric/pull/6331), [#6332](https://github.com/pyg-team/pytorch_geometric/pull/6332))\n- Added the `GPSConv` Graph Transformer layer and example ([#6326](https://github.com/pyg-team/pytorch_geometric/pull/6326), [#6327](https://github.com/pyg-team/pytorch_geometric/pull/6327))\n- Added `networkit` conversion utilities ([#6321](https://github.com/pyg-team/pytorch_geometric/pull/6321))\n- Added global dataset attribute access via `dataset.{attr_name}` ([#6319](https://github.com/pyg-team/pytorch_geometric/pull/6319))\n- Added the `TransE` KGE model and example ([#6314](https://github.com/pyg-team/pytorch_geometric/pull/6314))\n- Added the Freebase `FB15k_237` dataset ([#3204](https://github.com/pyg-team/pytorch_geometric/pull/3204))\n- Added `Data.update()` and `HeteroData.update()` functionality ([#6313](https://github.com/pyg-team/pytorch_geometric/pull/6313))\n- Added `PGExplainer` ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6204))\n- Added the `AirfRANS` dataset ([#6287](https://github.com/pyg-team/pytorch_geometric/pull/6287))\n- Added `AttentionExplainer` ([#6279](https://github.com/pyg-team/pytorch_geometric/pull/6279))\n- Added (un)faithfulness explainability 
metric ([#6090](https://github.com/pyg-team/pytorch_geometric/pull/6090))\n- Added fidelity explainability metric ([#6116](https://github.com/pyg-team/pytorch_geometric/pull/6116), [#6510](https://github.com/pyg-team/pytorch_geometric/pull/6510))\n- Added subgraph visualization of GNN explanations ([#6235](https://github.com/pyg-team/pytorch_geometric/pull/6235), [#6271](https://github.com/pyg-team/pytorch_geometric/pull/6271))\n- Added weighted negative sampling option in `LinkNeighborLoader` ([#6264](https://github.com/pyg-team/pytorch_geometric/pull/6264))\n- Added the `BA2MotifDataset` explainer dataset ([#6257](https://github.com/pyg-team/pytorch_geometric/pull/6257))\n- Added `CycleMotif` motif generator to generate `n`-node cycle shaped motifs ([#6256](https://github.com/pyg-team/pytorch_geometric/pull/6256))\n- Added the `InfectionDataset` to evaluate explanations ([#6222](https://github.com/pyg-team/pytorch_geometric/pull/6222))\n- Added `characterization_score` and `fidelity_curve_auc` explainer metrics ([#6188](https://github.com/pyg-team/pytorch_geometric/pull/6188))\n- Added `get_message_passing_embeddings` ([#6201](https://github.com/pyg-team/pytorch_geometric/pull/6201))\n- Added the `PointGNNConv` layer ([#6194](https://github.com/pyg-team/pytorch_geometric/pull/6194))\n- Added `GridGraph` graph generator to generate grid graphs ([#6220](https://github.com/pyg-team/pytorch_geometric/pull/6220))\n- Added explainability metrics for when ground truth is available ([#6137](https://github.com/pyg-team/pytorch_geometric/pull/6137))\n- Added `visualize_feature_importance` to support node feature visualizations ([#6094](https://github.com/pyg-team/pytorch_geometric/pull/6094))\n- Added heterogeneous graph support to `Explanation` framework ([#6091](https://github.com/pyg-team/pytorch_geometric/pull/6091), [#6218](https://github.com/pyg-team/pytorch_geometric/pull/6218))\n- Added a `CustomMotif` motif generator 
([#6179](https://github.com/pyg-team/pytorch_geometric/pull/6179))\n- Added `ERGraph` graph generator to generate Erdos-Renyi (ER) graphs ([#6073](https://github.com/pyg-team/pytorch_geometric/pull/6073))\n- Added `BAGraph` graph generator to generate Barabasi-Albert graphs - the usage of `datasets.BAShapes` is now deprecated ([#6072](https://github.com/pyg-team/pytorch_geometric/pull/6072))\n- Added explainability benchmark dataset framework ([#6104](https://github.com/pyg-team/pytorch_geometric/pull/6104))\n- Added `seed_time` attribute to temporal `NodeLoader` outputs in case `input_time` is given ([#6196](https://github.com/pyg-team/pytorch_geometric/pull/6196))\n- Added `Data.edge_subgraph` and `HeteroData.edge_subgraph` functionalities ([#6193](https://github.com/pyg-team/pytorch_geometric/pull/6193))\n- Added `input_time` option to `LightningNodeData` and `transform_sampler_output` to `NodeLoader` and `LinkLoader` ([#6187](https://github.com/pyg-team/pytorch_geometric/pull/6187))\n- Added `summary` for PyG/PyTorch models ([#5859](https://github.com/pyg-team/pytorch_geometric/pull/5859), [#6161](https://github.com/pyg-team/pytorch_geometric/pull/6161))\n- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6033](https://github.com/pyg-team/pytorch_geometric/pull/6033), [#6514](https://github.com/pyg-team/pytorch_geometric/pull/6514), [#6532](https://github.com/pyg-team/pytorch_geometric/pull/6532), [#6748](https://github.com/pyg-team/pytorch_geometric/pull/6748), [#6847](https://github.com/pyg-team/pytorch_geometric/pull/6847), [#6868](https://github.com/pyg-team/pytorch_geometric/pull/6868), [#6874](https://github.com/pyg-team/pytorch_geometric/pull/6874), [#6897](https://github.com/pyg-team/pytorch_geometric/pull/6897), 
[#6930](https://github.com/pyg-team/pytorch_geometric/pull/6930), [#6932](https://github.com/pyg-team/pytorch_geometric/pull/6932), [#6936](https://github.com/pyg-team/pytorch_geometric/pull/6936), [#6937](https://github.com/pyg-team/pytorch_geometric/pull/6937), [#6939](https://github.com/pyg-team/pytorch_geometric/pull/6939), [#6947](https://github.com/pyg-team/pytorch_geometric/pull/6947), [#6950](https://github.com/pyg-team/pytorch_geometric/pull/6950), [#6951](https://github.com/pyg-team/pytorch_geometric/pull/6951), [#6957](https://github.com/pyg-team/pytorch_geometric/pull/6957))\n- Add `inputs_channels` back in training benchmark ([#6154](https://github.com/pyg-team/pytorch_geometric/pull/6154))\n- Added support for dropping nodes in `utils.to_dense_batch` in case `max_num_nodes` is smaller than the number of nodes ([#6124](https://github.com/pyg-team/pytorch_geometric/pull/6124))\n- Added the RandLA-Net architecture as an example ([#5117](https://github.com/pyg-team/pytorch_geometric/pull/5117))\n\n### Changed\n\n- Migrate to `pyproject.toml` for packaging ([#6880](https://github.com/pyg-team/pytorch_geometric/pull/6880))\n- Drop internal usage of `__dunder__` names ([#6999](https://github.com/pyg-team/pytorch_geometric/issues/6999))\n- Changed the interface of `sort_edge_index`, `coalesce` and `to_undirected` to only return single `edge_index` information in case the `edge_attr` argument is not specified ([#6875](https://github.com/pyg-team/pytorch_geometric/issues/6875), [#6879](https://github.com/pyg-team/pytorch_geometric/issues/6879), [#6893](https://github.com/pyg-team/pytorch_geometric/issues/6893))\n- Fixed a bug in `to_hetero` when using an uninitialized submodule without implementing `reset_parameters` ([#6863](https://github.com/pyg-team/pytorch_geometric/issues/6790))\n- Fixed a bug in `get_mesh_laplacian` ([#6790](https://github.com/pyg-team/pytorch_geometric/issues/6790))\n- Fixed a bug in which masks were not properly masked in 
`GNNExplainer` on link prediction tasks ([#6787](https://github.com/pyg-team/pytorch_geometric/pull/6787))\n- Allow the usage of `ChebConv` within `GNNExplainer` ([#6778](https://github.com/pyg-team/pytorch_geometric/pull/6778))\n- Allow setting the `EdgeStorage.num_edges` property ([#6710](https://github.com/pyg-team/pytorch_geometric/pull/6710))\n- Fixed a bug in `utils.bipartite_subgraph()` and updated docs of `HeteroData.subgraph()` ([#6654](https://github.com/pyg-team/pytorch_geometric/pull/6654))\n- Properly reset the `data_list` cache of an `InMemoryDataset` when accessing `dataset.data` ([#6685](https://github.com/pyg-team/pytorch_geometric/pull/6685))\n- Fixed a bug in `Data.subgraph()` and `HeteroData.subgraph()` ([#6613](https://github.com/pyg-team/pytorch_geometric/pull/6613))\n- Fixed a bug in `PNAConv` and `DegreeScalerAggregation` to correctly incorporate degree statistics of isolated nodes ([#6609](https://github.com/pyg-team/pytorch_geometric/pull/6609))\n- Improved code coverage ([#6523](https://github.com/pyg-team/pytorch_geometric/pull/6523), [#6538](https://github.com/pyg-team/pytorch_geometric/pull/6538), [#6555](https://github.com/pyg-team/pytorch_geometric/pull/6555), [#6558](https://github.com/pyg-team/pytorch_geometric/pull/6558), [#6568](https://github.com/pyg-team/pytorch_geometric/pull/6568), [#6573](https://github.com/pyg-team/pytorch_geometric/pull/6573), [#6578](https://github.com/pyg-team/pytorch_geometric/pull/6578), [#6597](https://github.com/pyg-team/pytorch_geometric/pull/6597), [#6600](https://github.com/pyg-team/pytorch_geometric/pull/6600), [#6618](https://github.com/pyg-team/pytorch_geometric/pull/6618), [#6619](https://github.com/pyg-team/pytorch_geometric/pull/6619), [#6621](https://github.com/pyg-team/pytorch_geometric/pull/6621), [#6623](https://github.com/pyg-team/pytorch_geometric/pull/6623), [#6637](https://github.com/pyg-team/pytorch_geometric/pull/6637), 
[#6638](https://github.com/pyg-team/pytorch_geometric/pull/6638), [#6640](https://github.com/pyg-team/pytorch_geometric/pull/6640), [#6645](https://github.com/pyg-team/pytorch_geometric/pull/6645), [#6648](https://github.com/pyg-team/pytorch_geometric/pull/6648), [#6647](https://github.com/pyg-team/pytorch_geometric/pull/6647), [#6653](https://github.com/pyg-team/pytorch_geometric/pull/6653), [#6657](https://github.com/pyg-team/pytorch_geometric/pull/6657), [#6662](https://github.com/pyg-team/pytorch_geometric/pull/6662), [#6664](https://github.com/pyg-team/pytorch_geometric/pull/6664), [#6667](https://github.com/pyg-team/pytorch_geometric/pull/6667), [#6668](https://github.com/pyg-team/pytorch_geometric/pull/6668), [#6669](https://github.com/pyg-team/pytorch_geometric/pull/6669), [#6670](https://github.com/pyg-team/pytorch_geometric/pull/6670), [#6671](https://github.com/pyg-team/pytorch_geometric/pull/6671), [#6673](https://github.com/pyg-team/pytorch_geometric/pull/6673), [#6675](https://github.com/pyg-team/pytorch_geometric/pull/6675), [#6676](https://github.com/pyg-team/pytorch_geometric/pull/6676), [#6677](https://github.com/pyg-team/pytorch_geometric/pull/6677), [#6678](https://github.com/pyg-team/pytorch_geometric/pull/6678), [#6681](https://github.com/pyg-team/pytorch_geometric/pull/6681), [#6683](https://github.com/pyg-team/pytorch_geometric/pull/6683), [#6703](https://github.com/pyg-team/pytorch_geometric/pull/6703), [#6720](https://github.com/pyg-team/pytorch_geometric/pull/6720), [#6735](https://github.com/pyg-team/pytorch_geometric/pull/6735), [#6736](https://github.com/pyg-team/pytorch_geometric/pull/6736), [#6763](https://github.com/pyg-team/pytorch_geometric/pull/6763), [#6781](https://github.com/pyg-team/pytorch_geometric/pull/6781), [#6797](https://github.com/pyg-team/pytorch_geometric/pull/6797), [#6799](https://github.com/pyg-team/pytorch_geometric/pull/6799), [#6824](https://github.com/pyg-team/pytorch_geometric/pull/6824), 
[#6858](https://github.com/pyg-team/pytorch_geometric/pull/6858))\n- Fixed a bug in which `data.to_heterogeneous()` filtered attributes in the wrong dimension ([#6522](https://github.com/pyg-team/pytorch_geometric/pull/6522))\n- Breaking Change: Temporal sampling will now also sample nodes with an equal timestamp to the seed time (requires `pyg-lib>0.1.0`) ([#6517](https://github.com/pyg-team/pytorch_geometric/pull/6517))\n- Changed `DataLoader` workers with affinity to start at `cpu0` ([#6512](https://github.com/pyg-team/pytorch_geometric/pull/6512))\n- Allow 1D input to `global_*_pool` functions ([#6504](https://github.com/pyg-team/pytorch_geometric/pull/6504))\n- Add information about dynamic shapes in `RGCNConv` ([#6482](https://github.com/pyg-team/pytorch_geometric/pull/6482))\n- Fixed the use of types removed in `numpy 1.24.0` ([#6495](https://github.com/pyg-team/pytorch_geometric/pull/6495))\n- Fixed keyword parameters in `examples/mnist_voxel_grid.py` ([#6478](https://github.com/pyg-team/pytorch_geometric/pull/6478))\n- Unified `LightningNodeData` and `LightningLinkData` code paths ([#6473](https://github.com/pyg-team/pytorch_geometric/pull/6473))\n- Allow indices with any integer type in `RGCNConv` ([#6463](https://github.com/pyg-team/pytorch_geometric/pull/6463))\n- Re-structured the documentation ([#6420](https://github.com/pyg-team/pytorch_geometric/pull/6420), [#6423](https://github.com/pyg-team/pytorch_geometric/pull/6423), [#6429](https://github.com/pyg-team/pytorch_geometric/pull/6429), [#6440](https://github.com/pyg-team/pytorch_geometric/pull/6440), [#6443](https://github.com/pyg-team/pytorch_geometric/pull/6443), [#6445](https://github.com/pyg-team/pytorch_geometric/pull/6445), [#6452](https://github.com/pyg-team/pytorch_geometric/pull/6452), [#6453](https://github.com/pyg-team/pytorch_geometric/pull/6453), [#6458](https://github.com/pyg-team/pytorch_geometric/pull/6458), [#6459](https://github.com/pyg-team/pytorch_geometric/pull/6459), 
[#6460](https://github.com/pyg-team/pytorch_geometric/pull/6460), [#6490](https://github.com/pyg-team/pytorch_geometric/pull/6490), [#6491](https://github.com/pyg-team/pytorch_geometric/pull/6491), [#6693](https://github.com/pyg-team/pytorch_geometric/pull/6693), [#6744](https://github.com/pyg-team/pytorch_geometric/pull/6744))\n- Fix the default arguments of `DataParallel` class ([#6376](https://github.com/pyg-team/pytorch_geometric/pull/6376))\n- Fix `ImbalancedSampler` on sliced `InMemoryDataset` ([#6374](https://github.com/pyg-team/pytorch_geometric/pull/6374))\n- Breaking Change: Changed the interface and implementation of `GraphMultisetTransformer` ([#6343](https://github.com/pyg-team/pytorch_geometric/pull/6343))\n- Fixed the approximate PPR variant in `transforms.GDC` to not crash on graphs with isolated nodes ([#6242](https://github.com/pyg-team/pytorch_geometric/pull/6242))\n- Added a warning when accessing `InMemoryDataset.data` ([#6318](https://github.com/pyg-team/pytorch_geometric/pull/6318))\n- Drop `SparseTensor` dependency in `GraphStore` ([#5517](https://github.com/pyg-team/pytorch_geometric/pull/5517))\n- Replace `NeighborSampler` with `NeighborLoader` in the distributed sampling example ([#6204](https://github.com/pyg-team/pytorch_geometric/pull/6307))\n- Fixed the filtering of node features in `transforms.RemoveIsolatedNodes` ([#6308](https://github.com/pyg-team/pytorch_geometric/pull/6308))\n- Fixed a bug in `DimeNet` that causes an output dimension mismatch ([#6305](https://github.com/pyg-team/pytorch_geometric/pull/6305))\n- Fixed `Data.to_heterogeneous()` with empty `edge_index` ([#6304](https://github.com/pyg-team/pytorch_geometric/pull/6304))\n- Unify `Explanation.node_mask` and `Explanation.node_feat_mask` ([#6267](https://github.com/pyg-team/pytorch_geometric/pull/6267))\n- Moved thresholding config of the `Explainer` to `Explanation` ([#6215](https://github.com/pyg-team/pytorch_geometric/pull/6215))\n- Fixed a bug in the output order in 
`HeteroLinear` for un-sorted type vectors ([#6198](https://github.com/pyg-team/pytorch_geometric/pull/6198))\n- Breaking Change: Move `ExplainerConfig` arguments to the `Explainer` class ([#6176](https://github.com/pyg-team/pytorch_geometric/pull/6176))\n- Refactored `NeighborSampler` to be input-type agnostic ([#6173](https://github.com/pyg-team/pytorch_geometric/pull/6173))\n- Infer correct CUDA device ID in `profileit` decorator ([#6164](https://github.com/pyg-team/pytorch_geometric/pull/6164))\n- Correctly use edge weights in `GDC` example ([#6159](https://github.com/pyg-team/pytorch_geometric/pull/6159))\n- Breaking Change: Moved PyTorch Lightning data modules to `torch_geometric.data.lightning` ([#6140](https://github.com/pyg-team/pytorch_geometric/pull/6140))\n- Make `torch_sparse` an optional dependency ([#6132](https://github.com/pyg-team/pytorch_geometric/pull/6132), [#6134](https://github.com/pyg-team/pytorch_geometric/pull/6134), [#6138](https://github.com/pyg-team/pytorch_geometric/pull/6138), [#6139](https://github.com/pyg-team/pytorch_geometric/pull/6139), [#7387](https://github.com/pyg-team/pytorch_geometric/pull/7387))\n- Optimized `utils.softmax` implementation ([#6113](https://github.com/pyg-team/pytorch_geometric/pull/6113), [#6155](https://github.com/pyg-team/pytorch_geometric/pull/6155), [#6805](https://github.com/pyg-team/pytorch_geometric/pull/6805))\n- Optimized `topk` implementation for large enough graphs ([#6123](https://github.com/pyg-team/pytorch_geometric/pull/6123))\n\n### Removed\n\n- `torch-sparse` is now an optional dependency ([#6625](https://github.com/pyg-team/pytorch_geometric/pull/6625), [#6626](https://github.com/pyg-team/pytorch_geometric/pull/6626), [#6627](https://github.com/pyg-team/pytorch_geometric/pull/6627), [#6628](https://github.com/pyg-team/pytorch_geometric/pull/6628), [#6629](https://github.com/pyg-team/pytorch_geometric/pull/6629), [#6630](https://github.com/pyg-team/pytorch_geometric/pull/6630))\n- Removed 
most of the `torch-scatter` dependencies ([#6394](https://github.com/pyg-team/pytorch_geometric/pull/6394), [#6395](https://github.com/pyg-team/pytorch_geometric/pull/6395), [#6399](https://github.com/pyg-team/pytorch_geometric/pull/6399), [#6400](https://github.com/pyg-team/pytorch_geometric/pull/6400), [#6615](https://github.com/pyg-team/pytorch_geometric/pull/6615), [#6617](https://github.com/pyg-team/pytorch_geometric/pull/6617))\n- Removed the deprecated classes `GNNExplainer` and `Explainer` from `nn.models` ([#6382](https://github.com/pyg-team/pytorch_geometric/pull/6382))\n- Removed `target_index` argument in the `Explainer` interface ([#6270](https://github.com/pyg-team/pytorch_geometric/pull/6270))\n- Removed `Aggregation.set_validate_args` option ([#6175](https://github.com/pyg-team/pytorch_geometric/pull/6175))\n\n## [2.2.0] - 2022-12-01\n\n### Added\n\n- Extended `GNNExplainer` to support edge level explanations ([#6056](https://github.com/pyg-team/pytorch_geometric/pull/6056), [#6083](https://github.com/pyg-team/pytorch_geometric/pull/6083))\n- Added CPU affinitization for `NodeLoader` ([#6005](https://github.com/pyg-team/pytorch_geometric/pull/6005))\n- Added triplet sampling in `LinkNeighborLoader` ([#6004](https://github.com/pyg-team/pytorch_geometric/pull/6004))\n- Added `FusedAggregation` of simple scatter reductions ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036))\n- Added a `to_smiles` function ([#6038](https://github.com/pyg-team/pytorch_geometric/pull/6038))\n- Added option to make normalization coefficients trainable in `PNAConv` ([#6039](https://github.com/pyg-team/pytorch_geometric/pull/6039))\n- Added `semi_grad` option in `VarAggregation` and `StdAggregation` ([#6042](https://github.com/pyg-team/pytorch_geometric/pull/6042))\n- Allow for fused aggregations in `MultiAggregation` ([#6036](https://github.com/pyg-team/pytorch_geometric/pull/6036), [#6040](https://github.com/pyg-team/pytorch_geometric/pull/6040))\n- Added 
`HeteroData` support for `to_captum_model` and added `to_captum_input` ([#5934](https://github.com/pyg-team/pytorch_geometric/pull/5934))\n- Added `HeteroData` support in `RandomNodeLoader` ([#6007](https://github.com/pyg-team/pytorch_geometric/pull/6007))\n- Added bipartite `GraphSAGE` example ([#5834](https://github.com/pyg-team/pytorch_geometric/pull/5834))\n- Added `LRGBDataset` to include 5 datasets from the [Long Range Graph Benchmark](https://openreview.net/pdf?id=in7XC5RcjEn) ([#5935](https://github.com/pyg-team/pytorch_geometric/pull/5935))\n- Added a warning for invalid node and edge type names in `HeteroData` ([#5990](https://github.com/pyg-team/pytorch_geometric/pull/5990))\n- Added PyTorch 1.13 support ([#5975](https://github.com/pyg-team/pytorch_geometric/pull/5975))\n- Added `int32` support in `NeighborLoader` ([#5948](https://github.com/pyg-team/pytorch_geometric/pull/5948))\n- Add `dgNN` support and `FusedGATConv` implementation ([#5140](https://github.com/pyg-team/pytorch_geometric/pull/5140))\n- Added `lr_scheduler_solver` and customized `lr_scheduler` classes ([#5942](https://github.com/pyg-team/pytorch_geometric/pull/5942))\n- Add `to_fixed_size` graph transformer ([#5939](https://github.com/pyg-team/pytorch_geometric/pull/5939))\n- Add support for symbolic tracing of `SchNet` model ([#5938](https://github.com/pyg-team/pytorch_geometric/pull/5938))\n- Add support for customizable interaction graph in `SchNet` model ([#5919](https://github.com/pyg-team/pytorch_geometric/pull/5919))\n- Started adding `torch.sparse` support to PyG ([#5906](https://github.com/pyg-team/pytorch_geometric/pull/5906), [#5944](https://github.com/pyg-team/pytorch_geometric/pull/5944), [#6003](https://github.com/pyg-team/pytorch_geometric/pull/6003), [#6633](https://github.com/pyg-team/pytorch_geometric/pull/6633))\n- Added `HydroNet` water cluster dataset ([#5537](https://github.com/pyg-team/pytorch_geometric/pull/5537), 
[#5902](https://github.com/pyg-team/pytorch_geometric/pull/5902), [#5903](https://github.com/pyg-team/pytorch_geometric/pull/5903))\n- Added explainability support for heterogeneous GNNs ([#5886](https://github.com/pyg-team/pytorch_geometric/pull/5886))\n- Added `SparseTensor` support to `SuperGATConv` ([#5888](https://github.com/pyg-team/pytorch_geometric/pull/5888))\n- Added TorchScript support for `AttentiveFP` ([#5868](https://github.com/pyg-team/pytorch_geometric/pull/5868))\n- Added `num_steps` argument to training and inference benchmarks ([#5898](https://github.com/pyg-team/pytorch_geometric/pull/5898))\n- Added `torch.onnx.export` support ([#5877](https://github.com/pyg-team/pytorch_geometric/pull/5877), [#5997](https://github.com/pyg-team/pytorch_geometric/pull/5997))\n- Enable VTune ITT in inference and training benchmarks ([#5830](https://github.com/pyg-team/pytorch_geometric/pull/5830), [#5878](https://github.com/pyg-team/pytorch_geometric/pull/5878))\n- Add training benchmark ([#5774](https://github.com/pyg-team/pytorch_geometric/pull/5774))\n- Added a \"Link Prediction on MovieLens\" Colab notebook ([#5823](https://github.com/pyg-team/pytorch_geometric/pull/5823))\n- Added custom `sampler` support in `LightningDataModule` ([#5820](https://github.com/pyg-team/pytorch_geometric/pull/5820))\n- Added a `return_semantic_attention_weights` argument to `HANConv` ([#5787](https://github.com/pyg-team/pytorch_geometric/pull/5787))\n- Added `disjoint` argument to `NeighborLoader` and `LinkNeighborLoader` ([#5775](https://github.com/pyg-team/pytorch_geometric/pull/5775))\n- Added support for `input_time` in `NeighborLoader` ([#5763](https://github.com/pyg-team/pytorch_geometric/pull/5763))\n- Added `disjoint` mode for temporal `LinkNeighborLoader` ([#5717](https://github.com/pyg-team/pytorch_geometric/pull/5717))\n- Added `HeteroData` support for `transforms.Constant` ([#5700](https://github.com/pyg-team/pytorch_geometric/pull/5700))\n- Added `np.memmap` support in 
`NeighborLoader` ([#5696](https://github.com/pyg-team/pytorch_geometric/pull/5696))\n- Added `assortativity` that computes the degree assortativity coefficient ([#5587](https://github.com/pyg-team/pytorch_geometric/pull/5587))\n- Added `SSGConv` layer ([#5599](https://github.com/pyg-team/pytorch_geometric/pull/5599))\n- Added `shuffle_node`, `mask_feature` and `add_random_edge` augmentation methods ([#5548](https://github.com/pyg-team/pytorch_geometric/pull/5548))\n- Added `dropout_path` augmentation that drops edges from a graph based on random walks ([#5531](https://github.com/pyg-team/pytorch_geometric/pull/5531))\n- Add support for filling labels with dummy values in `HeteroData.to_homogeneous()` ([#5540](https://github.com/pyg-team/pytorch_geometric/pull/5540))\n- Added `temporal_strategy` option to `neighbor_sample` ([#5576](https://github.com/pyg-team/pyg-lib/pull/5576))\n- Added `torch_geometric.sampler` package to docs ([#5563](https://github.com/pyg-team/pytorch_geometric/pull/5563))\n- Added the `DGraphFin` dynamic graph dataset ([#5504](https://github.com/pyg-team/pytorch_geometric/pull/5504))\n- Added `dropout_edge` augmentation that randomly drops edges from a graph - the usage of `dropout_adj` is now deprecated ([#5495](https://github.com/pyg-team/pytorch_geometric/pull/5495))\n- Added `dropout_node` augmentation that randomly drops nodes from a graph ([#5481](https://github.com/pyg-team/pytorch_geometric/pull/5481))\n- Added `AddRandomMetaPaths` that adds edges based on random walks along a metapath ([#5397](https://github.com/pyg-team/pytorch_geometric/pull/5397))\n- Added `WLConvContinuous` for performing WL refinement with continuous attributes ([#5316](https://github.com/pyg-team/pytorch_geometric/pull/5316))\n- Added `print_summary` method for the `torch_geometric.data.Dataset` interface ([#5438](https://github.com/pyg-team/pytorch_geometric/pull/5438))\n- Added `sampler` support to `LightningDataModule` 
([#5456](https://github.com/pyg-team/pytorch_geometric/pull/5456), [#5457](https://github.com/pyg-team/pytorch_geometric/pull/5457))\n- Added official splits to `MalNetTiny` dataset ([#5078](https://github.com/pyg-team/pytorch_geometric/pull/5078))\n- Added `IndexToMask` and `MaskToIndex` transforms ([#5375](https://github.com/pyg-team/pytorch_geometric/pull/5375), [#5455](https://github.com/pyg-team/pytorch_geometric/pull/5455))\n- Added `FeaturePropagation` transform ([#5387](https://github.com/pyg-team/pytorch_geometric/pull/5387))\n- Added `PositionalEncoding` ([#5381](https://github.com/pyg-team/pytorch_geometric/pull/5381))\n- Consolidated sampler routines behind `torch_geometric.sampler`, enabling ease of extensibility in the future ([#5312](https://github.com/pyg-team/pytorch_geometric/pull/5312), [#5365](https://github.com/pyg-team/pytorch_geometric/pull/5365), [#5402](https://github.com/pyg-team/pytorch_geometric/pull/5402), [#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404), [#5418](https://github.com/pyg-team/pytorch_geometric/pull/5418))\n- Added `pyg-lib` neighbor sampling ([#5384](https://github.com/pyg-team/pytorch_geometric/pull/5384), [#5388](https://github.com/pyg-team/pytorch_geometric/pull/5388))\n- Added `pyg_lib.segment_matmul` integration within `HeteroLinear` ([#5330](https://github.com/pyg-team/pytorch_geometric/pull/5330), [#5347](https://github.com/pyg-team/pytorch_geometric/pull/5347))\n- Enabled `bf16` support in benchmark scripts ([#5293](https://github.com/pyg-team/pytorch_geometric/pull/5293), [#5341](https://github.com/pyg-team/pytorch_geometric/pull/5341))\n- Added `Aggregation.set_validate_args` option to skip validation of `dim_size` ([#5290](https://github.com/pyg-team/pytorch_geometric/pull/5290))\n- Added `SparseTensor` support to inference and training benchmark suite ([#5242](https://github.com/pyg-team/pytorch_geometric/pull/5242), [#5258](https://github.com/pyg-team/pytorch_geometric/pull/5258), 
[#5881](https://github.com/pyg-team/pytorch_geometric/pull/5881))\n- Added experimental mode in inference benchmarks ([#5254](https://github.com/pyg-team/pytorch_geometric/pull/5254))\n- Added node classification example instrumented with [Weights and Biases (W&B) logging](https://wandb.com) and [W&B Sweeps](https://wandb.com/sweeps) ([#5192](https://github.com/pyg-team/pytorch_geometric/pull/5192))\n- Added experimental mode for `utils.scatter` ([#5232](https://github.com/pyg-team/pytorch_geometric/pull/5232), [#5241](https://github.com/pyg-team/pytorch_geometric/pull/5241), [#5386](https://github.com/pyg-team/pytorch_geometric/pull/5386))\n- Added missing test labels in `HGBDataset` ([#5233](https://github.com/pyg-team/pytorch_geometric/pull/5233))\n- Added `BaseStorage.get()` functionality ([#5240](https://github.com/pyg-team/pytorch_geometric/pull/5240))\n- Added a test to confirm that `to_hetero` works with `SparseTensor` ([#5222](https://github.com/pyg-team/pytorch_geometric/pull/5222))\n- Added `torch_geometric.explain` module with base functionality for explainability methods ([#5804](https://github.com/pyg-team/pytorch_geometric/pull/5804), [#6054](https://github.com/pyg-team/pytorch_geometric/pull/6054), [#6089](https://github.com/pyg-team/pytorch_geometric/pull/6089))\n\n### Changed\n\n- Moved and adapted `GNNExplainer` from `torch_geometric.nn` to `torch_geometric.explain.algorithm` ([#5967](https://github.com/pyg-team/pytorch_geometric/pull/5967), [#6065](https://github.com/pyg-team/pytorch_geometric/pull/6065))\n- Optimized scatter implementations for CPU/GPU, both with and without backward computation ([#6051](https://github.com/pyg-team/pytorch_geometric/pull/6051), [#6052](https://github.com/pyg-team/pytorch_geometric/pull/6052))\n- Support temperature value in `dense_mincut_pool` ([#5908](https://github.com/pyg-team/pytorch_geometric/pull/5908))\n- Fixed a bug in which `VirtualNode` mistakenly treated node features as edge features 
([#5819](https://github.com/pyg-team/pytorch_geometric/pull/5819))\n- Fixed `setter` and `getter` handling in `BaseStorage` ([#5815](https://github.com/pyg-team/pytorch_geometric/pull/5815))\n- Fixed `path` in `hetero_conv_dblp.py` example ([#5686](https://github.com/pyg-team/pytorch_geometric/pull/5686))\n- Fix `auto_select_device` routine in GraphGym for PyTorch Lightning>=1.7 ([#5677](https://github.com/pyg-team/pytorch_geometric/pull/5677))\n- Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641))\n- Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642))\n- Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610))\n- Improved type hint support ([#5842](https://github.com/pyg-team/pytorch_geometric/pull/5842), [#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5676](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), 
[#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701](https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714), [#5715](https://github.com/pyg-team/pytorch_geometric/pull/5715), [#5716](https://github.com/pyg-team/pytorch_geometric/pull/5716), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5724](https://github.com/pyg-team/pytorch_geometric/pull/5724), [#5725](https://github.com/pyg-team/pytorch_geometric/pull/5725), [#5726](https://github.com/pyg-team/pytorch_geometric/pull/5726), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729), [#5730](https://github.com/pyg-team/pytorch_geometric/pull/5730), [#5731](https://github.com/pyg-team/pytorch_geometric/pull/5731), [#5732](https://github.com/pyg-team/pytorch_geometric/pull/5732), [#5733](https://github.com/pyg-team/pytorch_geometric/pull/5733), [#5743](https://github.com/pyg-team/pytorch_geometric/pull/5743), [#5734](https://github.com/pyg-team/pytorch_geometric/pull/5734), [#5735](https://github.com/pyg-team/pytorch_geometric/pull/5735), [#5736](https://github.com/pyg-team/pytorch_geometric/pull/5736), [#5737](https://github.com/pyg-team/pytorch_geometric/pull/5737), [#5738](https://github.com/pyg-team/pytorch_geometric/pull/5738), 
[#5747](https://github.com/pyg-team/pytorch_geometric/pull/5747), [#5752](https://github.com/pyg-team/pytorch_geometric/pull/5752), [#5753](https://github.com/pyg-team/pytorch_geometric/pull/5753), [#5754](https://github.com/pyg-team/pytorch_geometric/pull/5754), [#5756](https://github.com/pyg-team/pytorch_geometric/pull/5756), [#5757](https://github.com/pyg-team/pytorch_geometric/pull/5757), [#5758](https://github.com/pyg-team/pytorch_geometric/pull/5758), [#5760](https://github.com/pyg-team/pytorch_geometric/pull/5760), [#5766](https://github.com/pyg-team/pytorch_geometric/pull/5766), [#5767](https://github.com/pyg-team/pytorch_geometric/pull/5767), [#5768](https://github.com/pyg-team/pytorch_geometric/pull/5768), [#5781](https://github.com/pyg-team/pytorch_geometric/pull/5781), [#5778](https://github.com/pyg-team/pytorch_geometric/pull/5778), [#5797](https://github.com/pyg-team/pytorch_geometric/pull/5797), [#5798](https://github.com/pyg-team/pytorch_geometric/pull/5798), [#5799](https://github.com/pyg-team/pytorch_geometric/pull/5799), [#5800](https://github.com/pyg-team/pytorch_geometric/pull/5800), [#5806](https://github.com/pyg-team/pytorch_geometric/pull/5806), [#5810](https://github.com/pyg-team/pytorch_geometric/pull/5810), [#5811](https://github.com/pyg-team/pytorch_geometric/pull/5811), [#5828](https://github.com/pyg-team/pytorch_geometric/pull/5828), [#5847](https://github.com/pyg-team/pytorch_geometric/pull/5847), [#5851](https://github.com/pyg-team/pytorch_geometric/pull/5851), [#5852](https://github.com/pyg-team/pytorch_geometric/pull/5852))\n- Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601))\n- Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614))\n- Integrated better temporal sampling support by requiring that local neighborhoods 
are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602))\n- Fixed a bug when applying several scalers with `PNAConv` ([#5514](https://github.com/pyg-team/pytorch_geometric/issues/5514))\n- Allow `.` in `ParameterDict` key names ([#5494](https://github.com/pyg-team/pytorch_geometric/pull/5494))\n- Renamed `drop_unconnected_nodes` to `drop_unconnected_node_types` and `drop_orig_edges` to `drop_orig_edge_types` in `AddMetapaths` ([#5490](https://github.com/pyg-team/pytorch_geometric/pull/5490))\n- Improved `utils.scatter` performance by explicitly choosing a better implementation for `add` and `mean` reduction ([#5399](https://github.com/pyg-team/pytorch_geometric/pull/5399))\n- Fix `to_dense_adj` with empty `edge_index` ([#5476](https://github.com/pyg-team/pytorch_geometric/pull/5476))\n- The `AttentionalAggregation` module can now be applied to compute attention on a per-feature level ([#5449](https://github.com/pyg-team/pytorch_geometric/pull/5449))\n- Ensure equal lengths of `num_neighbors` across edge types in `NeighborLoader` ([#5444](https://github.com/pyg-team/pytorch_geometric/pull/5444))\n- Fixed a bug in `TUDataset` in which node features were wrongly constructed whenever `node_attributes` only holds a single feature (_e.g._, in `PROTEINS`) ([#5441](https://github.com/pyg-team/pytorch_geometric/pull/5441))\n- Breaking change: removed `num_neighbors` as an attribute of loader ([#5404](https://github.com/pyg-team/pytorch_geometric/pull/5404))\n- `ASAPooling` is now jittable ([#5395](https://github.com/pyg-team/pytorch_geometric/pull/5395))\n- Updated unsupervised `GraphSAGE` example to leverage `LinkNeighborLoader` ([#5317](https://github.com/pyg-team/pytorch_geometric/pull/5317))\n- Replace in-place operations with out-of-place ones to align with `torch.scatter_reduce` API ([#5353](https://github.com/pyg-team/pytorch_geometric/pull/5353))\n- Breaking 
bugfix: `PointTransformerConv` now correctly uses `sum` aggregation ([#5332](https://github.com/pyg-team/pytorch_geometric/pull/5332))\n- Improve out-of-bounds error message in `MessagePassing` ([#5339](https://github.com/pyg-team/pytorch_geometric/pull/5339))\n- Allow file names of a `Dataset` to be specified as either a property or a method ([#5338](https://github.com/pyg-team/pytorch_geometric/pull/5338))\n- Fixed separating a list of `SparseTensor` within `InMemoryDataset` ([#5299](https://github.com/pyg-team/pytorch_geometric/pull/5299))\n- Improved name resolving of normalization layers ([#5277](https://github.com/pyg-team/pytorch_geometric/pull/5277))\n- Fail gracefully on `GLIBC` errors within `torch-spline-conv` ([#5276](https://github.com/pyg-team/pytorch_geometric/pull/5276))\n- Fixed `Dataset.num_classes` in case a `transform` modifies `data.y` ([#5274](https://github.com/pyg-team/pytorch_geometric/pull/5274))\n- Allow customization of the activation function within `PNAConv` ([#5262](https://github.com/pyg-team/pytorch_geometric/pull/5262))\n- Do not fill `InMemoryDataset` cache on `dataset.num_features` ([#5264](https://github.com/pyg-team/pytorch_geometric/pull/5264))\n- Changed tests relying on `dblp` datasets to instead use synthetic data ([#5250](https://github.com/pyg-team/pytorch_geometric/pull/5250))\n- Fixed a bug for the initialization of activation function examples in `custom_graphgym` ([#5243](https://github.com/pyg-team/pytorch_geometric/pull/5243))\n- Allow any integer tensors when checking `edge_index` input to message passing ([#5281](https://github.com/pyg-team/pytorch_geometric/pull/5281))\n\n### Removed\n\n- Removed `scatter_reduce` option from experimental mode ([#5399](https://github.com/pyg-team/pytorch_geometric/pull/5399))\n\n## [2.1.0] - 2022-08-17\n\n### Added\n\n- Added the test for `DeepGCNLayer` ([#5704](https://github.com/pyg-team/pytorch_geometric/pull/5704))\n- Allow `.` in `ModuleDict` key names 
([#5227](https://github.com/pyg-team/pytorch_geometric/pull/5227))\n- Added `edge_label_time` argument to `LinkNeighborLoader` ([#5137](https://github.com/pyg-team/pytorch_geometric/pull/5137), [#5173](https://github.com/pyg-team/pytorch_geometric/pull/5173))\n- Let `ImbalancedSampler` accept `torch.Tensor` as input ([#5138](https://github.com/pyg-team/pytorch_geometric/pull/5138))\n- Added `flow` argument to `gcn_norm` to correctly normalize the adjacency matrix in `GCNConv` ([#5149](https://github.com/pyg-team/pytorch_geometric/pull/5149))\n- `NeighborSampler` supports graphs without edges ([#5072](https://github.com/pyg-team/pytorch_geometric/pull/5072))\n- Added the `MeanSubtractionNorm` layer ([#5068](https://github.com/pyg-team/pytorch_geometric/pull/5068))\n- Added `pyg_lib.segment_matmul` integration within `RGCNConv` ([#5052](https://github.com/pyg-team/pytorch_geometric/pull/5052), [#5096](https://github.com/pyg-team/pytorch_geometric/pull/5096))\n- Support `SparseTensor` as edge label in `LightGCN` ([#5046](https://github.com/pyg-team/pytorch_geometric/issues/5046))\n- Added support for `BasicGNN` models within `to_hetero` ([#5091](https://github.com/pyg-team/pytorch_geometric/pull/5091))\n- Added support for computing weighted metapaths in `AddMetapaths` ([#5049](https://github.com/pyg-team/pytorch_geometric/pull/5049))\n- Added inference benchmark suite ([#4915](https://github.com/pyg-team/pytorch_geometric/pull/4915))\n- Added a dynamically sized batch sampler for filling a mini-batch with a variable number of samples up to a maximum size ([#4972](https://github.com/pyg-team/pytorch_geometric/pull/4972))\n- Added fine-grained options for setting `bias` and `dropout` per layer in the `MLP` model ([#4981](https://github.com/pyg-team/pytorch_geometric/pull/4981))\n- Added `EdgeCNN` model ([#4991](https://github.com/pyg-team/pytorch_geometric/pull/4991))\n- Added scalable `inference` mode in `BasicGNN` with layer-wise neighbor loading 
([#4977](https://github.com/pyg-team/pytorch_geometric/pull/4977))\n- Added inference benchmarks ([#4892](https://github.com/pyg-team/pytorch_geometric/pull/4892), [#5107](https://github.com/pyg-team/pytorch_geometric/pull/5107))\n- Added PyTorch 1.12 support ([#4975](https://github.com/pyg-team/pytorch_geometric/pull/4975))\n- Added `unbatch_edge_index` functionality for splitting an `edge_index` tensor according to a `batch` vector ([#4903](https://github.com/pyg-team/pytorch_geometric/pull/4903))\n- Added node-wise normalization mode in `LayerNorm` ([#4944](https://github.com/pyg-team/pytorch_geometric/pull/4944))\n- Added support for `normalization_resolver` ([#4926](https://github.com/pyg-team/pytorch_geometric/pull/4926), [#4951](https://github.com/pyg-team/pytorch_geometric/pull/4951), [#4958](https://github.com/pyg-team/pytorch_geometric/pull/4958), [#4959](https://github.com/pyg-team/pytorch_geometric/pull/4959))\n- Added notebook tutorial for `torch_geometric.nn.aggr` package to documentation ([#4927](https://github.com/pyg-team/pytorch_geometric/pull/4927))\n- Added support for `follow_batch` for lists or dictionaries of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))\n- Added `Data.validate()` and `HeteroData.validate()` functionality ([#4885](https://github.com/pyg-team/pytorch_geometric/pull/4885))\n- Added `LinkNeighborLoader` support to `LightningDataModule` ([#4868](https://github.com/pyg-team/pytorch_geometric/pull/4868))\n- Added `predict()` support to the `LightningNodeData` module ([#4884](https://github.com/pyg-team/pytorch_geometric/pull/4884))\n- Added `time_attr` argument to `LinkNeighborLoader` ([#4877](https://github.com/pyg-team/pytorch_geometric/pull/4877), [#4908](https://github.com/pyg-team/pytorch_geometric/pull/4908))\n- Added a `filter_per_worker` argument to data loaders to allow filtering of data within sub-processes ([#4873](https://github.com/pyg-team/pytorch_geometric/pull/4873))\n- Added a 
`NeighborLoader` benchmark script ([#4815](https://github.com/pyg-team/pytorch_geometric/pull/4815), [#4862](https://github.com/pyg-team/pytorch_geometric/pull/4862/files))\n- Added support for `FeatureStore` and `GraphStore` in `NeighborLoader` ([#4817](https://github.com/pyg-team/pytorch_geometric/pull/4817), [#4851](https://github.com/pyg-team/pytorch_geometric/pull/4851), [#4854](https://github.com/pyg-team/pytorch_geometric/pull/4854), [#4856](https://github.com/pyg-team/pytorch_geometric/pull/4856), [#4857](https://github.com/pyg-team/pytorch_geometric/pull/4857), [#4882](https://github.com/pyg-team/pytorch_geometric/pull/4882), [#4883](https://github.com/pyg-team/pytorch_geometric/pull/4883), [#4929](https://github.com/pyg-team/pytorch_geometric/pull/4929), [#4992](https://github.com/pyg-team/pytorch_geometric/pull/4922), [#4962](https://github.com/pyg-team/pytorch_geometric/pull/4962), [#4968](https://github.com/pyg-team/pytorch_geometric/pull/4968), [#5037](https://github.com/pyg-team/pytorch_geometric/pull/5037), [#5088](https://github.com/pyg-team/pytorch_geometric/pull/5088), [#5270](https://github.com/pyg-team/pytorch_geometric/pull/5270), [#5307](https://github.com/pyg-team/pytorch_geometric/pull/5307), [#5318](https://github.com/pyg-team/pytorch_geometric/pull/5318))\n- Added a `normalize` parameter to `dense_diff_pool` ([#4847](https://github.com/pyg-team/pytorch_geometric/pull/4847))\n- Added `size=None` explanation to jittable `MessagePassing` modules in the documentation ([#4850](https://github.com/pyg-team/pytorch_geometric/pull/4850))\n- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))\n- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))\n- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), 
[#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853))\n- Added `FeatureStore` and `GraphStore` abstractions ([#4534](https://github.com/pyg-team/pytorch_geometric/pull/4534), [#4568](https://github.com/pyg-team/pytorch_geometric/pull/4568), [#5120](https://github.com/pyg-team/pytorch_geometric/pull/5120))\n- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))\n- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))\n- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))\n- Added a `max_sample` argument to `AddMetaPaths` in order to tackle very dense metapath edges ([#4750](https://github.com/pyg-team/pytorch_geometric/pull/4750))\n- Test `HANConv` with empty tensors ([#4756](https://github.com/pyg-team/pytorch_geometric/pull/4756), [#4841](https://github.com/pyg-team/pytorch_geometric/pull/4841))\n- Added the `bias` vector to the `GCN` model definition in the \"Create Message Passing Networks\" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))\n- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))\n- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))\n- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), 
[#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033), [#5085](https://github.com/pyg-team/pytorch_geometric/pull/5085), [#5097](https://github.com/pyg-team/pytorch_geometric/pull/5097), [#5099](https://github.com/pyg-team/pytorch_geometric/pull/5099), [#5104](https://github.com/pyg-team/pytorch_geometric/pull/5104), [#5113](https://github.com/pyg-team/pytorch_geometric/pull/5113), [#5130](https://github.com/pyg-team/pytorch_geometric/pull/5130), [#5098](https://github.com/pyg-team/pytorch_geometric/pull/5098), [#5191](https://github.com/pyg-team/pytorch_geometric/pull/5191))\n- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), 
[#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))\n- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))\n- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))\n- Added benchmarks via [`wandb`](https://wandb.ai/site) ([#4656](https://github.com/pyg-team/pytorch_geometric/pull/4656), [#4672](https://github.com/pyg-team/pytorch_geometric/pull/4672), [#4676](https://github.com/pyg-team/pytorch_geometric/pull/4676))\n- Added `unbatch` functionality ([#4628](https://github.com/pyg-team/pytorch_geometric/pull/4628))\n- Confirm that `to_hetero()` works with custom functions, _e.g._, `dropout_adj` ([#4653](https://github.com/pyg-team/pytorch_geometric/pull/4653))\n- Added the `MLP.plain_last=False` option ([#4652](https://github.com/pyg-team/pytorch_geometric/pull/4652))\n- Added a check in `HeteroConv` and `to_hetero()` to ensure that `MessagePassing.add_self_loops` is disabled ([#4647](https://github.com/pyg-team/pytorch_geometric/pull/4647))\n- Added `HeteroData.subgraph()`, `HeteroData.node_type_subgraph()` and `HeteroData.edge_type_subgraph()` support ([#4635](https://github.com/pyg-team/pytorch_geometric/pull/4635))\n- Added the `AQSOL` dataset ([#4626](https://github.com/pyg-team/pytorch_geometric/pull/4626))\n- Added `HeteroData.node_items()` and `HeteroData.edge_items()` functionality ([#4644](https://github.com/pyg-team/pytorch_geometric/pull/4644))\n- Added PyTorch Lightning support in GraphGym ([#4511](https://github.com/pyg-team/pytorch_geometric/pull/4511), [#4516](https://github.com/pyg-team/pytorch_geometric/pull/4516), [#4531](https://github.com/pyg-team/pytorch_geometric/pull/4531), 
[#4689](https://github.com/pyg-team/pytorch_geometric/pull/4689), [#4843](https://github.com/pyg-team/pytorch_geometric/pull/4843))\n- Added support for returning embeddings in `MLP` models ([#4625](https://github.com/pyg-team/pytorch_geometric/pull/4625))\n- Added faster initialization of `NeighborLoader` in case edge indices are already sorted (via `is_sorted=True`) ([#4620](https://github.com/pyg-team/pytorch_geometric/pull/4620), [#4702](https://github.com/pyg-team/pytorch_geometric/pull/4702))\n- Added `AddPositionalEncoding` transform ([#4521](https://github.com/pyg-team/pytorch_geometric/pull/4521))\n- Added `HeteroData.is_undirected()` support ([#4604](https://github.com/pyg-team/pytorch_geometric/pull/4604))\n- Added the `Genius` and `Wiki` datasets to `nn.datasets.LINKXDataset` ([#4570](https://github.com/pyg-team/pytorch_geometric/pull/4570), [#4600](https://github.com/pyg-team/pytorch_geometric/pull/4600))\n- Added `nn.aggr.EquilibrumAggregation` implicit global layer ([#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522))\n- Added support for graph-level outputs in `to_hetero` ([#4582](https://github.com/pyg-team/pytorch_geometric/pull/4582))\n- Added `CHANGELOG.md` ([#4581](https://github.com/pyg-team/pytorch_geometric/pull/4581))\n- Added `HeteroData` support to the `RemoveIsolatedNodes` transform ([#4479](https://github.com/pyg-team/pytorch_geometric/pull/4479))\n- Added `HeteroData.num_features` functionality ([#4504](https://github.com/pyg-team/pytorch_geometric/pull/4504))\n- Added support for projecting features before propagation in `SAGEConv` ([#4437](https://github.com/pyg-team/pytorch_geometric/pull/4437))\n- Added `Geom-GCN` splits to the `Planetoid` datasets ([#4442](https://github.com/pyg-team/pytorch_geometric/pull/4442))\n- Added a `LinkNeighborLoader` for training scalable link prediction models ([#4396](https://github.com/pyg-team/pytorch_geometric/pull/4396), 
[#4439](https://github.com/pyg-team/pytorch_geometric/pull/4439), [#4441](https://github.com/pyg-team/pytorch_geometric/pull/4441), [#4446](https://github.com/pyg-team/pytorch_geometric/pull/4446), [#4508](https://github.com/pyg-team/pytorch_geometric/pull/4508), [#4509](https://github.com/pyg-team/pytorch_geometric/pull/4509))\n- Added an unsupervised `GraphSAGE` example on `PPI` ([#4416](https://github.com/pyg-team/pytorch_geometric/pull/4416))\n- Added support for `LSTM` aggregation in `SAGEConv` ([#4379](https://github.com/pyg-team/pytorch_geometric/pull/4379))\n- Added support for floating-point labels in `RandomLinkSplit` ([#4311](https://github.com/pyg-team/pytorch_geometric/pull/4311), [#4383](https://github.com/pyg-team/pytorch_geometric/pull/4383))\n- Added support for `torch.data` `DataPipes` ([#4302](https://github.com/pyg-team/pytorch_geometric/pull/4302), [#4345](https://github.com/pyg-team/pytorch_geometric/pull/4345), [#4349](https://github.com/pyg-team/pytorch_geometric/pull/4349))\n- Added support for the `cosine` argument in the `KNNGraph`/`RadiusGraph` transforms ([#4344](https://github.com/pyg-team/pytorch_geometric/pull/4344))\n- Added support for graph-level attributes in `networkx` conversion ([#4343](https://github.com/pyg-team/pytorch_geometric/pull/4343))\n- Added support for renaming node types via `HeteroData.rename` ([#4329](https://github.com/pyg-team/pytorch_geometric/pull/4329))\n- Added an example to load a trained PyG model in C++ ([#4307](https://github.com/pyg-team/pytorch_geometric/pull/4307))\n- Added a `MessagePassing.explain_message` method to customize making explanations on messages ([#4278](https://github.com/pyg-team/pytorch_geometric/pull/4278), [#4448](https://github.com/pyg-team/pytorch_geometric/pull/4448))\n- Added support for `GATv2Conv` in the `nn.models.GAT` model ([#4357](https://github.com/pyg-team/pytorch_geometric/pull/4357))\n- Added `HeteroData.subgraph` functionality 
([#4243](https://github.com/pyg-team/pytorch_geometric/pull/4243))\n- Added the `MaskLabel` module and a corresponding masked label propagation example ([#4197](https://github.com/pyg-team/pytorch_geometric/pull/4197))\n- Added temporal sampling support to `NeighborLoader` ([#4025](https://github.com/pyg-team/pytorch_geometric/pull/4025))\n- Added an example for unsupervised heterogeneous graph learning based on \"Deep Multiplex Graph Infomax\" ([#3189](https://github.com/pyg-team/pytorch_geometric/pull/3189))\n\n### Changed\n\n- Changed docstring for `RandomLinkSplit` ([#5190](https://github.com/pyg-team/pytorch_geometric/issues/5190))\n- Switched to PyTorch `scatter_reduce` implementation - experimental feature ([#5120](https://github.com/pyg-team/pytorch_geometric/pull/5120))\n- Fixed `RGATConv` device mismatches for `f-scaled` mode ([#5187](https://github.com/pyg-team/pytorch_geometric/pull/5187))\n- Allow for multi-dimensional `edge_labels` in `LinkNeighborLoader` ([#5186](https://github.com/pyg-team/pytorch_geometric/pull/5186))\n- Fixed `GINEConv` bug with non-sequential input ([#5154](https://github.com/pyg-team/pytorch_geometric/pull/5154))\n- Improved error message ([#5095](https://github.com/pyg-team/pytorch_geometric/pull/5095))\n- Fixed `HGTLoader` bug which produced outputs with missing edge types ([#5067](https://github.com/pyg-team/pytorch_geometric/pull/5067))\n- Fixed dynamic inheritance issue in data batching ([#5051](https://github.com/pyg-team/pytorch_geometric/pull/5051))\n- Fixed `load_state_dict` in `Linear` with `strict=False` mode ([#5094](https://github.com/pyg-team/pytorch_geometric/pull/5094))\n- Fixed typo in `MaskLabel.ratio_mask` ([#5093](https://github.com/pyg-team/pytorch_geometric/pull/5093))\n- Fixed `data.num_node_features` computation for sparse matrices ([#5089](https://github.com/pyg-team/pytorch_geometric/pull/5089))\n- Fixed `torch.fx` bug with `torch.nn.aggr` package 
([#5021](https://github.com/pyg-team/pytorch_geometric/pull/5021))\n- Fixed `GenConv` test ([#4993](https://github.com/pyg-team/pytorch_geometric/pull/4993))\n- Fixed packaging tests for Python 3.10 ([#4982](https://github.com/pyg-team/pytorch_geometric/pull/4982))\n- Changed `act_dict` (part of `graphgym`) to create individual instances instead of reusing the same ones everywhere ([#4978](https://github.com/pyg-team/pytorch_geometric/pull/4978))\n- Fixed issue where one-hot tensors were passed to `F.one_hot` ([#4970](https://github.com/pyg-team/pytorch_geometric/pull/4970))\n- Fixed `bool` arguments in `argparse` in `benchmark/` ([#4967](https://github.com/pyg-team/pytorch_geometric/pull/4967))\n- Fixed `BasicGNN` for `num_layers=1`, which now respects a desired number of `out_channels` ([#4943](https://github.com/pyg-team/pytorch_geometric/pull/4943))\n- `len(batch)` will now return the number of graphs inside the batch, not the number of attributes ([#4931](https://github.com/pyg-team/pytorch_geometric/pull/4931))\n- Fixed `data.subgraph` generation for 0-dim tensors ([#4932](https://github.com/pyg-team/pytorch_geometric/pull/4932))\n- Removed unnecessary inclusion of self-loops when sampling negative edges ([#4880](https://github.com/pyg-team/pytorch_geometric/pull/4880))\n- Fixed `InMemoryDataset` inferring wrong `len` for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))\n- Fixed `Batch.separate` when using it for lists of tensors ([#4837](https://github.com/pyg-team/pytorch_geometric/pull/4837))\n- Corrected docstring for `SAGEConv` ([#4852](https://github.com/pyg-team/pytorch_geometric/pull/4852))\n- Fixed a bug in `TUDataset` where `pre_filter` was not applied whenever `pre_transform` was present\n- Renamed `RandomTranslate` to `RandomJitter` - the usage of `RandomTranslate` is now deprecated ([#4828](https://github.com/pyg-team/pytorch_geometric/pull/4828))\n- Do not allow accessing edge types in `HeteroData` with two node types 
when there exist multiple relations between these types ([#4782](https://github.com/pyg-team/pytorch_geometric/pull/4782))\n- Allow `edge_type == rev_edge_type` argument in `RandomLinkSplit` ([#4757](https://github.com/pyg-team/pytorch_geometric/pull/4757), [#5221](https://github.com/pyg-team/pytorch_geometric/pull/5221))\n- Fixed a numerical instability in the `GeneralConv` and `neighbor_sample` tests ([#4754](https://github.com/pyg-team/pytorch_geometric/pull/4754))\n- Fixed a bug in `HANConv` in which destination node features rather than source node features were propagated ([#4753](https://github.com/pyg-team/pytorch_geometric/pull/4753))\n- Fixed versions of `checkout` and `setup-python` in CI ([#4751](https://github.com/pyg-team/pytorch_geometric/pull/4751))\n- Fixed `protobuf` version ([#4719](https://github.com/pyg-team/pytorch_geometric/pull/4719))\n- Fixed the ranking protocol bug in the RGCN link prediction example ([#4688](https://github.com/pyg-team/pytorch_geometric/pull/4688))\n- Math support in Markdown ([#4683](https://github.com/pyg-team/pytorch_geometric/pull/4683))\n- Allow for `setter` properties in `Data` ([#4682](https://github.com/pyg-team/pytorch_geometric/pull/4682), [#4686](https://github.com/pyg-team/pytorch_geometric/pull/4686))\n- Allow for optional `edge_weight` in `GCN2Conv` ([#4670](https://github.com/pyg-team/pytorch_geometric/pull/4670))\n- Fixed the interplay between `TUDataset` and `pre_transform` functions that modify node features ([#4669](https://github.com/pyg-team/pytorch_geometric/pull/4669))\n- Make use of the `pyg_sphinx_theme` documentation template ([#4664](https://github.com/pyg-team/pyg-lib/pull/4664), [#4667](https://github.com/pyg-team/pyg-lib/pull/4667))\n- Refactored reading molecular positions from the SDF file for QM9 datasets ([#4654](https://github.com/pyg-team/pytorch_geometric/pull/4654))\n- Fixed `MLP.jittable()` bug in case `return_emb=True` ([#4645](https://github.com/pyg-team/pytorch_geometric/pull/4645), 
[#4648](https://github.com/pyg-team/pytorch_geometric/pull/4648))\n- The generated node features of `StochasticBlockModelDataset` are now ordered with respect to their labels ([#4617](https://github.com/pyg-team/pytorch_geometric/pull/4617))\n- Fixed typos in the documentation ([#4616](https://github.com/pyg-team/pytorch_geometric/pull/4616), [#4824](https://github.com/pyg-team/pytorch_geometric/pull/4824), [#4895](https://github.com/pyg-team/pytorch_geometric/pull/4895), [#5161](https://github.com/pyg-team/pytorch_geometric/pull/5161))\n- The `bias` argument in `TAGConv` is now actually applied ([#4597](https://github.com/pyg-team/pytorch_geometric/pull/4597))\n- Fixed subclass behavior of `process` and `download` in `Dataset` ([#4586](https://github.com/pyg-team/pytorch_geometric/pull/4586))\n- Fixed filtering of attributes for loaders in case `__cat_dim__ != 0` ([#4629](https://github.com/pyg-team/pytorch_geometric/pull/4629))\n- Fixed `SparseTensor` support in `NeighborLoader` ([#4320](https://github.com/pyg-team/pytorch_geometric/pull/4320))\n- Fixed average degree handling in `PNAConv` ([#4312](https://github.com/pyg-team/pytorch_geometric/pull/4312))\n- Fixed a bug in `from_networkx` in case some attributes are PyTorch tensors ([#4486](https://github.com/pyg-team/pytorch_geometric/pull/4486))\n- Added a missing clamp in `DimeNet` ([#4506](https://github.com/pyg-team/pytorch_geometric/pull/4506), [#4562](https://github.com/pyg-team/pytorch_geometric/pull/4562))\n- Fixed the download link in `DBP15K` ([#4428](https://github.com/pyg-team/pytorch_geometric/pull/4428))\n- Fixed an autograd bug in `DimeNet` when resetting parameters ([#4424](https://github.com/pyg-team/pytorch_geometric/pull/4424))\n- Fixed bipartite message passing in case `flow=\"target_to_source\"` ([#4418](https://github.com/pyg-team/pytorch_geometric/pull/4418))\n- Fixed a bug in which `num_nodes` was not properly updated in the `FixedPoints` transform 
([#4394](https://github.com/pyg-team/pytorch_geometric/pull/4394))\n- PyTorch Lightning >= 1.6 support ([#4377](https://github.com/pyg-team/pytorch_geometric/pull/4377))\n- Fixed a bug in which `GATConv` was not jittable ([#4347](https://github.com/pyg-team/pytorch_geometric/pull/4347))\n- Fixed a bug in which the GraphGym config was not stored in each specific experiment directory ([#4338](https://github.com/pyg-team/pytorch_geometric/pull/4338))\n- Fixed a bug in which `nn.models.GAT` did not produce `out_channels`-many output channels ([#4299](https://github.com/pyg-team/pytorch_geometric/pull/4299))\n- Fixed mini-batching with empty lists as attributes ([#4293](https://github.com/pyg-team/pytorch_geometric/pull/4293))\n- Fixed a bug in which `GCNConv` could not be combined with `to_hetero` on heterogeneous graphs with one node type ([#4279](https://github.com/pyg-team/pytorch_geometric/pull/4279))\n- Added a scheduler to the Graph Sage OGBN Example ([#9877](https://github.com/pyg-team/pytorch_geometric/pull/9877))\n\n### Removed\n\n- Remove internal metrics in favor of `torchmetrics` ([#4287](https://github.com/pyg-team/pytorch_geometric/pull/4287))\n"
  },
  {
    "path": "CITATION.cff",
    "content": "---\ncff-version: 1.2.0\nmessage: \"Please cite our papers if you use this code in your own work.\"\ntitle: \"Fast Graph Representation Learning with PyTorch Geometric\"\nauthors:\n- family-names: \"Fey\"\n  given-names: \"Matthias\"\ndate-released: 2019-05-06\nlicense: MIT\nurl: \"https://github.com/pyg-team/pytorch_geometric\"\npreferred-citation:\n  type: article\n  title: \"PyG 2.0: Scalable Learning on Real World Graphs\"\n  authors:\n  - family-names: \"Fey\"\n    given-names: \"Matthias\"\n  - family-names: \"Sunil\"\n    given-names: \"Jinu\"\n  - family-names: \"Nitta\"\n    given-names: \"Akihiro\"\n  - family-names: \"Puri\"\n    given-names: \"Rishi\"\n  - family-names: \"Shah\"\n    given-names: \"Manan\"\n  - family-names: \"Stojanovi{\\v{c}}\"\n    given-names: \"Bla{\\v{z}}\"\n  - family-names: \"Bendias\"\n    given-names: \"Ramona\"\n  - family-names: \"Barghi\"\n    given-names: \"Alexandria\"\n  - family-names: \"Kocijan\"\n    given-names: \"Vid\"\n  - family-names: \"Zhang\"\n    given-names: \"Zecheng\"\n  - family-names: \"He\"\n    given-names: \"Xinwei\"\n  - family-names: \"Lenssen\"\n    given-names: \"Jan Eric\"\n  - family-names: \"Leskovec\"\n    given-names: \"Jure\"\n  journal: \"Temporal Graph Learning Workshop @ KDD\"\n  year: 2025\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2023 PyG Team <team@pyg.org>\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in\nall copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\nTHE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n  <img height=\"150\" src=\"https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/master/pyg_sphinx_theme/static/img/pyg_logo_text.svg?sanitize=true\" />\n</p>\n\n______________________________________________________________________\n\n<div align=\"center\">\n\n[![PyPI Version][pypi-image]][pypi-url]\n[![PyPI Download][pypi-download-image]][pypi-download-url]\n[![Slack][slack-image]][slack-url]\n[![Contributing][contributing-image]][contributing-url]\n\n**[Documentation](https://pytorch-geometric.readthedocs.io)** |\n**[PyG 1.0 Paper](https://arxiv.org/abs/1903.02428)** |\n**[PyG 2.0 Paper](https://arxiv.org/abs/2507.16991)** |\n**[Colab Notebooks](https://pytorch-geometric.readthedocs.io/en/latest/get_started/colabs.html)** |\n**[External Resources](https://pytorch-geometric.readthedocs.io/en/latest/external/resources.html)** |\n**[OGB Examples](https://github.com/snap-stanford/ogb/tree/master/examples)**\n\n</div>\n\n**PyG** *(PyTorch Geometric)* is a library built upon [PyTorch](https://pytorch.org/) to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.\n\nIt consists of various methods for deep learning on graphs and other irregular structures, also known as *[geometric deep learning](http://geometricdeeplearning.com/)*, from a variety of published papers.\nIn addition, it consists of easy-to-use mini-batch loaders for operating on many small and single giant graphs, [multi GPU-support](https://github.com/pyg-team/pytorch_geometric/tree/master/examples/multi_gpu), [`torch.compile`](https://pytorch-geometric.readthedocs.io/en/latest/advanced/compile.html) support, [`DataPipe`](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/datapipe.py) support, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point 
clouds.\n\n**[Click here to join our Slack community!][slack-url]**\n\n<p align=\"center\">\n  <a href=\"https://medium.com/stanford-cs224w\"><img style=\"max-width: 941px\" src=\"https://data.pyg.org/img/cs224w_tutorials.png\" /></a>\n</p>\n\n______________________________________________________________________\n\n- [Library Highlights](#library-highlights)\n- [Quick Tour for New Users](#quick-tour-for-new-users)\n- [Architecture Overview](#architecture-overview)\n- [Implemented GNN Models](#implemented-gnn-models)\n- [Installation](#installation)\n\n## Library Highlights\n\nWhether you are a machine learning researcher or first-time user of machine learning toolkits, here are some reasons to try out PyG for machine learning on graph-structured data.\n\n- **Easy-to-use and unified API**:\n  All it takes is 10-20 lines of code to get started with training a GNN model (see the next section for a [quick tour](#quick-tour-for-new-users)).\n  PyG is *PyTorch-on-the-rocks*: It utilizes a tensor-centric API and keeps design principles close to vanilla PyTorch.\n  If you are already familiar with PyTorch, utilizing PyG is straightforward.\n- **Comprehensive and well-maintained GNN models**:\n  Most of the state-of-the-art Graph Neural Network architectures have been implemented by library developers or authors of research papers and are ready to be applied.\n- **Great flexibility**:\n  Existing PyG models can easily be extended for conducting your own research with GNNs.\n  Making modifications to existing models or creating new architectures is simple, thanks to its easy-to-use message passing API, and a variety of operators and utility functions.\n- **Large-scale real-world GNN models**:\n  We focus on the need of GNN applications in challenging real-world scenarios, and support learning on diverse types of graphs, including but not limited to: scalable GNNs for graphs with millions of nodes; dynamic GNNs for node predictions over time; heterogeneous GNNs with multiple 
node types and edge types.\n\n## Quick Tour for New Users\n\nIn this quick tour, we highlight the ease of creating and training a GNN model with only a few lines of code.\n\n### Train your own GNN model\n\nIn the first glimpse of PyG, we implement the training of a GNN for classifying papers in a citation graph.\nFor this, we load the [Cora](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset, and create a simple 2-layer GCN model using the pre-defined [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html):\n\n```python\nimport torch\nfrom torch import Tensor\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.datasets import Planetoid\n\ndataset = Planetoid(root='.', name='Cora')\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        self.conv2 = GCNConv(hidden_channels, out_channels)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        # x: Node feature matrix of shape [num_nodes, in_channels]\n        # edge_index: Graph connectivity matrix of shape [2, num_edges]\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        return x\n\nmodel = GCN(dataset.num_features, 16, dataset.num_classes)\n```\n\n<details>\n<summary>We can now optimize the model in a training loop, similar to the <a href=\"https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html#full-implementation\">standard PyTorch training procedure</a>.</summary>\n\n```python\nimport torch.nn.functional as F\n\ndata = dataset[0]\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\nfor epoch in range(200):\n    pred = model(data.x, data.edge_index)\n    loss = F.cross_entropy(pred[data.train_mask], data.y[data.train_mask])\n\n    # Backpropagation\n    
optimizer.zero_grad()\n    loss.backward()\n    optimizer.step()\n```\n\n</details>\n\nMore information about evaluating final model performance can be found in the corresponding [example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py).\n\n### Create your own GNN layer\n\nIn addition to the easy application of existing GNNs, PyG makes it simple to implement custom Graph Neural Networks (see [here](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html) for the accompanying tutorial).\nFor example, this is all it takes to implement the [edge convolutional layer](https://arxiv.org/abs/1801.07829) from Wang *et al.*:\n\n$$x_i^{\\\\prime} ~ = ~ \\\\max\\_{j \\\\in \\\\mathcal{N}(i)} ~ \\\\textrm{MLP}\\_{\\\\theta} \\\\left( [ ~ x_i, ~ x_j - x_i ~ ] \\\\right)$$\n\n```python\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Sequential, Linear, ReLU\nfrom torch_geometric.nn import MessagePassing\n\nclass EdgeConv(MessagePassing):\n    def __init__(self, in_channels, out_channels):\n        super().__init__(aggr=\"max\")  # \"Max\" aggregation.\n        self.mlp = Sequential(\n            Linear(2 * in_channels, out_channels),\n            ReLU(),\n            Linear(out_channels, out_channels),\n        )\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        # x: Node feature matrix of shape [num_nodes, in_channels]\n        # edge_index: Graph connectivity matrix of shape [2, num_edges]\n        return self.propagate(edge_index, x=x)  # shape [num_nodes, out_channels]\n\n    def message(self, x_j: Tensor, x_i: Tensor) -> Tensor:\n        # x_j: Source node features of shape [num_edges, in_channels]\n        # x_i: Target node features of shape [num_edges, in_channels]\n        edge_features = torch.cat([x_i, x_j - x_i], dim=-1)\n        return self.mlp(edge_features)  # shape [num_edges, out_channels]\n```\n\n## Architecture Overview\n\nPyG provides a multi-layer framework that enables 
users to build Graph Neural Network solutions on both low and high levels.\nIt comprises the following components:\n\n- The PyG **engine** utilizes the powerful PyTorch deep learning framework with full [`torch.compile`](https://pytorch-geometric.readthedocs.io/en/latest/advanced/compile.html) and [TorchScript](https://pytorch-geometric.readthedocs.io/en/latest/advanced/jit.html) support, as well as additions of efficient CPU/CUDA libraries for operating on sparse data, *e.g.*, [`pyg-lib`](https://github.com/pyg-team/pyg-lib).\n- The PyG **storage** handles data processing, transformation and loading pipelines. It is capable of handling and processing large-scale graph datasets, and provides effective solutions for heterogeneous graphs. It further provides a variety of sampling solutions, which enable training of GNNs on large-scale graphs.\n- The PyG **operators** bundle essential functionalities for implementing Graph Neural Networks. PyG supports important GNN building blocks that can be combined and applied to various parts of a GNN model, ensuring rich flexibility of GNN design.\n- Finally, PyG provides an abundant set of GNN **models**, and examples that showcase GNN models on standard graph benchmarks. 
Thanks to its flexibility, users can easily build and modify custom GNN models to fit their specific needs.\n\n<p align=\"center\">\n  <img width=\"100%\" src=\"https://raw.githubusercontent.com/pyg-team/pytorch_geometric/master/docs/source/_figures/architecture.svg?sanitize=true\" />\n</p>\n\n## Implemented GNN Models\n\nWe list currently supported PyG models, layers and operators according to category:\n\n**GNN layers:**\nAll Graph Neural Network layers are implemented via the **[`nn.MessagePassing`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MessagePassing.html)** interface.\nA GNN layer specifies how to perform message passing, *i.e.* by designing different message, aggregation and update functions as defined [here](https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html).\nThese GNN layers can be stacked together to create Graph Neural Network models.\n\n- **[GCNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCNConv.html)** from Kipf and Welling: [Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907) (ICLR 2017) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py)\\]\n- **[ChebConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ChebConv.html)** from Defferrard *et al.*: [Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering](https://arxiv.org/abs/1606.09375) (NIPS 2016) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py#L36-L37)\\]\n- **[GATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATConv.html)** from Veličković *et al.*: [Graph Attention Networks](https://arxiv.org/abs/1710.10903) (ICLR 2018) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gat.py)\\]\n\n<details>\n<summary><b>Expand to 
see all implemented GNN layers...</b></summary>\n\n- **[GCN2Conv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GCN2Conv.html)** from Chen *et al.*: [Simple and Deep Graph Convolutional Networks](https://arxiv.org/abs/2007.02133) (ICML 2020) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_cora.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn2_ppi.py)\\]\n- **[SplineConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SplineConv.html)** from Fey *et al.*: [SplineCNN: Fast Geometric Deep Learning with Continuous B-Spline Kernels](https://arxiv.org/abs/1711.08920) (CVPR 2018) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cora.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/faust.py)\\]\n- **[NNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.NNConv.html)** from Gilmer *et al.*: [Neural Message Passing for Quantum Chemistry](https://arxiv.org/abs/1704.01212) (ICML 2017) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_nn_conv.py)\\]\n- **[CGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.CGConv.html)** from Xie and Grossman: [Crystal Graph Convolutional Neural Networks for an Accurate and Interpretable Prediction of Material Properties](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301) (Physical Review Letters 120, 2018)\n- **[ECConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ECConv.html)** from Simonovsky and Komodakis: [Edge-Conditioned Convolution on Graphs](https://arxiv.org/abs/1704.02901) (CVPR 2017)\n- 
**[EGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.EGConv.html)** from Tailor *et al.*: [Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions](https://arxiv.org/abs/2104.01481) (GNNSys 2021) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/egc.py)\\]\n- **[GATv2Conv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GATv2Conv.html)** from Brody *et al.*: [How Attentive are Graph Attention Networks?](https://arxiv.org/abs/2105.14491) (ICLR 2022)\n- **[TransformerConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.TransformerConv.html)** from Shi *et al.*: [Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification](https://arxiv.org/abs/2009.03509) (CoRR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/unimp_arxiv.py)\\]\n- **[SAGEConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SAGEConv.html)** from Hamilton *et al.*: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216) (NIPS 2017) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_train.py), [**Example3**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_sage_unsup.py), [**Example4**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_sage_unsup_ppi.py)\\]\n- **[GraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GraphConv.html)** from, *e.g.*, Morris *et al.*: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244) (AAAI 2019)\n- 
**[GatedGraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GatedGraphConv.html)** from Li *et al.*: [Gated Graph Sequence Neural Networks](https://arxiv.org/abs/1511.05493) (ICLR 2016)\n- **[ResGatedGraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ResGatedGraphConv.html)** from Bresson and Laurent: [Residual Gated Graph ConvNets](https://arxiv.org/abs/1711.07553) (CoRR 2017)\n- **[GINConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GINConv.html)** from Xu *et al.*: [How Powerful are Graph Neural Networks?](https://arxiv.org/abs/1810.00826) (ICLR 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mutag_gin.py)\\]\n- **[GINEConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GINEConv.html)** from Hu *et al.*: [Strategies for Pre-training Graph Neural Networks](https://arxiv.org/abs/1905.12265) (ICLR 2020)\n- **[ARMAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.ARMAConv.html)** from Bianchi *et al.*: [Graph Neural Networks with Convolutional ARMA Filters](https://arxiv.org/abs/1901.01343) (CoRR 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/arma.py)\\]\n- **[SGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SGConv.html)** from Wu *et al.*: [Simplifying Graph Convolutional Networks](https://arxiv.org/abs/1902.07153) (CoRR 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/sgc.py)\\]\n- **[APPNP](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.APPNP.html)** from Klicpera *et al.*: [Predict then Propagate: Graph Neural Networks meet Personalized PageRank](https://arxiv.org/abs/1810.05997) (ICLR 2019) 
\\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/appnp.py)\\]\n- **[MFConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.MFConv.html)** from Duvenaud *et al.*: [Convolutional Networks on Graphs for Learning Molecular Fingerprints](https://arxiv.org/abs/1509.09292) (NIPS 2015)\n- **[AGNNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.AGNNConv.html)** from Thekumparampil *et al.*: [Attention-based Graph Neural Network for Semi-Supervised Learning](https://arxiv.org/abs/1803.03735) (CoRR 2017) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/agnn.py)\\]\n- **[TAGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.TAGConv.html)** from Du *et al.*: [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/abs/1710.10370) (CoRR 2017) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tagcn.py)\\]\n- **[PNAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PNAConv.html)** from Corso *et al.*: [Principal Neighbourhood Aggregation for Graph Nets](https://arxiv.org/abs/2004.05718) (CoRR 2020) \\[**[Example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pna.py)**\\]\n- **[FAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FAConv.html)** from Bo *et al.*: [Beyond Low-Frequency Information in Graph Convolutional Networks](https://arxiv.org/abs/2101.00797) (AAAI 2021)\n- **[PDNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.nn.conv.PDNConv.html)** from Rozemberczki *et al.*: [Pathfinder Discovery Networks for Neural Message Passing](https://arxiv.org/abs/2010.12878) (WWW 2021)\n- **[RGCNConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.RGCNConv.html)** from 
Schlichtkrull *et al.*: [Modeling Relational Data with Graph Convolutional Networks](https://arxiv.org/abs/1703.06103) (ESWC 2018) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgcn.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgcn_link_pred.py)\\]\n- **[RGATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.RGATConv.html)** from Busbridge *et al.*: [Relational Graph Attention Networks](https://arxiv.org/abs/1904.05811) (CoRR 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rgat.py)\\]\n- **[FiLMConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FiLMConv.html)** from Brockschmidt: [GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation](https://arxiv.org/abs/1906.12192) (ICML 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/film.py)\\]\n- **[SignedConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SignedConv.html)** from Derr *et al.*: [Signed Graph Convolutional Network](https://arxiv.org/abs/1808.06354) (ICDM 2018) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/signed_gcn.py)\\]\n- **[DNAConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.DNAConv.html)** from Fey: [Just Jump: Dynamic Neighborhood Aggregation in Graph Neural Networks](https://arxiv.org/abs/1904.04849) (ICLR-W 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dna.py)\\]\n- **[PANConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PANConv.html)** from Ma *et al.*: [Path Integral Based Convolution and Pooling for Graph Neural Networks](https://arxiv.org/abs/2006.16811) (NeurIPS 2020)\n- 
**[PointNetConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PointNetConv.html)** (including **[Iterative Farthest Point Sampling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.fps.html)**, dynamic graph generation based on **[nearest neighbor](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.knn_graph.html)** or **[maximum distance](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.radius_graph.html)**, and **[k-NN interpolation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.unpool.knn_interpolate.html)** for upsampling) from Qi *et al.*: [PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation](https://arxiv.org/abs/1612.00593) (CVPR 2017) and [PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space](https://arxiv.org/abs/1706.02413) (NIPS 2017) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_classification.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_segmentation.py)\\]\n- **[EdgeConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.EdgeConv.html)** from Wang *et al.*: [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829) (CoRR, 2018) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dgcnn_classification.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/dgcnn_segmentation.py)\\]\n- **[XConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.XConv.html)** from Li *et al.*: [PointCNN: Convolution On X-Transformed Points](https://arxiv.org/abs/1801.07791) (NeurIPS 2018) 
\\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_cnn.py)\\]\n- **[PPFConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PPFConv.html)** from Deng *et al.*: [PPFNet: Global Context Aware Local Features for Robust 3D Point Matching](https://arxiv.org/abs/1802.02669) (CVPR 2018)\n- **[GMMConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GMMConv.html)** from Monti *et al.*: [Geometric Deep Learning on Graphs and Manifolds using Mixture Model CNNs](https://arxiv.org/abs/1611.08402) (CVPR 2017)\n- **[FeaStConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FeaStConv.html)** from Verma *et al.*: [FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis](https://arxiv.org/abs/1706.05206) (CVPR 2018)\n- **[PointTransformerConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.PointTransformerConv.html)** from Zhao *et al.*: [Point Transformer](https://arxiv.org/abs/2012.09164) (2020)\n- **[HypergraphConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HypergraphConv.html)** from Bai *et al.*: [Hypergraph Convolution and Hypergraph Attention](https://arxiv.org/abs/1901.08150) (CoRR 2019)\n- **[GravNetConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GravNetConv.html)** from Qasim *et al.*: [Learning Representations of Irregular Particle-detector Geometry with Distance-weighted Graph Networks](https://arxiv.org/abs/1902.07987) (European Physics Journal C, 2019)\n- **[SuperGAT](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SuperGATConv.html)** from Kim and Oh: [How To Find Your Friendly Neighborhood: Graph Attention Design With Self-Supervision](https://openreview.net/forum?id=Wi5KUNlqWty) (ICLR 2021) 
\\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/super_gat.py)\\]\n- **[HGTConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HGTConv.html)** from Hu *et al.*: [Heterogeneous Graph Transformer](https://arxiv.org/abs/2003.01332) (WWW 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hgt_dblp.py)\\]\n- **[HEATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.HEATConv.html)** from Mo *et al.*: [Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent Trajectory Prediction](https://arxiv.org/abs/2106.07161) (CoRR 2021)\n- **[SSGConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SSGConv.html)** from Zhu *et al.*: [Simple Spectral Graph Convolution](https://openreview.net/forum?id=CYO5T-YjWZV) (ICLR 2021)\n- **[FusedGATConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.FusedGATConv.html)** from Zhang *et al.*: [Understanding GNN Computational Graph: A Coordinated Computation, IO, and Memory Perspective](https://proceedings.mlsys.org/paper/2022/file/9a1158154dfa42caddbd0694a4e9bdc8-Paper.pdf) (MLSys 2022)\n- **[GPSConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GPSConv.html)** from Rampášek *et al.*: [Recipe for a General, Powerful, Scalable Graph Transformer](https://arxiv.org/abs/2205.12454) (NeurIPS 2022) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py)\\]\n\n</details>\n\n**Pooling layers:**\nGraph pooling layers combine the vectorial representations of a set of nodes in a graph (or a subgraph) into a single vector representation that summarizes the properties of those nodes.\nThey are commonly applied to graph-level tasks, which require combining node features into a single graph representation.\n\n- **[Top-K 
Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.TopKPooling.html)** from Gao and Ji: [Graph U-Nets](https://arxiv.org/abs/1905.05178) (ICML 2019), Cangea *et al.*: [Towards Sparse Hierarchical Graph Classifiers](https://arxiv.org/abs/1811.01287) (NeurIPS-W 2018) and Knyazev *et al.*: [Understanding Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850) (ICLR-W 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_topk_pool.py)\\]\n- **[DiffPool](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.dense_diff_pool.html)** from Ying *et al.*: [Hierarchical Graph Representation Learning with Differentiable Pooling](https://arxiv.org/abs/1806.08804) (NeurIPS 2018) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_diff_pool.py)\\]\n\n<details>\n<summary><b>Expand to see all implemented pooling layers...</b></summary>\n\n- **[Attentional Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.AttentionalAggregation.html)** from Li *et al.*: [Graph Matching Networks for Learning the Similarity of Graph Structured Objects](https://arxiv.org/abs/1904.12787) (ICML 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)\\]\n- **[Set2Set](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.Set2Set.html)** from Vinyals *et al.*: [Order Matters: Sequence to Sequence for Sets](https://arxiv.org/abs/1511.06391) (ICLR 2016) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)\\]\n- **[Sort Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.SortAggregation.html)** from Zhang *et al.*: [An End-to-End Deep Learning Architecture for Graph 
Classification](https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf) (AAAI 2018) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)\\]\n- **[MinCut Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.dense_mincut_pool.html)** from Bianchi *et al.*: [Spectral Clustering with Graph Neural Networks for Graph Pooling](https://arxiv.org/abs/1907.00481) (ICML 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_mincut_pool.py)\\]\n- **[DMoN Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.dense.DMoNPooling.html)** from Tsitsulin *et al.*: [Graph Clustering with Graph Neural Networks](https://arxiv.org/abs/2006.16904) (CoRR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_dmon_pool.py)\\]\n- **[Graclus Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.graclus.html)** from Dhillon *et al.*: [Weighted Graph Cuts without Eigenvectors: A Multilevel Approach](http://www.cs.utexas.edu/users/inderjit/public_papers/multilevel_pami.pdf) (PAMI 2007) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_graclus.py)\\]\n- **[Voxel Grid Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.voxel_grid.html)** from, *e.g.*, Simonovsky and Komodakis: [Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on Graphs](https://arxiv.org/abs/1704.02901) (CVPR 2017) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mnist_voxel_grid.py)\\]\n- **[SAG Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.SAGPooling.html)** from Lee *et al.*: [Self-Attention Graph Pooling](https://arxiv.org/abs/1904.08082) (ICML 2019) and Knyazev *et al.*: [Understanding 
Attention and Generalization in Graph Neural Networks](https://arxiv.org/abs/1905.02850) (ICLR-W 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sag_pool.py)\\]\n- **[Edge Pooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.EdgePooling.html)** from Diehl *et al.*: [Towards Graph Pooling by Edge Contraction](https://graphreason.github.io/papers/17.pdf) (ICML-W 2019) and Diehl: [Edge Contraction Pooling for Graph Neural Networks](https://arxiv.org/abs/1905.10990) (CoRR 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/edge_pool.py)\\]\n- **[ASAPooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.ASAPooling.html)** from Ranjan *et al.*: [ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph Representations](https://arxiv.org/abs/1911.07979) (AAAI 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/asap.py)\\]\n- **[PANPooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.PANPooling.html)** from Ma *et al.*: [Path Integral Based Convolution and Pooling for Graph Neural Networks](https://arxiv.org/abs/2006.16811) (NeurIPS 2020)\n- **[MemPooling](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.pool.MemPooling.html)** from Khasahmadi *et al.*: [Memory-Based Graph Networks](https://arxiv.org/abs/2002.09518) (ICLR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/mem_pool.py)\\]\n- **[Graph Multiset Transformer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.GraphMultisetTransformer.html)** from Baek *et al.*: [Accurate Learning of Graph Representations with Graph Multiset Pooling](https://arxiv.org/abs/2102.11533) (ICLR 2021) 
\\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/proteins_gmt.py)\\]\n- **[Equilibrium Aggregation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.aggr.EquilibriumAggregation.html)** from Bartunov *et al.*: [Equilibrium Aggregation: Encoding Sets via Optimization](https://arxiv.org/abs/2202.12795) (UAI 2022) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/equilibrium_median.py)\\]\n\n</details>\n\n**GNN models:**\nOur supported GNN models incorporate multiple message passing layers, and users can directly use these pre-defined models to make predictions on graphs.\nUnlike simple stacking of GNN layers, these models could involve pre-processing, additional learnable parameters, skip connections, graph coarsening, etc.\n\n- **[SchNet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.SchNet.html)** from Schütt *et al.*: [SchNet: A Continuous-filter Convolutional Neural Network for Modeling Quantum Interactions](https://arxiv.org/abs/1706.08566) (NIPS 2017) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_pretrained_schnet.py)\\]\n- **[DimeNet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DimeNet.html)** and **[DimeNetPlusPlus](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DimeNetPlusPlus.html)** from Klicpera *et al.*: [Directional Message Passing for Molecular Graphs](https://arxiv.org/abs/2003.03123) (ICLR 2020) and [Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules](https://arxiv.org/abs/2011.14115) (NeurIPS-W 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_pretrained_dimenet.py)\\]\n- **[Node2Vec](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.Node2Vec.html)** from Grover and Leskovec: [node2vec: Scalable Feature Learning for 
Networks](https://arxiv.org/abs/1607.00653) (KDD 2016) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/node2vec.py)\\]\n- **[Deep Graph Infomax](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DeepGraphInfomax.html)** from Veličković *et al.*: [Deep Graph Infomax](https://arxiv.org/abs/1809.10341) (ICLR 2019) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/infomax_transductive.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/infomax_inductive.py)\\]\n- **Deep Multiplex Graph Infomax** from Park *et al.*: [Unsupervised Attributed Multiplex Network Embedding](https://arxiv.org/abs/1911.06750) (AAAI 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/dmgi_unsup.py)\\]\n- **[Masked Label Prediction](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MaskLabel.html)** from Shi *et al.*: [Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification](https://arxiv.org/abs/2009.03509) (CoRR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/unimp_arxiv.py)\\]\n- **[PMLP](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.PMLP.html)** from Yang *et al.*: [Graph Neural Networks are Inherently Good Generalizers: Insights by Bridging GNNs and MLPs](https://arxiv.org/abs/2212.09034) (ICLR 2023)\n\n<details>\n<summary><b>Expand to see all implemented GNN models...</b></summary>\n\n- **[Jumping Knowledge](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.JumpingKnowledge.html)** from Xu *et al.*: [Representation Learning on Graphs with Jumping Knowledge Networks](https://arxiv.org/abs/1806.03536) (ICML 2018) 
\\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/gin.py#L54-L106)\\]\n- A **[MetaLayer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaLayer.html)** for building any kind of graph network similar to the [TensorFlow Graph Nets library](https://github.com/deepmind/graph_nets) from Battaglia *et al.*: [Relational Inductive Biases, Deep Learning, and Graph Networks](https://arxiv.org/abs/1806.01261) (CoRR 2018)\n- **[MetaPath2Vec](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.MetaPath2Vec.html)** from Dong *et al.*: [metapath2vec: Scalable Representation Learning for Heterogeneous Networks](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) (KDD 2017) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/metapath2vec.py)\\]\n- All variants of **[Graph Autoencoders](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GAE.html)** and **[Variational Autoencoders](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.VGAE.html)** from:\n  - [Variational Graph Auto-Encoders](https://arxiv.org/abs/1611.07308) from Kipf and Welling (NIPS-W 2016) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py)\\]\n  - [Adversarially Regularized Graph Autoencoder for Graph Embedding](https://arxiv.org/abs/1802.04407) from Pan *et al.* (IJCAI 2018) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/argva_node_clustering.py)\\]\n  - [Simple and Effective Graph Autoencoders with One-Hop Linear Models](https://arxiv.org/abs/2001.07614) from Salha *et al.* (ECML 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/autoencoder.py)\\]\n- 
**[SEAL](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/seal_link_pred.py)** from Zhang and Chen: [Link Prediction Based on Graph Neural Networks](https://arxiv.org/pdf/1802.09691.pdf) (NeurIPS 2018) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/seal_link_pred.py)\\]\n- **[RENet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.RENet.html)** from Jin *et al.*: [Recurrent Event Network for Reasoning over Temporal Knowledge Graphs](https://arxiv.org/abs/1904.05530) (ICLR-W 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/renet.py)\\]\n- **[GraphUNet](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphUNet.html)** from Gao and Ji: [Graph U-Nets](https://arxiv.org/abs/1905.05178) (ICML 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_unet.py)\\]\n- **[AttentiveFP](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.AttentiveFP.html)** from Xiong *et al.*: [Pushing the Boundaries of Molecular Representation for Drug Discovery with the Graph Attention Mechanism](https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959) (J. Med. Chem. 
2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/attentive_fp.py)\\]\n- **[DeepGCN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.DeepGCNLayer.html)** and the **[GENConv](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.GENConv.html)** from Li *et al.*: [DeepGCNs: Can GCNs Go as Deep as CNNs?](https://arxiv.org/abs/1904.03751) (ICCV 2019) and [DeeperGCN: All You Need to Train Deeper GCNs](https://arxiv.org/abs/2006.07739) (CoRR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_proteins_deepgcn.py)\\]\n- **[RECT](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.RECT_L.html)** from Wang *et al.*: [Network Embedding with Completely-imbalanced Labels](https://ieeexplore.ieee.org/document/8979355) (TKDE 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rect.py)\\]\n- **[GNNExplainer](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.explain.algorithm.GNNExplainer.html)** from Ying *et al.*: [GNNExplainer: Generating Explanations for Graph Neural Networks](https://arxiv.org/abs/1903.03894) (NeurIPS 2019) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer_ba_shapes.py), [**Example3**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/explain/gnn_explainer_link_pred.py)\\]\n- **Graph-less Neural Networks** from Zhang *et al.*: [Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation](https://arxiv.org/abs/2110.08727) (CoRR 2021) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/glnn.py)\\]\n- 
**[LINKX](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.LINKX.html)** from Lim *et al.*: [Large Scale Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple Methods](https://arxiv.org/abs/2110.14446) (NeurIPS 2021) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/linkx.py)\\]\n- **[RevGNN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GroupAddRev.html)** from Li *et al.*: [Training Graph Neural Networks with 1000 Layers](https://arxiv.org/abs/2106.07476) (ICML 2021) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/rev_gnn.py)\\]\n- **[TransE](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.TransE.html)** from Bordes *et al.*: [Translating Embeddings for Modeling Multi-Relational Data](https://proceedings.neurips.cc/paper/2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf) (NIPS 2013) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\\]\n- **[ComplEx](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.ComplEx.html)** from Trouillon *et al.*: [Complex Embeddings for Simple Link Prediction](https://arxiv.org/abs/1606.06357) (ICML 2016) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\\]\n- **[DistMult](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.DistMult.html)** from Yang *et al.*: [Embedding Entities and Relations for Learning and Inference in Knowledge Bases](https://arxiv.org/abs/1412.6575) (ICLR 2015) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\\]\n- **[RotatE](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.kge.RotatE.html)** from Sun *et al.*: [RotatE: Knowledge Graph Embedding by Relational Rotation in 
Complex Space](https://arxiv.org/abs/1902.10197) (ICLR 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/kge_fb15k_237.py)\\]\n\n</details>\n\n**GNN operators and utilities:**\nPyG comes with a rich set of neural network operators that are commonly used in many GNN models.\nThey follow an extensible design: It is easy to apply these operators and graph utilities to existing GNN layers and models to further enhance model performance.\n\n- **[DropEdge](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_edge)** from Rong *et al.*: [DropEdge: Towards Deep Graph Convolutional Networks on Node Classification](https://openreview.net/forum?id=Hkx1qkrKPr) (ICLR 2020)\n- **[DropNode](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_node)**, **[MaskFeature](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.mask_feature)** and **[AddRandomEdge](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.add_random_edge)** from You *et al.*: [Graph Contrastive Learning with Augmentations](https://arxiv.org/abs/2010.13902) (NeurIPS 2020)\n- **[DropPath](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.dropout_path)** from Li *et al.*: [MaskGAE: Masked Graph Modeling Meets Graph Autoencoders](https://arxiv.org/abs/2205.10053) (arXiv 2022)\n- **[ShuffleNode](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.shuffle_node)** from Veličković *et al.*: [Deep Graph Infomax](https://arxiv.org/abs/1809.10341) (ICLR 2019)\n- **[GraphNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.GraphNorm.html)** from Cai *et al.*: [GraphNorm: A Principled Approach to Accelerating Graph Neural Network Training](https://proceedings.mlr.press/v139/cai21e.html) (ICML 
2021)\n- **[GDC](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.GDC.html)** from Klicpera *et al.*: [Diffusion Improves Graph Learning](https://arxiv.org/abs/1911.05485) (NeurIPS 2019) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py)\\]\n\n<details>\n<summary><b>Expand to see all implemented GNN operators and utilities...</b></summary>\n\n- **[GraphSizeNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.GraphSizeNorm.html)** from Dwivedi *et al.*: [Benchmarking Graph Neural Networks](https://arxiv.org/abs/2003.00982) (CoRR 2020)\n- **[PairNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.PairNorm.html)** from Zhao and Akoglu: [PairNorm: Tackling Oversmoothing in GNNs](https://arxiv.org/abs/1909.12223) (ICLR 2020)\n- **[MeanSubtractionNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.MeanSubtractionNorm.html)** from Yang *et al.*: [Revisiting \"Over-smoothing\" in Deep GCNs](https://arxiv.org/abs/2003.13663) (CoRR 2020)\n- **[DiffGroupNorm](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.norm.DiffGroupNorm.html)** from Zhou *et al.*: [Towards Deeper Graph Neural Networks with Differentiable Group Normalization](https://arxiv.org/abs/2006.06972) (NeurIPS 2020)\n- **[Tree Decomposition](https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.tree_decomposition)** from Jin *et al.*: [Junction Tree Variational Autoencoder for Molecular Graph Generation](https://arxiv.org/abs/1802.04364) (ICML 2018)\n- **[TGN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.TGNMemory.html)** from Rossi *et al.*: [Temporal Graph Networks for Deep Learning on Dynamic Graphs](https://arxiv.org/abs/2006.10637) (GRL+ 2020) 
\\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/tgn.py)\\]\n- **[Weisfeiler Lehman Operator](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.WLConv.html)** from Weisfeiler and Lehman: [A Reduction of a Graph to a Canonical Form and an Algebra Arising During this Reduction](https://www.iti.zcu.cz/wl2018/pdf/wl_paper_translation.pdf) (Nauchno-Technicheskaya Informatsia 1968) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/wl_kernel.py)\\]\n- **[Continuous Weisfeiler Lehman Operator](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.WLConvContinuous.html)** from Togninalli *et al.*: [Wasserstein Weisfeiler-Lehman Graph Kernels](https://arxiv.org/abs/1906.01277) (NeurIPS 2019)\n- **[Label Propagation](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.LabelPropagation.html)** from Zhu and Ghahramani: [Learning from Labeled and Unlabeled Data with Label Propagation](http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf) (CMU-CALD 2002) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/label_prop.py)\\]\n- **[Local Degree Profile](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.transforms.LocalDegreeProfile)** from Cai and Wang: [A Simple yet Effective Baseline for Non-attribute Graph Classification](https://arxiv.org/abs/1811.03508) (CoRR 2018)\n- **[CorrectAndSmooth](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.CorrectAndSmooth.html)** from Huang *et al.*: [Combining Label Propagation And Simple Models Out-performs Graph Neural Networks](https://arxiv.org/abs/2010.13993) (CoRR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/correct_and_smooth.py)\\]\n- 
**[Gini](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.functional.gini.html)** and **[BRO](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.functional.bro.html)** regularization from Henderson *et al.*: [Improving Molecular Graph Neural Network Explainability with Orthonormalization and Induced Sparsity](https://arxiv.org/abs/2105.04854) (ICML 2021)\n- **[RootedEgoNets](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.transforms.RootedEgoNets)** and **[RootedRWSubgraph](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.transforms.RootedRWSubgraph)** from Zhao *et al.*: [From Stars to Subgraphs: Uplifting Any GNN with Local Structure Awareness](https://arxiv.org/abs/2110.03753) (ICLR 2022)\n- **[FeaturePropagation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.transforms.FeaturePropagation)** from Rossi *et al.*: [On the Unreasonable Effectiveness of Feature Propagation in Learning on Graphs with Missing Node Features](https://arxiv.org/abs/2111.12128) (CoRR 2021)\n\n</details>\n\n**Scalable GNNs:**\nPyG supports the implementation of Graph Neural Networks that scale to large graphs.\nThis setting is challenging because the entire graph, its associated features and the GNN parameters often cannot fit into GPU memory.\nMany state-of-the-art scalability approaches tackle this challenge by sampling neighborhoods for mini-batch training, by graph clustering and partitioning, or by using simplified GNN models.\nThese approaches have been implemented in PyG, and can benefit from the above GNN layers, operators and models.\n\n- **[NeighborLoader](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.NeighborLoader)** from Hamilton *et al.*: [Inductive Representation Learning on Large Graphs](https://arxiv.org/abs/1706.02216) (NIPS 2017) 
\\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ogbn_train.py), [**Example3**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py)\\]\n- **[ClusterGCN](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.ClusterLoader)** from Chiang *et al.*: [Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph Convolutional Networks](https://arxiv.org/abs/1905.07953) (KDD 2019) \\[[**Example1**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cluster_gcn_reddit.py), [**Example2**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cluster_gcn_ppi.py)\\]\n- **[GraphSAINT](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.GraphSAINTSampler)** from Zeng *et al.*: [GraphSAINT: Graph Sampling Based Inductive Learning Method](https://arxiv.org/abs/1907.04931) (ICLR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_saint.py)\\]\n\n<details>\n<summary><b>Expand to see all implemented scalable GNNs...</b></summary>\n\n- **[ShaDow](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.ShaDowKHopSampler)** from Zeng *et al.*: [Decoupling the Depth and Scope of Graph Neural Networks](https://arxiv.org/abs/2201.07858) (NeurIPS 2021) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/shadow.py)\\]\n- **[SIGN](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.transforms.SIGN.html)** from Rossi *et al.*: [SIGN: Scalable Inception Graph Neural Networks](https://arxiv.org/abs/2004.11198) (CoRR 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/sign.py)\\]\n- 
**[HGTLoader](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.loader.HGTLoader.html)** from Hu *et al.*: [Heterogeneous Graph Transformer](https://arxiv.org/abs/2003.01332) (WWW 2020) \\[[**Example**](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py)\\]\n\n</details>\n\n## Installation\n\nPyG is available for Python 3.10 to Python 3.14.\n\nFrom **PyG 2.3** onwards, you can install and use PyG **without any external library** other than PyTorch.\nFor this, simply run\n\n```\npip install torch_geometric\n```\n\n### Additional Libraries\n\nIf you want to utilize the full set of features from PyG, there exist several additional libraries you may want to install:\n\n- **[`pyg-lib`](https://github.com/pyg-team/pyg-lib)**: Heterogeneous GNN operators, graph sampling routines, and [`SplineConv`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.SplineConv.html) support\n- **[`torch-scatter`](https://github.com/rusty1s/pytorch_scatter)**: Accelerated and efficient sparse reductions\n- **[`torch-sparse`](https://github.com/rusty1s/pytorch_sparse)**: [`SparseTensor`](https://pytorch-geometric.readthedocs.io/en/latest/advanced/sparse_tensor.html) support\n- **[`torch-cluster`](https://github.com/rusty1s/pytorch_cluster)**: Graph clustering routines\n\nThese packages come with their own CPU and GPU kernel implementations based on the [PyTorch C++/CUDA/hip(ROCm) extension interface](https://github.com/pytorch/extension-cpp).\nFor basic usage of PyG, these dependencies are **fully optional**.\nWe recommend starting with a minimal installation and installing additional dependencies once you actually need them.\n\nFor ease of installation of these extensions, we provide `pip` wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl).\n\n#### PyTorch 2.10\n\nTo install the binaries for PyTorch 2.10, simply run\n\n```\npip install 
pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.10.0+${CUDA}.html\n```\n\nwhere `${CUDA}` should be replaced by either `cpu`, `cu126`, `cu128`, or `cu130` depending on your PyTorch installation.\n\n|             | `cpu` | `cu126` | `cu128` | `cu130` |\n| ----------- | ----- | ------- | ------- | ------- |\n| **Linux**   | ✅    | ✅      | ✅      | ✅      |\n| **Windows** | ✅    | ✅      | ✅      | ✅      |\n| **macOS**   | ✅    |         |         |         |\n\n#### PyTorch 2.9\n\nTo install the binaries for PyTorch 2.9, simply run\n\n```\npip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.9.0+${CUDA}.html\n```\n\nwhere `${CUDA}` should be replaced by either `cpu`, `cu126`, `cu128`, or `cu130` depending on your PyTorch installation.\n\n|             | `cpu` | `cu126` | `cu128` | `cu130` |\n| ----------- | ----- | ------- | ------- | ------- |\n| **Linux**   | ✅    | ✅      | ✅      | ✅      |\n| **Windows** | ✅    | ✅      | ✅      | ✅      |\n| **macOS**   | ✅    |         |         |         |\n\n#### PyTorch 2.8\n\nTo install the binaries for PyTorch 2.8, simply run\n\n```\npip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.8.0+${CUDA}.html\n```\n\nwhere `${CUDA}` should be replaced by either `cpu`, `cu126`, `cu128`, or `cu129` depending on your PyTorch installation.\n\n|             | `cpu` | `cu126` | `cu128` | `cu129` |\n| ----------- | ----- | ------- | ------- | ------- |\n| **Linux**   | ✅    | ✅      | ✅      | ✅      |\n| **Windows** | ✅    | ✅      | ✅      | ✅      |\n| **macOS**   | ✅    |         |         |         |\n\n**Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 
2.2.0/2.2.1/2.2.2, PyTorch 2.3.0/2.3.1, PyTorch 2.4.0/2.4.1, PyTorch 2.5.0/2.5.1, PyTorch 2.6.0, and PyTorch 2.7.0/2.7.1 (following the same procedure).\n**For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source.\nYou can look up the latest supported version number [here](https://data.pyg.org/whl).\n\n### NVIDIA PyG Container\n\nNVIDIA provides a PyG Docker container for effortlessly training and deploying GPU-accelerated GNNs with PyG, see [here](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg).\n\n### Nightly and Master\n\nIn case you want to experiment with the latest PyG features which are not fully released yet, either install the **nightly version** of PyG via\n\n```\npip install pyg-nightly\n```\n\nor install PyG **from master** via\n\n```\npip install git+https://github.com/pyg-team/pytorch_geometric.git\n```\n\n### ROCm Wheels\n\nThe external [`pyg-rocm-build` repository](https://github.com/Looong01/pyg-rocm-build) provides wheels and detailed instructions on how to install PyG for ROCm.\nIf you have any questions about it, please open an issue [here](https://github.com/Looong01/pyg-rocm-build/issues).\n\n## Cite\n\nPlease cite our [PyG 1.0](https://arxiv.org/abs/1903.02428) and [PyG 2.0](https://www.arxiv.org/abs/2507.16991) papers if you use this code in your own work:\n\n```\n@inproceedings{Fey/Lenssen/2019,\n  title={Fast Graph Representation Learning with {PyTorch Geometric}},\n  author={Fey, Matthias and Lenssen, Jan E.},\n  booktitle={ICLR Workshop on Representation Learning on Graphs and Manifolds},\n  year={2019},\n}\n\n@inproceedings{Fey/etal/2025,\n  title={{PyG} 2.0: Scalable Learning on Real World Graphs},\n  author={Fey, Matthias and Sunil, Jinu and Nitta, Akihiro and Puri, Rishi and Shah, Manan and Stojanovi{\\v{c}}, Bla{\\v{z}} and Bendias, Ramona and Barghi, Alexandria and Kocijan, Vid and Zhang, 
Zecheng and He, Xinwei and Lenssen, Jan E. and Leskovec, Jure},\n  booktitle={Temporal Graph Learning Workshop @ KDD},\n  year={2025},\n}\n```\n\nFeel free to [email us](mailto:matthias.fey@tu-dortmund.de) if you wish your work to be listed in the [external resources](https://pytorch-geometric.readthedocs.io/en/latest/external/resources.html).\nIf you notice anything unexpected, please open an [issue](https://github.com/pyg-team/pytorch_geometric/issues) and let us know.\nIf you have any questions or are missing a specific feature, feel free [to discuss them with us](https://github.com/pyg-team/pytorch_geometric/discussions).\nWe are motivated to constantly make PyG even better.\n\n[contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat&color=4B26A4\n[contributing-url]: https://github.com/pyg-team/pytorch_geometric/blob/master/.github/CONTRIBUTING.md\n[pypi-download-image]: https://img.shields.io/pypi/dm/torch_geometric?color=4B26A4\n[pypi-download-url]: https://pepy.tech/projects/torch_geometric\n[pypi-image]: https://img.shields.io/pypi/pyversions/torch-geometric?color=4B26A4\n[pypi-url]: https://pypi.python.org/pypi/torch-geometric\n[slack-image]: https://img.shields.io/badge/slack-join-white.svg?logo=slack&color=4B26A4\n[slack-url]: https://data.pyg.org/slack.html\n"
  },
  {
    "path": "benchmark/README.md",
    "content": "# PyG Benchmark Suite\n\nThis benchmark suite provides evaluation scripts for **[semi-supervised node classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/citation)**, **[graph classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/kernel)**, and **[point cloud classification](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/points)** and **[runtimes](https://github.com/pyg-team/pytorch_geometric/tree/master/benchmark/runtime)** in order to compare various methods in homogeneous evaluation scenarios.\nIn particular, we take care to avoid to perform hyperparameter and model selection on the test set and instead use an additional validation set.\n\n## Installation\n\n```\n$ pip install -e .\n```\n"
  },
  {
    "path": "benchmark/citation/README.md",
    "content": "# Semi-supervised Node Classification\n\nEvaluation scripts for various methods on the Cora, CiteSeer and PubMed citation networks.\nEach experiment is repeated 100 times on either a fixed train/val/test split or on multiple random splits:\n\n- **[GCN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/gcn.py)**: `python gcn.py`\n- **[GAT](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/gat.py)**: `python gat.py`\n- **[Cheby](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/cheb.py)**: `python cheb.py`\n- **[SGC](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/sgc.py)**: `python sgc.py`\n- **[ARMA](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/arma.py)**: `python arma.py`\n- **[APPNP](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/citation/appnp.py)**: `python appnp.py`\n\nRun the whole test suite via\n\n```\n$ ./run.sh\n```\n"
  },
  {
    "path": "benchmark/citation/__init__.py",
    "content": "from .datasets import get_planetoid_dataset\nfrom .train_eval import random_planetoid_splits, run\n\n__all__ = [\n    'get_planetoid_dataset',\n    'random_planetoid_splits',\n    'run',\n]\n"
  },
  {
    "path": "benchmark/citation/appnp.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom citation import get_planetoid_dataset, random_planetoid_splits, run\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import APPNP\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, required=True)\nparser.add_argument('--random_splits', action='store_true')\nparser.add_argument('--runs', type=int, default=100)\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--weight_decay', type=float, default=0.0005)\nparser.add_argument('--early_stopping', type=int, default=10)\nparser.add_argument('--hidden', type=int, default=64)\nparser.add_argument('--dropout', type=float, default=0.5)\nparser.add_argument('--no_normalize_features', action='store_true')\nparser.add_argument('--K', type=int, default=10)\nparser.add_argument('--alpha', type=float, default=0.1)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, dataset):\n        super().__init__()\n        self.lin1 = Linear(dataset.num_features, args.hidden)\n        self.lin2 = Linear(args.hidden, dataset.num_classes)\n        self.prop1 = APPNP(args.K, args.alpha)\n\n    def reset_parameters(self):\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = F.dropout(x, p=args.dropout, training=self.training)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=args.dropout, training=self.training)\n        x = self.lin2(x)\n        x = self.prop1(x, edge_index)\n        return F.log_softmax(x, 
dim=1)\n\n\ndataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)\npermute_masks = random_planetoid_splits if args.random_splits else None\nrun(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,\n    args.early_stopping, args.inference, args.profile, args.bf16, args.compile,\n    permute_masks)\n\nif args.profile:\n    rename_profile_file('citation', APPNP.__name__, args.dataset,\n                        str(args.random_splits),\n                        'inference' if args.inference else 'train')\n"
  },
  {
    "path": "benchmark/citation/arma.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom citation import get_planetoid_dataset, random_planetoid_splits, run\n\nfrom torch_geometric.nn import ARMAConv\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, required=True)\nparser.add_argument('--random_splits', action='store_true')\nparser.add_argument('--runs', type=int, default=100)\nparser.add_argument('--epochs', type=int, default=1000)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--weight_decay', type=float, default=0.0005)\nparser.add_argument('--early_stopping', type=int, default=100)\nparser.add_argument('--hidden', type=int, default=16)\nparser.add_argument('--dropout', type=float, default=0.5)\nparser.add_argument('--no_normalize_features', action='store_true')\nparser.add_argument('--num_stacks', type=int, default=1)\nparser.add_argument('--num_layers', type=int, default=1)\nparser.add_argument('--shared_weights', action='store_true')\nparser.add_argument('--skip_dropout', type=float, default=0.75)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, dataset):\n        super().__init__()\n        self.conv1 = ARMAConv(dataset.num_features, args.hidden,\n                              args.num_stacks, args.num_layers,\n                              args.shared_weights, dropout=args.skip_dropout)\n        self.conv2 = ARMAConv(args.hidden, dataset.num_classes,\n                              args.num_stacks, args.num_layers,\n                              args.shared_weights, dropout=args.skip_dropout)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        
self.conv2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, p=args.dropout, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)\npermute_masks = random_planetoid_splits if args.random_splits else None\nrun(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,\n    args.early_stopping, args.inference, args.profile, args.bf16, args.compile,\n    permute_masks)\n\nif args.profile:\n    rename_profile_file('citation', ARMAConv.__name__, args.dataset,\n                        str(args.random_splits),\n                        'inference' if args.inference else 'train')\n"
  },
  {
    "path": "benchmark/citation/cheb.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom citation import get_planetoid_dataset, random_planetoid_splits, run\n\nfrom torch_geometric.nn import ChebConv\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, required=True)\nparser.add_argument('--random_splits', action='store_true')\nparser.add_argument('--runs', type=int, default=100)\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--weight_decay', type=float, default=0.0005)\nparser.add_argument('--early_stopping', type=int, default=10)\nparser.add_argument('--hidden', type=int, default=16)\nparser.add_argument('--dropout', type=float, default=0.5)\nparser.add_argument('--no_normalize_features', action='store_true')\nparser.add_argument('--num_hops', type=int, default=3)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, dataset):\n        super().__init__()\n        self.conv1 = ChebConv(dataset.num_features, args.hidden, args.num_hops)\n        self.conv2 = ChebConv(args.hidden, dataset.num_classes, args.num_hops)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        self.conv2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, p=args.dropout, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)\npermute_masks = random_planetoid_splits if args.random_splits else 
None\nrun(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,\n    args.early_stopping, args.inference, args.profile, args.bf16, args.compile,\n    permute_masks)\n\nif args.profile:\n    rename_profile_file('citation', ChebConv.__name__, args.dataset,\n                        str(args.random_splits),\n                        'inference' if args.inference else 'train')\n"
  },
  {
    "path": "benchmark/citation/datasets.py",
    "content": "import os.path as osp\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\n\n\ndef get_planetoid_dataset(name, normalize_features=False, transform=None):\n    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)\n    dataset = Planetoid(path, name)\n\n    if transform is not None and normalize_features:\n        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])\n    elif normalize_features:\n        dataset.transform = T.NormalizeFeatures()\n    elif transform is not None:\n        dataset.transform = transform\n\n    return dataset\n"
  },
  {
    "path": "benchmark/citation/gat.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom citation import get_planetoid_dataset, random_planetoid_splits, run\n\nfrom torch_geometric.nn import GATConv\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, required=True)\nparser.add_argument('--random_splits', action='store_true')\nparser.add_argument('--runs', type=int, default=100)\nparser.add_argument('--epochs', type=int, default=1000)\nparser.add_argument('--lr', type=float, default=0.005)\nparser.add_argument('--weight_decay', type=float, default=0.0005)\nparser.add_argument('--early_stopping', type=int, default=100)\nparser.add_argument('--hidden', type=int, default=8)\nparser.add_argument('--dropout', type=float, default=0.6)\nparser.add_argument('--no_normalize_features', action='store_true')\nparser.add_argument('--heads', type=int, default=8)\nparser.add_argument('--output_heads', type=int, default=1)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, dataset):\n        super().__init__()\n        self.conv1 = GATConv(dataset.num_features, args.hidden,\n                             heads=args.heads, dropout=args.dropout)\n        self.conv2 = GATConv(args.hidden * args.heads, dataset.num_classes,\n                             heads=args.output_heads, concat=False,\n                             dropout=args.dropout)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        self.conv2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = F.dropout(x, p=args.dropout, training=self.training)\n        x = F.elu(self.conv1(x, edge_index))\n        x = 
F.dropout(x, p=args.dropout, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)\npermute_masks = random_planetoid_splits if args.random_splits else None\nrun(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,\n    args.early_stopping, args.inference, args.profile, args.bf16, args.compile,\n    permute_masks)\n\nif args.profile:\n    rename_profile_file('citation', GATConv.__name__, args.dataset,\n                        str(args.random_splits),\n                        'inference' if args.inference else 'train')\n"
  },
  {
    "path": "benchmark/citation/gcn.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom citation import get_planetoid_dataset, random_planetoid_splits, run\n\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, required=True)\nparser.add_argument('--random_splits', action='store_true')\nparser.add_argument('--runs', type=int, default=100)\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--weight_decay', type=float, default=0.0005)\nparser.add_argument('--early_stopping', type=int, default=10)\nparser.add_argument('--hidden', type=int, default=16)\nparser.add_argument('--dropout', type=float, default=0.5)\nparser.add_argument('--no_normalize_features', action='store_true')\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, dataset):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, args.hidden)\n        self.conv2 = GCNConv(args.hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        self.conv2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, p=args.dropout, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)\npermute_masks = random_planetoid_splits if args.random_splits else None\nrun(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,\n    
args.early_stopping, args.inference, args.profile, args.bf16, args.compile,\n    permute_masks)\n\nif args.profile:\n    rename_profile_file('citation', GCNConv.__name__, args.dataset,\n                        str(args.random_splits),\n                        'inference' if args.inference else 'train')\n"
  },
  {
    "path": "benchmark/citation/inference.sh",
    "content": "#!/bin/sh\n\necho \"Cora\"\necho \"====\"\n\necho \"GCN\"\npython gcn.py --dataset=Cora --inference\npython gcn.py --dataset=Cora --random_splits --inference\npython gcn.py --dataset=Cora --inference --profile\npython gcn.py --dataset=Cora --random_splits --inference --profile\n\necho \"GAT\"\npython gat.py --dataset=Cora --inference\npython gat.py --dataset=Cora --random_splits --inference\npython gat.py --dataset=Cora --inference --profile\npython gat.py --dataset=Cora --random_splits --inference --profile\n\necho \"Cheby\"\npython cheb.py --dataset=Cora --num_hops=3 --inference\npython cheb.py --dataset=Cora --num_hops=3 --random_splits --inference\npython cheb.py --dataset=Cora --num_hops=3 --inference --profile\npython cheb.py --dataset=Cora --num_hops=3 --random_splits --inference --profile\n\necho \"SGC\"\npython sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --inference\npython sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --random_splits --inference\npython sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --inference --profile\npython sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --random_splits --inference --profile\n\necho \"ARMA\"\npython arma.py --dataset=Cora --num_stacks=2 --num_layers=1 --shared_weights=True --inference\npython arma.py --dataset=Cora --num_stacks=3 --num_layers=1 --shared_weights=True --random_splits --inference\npython arma.py --dataset=Cora --num_stacks=2 --num_layers=1 --shared_weights=True --inference --profile\npython arma.py --dataset=Cora --num_stacks=3 --num_layers=1 --shared_weights=True --random_splits --inference --profile\n\necho \"APPNP\"\npython appnp.py --dataset=Cora --alpha=0.1 --inference\npython appnp.py --dataset=Cora --alpha=0.1 --random_splits --inference\npython appnp.py --dataset=Cora --alpha=0.1 --inference --profile\npython appnp.py --dataset=Cora --alpha=0.1 --random_splits --inference --profile\n\necho \"CiteSeer\"\necho \"========\"\n\necho \"GCN\"\npython gcn.py 
--dataset=CiteSeer --inference\npython gcn.py --dataset=CiteSeer --random_splits --inference\npython gcn.py --dataset=CiteSeer --inference --profile\npython gcn.py --dataset=CiteSeer --random_splits --inference --profile\n\necho \"GAT\"\npython gat.py --dataset=CiteSeer --inference\npython gat.py --dataset=CiteSeer --random_splits --inference\npython gat.py --dataset=CiteSeer --inference --profile\npython gat.py --dataset=CiteSeer --random_splits --inference --profile\n\necho \"Cheby\"\npython cheb.py --dataset=CiteSeer --num_hops=2 --inference\npython cheb.py --dataset=CiteSeer --num_hops=3 --random_splits --inference\npython cheb.py --dataset=CiteSeer --num_hops=2 --inference --profile\npython cheb.py --dataset=CiteSeer --num_hops=3 --random_splits --inference --profile\n\necho \"SGC\"\npython sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --inference\npython sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --random_splits --inference\npython sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --inference --profile\npython sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --random_splits --inference --profile\n\necho \"ARMA\"\npython arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights=True --inference\npython arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights=True --random_splits --inference\npython arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights=True --inference --profile\npython arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights=True --random_splits --inference --profile\n\necho \"APPNP\"\npython appnp.py --dataset=CiteSeer --alpha=0.1 --inference\npython appnp.py --dataset=CiteSeer --alpha=0.1 --random_splits --inference\npython appnp.py --dataset=CiteSeer --alpha=0.1 --inference --profile\npython appnp.py --dataset=CiteSeer --alpha=0.1 --random_splits --inference --profile\n\necho \"PubMed\"\necho \"======\"\n\necho \"GCN\"\npython gcn.py 
--dataset=PubMed --inference\npython gcn.py --dataset=PubMed --random_splits --inference\npython gcn.py --dataset=PubMed --inference --profile\npython gcn.py --dataset=PubMed --random_splits --inference --profile\n\necho \"GAT\"\npython gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --inference\npython gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --random_splits --inference\npython gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --inference --profile\npython gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 --random_splits --inference --profile\n\necho \"Cheby\"\npython cheb.py --dataset=PubMed --num_hops=2 --inference\npython cheb.py --dataset=PubMed --num_hops=2 --random_splits --inference\npython cheb.py --dataset=PubMed --num_hops=2 --inference --profile\npython cheb.py --dataset=PubMed --num_hops=2 --random_splits --inference --profile\n\necho \"SGC\"\npython sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --inference\npython sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --random_splits --inference\npython sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --inference --profile\npython sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --random_splits --inference --profile\n\necho \"ARMA\"\npython arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0 --inference\npython arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0.5 --random_splits --inference\npython arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0 --inference --profile\npython arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0.5 --random_splits --inference --profile\n\necho \"APPNP\"\npython appnp.py --dataset=PubMed --alpha=0.1 --inference\npython appnp.py --dataset=PubMed --alpha=0.1 --random_splits --inference\npython appnp.py --dataset=PubMed --alpha=0.1 --inference --profile\npython appnp.py --dataset=PubMed 
--alpha=0.1 --random_splits --inference --profile\n"
  },
  {
    "path": "benchmark/citation/run.sh",
    "content": "#!/bin/sh\n\necho \"Cora\"\necho \"====\"\n\necho \"GCN\"\npython gcn.py --dataset=Cora\npython gcn.py --dataset=Cora --random_splits\n\necho \"GAT\"\npython gat.py --dataset=Cora\npython gat.py --dataset=Cora --random_splits\n\necho \"Cheby\"\npython cheb.py --dataset=Cora --num_hops=3\npython cheb.py --dataset=Cora --num_hops=3 --random_splits\n\necho \"SGC\"\npython sgc.py --dataset=Cora --K=3 --weight_decay=0.0005\npython sgc.py --dataset=Cora --K=3 --weight_decay=0.0005 --random_splits\n\necho \"ARMA\"\npython arma.py --dataset=Cora --num_stacks=2 --num_layers=1 --shared_weights\npython arma.py --dataset=Cora --num_stacks=3 --num_layers=1 --shared_weights --random_splits\n\necho \"APPNP\"\npython appnp.py --dataset=Cora --alpha=0.1\npython appnp.py --dataset=Cora --alpha=0.1 --random_splits\n\necho \"CiteSeer\"\necho \"========\"\n\necho \"GCN\"\npython gcn.py --dataset=CiteSeer\npython gcn.py --dataset=CiteSeer --random_splits\n\necho \"GAT\"\npython gat.py --dataset=CiteSeer\npython gat.py --dataset=CiteSeer --random_splits\n\necho \"Cheby\"\npython cheb.py --dataset=CiteSeer --num_hops=2\npython cheb.py --dataset=CiteSeer --num_hops=3 --random_splits\n\necho \"SGC\"\npython sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005\npython sgc.py --dataset=CiteSeer --K=2 --weight_decay=0.005 --random_splits\n\necho \"ARMA\"\npython arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights\npython arma.py --dataset=CiteSeer --num_stacks=3 --num_layers=1 --shared_weights --random_splits\n\necho \"APPNP\"\npython appnp.py --dataset=CiteSeer --alpha=0.1\npython appnp.py --dataset=CiteSeer --alpha=0.1 --random_splits\n\necho \"PubMed\"\necho \"======\"\n\necho \"GCN\"\npython gcn.py --dataset=PubMed\npython gcn.py --dataset=PubMed --random_splits\n\necho \"GAT\"\npython gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8\npython gat.py --dataset=PubMed --lr=0.01 --weight_decay=0.001 --output_heads=8 
--random_splits\n\necho \"Cheby\"\npython cheb.py --dataset=PubMed --num_hops=2\npython cheb.py --dataset=PubMed --num_hops=2 --random_splits\n\necho \"SGC\"\npython sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005\npython sgc.py --dataset=PubMed --K=2 --weight_decay=0.0005 --random_splits\n\necho \"ARMA\"\npython arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0\npython arma.py --dataset=PubMed --num_stacks=2 --num_layers=1 --skip_dropout=0.5 --random_splits\n\necho \"APPNP\"\npython appnp.py --dataset=PubMed --alpha=0.1\npython appnp.py --dataset=PubMed --alpha=0.1 --random_splits\n"
  },
  {
    "path": "benchmark/citation/sgc.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom citation import get_planetoid_dataset, random_planetoid_splits, run\n\nfrom torch_geometric.nn import SGConv\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, required=True)\nparser.add_argument('--random_splits', action='store_true')\nparser.add_argument('--runs', type=int, default=100)\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--lr', type=float, default=0.1)\nparser.add_argument('--weight_decay', type=float, default=0.0005)\nparser.add_argument('--early_stopping', type=int, default=10)\nparser.add_argument('--no_normalize_features', action='store_true')\nparser.add_argument('--K', type=int, default=2)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, dataset):\n        super().__init__()\n        self.conv1 = SGConv(dataset.num_features, dataset.num_classes,\n                            K=args.K, cached=True)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = self.conv1(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndataset = get_planetoid_dataset(args.dataset, not args.no_normalize_features)\npermute_masks = random_planetoid_splits if args.random_splits else None\nrun(dataset, Net(dataset), args.runs, args.epochs, args.lr, args.weight_decay,\n    args.early_stopping, args.inference, args.profile, args.bf16, args.compile,\n    permute_masks)\n\nif args.profile:\n    rename_profile_file('citation', SGConv.__name__, args.dataset,\n                        str(args.random_splits),\n 
                       'inference' if args.inference else 'train')\n"
  },
  {
    "path": "benchmark/citation/statistics.py",
    "content": "from citation import get_planetoid_dataset\n\n\ndef print_dataset(dataset):\n    data = dataset[0]\n    print('Name', dataset)\n    print('Nodes', data.num_nodes)\n    print('Edges', data.num_edges // 2)\n    print('Features', dataset.num_features)\n    print('Classes', dataset.num_classes)\n    print('Label rate', data.train_mask.sum().item() / data.num_nodes)\n    print()\n\n\nfor name in ['Cora', 'CiteSeer', 'PubMed']:\n    print_dataset(get_planetoid_dataset(name))\n"
  },
  {
    "path": "benchmark/citation/train_eval.py",
    "content": "import time\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import tensor\nfrom torch.optim import Adam\n\nfrom torch_geometric.profile import timeit, torch_profile\nfrom torch_geometric.utils import index_to_mask\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\n\ndef random_planetoid_splits(data, num_classes):\n    # Set new random planetoid splits:\n    # * 20 * num_classes labels for training\n    # * 500 labels for validation\n    # * 1000 labels for testing\n\n    indices = []\n    for i in range(num_classes):\n        index = (data.y == i).nonzero().view(-1)\n        index = index[torch.randperm(index.size(0))]\n        indices.append(index)\n\n    train_index = torch.cat([i[:20] for i in indices], dim=0)\n\n    rest_index = torch.cat([i[20:] for i in indices], dim=0)\n    rest_index = rest_index[torch.randperm(rest_index.size(0))]\n\n    data.train_mask = index_to_mask(train_index, size=data.num_nodes)\n    data.val_mask = index_to_mask(rest_index[:500], size=data.num_nodes)\n    data.test_mask = index_to_mask(rest_index[500:1500], size=data.num_nodes)\n\n    return data\n\n\ndef run_train(dataset, model, runs, epochs, lr, weight_decay, early_stopping,\n              profiling, use_compile, permute_masks=None, logger=None):\n    val_losses, accs, durations = [], [], []\n    if use_compile:\n        model = torch.compile(model)\n\n    for run in range(runs):\n        data = dataset[0]\n        if permute_masks is not None:\n            data = permute_masks(data, dataset.num_classes)\n        data = data.to(device)\n\n        model.to(device).reset_parameters()\n        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        elif hasattr(torch.backends,\n       
              'mps') and torch.backends.mps.is_available():\n            try:\n                torch.mps.synchronize()\n            except ImportError:\n                pass\n\n        t_start = time.perf_counter()\n\n        best_val_loss = float('inf')\n        test_acc = 0\n        val_loss_history = []\n\n        for epoch in range(1, epochs + 1):\n            if run == runs - 1 and epoch == epochs:\n                with timeit():\n                    train(model, optimizer, data)\n            else:\n                train(model, optimizer, data)\n            eval_info = evaluate(model, data)\n            eval_info['epoch'] = epoch\n\n            if logger is not None:\n                logger(eval_info)\n\n            if eval_info['val_loss'] < best_val_loss:\n                best_val_loss = eval_info['val_loss']\n                test_acc = eval_info['test_acc']\n\n            val_loss_history.append(eval_info['val_loss'])\n            if early_stopping > 0 and epoch > epochs // 2:\n                tmp = tensor(val_loss_history[-(early_stopping + 1):-1])\n                if eval_info['val_loss'] > tmp.mean().item():\n                    break\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        elif hasattr(torch.backends,\n                     'mps') and torch.backends.mps.is_available():\n            try:\n                torch.mps.synchronize()\n            except ImportError:\n                pass\n\n        t_end = time.perf_counter()\n\n        val_losses.append(best_val_loss)\n        accs.append(test_acc)\n        durations.append(t_end - t_start)\n    loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)\n\n    print(f'Val Loss: {float(loss.mean()):.4f}, '\n          f'Test Accuracy: {float(acc.mean()):.3f} ± {float(acc.std()):.3f}, '\n          f'Duration: {float(duration.mean()):.3f}s')\n\n    if profiling:\n        with torch_profile():\n            train(model, optimizer, 
data)\n\n\n@torch.no_grad()\ndef run_inference(dataset, model, epochs, profiling, bf16, use_compile,\n                  permute_masks=None, logger=None):\n    data = dataset[0]\n    if permute_masks is not None:\n        data = permute_masks(data, dataset.num_classes)\n    data = data.to(device)\n\n    model.to(device).reset_parameters()\n    if use_compile:\n        model = torch.compile(model)\n\n    if torch.cuda.is_available():\n        amp = torch.amp.autocast('cuda', enabled=False)\n    else:\n        amp = torch.cpu.amp.autocast(enabled=bf16)\n    if bf16:\n        data.x = data.x.to(torch.bfloat16)\n\n    with amp:\n        for epoch in range(1, epochs + 1):\n            if epoch == epochs:\n                with timeit():\n                    inference(model, data)\n            else:\n                inference(model, data)\n\n        if profiling:\n            with torch_profile():\n                inference(model, data)\n\n\ndef run(dataset, model, runs, epochs, lr, weight_decay, early_stopping,\n        inference, profiling, bf16, use_compile, permute_masks=None,\n        logger=None):\n    if not inference:\n        run_train(dataset, model, runs, epochs, lr, weight_decay,\n                  early_stopping, profiling, use_compile, permute_masks,\n                  logger)\n    else:\n        run_inference(dataset, model, epochs, profiling, bf16, use_compile,\n                      permute_masks, logger)\n\n\ndef train(model, optimizer, data):\n    model.train()\n    optimizer.zero_grad()\n    out = model(data)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef evaluate(model, data):\n    model.eval()\n\n    out = model(data)\n\n    outs = {}\n    for key in ['train', 'val', 'test']:\n        mask = data[f'{key}_mask']\n        loss = float(F.nll_loss(out[mask], data.y[mask]))\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / 
mask.sum().item()\n\n        outs[f'{key}_loss'] = loss\n        outs[f'{key}_acc'] = acc\n\n    return outs\n\n\n@torch.no_grad()\ndef inference(model, data):\n    model.eval()\n    model(data)\n"
  },
  {
    "path": "benchmark/inference/README.md",
    "content": "# Inference Benchmark\n\n## Environment setup\n\n1. Confirm that PyG is properly installed.\n1. Install dataset package:\n   ```bash\n   pip install ogb\n   ```\n1. Install `autoconf` required for `jemalloc` setup:\n   ```bash\n   sudo apt-get install autoconf\n   ```\n1. Install `jemalloc` for performance benchmark:\n   ```bash\n   cd ${workspace}\n   git clone https://github.com/jemalloc/jemalloc.git\n   cd jemalloc\n   git checkout 5.2.1\n   ./autogen.sh\n   ./configure --prefix=${workspace}/jemalloc-bin\n   make\n   make install\n   ```\n\n## Running benchmark\n\n1. Set environment variables:\n   ```bash\n   source activate env_name\n   export DNNL_PRIMITIVE_CACHE_CAPACITY=1024\n   export KMP_BLOCKTIME=1\n   export KMP_AFFINITY=granularity=fine,compact,1,0\n\n   jemalloc_lib=${workspace}/jemalloc-bin/lib/libjemalloc.so\n   export LD_PRELOAD=\"$jemalloc_lib\"\n   export MALLOC_CONF=\"oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000\"\n   ```\n1. Core binding, *e.g.*, single socket / single core / 4 cores per instance:\n   ```bash\n   OMP_NUM_THREADS=${CORES} numactl -C 0-${LAST_CORE} -m 0 CMD......\n   ```\n1. Execute benchmarks, *e.g.*:\n   ```bash\n   python -u inference_benchmark.py --datasets=Reddit --models=gcn --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64\n   python -u inference_benchmark.py --datasets=Reddit --models=gcn --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --use-sparse-tensor\n   python -u inference_benchmark.py --datasets=ogbn-products --models=sage --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64\n   python -u inference_benchmark.py --datasets=ogbn-products --models=sage --eval-batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --use-sparse-tensor\n   ```\n"
  },
  {
    "path": "benchmark/inference/inference_benchmark.py",
    "content": "import argparse\nimport warnings\nfrom collections import defaultdict\nfrom contextlib import nullcontext\n\nimport torch\n\nfrom benchmark.utils import (\n    emit_itt,\n    get_dataset_with_transformation,\n    get_model,\n    get_split_masks,\n    save_benchmark_data,\n    test,\n    write_to_csv,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import PNAConv\nfrom torch_geometric.profile import (\n    rename_profile_file,\n    timeit,\n    torch_profile,\n    xpu_profile,\n)\n\nsupported_sets = {\n    'ogbn-mag': ['rgat', 'rgcn'],\n    'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],\n    'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],\n}\n\n\n@torch.no_grad()\ndef full_batch_inference(model, data):\n    model.eval()\n    if hasattr(data, 'adj_t'):\n        edge_index = data.adj_t\n    else:\n        edge_index = data.edge_index\n    return model(data.x, edge_index)\n\n\ndef run(args: argparse.ArgumentParser):\n    csv_data = defaultdict(list)\n\n    if args.write_csv == 'prof' and not args.profile:\n        warnings.warn(\n            \"Cannot write profile data to CSV because profiling is \"\n            \"disabled\", stacklevel=2)\n\n    if args.device == 'xpu':\n        try:\n            import intel_extension_for_pytorch as ipex\n        except ImportError as e:\n            raise RuntimeError(\n                'XPU device requires IPEX to be installed') from e\n\n    if ((args.device == 'cuda' and not torch.cuda.is_available())\n            or (args.device == 'xpu' and not torch.xpu.is_available())):\n        raise RuntimeError(f'{args.device.upper()} is not available')\n\n    if args.device == 'cuda' and args.full_batch:\n        raise RuntimeError('CUDA device is not suitable for full batch mode')\n\n    device = torch.device(args.device)\n\n    print('BENCHMARK STARTS')\n    print(f'Running on {args.device.upper()}')\n    for dataset_name in 
args.datasets:\n        assert dataset_name in supported_sets.keys(\n        ), f\"Dataset {dataset_name} isn't supported.\"\n        print(f'Dataset: {dataset_name}')\n        load_time = timeit() if args.measure_load_time else nullcontext()\n        with load_time:\n            result = get_dataset_with_transformation(dataset_name, args.root,\n                                                     args.use_sparse_tensor,\n                                                     args.bf16)\n            dataset, num_classes, transformation = result\n        data = dataset.to(device)\n        hetero = True if dataset_name == 'ogbn-mag' else False\n        mask = ('paper', None) if dataset_name == 'ogbn-mag' else None\n        _, _, test_mask = get_split_masks(data, dataset_name)\n        degree = None\n\n        if hetero and args.cached_loader:\n            args.cached_loader = False\n            print('Disabling CachedLoader, not supported in Hetero models')\n        if args.num_layers != [1] and not hetero and args.num_steps != -1:\n            raise ValueError(\"Layer-wise inference requires `steps=-1`\")\n\n        if args.device == 'cuda':\n            amp = torch.amp.autocast('cuda', enabled=False)\n        elif args.device == 'xpu':\n            amp = torch.xpu.amp.autocast(enabled=False)\n        else:\n            amp = torch.cpu.amp.autocast(enabled=args.bf16)\n\n        if args.device == 'xpu' and args.warmup < 1:\n            print('XPU device requires warmup - setting warmup=1')\n            args.warmup = 1\n\n        inputs_channels = data[\n            'paper'].num_features if dataset_name == 'ogbn-mag' \\\n            else dataset.num_features\n\n        for model_name in args.models:\n            if model_name not in supported_sets[dataset_name]:\n                print(f'Configuration of {dataset_name} + {model_name} '\n                      f'not supported. 
Skipping.')\n                continue\n            with_loader = not args.full_batch or (model_name == 'pna'\n                                                  and degree is None)\n            print(f'Evaluation bench for {model_name}:')\n\n            for batch_size in args.eval_batch_sizes:\n                num_nodes = data[\n                    'paper'].num_nodes if hetero else data.num_nodes\n                sampler = torch.utils.data.RandomSampler(\n                    range(num_nodes), num_samples=args.num_steps * batch_size\n                ) if args.num_steps != -1 and with_loader else None\n                kwargs = {\n                    'batch_size': batch_size,\n                    'shuffle': False,\n                    'num_workers': args.num_workers,\n                }\n                if not hetero:\n                    subgraph_loader = NeighborLoader(\n                        data,\n                        num_neighbors=[-1],  # layer-wise inference\n                        input_nodes=mask,\n                        sampler=sampler,\n                        **kwargs,\n                    ) if with_loader else None\n                    if args.evaluate and not args.full_batch:\n                        test_loader = NeighborLoader(\n                            data,\n                            num_neighbors=[-1],  # layer-wise inference\n                            input_nodes=test_mask,\n                            sampler=None,\n                            **kwargs,\n                        )\n\n                for layers in args.num_layers:\n                    num_neighbors = [args.hetero_num_neighbors] * layers\n                    if hetero:\n                        # batch-wise inference\n                        subgraph_loader = NeighborLoader(\n                            data,\n                            num_neighbors=num_neighbors,\n                            input_nodes=mask,\n                            sampler=sampler,\n               
             **kwargs,\n                        ) if with_loader else None\n                        if args.evaluate and not args.full_batch:\n                            test_loader = NeighborLoader(\n                                data,\n                                num_neighbors=num_neighbors,\n                                input_nodes=test_mask,\n                                sampler=None,\n                                **kwargs,\n                            )\n\n                    for hidden_channels in args.num_hidden_channels:\n                        print('----------------------------------------------')\n                        print(f'Batch size={batch_size}, '\n                              f'Layers amount={layers}, '\n                              f'Num_neighbors={num_neighbors}, '\n                              f'Hidden features size={hidden_channels}, '\n                              f'Sparse tensor={args.use_sparse_tensor}')\n                        params = {\n                            'inputs_channels': inputs_channels,\n                            'hidden_channels': hidden_channels,\n                            'output_channels': num_classes,\n                            'num_heads': args.num_heads,\n                            'num_layers': layers,\n                        }\n\n                        if model_name == 'pna':\n                            if degree is None:\n                                degree = PNAConv.get_degree_histogram(\n                                    subgraph_loader)\n                                print(f'Calculated degree for {dataset_name}.')\n                            params['degree'] = degree\n\n                        model = get_model(\n                            model_name, params,\n                            metadata=data.metadata() if hetero else None)\n                        model = model.to(device)\n                        # TODO: Migrate to ModelHubMixin.\n                        if 
args.ckpt_path:\n                            state_dict = fs.torch_load(args.ckpt_path)\n                            model.load_state_dict(state_dict)\n                        model.eval()\n                        if args.device == 'xpu':\n                            model = ipex.optimize(model)\n\n                        # Define context manager parameters:\n                        if args.cpu_affinity and with_loader:\n                            cpu_affinity = subgraph_loader.enable_cpu_affinity(\n                                args.loader_cores)\n                        else:\n                            cpu_affinity = nullcontext()\n                        if args.profile and args.device == 'xpu':\n                            profile = xpu_profile(args.export_chrome_trace)\n                        elif args.profile:\n                            profile = torch_profile(args.export_chrome_trace,\n                                                    csv_data, args.write_csv)\n                        else:\n                            profile = nullcontext()\n                        itt = emit_itt(\n                        ) if args.vtune_profile else nullcontext()\n\n                        if args.full_batch and args.use_sparse_tensor:\n                            data = transformation(data)\n\n                        with cpu_affinity, amp, timeit() as time:\n                            inference_kwargs = dict(cache=args.cached_loader)\n                            if args.reuse_device_for_embeddings and not hetero:\n                                inference_kwargs['embedding_device'] = device\n                            for _ in range(args.warmup):\n                                if args.full_batch:\n                                    full_batch_inference(model, data)\n                                else:\n                                    model.inference(\n                                        subgraph_loader,\n                                        
device,\n                                        progress_bar=True,\n                                        **inference_kwargs,\n                                    )\n                            if args.warmup > 0:\n                                time.reset()\n                            with itt, profile:\n                                if args.full_batch:\n                                    y = full_batch_inference(model, data)\n                                    if args.evaluate:\n                                        mask = data.test_mask\n                                        pred = y[mask].argmax(1)\n                                        test_acc = pred.eq(data.y[mask]).sum(\n                                        ).item() / mask.sum().item()\n                                        print(f'Full Batch Test Accuracy: \\\n                                            {test_acc:.4f}')\n                                else:\n                                    y = model.inference(\n                                        subgraph_loader,\n                                        device,\n                                        progress_bar=True,\n                                        **inference_kwargs,\n                                    )\n                                    if args.evaluate:\n                                        test_acc = test(\n                                            model,\n                                            test_loader,\n                                            device,\n                                            hetero,\n                                            progress_bar=True,\n                                        )\n                                        print(f'Mini Batch Test Accuracy: \\\n                                            {test_acc:.4f}')\n\n                        if args.profile and args.export_chrome_trace:\n                            rename_profile_file(model_name, 
dataset_name,\n                                                str(batch_size), str(layers),\n                                                str(hidden_channels),\n                                                str(num_neighbors))\n                        total_time = time.duration\n                        if args.num_steps != -1:\n                            total_num_samples = args.num_steps * batch_size\n                        else:\n                            total_num_samples = num_nodes\n                        throughput = total_num_samples / total_time\n                        latency = total_time / total_num_samples * 1000\n                        print(f'Throughput: {throughput:.3f} samples/s')\n                        print(f'Latency: {latency:.3f} ms')\n\n                        num_records = 1\n                        if args.write_csv == 'prof':\n                            # For profiling with PyTorch, we save the top-5\n                            # most time consuming operations. 
Therefore, the\n                            # same data should be entered for each of them.\n                            num_records = 5\n                        for _ in range(num_records):\n                            save_benchmark_data(\n                                csv_data,\n                                batch_size,\n                                layers,\n                                num_neighbors,\n                                hidden_channels,\n                                total_time,\n                                model_name,\n                                dataset_name,\n                                args.use_sparse_tensor,\n                            )\n    if args.write_csv:\n        write_to_csv(csv_data, args.write_csv)\n\n\nif __name__ == '__main__':\n    argparser = argparse.ArgumentParser('GNN inference benchmark')\n    add = argparser.add_argument\n\n    add('--device', choices=['cpu', 'cuda', 'xpu'], default='cpu',\n        help='Device to run benchmark on')\n    add('--reuse-device-for-embeddings', action='store_true',\n        help='Use the same device for embeddings as specified in \"--device\"')\n    add('--datasets', nargs='+',\n        default=['ogbn-mag', 'ogbn-products', 'Reddit'], type=str)\n    add('--use-sparse-tensor', action='store_true',\n        help='use torch_sparse.SparseTensor as graph storage format')\n    add('--models', nargs='+',\n        default=['edge_cnn', 'gat', 'gcn', 'pna', 'rgat', 'rgcn'], type=str)\n    add('--root', default='../../data', type=str,\n        help='relative path to look for the datasets')\n    add('--eval-batch-sizes', nargs='+', default=[512, 1024, 2048, 4096, 8192],\n        type=int)\n    add('--num-layers', nargs='+', default=[2, 3], type=int)\n    add('--num-hidden-channels', nargs='+', default=[64, 128, 256], type=int)\n    add('--num-heads', default=2, type=int,\n        help='number of hidden attention heads, applies only for gat and rgat')\n    
add('--hetero-num-neighbors', default=10, type=int,\n        help='number of neighbors to sample per layer for hetero workloads')\n    add('--num-workers', default=0, type=int)\n    add('--num-steps', default=-1, type=int,\n        help='number of steps, -1 means iterating through all the data')\n    add('--warmup', default=1, type=int)\n    add('--profile', action='store_true')\n    add('--vtune-profile', action='store_true')\n    add('--bf16', action='store_true')\n    add('--cpu-affinity', action='store_true',\n        help='Use DataLoader affinitization.')\n    add('--loader-cores', nargs='+', default=[], type=int,\n        help=\"List of CPU core IDs to use for DataLoader workers\")\n    add('--measure-load-time', action='store_true')\n    add('--full-batch', action='store_true', help='Use full batch mode')\n    add('--evaluate', action='store_true')\n    add('--ckpt_path', type=str, help='Checkpoint path for loading a model')\n    add('--write-csv', choices=[None, 'bench', 'prof'], default=None,\n        help='Write benchmark or PyTorch profile data to CSV')\n    add('--export-chrome-trace', default=True, type=bool,\n        help='Export chrome trace file. Works only with PyTorch profiler')\n    add('--cached-loader', action='store_true', help='Use CachedLoader')\n    run(argparser.parse_args())\n"
  },
  {
    "path": "benchmark/kernel/README.md",
    "content": "# Graph Classification\n\nEvaluation script for various methods on [common benchmark datasets](https://chrsmrrs.github.io/datasets/) via 10-fold cross validation, where a training fold is randomly sampled to serve as a validation set.\nHyperparameter selection is performed for the number of hidden units and the number of layers with respect to the validation set:\n\n- **[GCN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/gcn.py)**\n- **[GraphSAGE](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/graph_sage.py)**\n- **[GIN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/gin.py)**\n- **[Graclus](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/graclus.py)**\n- **[Top-K Pooling](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/top_k.py)**\n- **[SAG Pooling](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sag_pool.py)**\n- **[DiffPool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/diff_pool.py)**\n- **[EdgePool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/edge_pool.py)**\n- **[GlobalAttention](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/global_attention.py)**\n- **[Set2Set](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/set2set.py)**\n- **[SortPool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/sort_pool.py)**\n- **[ASAPool](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/kernel/asap.py)**\n\nRun (or modify) the whole test suite via\n\n```\n$ python main.py\n```\n\nFor more comprehensive time-measurement and memory usage information, you may use\n\n```\n$ python main_performance.py\n```\n"
  },
  {
    "path": "benchmark/kernel/__init__.py",
    "content": "from .datasets import get_dataset\nfrom .train_eval import cross_validation_with_val_set\n\n__all__ = [\n    'get_dataset',\n    'cross_validation_with_val_set',\n]\n"
  },
  {
    "path": "benchmark/kernel/asap.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import (\n    ASAPooling,\n    GraphConv,\n    JumpingKnowledge,\n    global_mean_pool,\n)\n\n\nclass ASAP(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, ratio=0.8, dropout=0):\n        super().__init__()\n        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')\n        self.convs = torch.nn.ModuleList()\n        self.pools = torch.nn.ModuleList()\n        self.convs.extend([\n            GraphConv(hidden, hidden, aggr='mean')\n            for i in range(num_layers - 1)\n        ])\n        self.pools.extend([\n            ASAPooling(hidden, ratio, dropout=dropout)\n            for i in range((num_layers) // 2)\n        ])\n        self.jump = JumpingKnowledge(mode='cat')\n        self.lin1 = Linear(num_layers * hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        for pool in self.pools:\n            pool.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        edge_weight = None\n        x = F.relu(self.conv1(x, edge_index))\n        xs = [global_mean_pool(x, batch)]\n        for i, conv in enumerate(self.convs):\n            x = conv(x=x, edge_index=edge_index, edge_weight=edge_weight)\n            x = F.relu(x)\n            xs += [global_mean_pool(x, batch)]\n            if i % 2 == 0 and i < len(self.convs) - 1:\n                pool = self.pools[i // 2]\n                x, edge_index, edge_weight, batch, _ = pool(\n                    x=x, edge_index=edge_index, edge_weight=edge_weight,\n                    batch=batch)\n        x = self.jump(xs)\n        x = F.relu(self.lin1(x))\n        
x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/datasets.py",
    "content": "import os.path as osp\n\nimport torch\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.utils import degree\n\n\nclass NormalizedDegree:\n    def __init__(self, mean, std):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, data):\n        deg = degree(data.edge_index[0], dtype=torch.float)\n        deg = (deg - self.mean) / self.std\n        data.x = deg.view(-1, 1)\n        return data\n\n\ndef get_dataset(name, sparse=True, cleaned=False):\n    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)\n    dataset = TUDataset(path, name, cleaned=cleaned)\n    dataset.data.edge_attr = None\n\n    if dataset.data.x is None:\n        max_degree = 0\n        degs = []\n        for data in dataset:\n            degs += [degree(data.edge_index[0], dtype=torch.long)]\n            max_degree = max(max_degree, degs[-1].max().item())\n\n        if max_degree < 1000:\n            dataset.transform = T.OneHotDegree(max_degree)\n        else:\n            deg = torch.cat(degs, dim=0).to(torch.float)\n            mean, std = deg.mean().item(), deg.std().item()\n            dataset.transform = NormalizedDegree(mean, std)\n\n    if not sparse:\n        num_nodes = max_num_nodes = 0\n        for data in dataset:\n            num_nodes += data.num_nodes\n            max_num_nodes = max(data.num_nodes, max_num_nodes)\n\n        # Filter out a few really large graphs in order to apply DiffPool.\n        if name == 'REDDIT-BINARY':\n            num_nodes = min(int(num_nodes / len(dataset) * 1.5), max_num_nodes)\n        else:\n            num_nodes = min(int(num_nodes / len(dataset) * 5), max_num_nodes)\n\n        indices = []\n        for i, data in enumerate(dataset):\n            if data.num_nodes <= num_nodes:\n                indices.append(i)\n        dataset = dataset.copy(torch.tensor(indices))\n\n        if dataset.transform is None:\n            
dataset.transform = T.ToDense(num_nodes)\n        else:\n            dataset.transform = T.Compose(\n                [dataset.transform, T.ToDense(num_nodes)])\n\n    return dataset\n"
  },
  {
    "path": "benchmark/kernel/diff_pool.py",
    "content": "from math import ceil\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import DenseSAGEConv, JumpingKnowledge, dense_diff_pool\n\n\nclass Block(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, mode='cat'):\n        super().__init__()\n\n        self.conv1 = DenseSAGEConv(in_channels, hidden_channels)\n        self.conv2 = DenseSAGEConv(hidden_channels, out_channels)\n        self.jump = JumpingKnowledge(mode)\n        if mode == 'cat':\n            self.lin = Linear(hidden_channels + out_channels, out_channels)\n        else:\n            self.lin = Linear(out_channels, out_channels)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        self.conv2.reset_parameters()\n        self.lin.reset_parameters()\n\n    def forward(self, x, adj, mask=None):\n        x1 = F.relu(self.conv1(x, adj, mask))\n        x2 = F.relu(self.conv2(x1, adj, mask))\n        return self.lin(self.jump([x1, x2]))\n\n\nclass DiffPool(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, ratio=0.25):\n        super().__init__()\n\n        num_nodes = ceil(ratio * dataset[0].num_nodes)\n        self.embed_block1 = Block(dataset.num_features, hidden, hidden)\n        self.pool_block1 = Block(dataset.num_features, hidden, num_nodes)\n\n        self.embed_blocks = torch.nn.ModuleList()\n        self.pool_blocks = torch.nn.ModuleList()\n        for _ in range((num_layers // 2) - 1):\n            num_nodes = ceil(ratio * num_nodes)\n            self.embed_blocks.append(Block(hidden, hidden, hidden))\n            self.pool_blocks.append(Block(hidden, hidden, num_nodes))\n\n        self.jump = JumpingKnowledge(mode='cat')\n        self.lin1 = Linear((len(self.embed_blocks) + 1) * hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.embed_block1.reset_parameters()\n        
self.pool_block1.reset_parameters()\n        for embed_block, pool_block in zip(self.embed_blocks,\n                                           self.pool_blocks):\n            embed_block.reset_parameters()\n            pool_block.reset_parameters()\n        self.jump.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, adj, mask = data.x, data.adj, data.mask\n\n        s = self.pool_block1(x, adj, mask)\n        x = F.relu(self.embed_block1(x, adj, mask))\n        xs = [x.mean(dim=1)]\n        x, adj, _, _ = dense_diff_pool(x, adj, s, mask)\n\n        for i, (embed_block, pool_block) in enumerate(\n                zip(self.embed_blocks, self.pool_blocks)):\n            s = pool_block(x, adj)\n            x = F.relu(embed_block(x, adj))\n            xs.append(x.mean(dim=1))\n            if i < len(self.embed_blocks) - 1:\n                x, adj, _, _ = dense_diff_pool(x, adj, s)\n\n        x = self.jump(xs)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/edge_pool.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import (\n    EdgePooling,\n    GraphConv,\n    JumpingKnowledge,\n    global_mean_pool,\n)\n\n\nclass EdgePool(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')\n        self.convs = torch.nn.ModuleList()\n        self.pools = torch.nn.ModuleList()\n        self.convs.extend([\n            GraphConv(hidden, hidden, aggr='mean')\n            for i in range(num_layers - 1)\n        ])\n        self.pools.extend(\n            [EdgePooling(hidden) for i in range((num_layers) // 2)])\n        self.jump = JumpingKnowledge(mode='cat')\n        self.lin1 = Linear(num_layers * hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        for pool in self.pools:\n            pool.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        xs = [global_mean_pool(x, batch)]\n        for i, conv in enumerate(self.convs):\n            x = F.relu(conv(x, edge_index))\n            xs += [global_mean_pool(x, batch)]\n            if i % 2 == 0 and i < len(self.convs) - 1:\n                pool = self.pools[i // 2]\n                x, edge_index, batch, _ = pool(x, edge_index, batch=batch)\n        x = self.jump(xs)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/gcn.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import GCNConv, JumpingKnowledge, global_mean_pool\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, hidden)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(GCNConv(hidden, hidden))\n        self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass GCNWithJK(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, mode='cat'):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, hidden)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(GCNConv(hidden, hidden))\n        self.jump = JumpingKnowledge(mode)\n        if mode == 'cat':\n            self.lin1 = Linear(num_layers * hidden, hidden)\n        else:\n            self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n          
  conv.reset_parameters()\n        self.jump.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        xs = [x]\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n            xs += [x]\n        x = self.jump(xs)\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/gin.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import BatchNorm1d as BN\nfrom torch.nn import Linear, ReLU, Sequential\n\nfrom torch_geometric.nn import GINConv, JumpingKnowledge, global_mean_pool\n\n\nclass GIN0(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = GINConv(\n            Sequential(\n                Linear(dataset.num_features, hidden),\n                ReLU(),\n                BN(hidden),\n                Linear(hidden, hidden),\n                ReLU(),\n                BN(hidden),\n            ), train_eps=False)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(\n                GINConv(\n                    Sequential(\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                    ), train_eps=False))\n        self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = self.conv1(x, edge_index)\n        for conv in self.convs:\n            x = conv(x, edge_index)\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass GIN0WithJK(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, mode='cat'):\n      
  super().__init__()\n        self.conv1 = GINConv(\n            Sequential(\n                Linear(dataset.num_features, hidden),\n                ReLU(),\n                BN(hidden),\n                Linear(hidden, hidden),\n                ReLU(),\n                BN(hidden),\n            ), train_eps=False)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(\n                GINConv(\n                    Sequential(\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                    ), train_eps=False))\n        self.jump = JumpingKnowledge(mode)\n        if mode == 'cat':\n            self.lin1 = Linear(num_layers * hidden, hidden)\n        else:\n            self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.jump.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = self.conv1(x, edge_index)\n        xs = [x]\n        for conv in self.convs:\n            x = conv(x, edge_index)\n            xs += [x]\n        x = self.jump(xs)\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass GIN(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = GINConv(\n            Sequential(\n      
          Linear(dataset.num_features, hidden),\n                ReLU(),\n                BN(hidden),\n                Linear(hidden, hidden),\n                ReLU(),\n                BN(hidden),\n            ), train_eps=True)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(\n                GINConv(\n                    Sequential(\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                    ), train_eps=True))\n        self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = self.conv1(x, edge_index)\n        for conv in self.convs:\n            x = conv(x, edge_index)\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass GINWithJK(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, mode='cat'):\n        super().__init__()\n        self.conv1 = GINConv(\n            Sequential(\n                Linear(dataset.num_features, hidden),\n                ReLU(),\n                BN(hidden),\n                Linear(hidden, hidden),\n                ReLU(),\n                BN(hidden),\n            ), train_eps=True)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n          
  self.convs.append(\n                GINConv(\n                    Sequential(\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                        Linear(hidden, hidden),\n                        ReLU(),\n                        BN(hidden),\n                    ), train_eps=True))\n        self.jump = JumpingKnowledge(mode)\n        if mode == 'cat':\n            self.lin1 = Linear(num_layers * hidden, hidden)\n        else:\n            self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.jump.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = self.conv1(x, edge_index)\n        xs = [x]\n        for conv in self.convs:\n            x = conv(x, edge_index)\n            xs += [x]\n        x = self.jump(xs)\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/global_attention.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import AttentionalAggregation, SAGEConv\n\n\nclass GlobalAttentionNet(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = SAGEConv(dataset.num_features, hidden)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(SAGEConv(hidden, hidden))\n        self.att = AttentionalAggregation(Linear(hidden, 1))\n        self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.att.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n        x = self.att(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/graclus.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.nn import (\n    GraphConv,\n    JumpingKnowledge,\n    global_mean_pool,\n    graclus,\n    max_pool,\n)\n\n\nclass Graclus(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))\n        self.jump = JumpingKnowledge(mode='cat')\n        self.lin1 = Linear(num_layers * hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.jump.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        xs = [global_mean_pool(x, batch)]\n        for i, conv in enumerate(self.convs):\n            x = F.relu(conv(x, edge_index))\n            xs += [global_mean_pool(x, batch)]\n            if i % 2 == 0 and i < len(self.convs) - 1:\n                cluster = graclus(edge_index, num_nodes=x.size(0))\n                data = Batch(x=x, edge_index=edge_index, batch=batch)\n                data = max_pool(cluster, data)\n                x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = self.jump(xs)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/graph_sage.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import JumpingKnowledge, SAGEConv, global_add_pool\n\n\nclass GraphSAGE(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = SAGEConv(dataset.num_features, hidden)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(SAGEConv(hidden, hidden))\n        self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n        x = global_add_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass GraphSAGEWithJK(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, mode='cat'):\n        super().__init__()\n        self.conv1 = SAGEConv(dataset.num_features, hidden)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(SAGEConv(hidden, hidden))\n        self.jump = JumpingKnowledge(mode)\n        if mode == 'cat':\n            self.lin1 = Linear(num_layers * hidden, hidden)\n        else:\n            self.lin1 = Linear(hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in 
self.convs:\n            conv.reset_parameters()\n        self.jump.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        xs = [x]\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n            xs += [x]\n        x = self.jump(xs)\n        x = global_add_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/main.py",
    "content": "import argparse\nfrom itertools import product\n\nfrom asap import ASAP\nfrom datasets import get_dataset\nfrom diff_pool import DiffPool\nfrom edge_pool import EdgePool\nfrom gcn import GCN, GCNWithJK\nfrom gin import GIN, GIN0, GIN0WithJK, GINWithJK\nfrom global_attention import GlobalAttentionNet\nfrom graclus import Graclus\nfrom graph_sage import GraphSAGE, GraphSAGEWithJK\nfrom sag_pool import SAGPool\nfrom set2set import Set2SetNet\nfrom sort_pool import SortPool\nfrom top_k import TopK\nfrom train_eval import cross_validation_with_val_set\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--epochs', type=int, default=100)\nparser.add_argument('--batch_size', type=int, default=128)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--lr_decay_factor', type=float, default=0.5)\nparser.add_argument('--lr_decay_step_size', type=int, default=50)\nargs = parser.parse_args()\n\nlayers = [1, 2, 3, 4, 5]\nhiddens = [16, 32, 64, 128]\ndatasets = ['MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY']  # , 'COLLAB']\nnets = [\n    GCNWithJK,\n    GraphSAGEWithJK,\n    GIN0WithJK,\n    GINWithJK,\n    Graclus,\n    TopK,\n    SAGPool,\n    DiffPool,\n    EdgePool,\n    GCN,\n    GraphSAGE,\n    GIN0,\n    GIN,\n    GlobalAttentionNet,\n    Set2SetNet,\n    SortPool,\n    ASAP,\n]\n\n\ndef logger(info):\n    fold, epoch = info['fold'] + 1, info['epoch']\n    val_loss, test_acc = info['val_loss'], info['test_acc']\n    print(f'{fold:02d}/{epoch:03d}: Val Loss: {val_loss:.4f}, '\n          f'Test Accuracy: {test_acc:.3f}')\n\n\nresults = []\nfor dataset_name, Net in product(datasets, nets):\n    best_result = (float('inf'), 0, 0)  # (loss, acc, std)\n    print(f'--\\n{dataset_name} - {Net.__name__}')\n    for num_layers, hidden in product(layers, hiddens):\n        dataset = get_dataset(dataset_name, sparse=Net != DiffPool)\n        model = Net(dataset, num_layers, hidden)\n        loss, acc, std = 
cross_validation_with_val_set(\n            dataset,\n            model,\n            folds=10,\n            epochs=args.epochs,\n            batch_size=args.batch_size,\n            lr=args.lr,\n            lr_decay_factor=args.lr_decay_factor,\n            lr_decay_step_size=args.lr_decay_step_size,\n            weight_decay=0,\n            logger=None,\n        )\n        if loss < best_result[0]:\n            best_result = (loss, acc, std)\n\n    desc = f'{best_result[1]:.3f} ± {best_result[2]:.3f}'\n    print(f'Best result - {desc}')\n    results += [f'{dataset_name} - {model}: {desc}']\nresults = '\\n'.join(results)\nprint(f'--\\n{results}')\n"
  },
  {
    "path": "benchmark/kernel/main_performance.py",
    "content": "import argparse\nfrom itertools import product\n\nimport torch\nfrom datasets import get_dataset\nfrom gcn import GCN\nfrom gin import GIN\nfrom graph_sage import GraphSAGE\nfrom train_eval import eval_acc, inference_run, train\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.profile import rename_profile_file, timeit, torch_profile\n\nseed_everything(0)\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    '--datasets', type=str, nargs='+',\n    default=['MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY'])\nparser.add_argument('--models', type=str, nargs='+',\n                    default=['GCN', 'GraphSAGE', 'GIN'])\nparser.add_argument('--layers', type=int, nargs='+', default=[1, 2, 3])\nparser.add_argument('--hiddens', type=int, nargs='+', default=[16, 32])\nparser.add_argument('--epochs', type=int, default=100)\nparser.add_argument('--batch_size', type=int, default=128)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--warmup_profile', type=int, default=1,\n                    help='Skip the first few runs')\nparser.add_argument('--goal_accuracy', type=int, default=1,\n                    help='The goal test accuracy')\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nif torch.cuda.is_available():\n    amp = torch.amp.autocast('cuda', enabled=False)\nelse:\n    amp = torch.cpu.amp.autocast(enabled=args.bf16)\n\nMODELS = {\n    'GCN': GCN,\n    'GraphSAGE': GraphSAGE,\n    'GIN': GIN,\n}\n\n\ndef 
prepare_dataloader(dataset_name):\n    dataset = get_dataset(dataset_name, sparse=True)\n    num_train = int(len(dataset) * 0.8)\n    num_val = int(len(dataset) * 0.1)\n\n    train_dataset = dataset[:num_train]\n    val_dataset = dataset[num_train:num_train + num_val]\n    test_dataset = dataset[num_train + num_val:]\n\n    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,\n                              shuffle=True)\n    val_loader = DataLoader(val_dataset, batch_size=args.batch_size,\n                            shuffle=False)\n    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,\n                             shuffle=False)\n    return dataset, train_loader, val_loader, test_loader\n\n\ndef run_train():\n    for dataset_name, model_name in product(args.datasets, args.models):\n        dataset, train_loader, val_loader, test_loader = prepare_dataloader(\n            dataset_name)\n        Model = MODELS[model_name]\n\n        for num_layers, hidden in product(args.layers, args.hiddens):\n            print('--')\n            print(f'{dataset_name} - {model_name}- {num_layers} - {hidden}')\n\n            model = Model(dataset, num_layers, hidden).to(device)\n            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n            if args.compile:\n                model = torch.compile(model)\n            loss_list = []\n            acc_list = []\n            for epoch in range(1, args.epochs + 1):\n                if epoch == args.epochs:\n                    with timeit():\n                        loss = train(model, optimizer, train_loader)\n                else:\n                    loss = train(model, optimizer, train_loader)\n\n                with timeit(log=False) as t:\n                    val_acc = eval_acc(model, val_loader)\n                val_time = t.duration\n                with timeit(log=False) as t:\n                    test_acc = eval_acc(model, test_loader)\n                test_time = 
t.duration\n\n                if epoch >= args.warmup_profile:\n                    loss_list.append(loss)\n                    acc_list.append([val_acc, val_time, test_acc, test_time])\n\n            if args.profile:\n                with torch_profile():\n                    train(model, optimizer, train_loader)\n                rename_profile_file(model_name, dataset_name, str(num_layers),\n                                    str(hidden), 'train')\n\n\n@torch.no_grad()\ndef run_inference():\n    for dataset_name, model_name in product(args.datasets, args.models):\n        dataset, _, _, test_loader = prepare_dataloader(dataset_name)\n        Model = MODELS[model_name]\n\n        for num_layers, hidden in product(args.layers, args.hiddens):\n            print('--')\n            print(f'{dataset_name} - {model_name}- {num_layers} - {hidden}')\n\n            model = Model(dataset, num_layers, hidden).to(device)\n            if args.compile:\n                model = torch.compile(model)\n            with amp:\n                for epoch in range(1, args.epochs + 1):\n                    if epoch == args.epochs:\n                        with timeit():\n                            inference_run(model, test_loader, args.bf16)\n                    else:\n                        inference_run(model, test_loader, args.bf16)\n\n                if args.profile:\n                    with torch_profile():\n                        inference_run(model, test_loader, args.bf16)\n                    rename_profile_file(model_name, dataset_name,\n                                        str(num_layers), str(hidden),\n                                        'inference')\n\n\nif not args.inference:\n    run_train()\nelse:\n    run_inference()\n"
  },
  {
    "path": "benchmark/kernel/sag_pool.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import (\n    GraphConv,\n    JumpingKnowledge,\n    SAGPooling,\n    global_mean_pool,\n)\n\n\nclass SAGPool(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, ratio=0.8):\n        super().__init__()\n        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')\n        self.convs = torch.nn.ModuleList()\n        self.pools = torch.nn.ModuleList()\n        self.convs.extend([\n            GraphConv(hidden, hidden, aggr='mean')\n            for i in range(num_layers - 1)\n        ])\n        self.pools.extend(\n            [SAGPooling(hidden, ratio) for i in range((num_layers) // 2)])\n        self.jump = JumpingKnowledge(mode='cat')\n        self.lin1 = Linear(num_layers * hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        for pool in self.pools:\n            pool.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        xs = [global_mean_pool(x, batch)]\n        for i, conv in enumerate(self.convs):\n            x = F.relu(conv(x, edge_index))\n            xs += [global_mean_pool(x, batch)]\n            if i % 2 == 0 and i < len(self.convs) - 1:\n                pool = self.pools[i // 2]\n                x, edge_index, _, batch, _, _ = pool(x, edge_index,\n                                                     batch=batch)\n        x = self.jump(xs)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return 
self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/set2set.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import SAGEConv, Set2Set\n\n\nclass Set2SetNet(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = SAGEConv(dataset.num_features, hidden)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(SAGEConv(hidden, hidden))\n        self.set2set = Set2Set(hidden, processing_steps=4)\n        self.lin1 = Linear(2 * hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.set2set.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n        x = self.set2set(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/sort_pool.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Conv1d, Linear\n\nfrom torch_geometric.nn import SAGEConv, SortAggregation\n\n\nclass SortPool(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden):\n        super().__init__()\n        self.conv1 = SAGEConv(dataset.num_features, hidden)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(SAGEConv(hidden, hidden))\n        self.pool = SortAggregation(k=30)\n        self.conv1d = Conv1d(hidden, 32, 5)\n        self.lin1 = Linear(32 * (30 - 5 + 1), hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.conv1d.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        for conv in self.convs:\n            x = F.relu(conv(x, edge_index))\n        x = self.pool(x, batch)\n        x = x.view(len(x), self.k, -1).permute(0, 2, 1)\n        x = F.relu(self.conv1d(x))\n        x = x.view(len(x), -1)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/statistics.py",
    "content": "from kernel.datasets import get_dataset\n\n\ndef print_dataset(dataset):\n    num_nodes = num_edges = 0\n    for data in dataset:\n        num_nodes += data.num_nodes\n        num_edges += data.num_edges\n\n    print('Name', dataset)\n    print('Graphs', len(dataset))\n    print('Nodes', num_nodes / len(dataset))\n    print('Edges', (num_edges // 2) / len(dataset))\n    print('Features', dataset.num_features)\n    print('Classes', dataset.num_classes)\n    print()\n\n\nfor name in ['MUTAG', 'PROTEINS', 'COLLAB', 'IMDB-BINARY', 'REDDIT-BINARY']:\n    print_dataset(get_dataset(name))\n"
  },
  {
    "path": "benchmark/kernel/top_k.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import (\n    GraphConv,\n    JumpingKnowledge,\n    TopKPooling,\n    global_mean_pool,\n)\n\n\nclass TopK(torch.nn.Module):\n    def __init__(self, dataset, num_layers, hidden, ratio=0.8):\n        super().__init__()\n        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')\n        self.convs = torch.nn.ModuleList()\n        self.pools = torch.nn.ModuleList()\n        self.convs.extend([\n            GraphConv(hidden, hidden, aggr='mean')\n            for i in range(num_layers - 1)\n        ])\n        self.pools.extend(\n            [TopKPooling(hidden, ratio) for i in range((num_layers) // 2)])\n        self.jump = JumpingKnowledge(mode='cat')\n        self.lin1 = Linear(num_layers * hidden, hidden)\n        self.lin2 = Linear(hidden, dataset.num_classes)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        for pool in self.pools:\n            pool.reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n        x = F.relu(self.conv1(x, edge_index))\n        xs = [global_mean_pool(x, batch)]\n        for i, conv in enumerate(self.convs):\n            x = F.relu(conv(x, edge_index))\n            xs += [global_mean_pool(x, batch)]\n            if i % 2 == 0 and i < len(self.convs) - 1:\n                pool = self.pools[i // 2]\n                x, edge_index, _, batch, _, _ = pool(x, edge_index,\n                                                     batch=batch)\n        x = self.jump(xs)\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n    def __repr__(self):\n        return 
self.__class__.__name__\n"
  },
  {
    "path": "benchmark/kernel/train_eval.py",
    "content": "import time\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.model_selection import StratifiedKFold\nfrom torch import tensor\nfrom torch.optim import Adam\n\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.loader import DenseDataLoader as DenseLoader\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\n\ndef cross_validation_with_val_set(dataset, model, folds, epochs, batch_size,\n                                  lr, lr_decay_factor, lr_decay_step_size,\n                                  weight_decay, logger=None):\n\n    val_losses, accs, durations = [], [], []\n    for fold, (train_idx, test_idx,\n               val_idx) in enumerate(zip(*k_fold(dataset, folds))):\n\n        train_dataset = dataset[train_idx]\n        test_dataset = dataset[test_idx]\n        val_dataset = dataset[val_idx]\n\n        if 'adj' in train_dataset[0]:\n            train_loader = DenseLoader(train_dataset, batch_size, shuffle=True)\n            val_loader = DenseLoader(val_dataset, batch_size, shuffle=False)\n            test_loader = DenseLoader(test_dataset, batch_size, shuffle=False)\n        else:\n            train_loader = DataLoader(train_dataset, batch_size, shuffle=True)\n            val_loader = DataLoader(val_dataset, batch_size, shuffle=False)\n            test_loader = DataLoader(test_dataset, batch_size, shuffle=False)\n\n        model.to(device).reset_parameters()\n        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        elif hasattr(torch.backends,\n                     'mps') and torch.backends.mps.is_available():\n            try:\n                torch.mps.synchronize()\n            except ImportError:\n                pass\n\n        t_start 
= time.perf_counter()\n\n        for epoch in range(1, epochs + 1):\n            train_loss = train(model, optimizer, train_loader)\n            val_losses.append(eval_loss(model, val_loader))\n            accs.append(eval_acc(model, test_loader))\n            eval_info = {\n                'fold': fold,\n                'epoch': epoch,\n                'train_loss': train_loss,\n                'val_loss': val_losses[-1],\n                'test_acc': accs[-1],\n            }\n\n            if logger is not None:\n                logger(eval_info)\n\n            if epoch % lr_decay_step_size == 0:\n                for param_group in optimizer.param_groups:\n                    param_group['lr'] = lr_decay_factor * param_group['lr']\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        elif hasattr(torch.backends,\n                     'mps') and torch.backends.mps.is_available():\n            torch.mps.synchronize()\n\n        t_end = time.perf_counter()\n        durations.append(t_end - t_start)\n\n    loss, acc, duration = tensor(val_losses), tensor(accs), tensor(durations)\n    loss, acc = loss.view(folds, epochs), acc.view(folds, epochs)\n    loss, argmin = loss.min(dim=1)\n    acc = acc[torch.arange(folds, dtype=torch.long), argmin]\n\n    loss_mean = loss.mean().item()\n    acc_mean = acc.mean().item()\n    acc_std = acc.std().item()\n    duration_mean = duration.mean().item()\n    print(f'Val Loss: {loss_mean:.4f}, Test Accuracy: {acc_mean:.3f} '\n          f'± {acc_std:.3f}, Duration: {duration_mean:.3f}')\n\n    return loss_mean, acc_mean, acc_std\n\n\ndef k_fold(dataset, folds):\n    skf = StratifiedKFold(folds, shuffle=True, random_state=12345)\n\n    test_indices, train_indices = [], []\n    for _, idx in skf.split(torch.zeros(len(dataset)), dataset.data.y):\n        test_indices.append(torch.from_numpy(idx).to(torch.long))\n\n    val_indices = [test_indices[i - 1] for i in range(folds)]\n\n    for i in range(folds):\n  
      train_mask = torch.ones(len(dataset), dtype=torch.bool)\n        train_mask[test_indices[i]] = 0\n        train_mask[val_indices[i]] = 0\n        train_indices.append(train_mask.nonzero(as_tuple=False).view(-1))\n\n    return train_indices, test_indices, val_indices\n\n\ndef num_graphs(data):\n    if hasattr(data, 'num_graphs'):\n        return data.num_graphs\n    else:\n        return data.x.size(0)\n\n\ndef train(model, optimizer, loader):\n    model.train()\n\n    total_loss = 0\n    for data in loader:\n        optimizer.zero_grad()\n        data = data.to(device)\n        out = model(data)\n        loss = F.nll_loss(out, data.y.view(-1))\n        loss.backward()\n        total_loss += loss.item() * num_graphs(data)\n        optimizer.step()\n    return total_loss / len(loader.dataset)\n\n\ndef eval_acc(model, loader):\n    model.eval()\n\n    correct = 0\n    for data in loader:\n        data = data.to(device)\n        with torch.no_grad():\n            pred = model(data).max(1)[1]\n        correct += pred.eq(data.y.view(-1)).sum().item()\n    return correct / len(loader.dataset)\n\n\ndef eval_loss(model, loader):\n    model.eval()\n\n    loss = 0\n    for data in loader:\n        data = data.to(device)\n        with torch.no_grad():\n            out = model(data)\n        loss += F.nll_loss(out, data.y.view(-1), reduction='sum').item()\n    return loss / len(loader.dataset)\n\n\n@torch.no_grad()\ndef inference_run(model, loader, bf16):\n    model.eval()\n    for data in loader:\n        data = data.to(device)\n        if bf16:\n            data.x = data.x.to(torch.bfloat16)\n        model(data)\n"
  },
  {
    "path": "benchmark/loader/neighbor_loader.py",
    "content": "import argparse\nimport ast\nimport os.path as osp\nfrom contextlib import nullcontext\nfrom timeit import default_timer\n\nimport tqdm\nfrom ogb.nodeproppred import PygNodePropPredDataset\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import OGB_MAG\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.profile import torch_profile\n\n\ndef run(args: argparse.ArgumentParser):\n    for dataset_name in args.datasets:\n        print(f\"Dataset: {dataset_name}\")\n        root = osp.join(args.root, dataset_name)\n        transform = T.ToSparseTensor(\n            remove_edge_index=False) if args.use_sparse_tensor else None\n        if dataset_name == 'mag':\n            transform = (T.ToUndirected(merge=True) if transform is None else\n                         T.Compose([T.ToUndirected(merge=True), transform]))\n            dataset = OGB_MAG(root=root, transform=transform)\n            train_idx = ('paper', dataset[0]['paper'].train_mask)\n            eval_idx = ('paper', None)\n            neighbor_sizes = (args.hetero_neighbor_sizes\n                              if args.hetero_neighbor_sizes else None)\n        else:\n            dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root)\n            split_idx = dataset.get_idx_split()\n            train_idx = split_idx['train']\n            eval_idx = None\n            neighbor_sizes = (args.homo_neighbor_sizes\n                              if args.homo_neighbor_sizes else None)\n\n        data = dataset[0].to(args.device)\n        average_times = []\n        profile = torch_profile() if args.profile else nullcontext()\n        # run dataloader iteration\n        if neighbor_sizes is not None:\n            for num_neighbors in neighbor_sizes:\n                print(f'Training sampling with {num_neighbors} neighbors')\n                for batch_size in args.batch_sizes:\n                    train_loader = NeighborLoader(\n                        
data,\n                        num_neighbors=num_neighbors,\n                        input_nodes=train_idx,\n                        batch_size=batch_size,\n                        shuffle=True,\n                        num_workers=args.num_workers,\n                        subgraph_type=args.subgraph_type,\n                    )\n                    cpu_affinity = train_loader.enable_cpu_affinity(\n                        args.loader_cores\n                    ) if args.cpu_affinity else nullcontext()\n                    runtimes = []\n                    num_iterations = 0\n                    with profile, cpu_affinity:\n                        for _ in range(args.runs):\n                            start = default_timer()\n                            for _ in tqdm.tqdm(train_loader):\n                                num_iterations += 1\n                            stop = default_timer()\n                            runtimes.append(round(stop - start, 3))\n                        average_time = round(sum(runtimes) / args.runs, 3)\n                        print(f'batch size={batch_size}, '\n                              f'iterations={num_iterations}, '\n                              f'runtimes={runtimes}, '\n                              f'average runtime={average_time}')\n                        average_times.append(average_time)\n        eval_batch_sizes = (args.eval_batch_sizes\n                            if args.eval_batch_sizes else None)\n        if eval_batch_sizes is not None:\n            print('Evaluation sampling with all neighbors')\n            for batch_size in eval_batch_sizes:\n                subgraph_loader = NeighborLoader(\n                    data,\n                    num_neighbors=[-1],\n                    input_nodes=eval_idx,\n                    batch_size=batch_size,\n                    shuffle=False,\n                    num_workers=args.num_workers,\n                )\n                cpu_affinity = 
subgraph_loader.enable_cpu_affinity(\n                    args.loader_cores) if args.cpu_affinity else nullcontext()\n                runtimes = []\n                num_iterations = 0\n                with profile, cpu_affinity:\n                    for _ in range(args.runs):\n                        start = default_timer()\n                        for _ in tqdm.tqdm(subgraph_loader):\n                            num_iterations += 1\n                        stop = default_timer()\n                        runtimes.append(round(stop - start, 3))\n                    average_time = round(sum(runtimes) / args.runs, 3)\n                    print(f'batch size={batch_size}, '\n                          f'iterations={num_iterations}, '\n                          f'runtimes={runtimes}, '\n                          f'average runtime={average_time}')\n                    average_times.append(average_time)\n        print(f\"Total time averages: {average_times}\")\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser('NeighborLoader Sampling Benchmarking')\n    add = parser.add_argument\n\n    add('--device', default='cpu')\n    add('--datasets', nargs=\"+\", default=['arxiv', 'products', 'mag'])\n    add('--root', default='../../data')\n    add('--batch-sizes', default=[8192, 4096, 2048, 1024, 512],\n        type=ast.literal_eval)\n    add('--eval-batch-sizes', default=[16384, 8192, 4096, 2048, 1024, 512],\n        type=ast.literal_eval)\n    add('--homo-neighbor_sizes', default=[[10, 5], [15, 10, 5], [20, 15, 10]],\n        type=ast.literal_eval)\n    add('--hetero-neighbor_sizes', default=[[5], [10], [10, 5]],\n        type=ast.literal_eval)\n    add('--use-sparse-tensor', action='store_true',\n        help='use torch_sparse.SparseTensor as graph storage format')\n    add('--num-workers', type=int, default=0,\n        help=\"Number of DataLoader workers to use.\")\n    add('--runs', type=int, default=3,\n        help=\"Number of iterations for each test 
setting.\")\n    add('--profile', default=False, action='store_true',\n        help=\"Run torch.profiler.\")\n    add('--cpu-affinity', default=False, action='store_true',\n        help=\"Use DataLoader affinitization.\")\n    add('--loader-cores', nargs='+', default=[], type=int,\n        help=\"List of CPU core IDs to use for DataLoader workers.\")\n    add('--subgraph-type', type=str, default='directional',\n        help=\"The type of the returned subgraph (directional, bidirectional)\")\n    run(parser.parse_args())\n"
  },
  {
    "path": "benchmark/multi_gpu/training/README.md",
    "content": "# Training Benchmark\n\n## Running benchmark on CUDA GPU\n\nRun benchmark, e.g. assuming you have `n` NVIDIA GPUs:\n\n```\npython training_benchmark_cuda.py --dataset ogbn-products --model edge_cnn --num-epochs 3 --n_gpus <n>\n```\n\n## Running benchmark on Intel GPU\n\n### Environment setup\n\n### Prerequisites\n\n- Intel Data Center GPU Max Series. You could try it through [Intel DevCloud](https://www.intel.com/content/www/us/en/developer/tools/devcloud/services.html).\n- Verify the Intel GPU Driver is installed, refer to the [guide](https://dgpu-docs.intel.com/driver/installation.html).\n\n### docker setup\n\nIf you want to run your scripts inside a docker image, you could refer to the [dockerfile](https://github.com/pyg-team/pytorch_geometric/blob/master/docker/Dockerfile.xpu) and the corresponding [guide](https://github.com/pyg-team/pytorch_geometric/blob/master/docker).\n\n### bare-metal setup\n\nIf you prefer to run your scripts directly on the bare-metal server. We recommend the installation guidance provided by [Intel® Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu&version=v2.1.30%2bxpu&os=linux%2fwsl2&package=pip). 
The following are some key steps:\n\n- Install [Intel® oneAPI Base Toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html), indluding [Intel® oneAPI DPC++ Compiler](https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compiler.html), [Intel® oneAPI Math Kernel Library (oneMKL)](https://www.intel.com/content/www/us/en/docs/oneapi/programming-guide/2024-1/intel-oneapi-math-kernel-library-onemkl.html), [Intel® oneAPI Collective Communications Library (oneCCL)](https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html), and [Intel® oneCCL Bindings for PyTorch](https://github.com/intel/torch-ccl).\n\n```bash\n# Install oneCCL package on Ubuntu\nsudo apt install -y intel-oneapi-dpcpp-cpp-2024.1=2024.1.0-963 intel-oneapi-mkl-devel=2024.1.0-691 intel-oneapi-ccl-devel=2021.12.0-309\n# Install oneccl_bindings_for_pytorch\npip install oneccl_bind_pt==2.1.300+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/\n# Runtime Dynamic Linking\nsource /opt/intel/oneapi/setvars.sh\n```\n\n- Install [Intel® Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch) and the corresponding version of PyTorch\n\n```bash\npip install torch==2.1.0.post2 intel-extension-for-pytorch==2.1.30+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/\n```\n\n### Running benchmark\n\nThis [guide](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/features/DDP.html) is helpful for you to launch DDP training on intel GPU.\n\nTo Run benchmark, e.g. assuming you have `n` XPUs:\n\n```\nmpirun -np <n> python training_benchmark_xpu.py --dataset ogbn-products --model edge_cnn --num-epochs 3\n```\n"
  },
  {
    "path": "benchmark/multi_gpu/training/common.py",
    "content": "import argparse\nimport ast\nfrom time import perf_counter\nfrom typing import Any, Callable, Tuple, Union\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch.nn.parallel import DistributedDataParallel as DDP\n\nfrom benchmark.utils import get_model, get_split_masks, test\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import PNAConv\n\nsupported_sets = {\n    'ogbn-mag': ['rgat', 'rgcn'],\n    'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],\n    'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],\n}\n\ndevice_conditions = {\n    'xpu': (lambda: torch.xpu.is_available()),\n    'cuda': (lambda: torch.cuda.is_available()),\n}\n\n\ndef train_homo(model: Any, loader: NeighborLoader, optimizer: torch.optim.Adam,\n               device: torch.device) -> torch.Tensor:\n    for batch in loader:\n        optimizer.zero_grad()\n        batch = batch.to(device)\n        out = model(batch.x, batch.edge_index)\n        batch_size = batch.batch_size\n        out = out[:batch_size]\n        target = batch.y[:batch_size]\n        loss = F.cross_entropy(out, target)\n        loss.backward()\n        optimizer.step()\n\n    return loss\n\n\ndef train_hetero(model: Any, loader: NeighborLoader,\n                 optimizer: torch.optim.Adam,\n                 device: torch.device) -> torch.Tensor:\n    for batch in loader:\n        optimizer.zero_grad()\n        batch = batch.to(device)\n        out = model(batch.x_dict, batch.edge_index_dict)\n        batch_size = batch['paper'].batch_size\n        out = out['paper'][:batch_size]\n        target = batch['paper'].y[:batch_size]\n        loss = F.cross_entropy(out, target)\n        loss.backward()\n        optimizer.step()\n\n    return loss\n\n\ndef maybe_synchronize(device: str):\n    if device == 'xpu' and torch.xpu.is_available():\n        torch.xpu.synchronize()\n    if device == 
'cuda' and torch.cuda.is_available():\n        torch.cuda.synchronize()\n\n\ndef create_mask_per_rank(\n        global_mask: Union[torch.Tensor,\n                           Tuple[str,\n                                 torch.Tensor]], rank: int, world_size: int,\n        hetero: bool = False) -> Union[torch.Tensor, Tuple[str, torch.Tensor]]:\n    mask = global_mask[-1] if hetero else global_mask\n    nonzero = mask.nonzero().reshape(-1)\n    rank_indices = nonzero.split(nonzero.size(0) // world_size,\n                                 dim=0)[rank].clone()\n    mask_per_rank = torch.full_like(mask, False)\n    mask_per_rank[rank_indices] = True\n\n    if hetero:\n        return tuple((global_mask[0], mask_per_rank))\n    else:\n        return mask_per_rank\n\n\ndef run(rank: int, world_size: int, args: argparse.ArgumentParser,\n        num_classes: int, data: Union[Data, HeteroData],\n        custom_optimizer: Callable[[Any, Any], Tuple[Any, Any]] = None):\n    if not device_conditions[args.device]():\n        raise RuntimeError(f'{args.device.upper()} is not available')\n\n    device = torch.device(f'{args.device}:{rank}')\n\n    if rank == 0:\n        print('BENCHMARK STARTS')\n        print(f'Running on {args.device.upper()}')\n        print(f'Dataset: {args.dataset}')\n\n    hetero = True if args.dataset == 'ogbn-mag' else False\n    mask, val_mask, test_mask = get_split_masks(data, args.dataset)\n    mask = create_mask_per_rank(mask, rank, world_size, hetero)\n    degree = None\n\n    inputs_channels = data[\n        'paper'].num_features if args.dataset == 'ogbn-mag' \\\n        else data.num_features\n\n    if args.model not in supported_sets[args.dataset]:\n        err_msg = (f'Configuration of {args.dataset} + {args.model} '\n                   'not supported')\n        raise RuntimeError(err_msg)\n    if rank == 0:\n        print(f'Training bench for {args.model}:')\n\n    num_nodes = int(mask[-1].sum()) if hetero else int(mask.sum())\n    num_neighbors = 
args.num_neighbors\n\n    if type(num_neighbors) is list:\n        if len(num_neighbors) == 1:\n            num_neighbors = num_neighbors * args.num_layers\n    elif type(num_neighbors) is int:\n        num_neighbors = [num_neighbors] * args.num_layers\n\n    if len(num_neighbors) != args.num_layers:\n        # Fail fast instead of silently building a mismatched loader.\n        err_msg = (f'num_neighbors={num_neighbors} length != num of '\n                   f'layers={args.num_layers}')\n        raise RuntimeError(err_msg)\n\n    kwargs = {\n        'num_neighbors': num_neighbors,\n        'batch_size': args.batch_size,\n        'num_workers': args.num_workers,\n    }\n    subgraph_loader = NeighborLoader(\n        data,\n        input_nodes=mask,\n        sampler=None,\n        **kwargs,\n    )\n    if rank == 0 and args.evaluate:\n        val_loader = NeighborLoader(\n            data,\n            input_nodes=val_mask,\n            sampler=None,\n            **kwargs,\n        )\n        test_loader = NeighborLoader(\n            data,\n            input_nodes=test_mask,\n            sampler=None,\n            **kwargs,\n        )\n\n    if rank == 0:\n        print('----------------------------------------------')\n        print(\n            f'Batch size={args.batch_size}, '\n            f'Layers amount={args.num_layers}, '\n            f'Num_neighbors={num_neighbors}, '\n            f'Hidden features size={args.num_hidden_channels}', flush=True)\n\n    params = {\n        'inputs_channels': inputs_channels,\n        'hidden_channels': args.num_hidden_channels,\n        'output_channels': num_classes,\n        'num_heads': args.num_heads,\n        'num_layers': args.num_layers,\n    }\n\n    if args.model == 'pna' and degree is None:\n        degree = PNAConv.get_degree_histogram(subgraph_loader)\n        print(f'Rank: {rank}, calculated degree for {args.dataset}.',\n              flush=True)\n        params['degree'] = degree\n    dist.barrier()\n\n    torch.manual_seed(12345)\n    model = get_model(args.model, params,\n                      metadata=data.metadata() 
if hetero else None)\n    model = model.to(device)\n\n    if hetero:\n        model.eval()\n        x_keys = data.metadata()[0]\n        edge_index_keys = data.metadata()[1]\n        fake_x_dict = {\n            k: torch.rand((32, inputs_channels), device=device)\n            for k in x_keys\n        }\n        fake_edge_index_dict = {\n            k: torch.randint(0, 32, (2, 8), device=device)\n            for k in edge_index_keys\n        }\n        model.forward(fake_x_dict, fake_edge_index_dict)\n\n    model = DDP(model, device_ids=[device], find_unused_parameters=hetero)\n    model.train()\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n    if custom_optimizer:\n        model, optimizer = custom_optimizer(model, optimizer)\n\n    train = train_hetero if hetero else train_homo\n\n    maybe_synchronize(args.device)\n    dist.barrier()\n    if rank == 0:\n        beg = perf_counter()\n\n    for epoch in range(args.num_epochs):\n        loss = train(\n            model,\n            subgraph_loader,\n            optimizer,\n            device,\n        )\n\n        dist.barrier()\n\n        if rank == 0:\n            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}', flush=True)\n\n        if rank == 0 and args.evaluate:\n            # In evaluate, throughput and\n            # latency are not accurate.\n            val_acc = test(model, val_loader, device, hetero,\n                           progress_bar=False)\n            print(f'Val Accuracy: {val_acc:.4f}')\n\n        dist.barrier()\n\n    maybe_synchronize(args.device)\n    dist.barrier()\n    if rank == 0:\n        end = perf_counter()\n        duration = end - beg\n\n    if rank == 0 and args.evaluate:\n        test_acc = test(model, test_loader, device, hetero, progress_bar=False)\n        print(f'Test Accuracy: {test_acc:.4f}')\n\n    dist.barrier()\n\n    if rank == 0:\n        num_nodes_total = num_nodes * world_size\n        duration_per_epoch = duration / args.num_epochs\n        
throughput = num_nodes_total / duration_per_epoch\n        latency = duration_per_epoch / num_nodes_total * 1000\n        print(f'Time: {duration_per_epoch:.4f}s')\n        print(f'Throughput: {throughput:.3f} samples/s')\n        print(f'Latency: {latency:.3f} ms', flush=True)\n\n    dist.destroy_process_group()\n\n\ndef get_predefined_args() -> argparse.ArgumentParser:\n    argparser = argparse.ArgumentParser(\n        'GNN distributed (DDP) training benchmark')\n    add = argparser.add_argument\n\n    add('--dataset', choices=['ogbn-mag', 'ogbn-products', 'Reddit'],\n        default='Reddit', type=str)\n    add('--model',\n        choices=['edge_cnn', 'gat', 'gcn', 'pna', 'rgat', 'rgcn',\n                 'sage'], default='sage', type=str)\n    add('--root', default='../../data', type=str,\n        help='relative path to look for the datasets')\n    add('--batch-size', default=4096, type=int)\n    add('--num-layers', default=3, type=int)\n    add('--num-hidden-channels', default=128, type=int)\n    add('--num-heads', default=2, type=int,\n        help='number of hidden attention heads, applies only for gat and rgat')\n    add('--num-neighbors', default=[10], type=ast.literal_eval,\n        help='number of neighbors to sample per layer')\n    add('--num-workers', default=0, type=int)\n    add('--num-epochs', default=1, type=int)\n    add('--evaluate', action='store_true')\n\n    return argparser\n"
  },
  {
    "path": "benchmark/multi_gpu/training/training_benchmark_cuda.py",
    "content": "import argparse\nimport os\nfrom typing import Union\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\n\nfrom benchmark.multi_gpu.training.common import (\n    get_predefined_args,\n    run,\n    supported_sets,\n)\nfrom benchmark.utils import get_dataset\nfrom torch_geometric.data import Data, HeteroData\n\n\ndef run_cuda(rank: int, world_size: int, args: argparse.ArgumentParser,\n             num_classes: int, data: Union[Data, HeteroData]):\n    os.environ['MASTER_ADDR'] = 'localhost'\n    os.environ['MASTER_PORT'] = '12355'\n    dist.init_process_group('nccl', rank=rank, world_size=world_size)\n    run(rank, world_size, args, num_classes, data)\n\n\nif __name__ == '__main__':\n    argparser = get_predefined_args()\n    argparser.add_argument('--n-gpus', default=1, type=int)\n    args = argparser.parse_args()\n    args.device = 'cuda'\n\n    assert args.dataset in supported_sets.keys(), \\\n        f\"Dataset {args.dataset} isn't supported.\"\n    data, num_classes = get_dataset(args.dataset, args.root)\n\n    max_world_size = torch.cuda.device_count()\n    chosen_world_size = args.n_gpus\n    if chosen_world_size <= max_world_size:\n        world_size = chosen_world_size\n    else:\n        print(f'User selected {chosen_world_size} GPUs '\n              f'but only {max_world_size} GPUs are available')\n        world_size = max_world_size\n    print(f'Let\\'s use {world_size} GPUs!')\n\n    mp.spawn(\n        run_cuda,\n        args=(world_size, args, num_classes, data),\n        nprocs=world_size,\n        join=True,\n    )\n"
  },
  {
    "path": "benchmark/multi_gpu/training/training_benchmark_xpu.py",
    "content": "import os\nfrom typing import Any, Tuple\n\nimport intel_extension_for_pytorch as ipex\nimport oneccl_bindings_for_pytorch  # noqa\nimport torch.distributed as dist\n\nfrom benchmark.multi_gpu.training.common import (\n    get_predefined_args,\n    run,\n    supported_sets,\n)\nfrom benchmark.utils import get_dataset\n\n\ndef get_dist_params() -> Tuple[int, int, str]:\n    master_addr = \"127.0.0.1\"\n    master_port = \"29500\"\n    os.environ[\"MASTER_ADDR\"] = master_addr\n    os.environ[\"MASTER_PORT\"] = master_port\n\n    mpi_rank = int(os.environ.get(\"PMI_RANK\", -1))\n    mpi_world_size = int(os.environ.get(\"PMI_SIZE\", -1))\n    rank = mpi_rank if mpi_world_size > 0 else os.environ.get(\"RANK\", 0)\n    world_size = (mpi_world_size if mpi_world_size > 0 else os.environ.get(\n        \"WORLD_SIZE\", 1))\n\n    os.environ[\"RANK\"] = str(rank)\n    os.environ[\"WORLD_SIZE\"] = str(world_size)\n\n    init_method = f\"tcp://{master_addr}:{master_port}\"\n\n    return rank, world_size, init_method\n\n\ndef custom_optimizer(model: Any, optimizer: Any) -> Tuple[Any, Any]:\n    return ipex.optimize(model, optimizer=optimizer)\n\n\nif __name__ == '__main__':\n    rank, world_size, init_method = get_dist_params()\n    dist.init_process_group(backend=\"ccl\", init_method=init_method,\n                            world_size=world_size, rank=rank)\n\n    argparser = get_predefined_args()\n    args = argparser.parse_args()\n    args.device = 'xpu'\n\n    assert args.dataset in supported_sets.keys(), \\\n        f\"Dataset {args.dataset} isn't supported.\"\n\n    # if the dataset is not present, it will be downloaded\n    # only by process with rank=0,\n    # other process will use the dataset cache from rank=0,\n    # and will not re-download and process it\n    if rank == 0:\n        data, num_classes = get_dataset(args.dataset, args.root)\n    dist.barrier()\n    if rank != 0:\n        data, num_classes = get_dataset(args.dataset, args.root)\n\n    
run(rank, world_size, args, num_classes, data, custom_optimizer)\n"
  },
  {
    "path": "benchmark/points/README.md",
    "content": "# Point Cloud classification\n\nEvaluation scripts for various methods on the ModelNet10 dataset:\n\n- **[MPNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/mpnn.py)**: `python mpnn.py`\n- **[PointNet++](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_net.py)**: `python point_net.py`\n- **[EdgeCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/edge_cnn.py)**: `python edge_cnn.py`\n- **[SplineCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/spline_cnn.py)**: `python spline_cnn.py`\n- **[PointCNN](https://github.com/pyg-team/pytorch_geometric/blob/master/benchmark/points/point_cnn.py)**: `python point_cnn.py`\n"
  },
  {
    "path": "benchmark/points/__init__.py",
    "content": "from .datasets import get_dataset\nfrom .train_eval import run\n\n__all__ = [\n    'get_dataset',\n    'run',\n]\n"
  },
  {
    "path": "benchmark/points/datasets.py",
    "content": "import os.path as osp\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ModelNet\n\n\ndef get_dataset(num_points):\n    name = 'ModelNet10'\n    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)\n    pre_transform = T.NormalizeScale()\n    transform = T.SamplePoints(num_points)\n\n    train_dataset = ModelNet(path, name='10', train=True, transform=transform,\n                             pre_transform=pre_transform)\n    test_dataset = ModelNet(path, name='10', train=False, transform=transform,\n                            pre_transform=pre_transform)\n\n    return train_dataset, test_dataset\n"
  },
  {
    "path": "benchmark/points/edge_cnn.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom points.datasets import get_dataset\nfrom points.train_eval import run\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.nn import DynamicEdgeConv, global_max_pool\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--batch_size', type=int, default=8)\nparser.add_argument('--lr', type=float, default=0.001)\nparser.add_argument('--lr_decay_factor', type=float, default=0.5)\nparser.add_argument('--lr_decay_step_size', type=int, default=50)\nparser.add_argument('--weight_decay', type=float, default=0)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, num_classes):\n        super().__init__()\n\n        nn = Seq(Lin(6, 64), ReLU(), Lin(64, 64), ReLU(), Lin(64, 64), ReLU())\n        self.conv1 = DynamicEdgeConv(nn, k=20, aggr='max')\n\n        nn = Seq(Lin(128, 128), ReLU(), Lin(128, 128), ReLU(), Lin(128, 256),\n                 ReLU())\n        self.conv2 = DynamicEdgeConv(nn, k=20, aggr='max')\n\n        self.lin0 = Lin(256, 512)\n\n        self.lin1 = Lin(512, 256)\n        self.lin2 = Lin(256, 256)\n        self.lin3 = Lin(256, num_classes)\n\n    def forward(self, pos, batch):\n        x = self.conv1(pos, batch)\n        x = self.conv2(x, batch)\n\n        x = F.relu(self.lin0(x))\n\n        x = global_max_pool(x, batch)\n\n        x = F.relu(self.lin1(x))\n        x = F.relu(self.lin2(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin3(x)\n        return F.log_softmax(x, dim=-1)\n\n\ntrain_dataset, 
test_dataset = get_dataset(num_points=1024)\nmodel = Net(train_dataset.num_classes)\nrun(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,\n    args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,\n    args.inference, args.profile, args.bf16, args.compile)\n\nif args.profile:\n    rename_profile_file('points', DynamicEdgeConv.__name__)\n"
  },
  {
    "path": "benchmark/points/mpnn.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom points.datasets import get_dataset\nfrom points.train_eval import run\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.nn import NNConv, fps, global_mean_pool, radius_graph\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--batch_size', type=int, default=8)\nparser.add_argument('--lr', type=float, default=0.001)\nparser.add_argument('--lr_decay_factor', type=float, default=0.5)\nparser.add_argument('--lr_decay_step_size', type=int, default=50)\nparser.add_argument('--weight_decay', type=float, default=0)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, num_classes):\n        super().__init__()\n\n        nn = Seq(Lin(3, 25), ReLU(), Lin(25, 1 * 64))\n        self.conv1 = NNConv(1, 64, nn, aggr='mean')\n\n        nn = Seq(Lin(3, 25), ReLU(), Lin(25, 64 * 64))\n        self.conv2 = NNConv(64, 64, nn, aggr='mean')\n\n        nn = Seq(Lin(3, 25), ReLU(), Lin(25, 64 * 128))\n        self.conv3 = NNConv(64, 128, nn, aggr='mean')\n\n        self.lin1 = torch.nn.Linear(128, 256)\n        self.lin2 = torch.nn.Linear(256, 256)\n        self.lin3 = torch.nn.Linear(256, num_classes)\n\n    def forward(self, pos, batch):\n        x = pos.new_ones((pos.size(0), 1))\n\n        radius = 0.2\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        pseudo = pos[edge_index[1]] - pos[edge_index[0]]\n        x = F.relu(self.conv1(x, edge_index, pseudo))\n\n        idx = fps(pos, batch, ratio=0.5)\n        x, pos, batch = x[idx], 
pos[idx], batch[idx]\n\n        radius = 0.4\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        pseudo = pos[edge_index[1]] - pos[edge_index[0]]\n        x = F.relu(self.conv2(x, edge_index, pseudo))\n\n        idx = fps(pos, batch, ratio=0.25)\n        x, pos, batch = x[idx], pos[idx], batch[idx]\n\n        radius = 1\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        pseudo = pos[edge_index[1]] - pos[edge_index[0]]\n        x = F.relu(self.conv3(x, edge_index, pseudo))\n\n        x = global_mean_pool(x, batch)\n        x = F.relu(self.lin1(x))\n        x = F.relu(self.lin2(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin3(x)\n        return F.log_softmax(x, dim=-1)\n\n\ntrain_dataset, test_dataset = get_dataset(num_points=1024)\nmodel = Net(train_dataset.num_classes)\nrun(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,\n    args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,\n    args.inference, args.profile, args.bf16, args.compile)\n\nif args.profile:\n    rename_profile_file('points', NNConv.__name__)\n"
  },
  {
    "path": "benchmark/points/point_cnn.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom points.datasets import get_dataset\nfrom points.train_eval import run\nfrom torch.nn import Linear as Lin\n\nfrom torch_geometric.nn import XConv, fps, global_mean_pool\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--batch_size', type=int, default=32)\nparser.add_argument('--lr', type=float, default=0.001)\nparser.add_argument('--lr_decay_factor', type=float, default=0.5)\nparser.add_argument('--lr_decay_step_size', type=int, default=50)\nparser.add_argument('--weight_decay', type=float, default=0)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, num_classes):\n        super().__init__()\n\n        self.conv1 = XConv(0, 48, dim=3, kernel_size=8, hidden_channels=32)\n        self.conv2 = XConv(48, 96, dim=3, kernel_size=12, hidden_channels=64,\n                           dilation=2)\n        self.conv3 = XConv(96, 192, dim=3, kernel_size=16, hidden_channels=128,\n                           dilation=2)\n        self.conv4 = XConv(192, 384, dim=3, kernel_size=16,\n                           hidden_channels=256, dilation=2)\n\n        self.lin1 = Lin(384, 256)\n        self.lin2 = Lin(256, 128)\n        self.lin3 = Lin(128, num_classes)\n\n    def forward(self, pos, batch):\n        x = F.relu(self.conv1(None, pos, batch))\n\n        idx = fps(pos, batch, ratio=0.375)\n        x, pos, batch = x[idx], pos[idx], batch[idx]\n\n        x = F.relu(self.conv2(x, pos, batch))\n\n        idx = fps(pos, batch, ratio=0.334)\n        x, pos, batch = x[idx], pos[idx], batch[idx]\n\n        x = F.relu(self.conv3(x, pos, 
batch))\n        x = F.relu(self.conv4(x, pos, batch))\n\n        x = global_mean_pool(x, batch)\n\n        x = F.relu(self.lin1(x))\n        x = F.relu(self.lin2(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin3(x)\n        return F.log_softmax(x, dim=-1)\n\n\ntrain_dataset, test_dataset = get_dataset(num_points=1024)\nmodel = Net(train_dataset.num_classes)\nrun(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,\n    args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,\n    args.inference, args.profile, args.bf16, args.compile)\n\nif args.profile:\n    rename_profile_file('points', XConv.__name__)\n"
  },
  {
    "path": "benchmark/points/point_net.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom points.datasets import get_dataset\nfrom points.train_eval import run\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.nn import PointNetConv, fps, global_max_pool, radius_graph\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--batch_size', type=int, default=8)\nparser.add_argument('--lr', type=float, default=0.001)\nparser.add_argument('--lr_decay_factor', type=float, default=0.5)\nparser.add_argument('--lr_decay_step_size', type=int, default=50)\nparser.add_argument('--weight_decay', type=float, default=0)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, num_classes):\n        super().__init__()\n\n        nn = Seq(Lin(3, 64), ReLU(), Lin(64, 64))\n        self.conv1 = PointNetConv(local_nn=nn)\n\n        nn = Seq(Lin(67, 128), ReLU(), Lin(128, 128))\n        self.conv2 = PointNetConv(local_nn=nn)\n\n        nn = Seq(Lin(131, 256), ReLU(), Lin(256, 256))\n        self.conv3 = PointNetConv(local_nn=nn)\n\n        self.lin1 = Lin(256, 256)\n        self.lin2 = Lin(256, 256)\n        self.lin3 = Lin(256, num_classes)\n\n    def forward(self, pos, batch):\n        radius = 0.2\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        x = F.relu(self.conv1(None, pos, edge_index))\n\n        idx = fps(pos, batch, ratio=0.5)\n        x, pos, batch = x[idx], pos[idx], batch[idx]\n\n        radius = 0.4\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        x = F.relu(self.conv2(x, pos, edge_index))\n\n     
   idx = fps(pos, batch, ratio=0.25)\n        x, pos, batch = x[idx], pos[idx], batch[idx]\n\n        radius = 1\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        x = F.relu(self.conv3(x, pos, edge_index))\n\n        x = global_max_pool(x, batch)\n\n        x = F.relu(self.lin1(x))\n        x = F.relu(self.lin2(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin3(x)\n        return F.log_softmax(x, dim=-1)\n\n\ntrain_dataset, test_dataset = get_dataset(num_points=1024)\nmodel = Net(train_dataset.num_classes)\nrun(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,\n    args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,\n    args.inference, args.profile, args.bf16, args.compile)\n\nif args.profile:\n    rename_profile_file('points', PointNetConv.__name__)\n"
  },
  {
    "path": "benchmark/points/spline_cnn.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom points.datasets import get_dataset\nfrom points.train_eval import run\nfrom torch.nn import Linear as Lin\n\nfrom torch_geometric.nn import SplineConv, fps, global_mean_pool, radius_graph\nfrom torch_geometric.profile import rename_profile_file\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--batch_size', type=int, default=8)\nparser.add_argument('--lr', type=float, default=0.001)\nparser.add_argument('--lr_decay_factor', type=float, default=0.5)\nparser.add_argument('--lr_decay_step_size', type=int, default=50)\nparser.add_argument('--weight_decay', type=float, default=0)\nparser.add_argument('--inference', action='store_true')\nparser.add_argument('--profile', action='store_true')\nparser.add_argument('--bf16', action='store_true')\nparser.add_argument('--compile', action='store_true')\nargs = parser.parse_args()\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, num_classes):\n        super().__init__()\n\n        self.conv1 = SplineConv(1, 64, dim=3, kernel_size=5)\n        self.conv2 = SplineConv(64, 64, dim=3, kernel_size=5)\n        self.conv3 = SplineConv(64, 128, dim=3, kernel_size=5)\n\n        self.lin1 = Lin(128, 256)\n        self.lin2 = Lin(256, 256)\n        self.lin3 = Lin(256, num_classes)\n\n    def forward(self, pos, batch):\n        x = pos.new_ones((pos.size(0), 1))\n\n        radius = 0.2\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5\n        pseudo = pseudo.clamp(min=0, max=1)\n        x = F.elu(self.conv1(x, edge_index, pseudo))\n\n        idx = fps(pos, batch, ratio=0.5)\n        x, pos, batch = x[idx], pos[idx], batch[idx]\n\n        radius = 0.4\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 
0.5\n        pseudo = pseudo.clamp(min=0, max=1)\n        x = F.elu(self.conv2(x, edge_index, pseudo))\n\n        idx = fps(pos, batch, ratio=0.25)\n        x, pos, batch = x[idx], pos[idx], batch[idx]\n\n        radius = 1\n        edge_index = radius_graph(pos, r=radius, batch=batch)\n        pseudo = (pos[edge_index[1]] - pos[edge_index[0]]) / (2 * radius) + 0.5\n        pseudo = pseudo.clamp(min=0, max=1)\n        x = F.elu(self.conv3(x, edge_index, pseudo))\n\n        x = global_mean_pool(x, batch)\n\n        x = F.elu(self.lin1(x))\n        x = F.elu(self.lin2(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin3(x)\n        return F.log_softmax(x, dim=-1)\n\n\ntrain_dataset, test_dataset = get_dataset(num_points=1024)\nmodel = Net(train_dataset.num_classes)\nrun(train_dataset, test_dataset, model, args.epochs, args.batch_size, args.lr,\n    args.lr_decay_factor, args.lr_decay_step_size, args.weight_decay,\n    args.inference, args.profile, args.bf16, args.compile)\n\nif args.profile:\n    rename_profile_file('points', SplineConv.__name__)\n"
  },
  {
    "path": "benchmark/points/statistics.py",
    "content": "from points.datasets import get_dataset\n\nfrom torch_geometric.transforms import RadiusGraph\n\n\ndef print_dataset(train_dataset, test_dataset):\n    num_nodes = num_edges = 0\n    for data in train_dataset:\n        data = RadiusGraph(0.2)(data)\n        num_nodes += data.num_nodes\n        num_edges += data.num_edges\n    for data in test_dataset:\n        data = RadiusGraph(0.2)(data)\n        num_nodes += data.num_nodes\n        num_edges += data.num_edges\n\n    num_graphs = len(train_dataset) + len(test_dataset)\n    print('Graphs', num_graphs)\n    print('Nodes', num_nodes / num_graphs)\n    print('Edges', (num_edges // 2) / num_graphs)\n    print('Label rate', len(train_dataset) / num_graphs)\n    print()\n\n\nprint_dataset(*get_dataset(num_points=1024))\n"
  },
  {
    "path": "benchmark/points/train_eval.py",
    "content": "import time\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.optim import Adam\n\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.profile import timeit, torch_profile\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\n\ndef run_train(train_dataset, test_dataset, model, epochs, batch_size,\n              use_compile, lr, lr_decay_factor, lr_decay_step_size,\n              weight_decay):\n    model = model.to(device)\n    if use_compile:\n        model = torch.compile(model)\n    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n\n    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)\n    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)\n\n    for epoch in range(1, epochs + 1):\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        elif (hasattr(torch.backends, 'mps')\n              and torch.backends.mps.is_available()):\n            torch.mps.synchronize()\n\n        t_start = time.perf_counter()\n\n        train(model, optimizer, train_loader, device)\n        test_acc = test(model, test_loader, device)\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        elif (hasattr(torch.backends, 'mps')\n              and torch.backends.mps.is_available()):\n            torch.mps.synchronize()\n\n        t_end = time.perf_counter()\n\n        print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}, '\n              f'Duration: {t_end - t_start:.2f}')\n\n        if epoch % lr_decay_step_size == 0:\n            for param_group in optimizer.param_groups:\n                param_group['lr'] = lr_decay_factor * param_group['lr']\n\n\n@torch.no_grad()\ndef run_inference(test_dataset, model, epochs, batch_size, profiling, bf16,\n                  
use_compile):\n    model = model.to(device)\n    if use_compile:\n        model = torch.compile(model)\n    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)\n\n    if torch.cuda.is_available():\n        amp = torch.amp.autocast('cuda', enabled=False)\n    else:\n        amp = torch.cpu.amp.autocast(enabled=bf16)\n\n    with amp:\n        for epoch in range(1, epochs + 1):\n            print(\"Epoch: \", epoch)\n            if epoch == epochs:\n                with timeit():\n                    inference(model, test_loader, device, bf16)\n            else:\n                inference(model, test_loader, device, bf16)\n\n        if profiling:\n            with torch_profile():\n                inference(model, test_loader, device, bf16)\n\n\ndef run(train_dataset, test_dataset, model, epochs, batch_size, lr,\n        lr_decay_factor, lr_decay_step_size, weight_decay, inference,\n        profiling, bf16, use_compile):\n    if not inference:\n        run_train(train_dataset, test_dataset, model, epochs, batch_size,\n                  use_compile, lr, lr_decay_factor, lr_decay_step_size,\n                  weight_decay)\n    else:\n        run_inference(test_dataset, model, epochs, batch_size, profiling, bf16,\n                      use_compile)\n\n\ndef train(model, optimizer, train_loader, device):\n    model.train()\n\n    for data in train_loader:\n        optimizer.zero_grad()\n        data = data.to(device)\n        out = model(data.pos, data.batch)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n\n\n@torch.no_grad()\ndef test(model, test_loader, device):\n    model.eval()\n\n    correct = 0\n    for data in test_loader:\n        data = data.to(device)\n        pred = model(data.pos, data.batch).max(1)[1]\n        correct += pred.eq(data.y).sum().item()\n    test_acc = correct / len(test_loader.dataset)\n\n    return test_acc\n\n\n@torch.no_grad()\ndef inference(model, test_loader, device, bf16):\n    
model.eval()\n    if bf16:\n        # Cast the model once instead of on every batch.\n        model = model.to(torch.bfloat16)\n    for data in test_loader:\n        data = data.to(device)\n        if bf16:\n            data.pos = data.pos.to(torch.bfloat16)\n        model(data.pos, data.batch)\n"
  },
  {
    "path": "benchmark/runtime/README.md",
    "content": "# Runtimes\n\nRun the test suite for PyG via\n\n```\npython main.py\n```\n\nInstall `dgl` and run the test suite for DGL via\n\n```\ncd dgl\npython main.py\n```\n"
  },
  {
    "path": "benchmark/runtime/__init__.py",
    "content": "from .train import train_runtime\n\n__all__ = [\n    'train_runtime',\n]\n"
  },
  {
    "path": "benchmark/runtime/dgl/gat.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn.functional as F\nfrom dgl.nn.pytorch import EdgeSoftmax\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.inits import glorot, zeros\n\n\nclass GATConv(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels, heads=1,\n                 negative_slope=0.2, dropout=0):\n        super().__init__()\n\n        self.g = g\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n\n        self.weight = Parameter(torch.empty(in_channels, heads * out_channels))\n        self.att = Parameter(torch.empty(1, heads, 2 * out_channels))\n        self.bias = Parameter(torch.empty(heads * out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.weight)\n        glorot(self.att)\n        zeros(self.bias)\n\n    def gat_msg(self, edge):\n        alpha = torch.cat([edge.src['x'], edge.dst['x']], dim=-1)\n        alpha = (alpha * self.att).sum(dim=-1)\n        alpha = F.leaky_relu(alpha, self.negative_slope)\n        return {'m': edge.src['x'], 'a': alpha}\n\n    def gat_reduce(self, node):\n        alpha = torch.softmax(node.mailbox['a'], dim=1)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        x = (node.mailbox['m'] * alpha.unsqueeze(-1)).sum(dim=1)\n        return {'x': x}\n\n    def forward(self, x):\n        x = torch.mm(x, self.weight).view(-1, self.heads, self.out_channels)\n        self.g.ndata['x'] = x\n        self.g.update_all(self.gat_msg, self.gat_reduce)\n        x = self.g.ndata.pop('x')\n        x = x.view(-1, self.heads * self.out_channels)\n        x = x + self.bias\n        return x\n\n\nclass GAT(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels):\n        super().__init__()\n        self.g = g\n        self.conv1 = GATConv(g, in_channels, 8, 
8, 0.6, 0.2)\n        self.conv2 = GATConv(g, 64, out_channels, 1, 0.6, 0.2)\n\n    def forward(self, x):\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = F.elu(self.conv1(x))\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = self.conv2(x)\n        return F.log_softmax(x, dim=1)\n\n\nclass GATSPMVConv(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels, heads=1,\n                 negative_slope=0.2, dropout=0):\n        super().__init__()\n        self.g = g\n        self.out_channels = out_channels\n        self.heads = heads\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n        self.weight = Parameter(torch.empty(in_channels, heads * out_channels))\n        self.att_l = Parameter(torch.empty(heads, out_channels, 1))\n        self.att_r = Parameter(torch.empty(heads, out_channels, 1))\n        self.bias = Parameter(torch.empty(heads * out_channels))\n        self.softmax = EdgeSoftmax()\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.weight)\n        glorot(self.att_l)\n        glorot(self.att_r)\n        zeros(self.bias)\n\n    def forward(self, x):\n        x = torch.matmul(x, self.weight)\n        x = x.reshape((x.size(0), self.heads, -1))  # NxHxD'\n        head_x = x.transpose(0, 1)  # HxNxD'\n        a1 = torch.bmm(head_x, self.att_l).transpose(0, 1)  # NxHx1\n        a2 = torch.bmm(head_x, self.att_r).transpose(0, 1)  # NxHx1\n        self.g.ndata.update({'x': x, 'a1': a1, 'a2': a2})\n        self.g.apply_edges(self.edge_attention)\n        self.edge_softmax()\n        self.g.update_all(fn.src_mul_edge('x', 'a', 'x'), fn.sum('x', 'x'))\n        x = self.g.ndata['x'] / self.g.ndata['z']  # NxHxD'\n        return x.view(-1, self.heads * self.out_channels)\n\n    def edge_attention(self, edge):\n        a = F.leaky_relu(edge.src['a1'] + edge.dst['a2'], self.negative_slope)\n        return {'a': a}\n\n    def 
edge_softmax(self):\n        alpha, normalizer = self.softmax(self.g.edata['a'], self.g)\n        self.g.ndata['z'] = normalizer\n        if self.training and self.dropout > 0:\n            alpha = F.dropout(alpha, p=self.dropout, training=True)\n        self.g.edata['a'] = alpha\n\n\nclass GATSPMV(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels):\n        super().__init__()\n        self.g = g\n        self.conv1 = GATSPMVConv(g, in_channels, 8, 8, 0.6, 0.2)\n        self.conv2 = GATSPMVConv(g, 64, out_channels, 1, 0.6, 0.2)\n\n    def forward(self, x):\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = F.elu(self.conv1(x))\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = self.conv2(x)\n        return F.log_softmax(x, dim=1)\n"
  },
  {
    "path": "benchmark/runtime/dgl/gcn.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.inits import glorot, zeros\n\n\nclass GCNConv(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels):\n        super().__init__()\n        self.g = g\n        self.weight = Parameter(torch.empty(in_channels, out_channels))\n        self.bias = Parameter(torch.empty(out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.weight)\n        zeros(self.bias)\n\n    def gcn_msg(self, edge):\n        return {'m': edge.src['x'] * edge.src['norm']}\n\n    def gcn_reduce(self, node):\n        return {'x': node.mailbox['m'].sum(dim=1) * node.data['norm']}\n\n    def forward(self, x):\n        self.g.ndata['x'] = torch.matmul(x, self.weight)\n        self.g.update_all(self.gcn_msg, self.gcn_reduce)\n        x = self.g.ndata.pop('x')\n        x = x + self.bias\n        return x\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(g, in_channels, 16)\n        self.conv2 = GCNConv(g, 16, out_channels)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x)\n        return F.log_softmax(x, dim=1)\n\n\nclass GCNSPMVConv(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels):\n        super().__init__()\n        self.g = g\n        self.weight = Parameter(torch.empty(in_channels, out_channels))\n        self.bias = Parameter(torch.empty(out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.weight)\n        zeros(self.bias)\n\n    def forward(self, x):\n        x = torch.matmul(x, self.weight)\n        self.g.ndata['x'] = x * self.g.ndata['norm']\n        self.g.update_all(fn.copy_src(src='x', out='m'),\n                          
fn.sum(msg='m', out='x'))\n        x = self.g.ndata.pop('x') * self.g.ndata['norm']\n        x = x + self.bias\n        return x\n\n\nclass GCNSPMV(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNSPMVConv(g, in_channels, 16)\n        self.conv2 = GCNSPMVConv(g, 16, out_channels)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(x))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x)\n        return F.log_softmax(x, dim=1)\n"
  },
  {
    "path": "benchmark/runtime/dgl/hidden.py",
    "content": "import os\nimport sys\nimport warnings\n\nwarnings.filterwarnings('ignore')\n\n\nclass HiddenPrint:\n    def __enter__(self):\n        self._original_stdout = sys.stdout\n        sys.stdout = open(os.devnull, 'w')\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        sys.stdout.close()\n        sys.stdout = self._original_stdout\n"
  },
  {
    "path": "benchmark/runtime/dgl/main.py",
    "content": "from itertools import product\n\nimport dgl\nimport torch\nfrom dgl import DGLGraph\nfrom dgl.contrib.data import load_data\nfrom dgl.data import citation_graph\nfrom runtime.dgl.gat import GAT, GATSPMV\nfrom runtime.dgl.gcn import GCN, GCNSPMV\nfrom runtime.dgl.hidden import HiddenPrint\nfrom runtime.dgl.rgcn import RGCN, RGCNSPMV\nfrom runtime.dgl.train import train_runtime\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nwith HiddenPrint():\n    Cora = citation_graph.load_cora()\n    CiteSeer = citation_graph.load_citeseer()\n    PubMed = citation_graph.load_pubmed()\n    MUTAG = load_data('mutag')  # fair comparison\n\n# One training run before we start tracking duration to warm up GPU.\ng = DGLGraph(Cora.graph)\ng.set_n_initializer(dgl.init.zero_initializer)\ng.add_edges(g.nodes(), g.nodes())\nnorm = torch.pow(g.in_degrees().float(), -0.5)\nnorm[torch.isinf(norm)] = 0\ng.ndata['norm'] = norm.unsqueeze(1).to(device)\nmodel = GCNSPMV(g, Cora.features.shape[1], Cora.num_labels).to(device)\ntrain_runtime(model, Cora, epochs=200, device=device)\n\nfor d, Net in product([Cora, CiteSeer, PubMed], [GCN, GCNSPMV, GAT, GATSPMV]):\n    g = DGLGraph(d.graph)\n    g.set_n_initializer(dgl.init.zero_initializer)\n    g.add_edges(g.nodes(), g.nodes())\n    norm = torch.pow(g.in_degrees().float(), -0.5)\n    norm[torch.isinf(norm)] = 0\n    g.ndata['norm'] = norm.unsqueeze(1).to(device)\n    model = Net(g, d.features.shape[1], d.num_labels).to(device)\n    t = train_runtime(model, d, epochs=200, device=device)\n    print(f'{d.name} - {Net.__name__}: {t:.2f}s')\n\nfor d, Net in product([MUTAG], [RGCN, RGCNSPMV]):\n    g = DGLGraph()\n    g.add_nodes(d.num_nodes)\n    g.add_edges(d.edge_src, d.edge_dst)\n    edge_type = torch.from_numpy(d.edge_type).to(device)\n    edge_norm = 
torch.from_numpy(d.edge_norm).to(device)\n    g.edata.update({'type': edge_type, 'norm': edge_norm})\n    g.ndata['id'] = torch.arange(d.num_nodes, dtype=torch.long, device=device)\n    model = Net(g, d.num_nodes, d.num_classes, d.num_rels)\n    t = train_runtime(model, d, epochs=200, device=device)\n    print(f'{d.name} - {Net.__name__}: {t:.2f}s')\n"
  },
  {
    "path": "benchmark/runtime/dgl/rgcn.py",
    "content": "import dgl.function as fn\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Parameter as Param\n\nfrom torch_geometric.nn.inits import uniform\n\n\nclass RGCNConv(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels, num_relations, num_bases):\n        super().__init__()\n\n        self.g = g\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_relations = num_relations\n        self.num_bases = num_bases\n\n        self.basis = Param(torch.empty(num_bases, in_channels, out_channels))\n        self.att = Param(torch.empty(num_relations, num_bases))\n        self.root = Param(torch.empty(in_channels, out_channels))\n        self.bias = Param(torch.empty(out_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        size = self.num_bases * self.in_channels\n        uniform(size, self.basis)\n        uniform(size, self.att)\n        uniform(size, self.root)\n        uniform(size, self.bias)\n\n    def rgcn_reduce(self, node):\n        return {'x': node.mailbox['m'].sum(dim=1)}\n\n    def forward(self, x):\n        self.w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))\n        self.w = self.w.view(self.num_relations, self.in_channels,\n                             self.out_channels)\n\n        if x is None:\n\n            def msg_func(edge):\n                w = self.w.view(-1, self.out_channels)\n                index = edge.data['type'] * self.in_channels + edge.src['id']\n                m = w.index_select(0, index) * edge.data['norm'].unsqueeze(1)\n                return {'m': m}\n        else:\n            self.g.ndata['x'] = x\n\n            def msg_func(edge):\n                w = self.w.index_select(0, edge.data['type'])\n                m = torch.bmm(edge.src['x'].unsqueeze(1), w).squeeze()\n                m = m * edge.data['norm'].unsqueeze(1)\n                return {'m': m}\n\n        self.g.update_all(msg_func, 
self.rgcn_reduce)\n        out = self.g.ndata.pop('x')\n\n        if x is None:\n            out = out + self.root\n        else:\n            out = out + torch.matmul(x, self.root)\n\n        out = out + self.bias\n        return out\n\n\nclass RGCN(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels, num_relations):\n        super().__init__()\n        self.conv1 = RGCNConv(g, in_channels, 16, num_relations, num_bases=30)\n        self.conv2 = RGCNConv(g, 16, out_channels, num_relations, num_bases=30)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(None))\n        x = self.conv2(x)\n        return F.log_softmax(x, dim=1)\n\n\nclass RGCNSPMVConv(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels, num_relations, num_bases):\n        super().__init__()\n\n        self.g = g\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_relations = num_relations\n        self.num_bases = num_bases\n\n        self.basis = Param(torch.empty(num_bases, in_channels, out_channels))\n        self.att = Param(torch.empty(num_relations, num_bases))\n        self.root = Param(torch.empty(in_channels, out_channels))\n        self.bias = Param(torch.empty(out_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        size = self.num_bases * self.in_channels\n        uniform(size, self.basis)\n        uniform(size, self.att)\n        uniform(size, self.root)\n        uniform(size, self.bias)\n\n    def forward(self, x):\n        self.w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))\n        self.w = self.w.view(self.num_relations, self.in_channels,\n                             self.out_channels)\n\n        if x is None:\n\n            def msg_func(edge):\n                w = self.w.view(-1, self.out_channels)\n                index = edge.data['type'] * self.in_channels + edge.src['id']\n                m = w.index_select(0, index) * 
edge.data['norm'].unsqueeze(1)\n                return {'m': m}\n        else:\n            self.g.ndata['x'] = x\n\n            def msg_func(edge):\n                w = self.w.index_select(0, edge.data['type'])\n                m = torch.bmm(edge.src['x'].unsqueeze(1), w).squeeze()\n                m = m * edge.data['norm'].unsqueeze(1)\n                return {'m': m}\n\n        self.g.update_all(msg_func, fn.sum(msg='m', out='x'))\n        out = self.g.ndata.pop('x')\n\n        if x is None:\n            out = out + self.root\n        else:\n            out = out + torch.matmul(x, self.root)\n\n        out = out + self.bias\n        return out\n\n\nclass RGCNSPMV(torch.nn.Module):\n    def __init__(self, g, in_channels, out_channels, num_relations):\n        super().__init__()\n        self.conv1 = RGCNSPMVConv(g, in_channels, 16, num_relations,\n                                  num_bases=30)\n        self.conv2 = RGCNSPMVConv(g, 16, out_channels, num_relations,\n                                  num_bases=30)\n\n    def forward(self, x):\n        x = F.relu(self.conv1(None))\n        x = self.conv2(x)\n        return F.log_softmax(x, dim=1)\n"
  },
  {
    "path": "benchmark/runtime/dgl/train.py",
    "content": "import time\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef train_runtime(model, data, epochs, device):\n    if hasattr(data, 'features'):\n        x = torch.tensor(data.features, dtype=torch.float, device=device)\n    else:\n        x = None\n    mask = data.train_mask if hasattr(data, 'train_mask') else data.train_idx\n    y = torch.tensor(data.labels, dtype=torch.long, device=device)[mask]\n\n    model = model.to(device)\n    model.train()\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n        torch.mps.synchronize()\n    t_start = time.perf_counter()\n\n    for _ in range(epochs):\n        optimizer.zero_grad()\n        out = model(x)\n        loss = F.nll_loss(out[mask], y.view(-1))\n        loss.backward()\n        optimizer.step()\n\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n        torch.mps.synchronize()\n    t_end = time.perf_counter()\n\n    return t_end - t_start\n"
  },
  {
    "path": "benchmark/runtime/gat.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.nn import GATConv\n\n\nclass GAT(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv1 = GATConv(in_channels, 8, heads=8, dropout=0.6)\n        self.conv2 = GATConv(8 * 8, out_channels, dropout=0.6)\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = F.elu(self.conv1(x, edge_index))\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n"
  },
  {
    "path": "benchmark/runtime/gcn.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.nn import GCNConv\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, 16, cached=True)\n        self.conv2 = GCNConv(16, out_channels, cached=True)\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n"
  },
  {
    "path": "benchmark/runtime/main.py",
    "content": "import os.path as osp\nfrom itertools import product\n\nimport torch\nfrom runtime.gat import GAT\nfrom runtime.gcn import GCN\nfrom runtime.rgcn import RGCN\nfrom runtime.train import train_runtime\n\nfrom torch_geometric.datasets import Entities, Planetoid\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nroot = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')\nCora = Planetoid(osp.join(root, 'Cora'), 'Cora')\nCiteSeer = Planetoid(osp.join(root, 'CiteSeer'), 'CiteSeer')\nPubMed = Planetoid(osp.join(root, 'PubMed'), 'PubMed')\nMUTAG = Entities(osp.join(root, 'EntitiesMUTAG'), 'MUTAG')\n\n# One training run before we start tracking duration to warm up GPU.\nmodel = GCN(Cora.num_features, Cora.num_classes)\ntrain_runtime(model, Cora[0], epochs=200, device=device)\n\nfor d, Net in product([Cora, CiteSeer, PubMed], [GCN, GAT]):\n    model = Net(d.num_features, d.num_classes)\n    t = train_runtime(model, d[0], epochs=200, device=device)\n    print(f'{str(d)[:-2]} - {Net.__name__}: {t:.2f}s')\n\nfor d, Net in product([MUTAG], [RGCN]):\n    model = Net(d[0].num_nodes, d.num_classes, d.num_relations)\n    t = train_runtime(model, d[0], epochs=200, device=device)\n    print(f'{str(d)[:-2]} - {Net.__name__}: {t:.2f}s')\n"
  },
  {
    "path": "benchmark/runtime/rgcn.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.nn import FastRGCNConv\n\n\nclass RGCN(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, num_relations):\n        super().__init__()\n        self.conv1 = FastRGCNConv(in_channels, 16, num_relations, num_bases=30)\n        self.conv2 = FastRGCNConv(16, out_channels, num_relations,\n                                  num_bases=30)\n\n    def forward(self, data):\n        edge_index, edge_type = data.edge_index, data.edge_type\n        x = F.relu(self.conv1(None, edge_index, edge_type))\n        x = self.conv2(x, edge_index, edge_type)\n        return F.log_softmax(x, dim=1)\n"
  },
  {
    "path": "benchmark/runtime/train.py",
    "content": "import time\n\nimport torch\nimport torch.nn.functional as F\n\n\ndef train_runtime(model, data, epochs, device):\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n    model = model.to(device)\n    data = data.to(device)\n    model.train()\n    mask = data.train_mask if 'train_mask' in data else data.train_idx\n    y = data.y[mask] if 'train_mask' in data else data.train_y\n\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n        torch.mps.synchronize()\n    t_start = time.perf_counter()\n\n    for _ in range(epochs):\n        optimizer.zero_grad()\n        out = model(data)\n        loss = F.nll_loss(out[mask], y)\n        loss.backward()\n        optimizer.step()\n\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n        torch.mps.synchronize()\n    t_end = time.perf_counter()\n\n    return t_end - t_start\n"
  },
  {
    "path": "benchmark/setup.py",
    "content": "from setuptools import find_packages, setup\n\nsetup(\n    name='torch_geometric_benchmark',\n    version='0.1.0',\n    description='PyG Benchmark Suite',\n    author='Matthias Fey',\n    author_email='matthias.fey@tu-dortmund.de',\n    url='https://github.com/pyg-team/pytorch_geometric_benchmark',\n    install_requires=['scikit-learn'],\n    packages=find_packages(),\n)\n"
  },
  {
    "path": "benchmark/training/README.md",
    "content": "# Training Benchmark\n\n## Environment setup\n\n1. Confirm that PyG is properly installed.\n1. Install dataset package:\n   ```bash\n   pip install ogb\n   ```\n1. Install `jemalloc` for performance benchmark:\n   ```bash\n   cd ${workspace}\n   git clone https://github.com/jemalloc/jemalloc.git\n   cd jemalloc\n   git checkout 5.2.1\n   ./autogen.sh\n   ./configure --prefix=${workspace}/jemalloc-bin\n   make\n   make install\n   ```\n\n## Running benchmark\n\n1. Set environment variables:\n   ```bash\n   source activate env_name\n   export DNNL_PRIMITIVE_CACHE_CAPACITY=1024\n   export KMP_BLOCKTIME=1\n   export KMP_AFFINITY=granularity=fine,compact,1,0\n\n   jemalloc_lib=${workspace}/jemalloc-bin/lib/libjemalloc.so\n   export LD_PRELOAD=\"$jemalloc_lib\"\n   export MALLOC_CONF=\"oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000\"\n   ```\n1. Core binding, *e.g.*, single socket / single core / 4 cores per instance:\n   ```bash\n   OMP_NUM_THREADS=${CORES} numactl -C 0-${LAST_CORE} -m 0 CMD......\n   ```\n1. Execute benchmarks, *e.g.*:\n   ```bash\n   python training_benchmark.py --models=gcn --datasets=Reddit --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50\n   python training_benchmark.py --models=gcn --datasets=Reddit --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50 --use-sparse-tensor\n   python training_benchmark.py --models=sage --datasets=ogbn-products --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50\n   python training_benchmark.py --models=sage --datasets=ogbn-products --num-workers=0 --batch-sizes=512 --num-layers=2 --num-hidden-channels=64 --num-steps=50 --use-sparse-tensor\n   ```\n"
  },
  {
    "path": "benchmark/training/training_benchmark.py",
    "content": "import argparse\nimport ast\nimport warnings\nfrom collections import defaultdict\nfrom contextlib import nullcontext\n\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom benchmark.utils import (\n    emit_itt,\n    get_dataset,\n    get_model,\n    get_split_masks,\n    save_benchmark_data,\n    test,\n    write_to_csv,\n)\nfrom torch_geometric import compile\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import PNAConv\nfrom torch_geometric.profile import (\n    rename_profile_file,\n    timeit,\n    torch_profile,\n    xpu_profile,\n)\n\nsupported_sets = {\n    'ogbn-mag': ['rgat', 'rgcn'],\n    'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],\n    'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],\n}\n\ndevice_conditions = {\n    'cpu': (lambda: True),\n    'cuda': (lambda: torch.cuda.is_available()),\n    'mps':\n    (lambda:\n     (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available())),\n    'xpu': (lambda: torch.xpu.is_available()),\n}\n\n\ndef train_homo(model, loader, optimizer, device, progress_bar=True, desc=\"\",\n               trim=False):\n    if progress_bar:\n        loader = tqdm(loader, desc=desc)\n    for batch in loader:\n        optimizer.zero_grad()\n        batch = batch.to(device)\n        if 'adj_t' in batch:\n            edge_index = batch.adj_t\n        else:\n            edge_index = batch.edge_index\n        if not trim:\n            out = model(batch.x, edge_index)\n        else:\n            out = model(\n                batch.x,\n                edge_index,\n                num_sampled_nodes_per_hop=batch.num_sampled_nodes,\n                num_sampled_edges_per_hop=batch.num_sampled_edges,\n            )\n        batch_size = batch.batch_size\n        out = out[:batch_size]\n        target = batch.y[:batch_size]\n        loss = F.cross_entropy(out, target)\n        loss.backward()\n        optimizer.step()\n\n\ndef 
train_hetero(model, loader, optimizer, device, progress_bar=True, desc=\"\",\n                 trim=False):\n    if trim:\n        warnings.warn(\"Trimming not yet implemented for heterogeneous graphs\",\n                      stacklevel=2)\n\n    if progress_bar:\n        loader = tqdm(loader, desc=desc)\n    for batch in loader:\n        optimizer.zero_grad()\n        batch = batch.to(device)\n        if 'adj_t' in batch:\n            edge_index_dict = batch.adj_t_dict\n        else:\n            edge_index_dict = batch.edge_index_dict\n        out = model(batch.x_dict, edge_index_dict)\n        batch_size = batch['paper'].batch_size\n        out = out['paper'][:batch_size]\n        target = batch['paper'].y[:batch_size]\n        loss = F.cross_entropy(out, target)\n        loss.backward()\n        optimizer.step()\n\n\ndef run(args: argparse.Namespace):\n    csv_data = defaultdict(list)\n\n    if args.write_csv == 'prof' and not args.profile:\n        warnings.warn(\n            \"Cannot write profile data to CSV because profiling is \"\n            \"disabled\", stacklevel=2)\n\n    if args.device == 'xpu':\n        try:\n            import intel_extension_for_pytorch as ipex\n        except ImportError as e:\n            raise RuntimeError(\n                'XPU device requires IPEX to be installed') from e\n\n    if not device_conditions[args.device]():\n        raise RuntimeError(f'{args.device.upper()} is not available')\n    device = torch.device(args.device)\n\n    # If we use a custom number of steps, then we need to use RandomSampler,\n    # which already shuffles the data.\n    shuffle = False if args.num_steps != -1 else True\n\n    print('BENCHMARK STARTS')\n    print(f'Running on {args.device.upper()}')\n    for dataset_name in args.datasets:\n        assert dataset_name in supported_sets.keys(\n        ), f\"Dataset {dataset_name} isn't supported.\"\n        print(f'Dataset: {dataset_name}')\n        load_time = timeit() if args.measure_load_time 
else nullcontext()\n        with load_time:\n            data, num_classes = get_dataset(dataset_name, args.root,\n                                            args.use_sparse_tensor, args.bf16)\n        hetero = True if dataset_name == 'ogbn-mag' else False\n        mask, val_mask, test_mask = get_split_masks(data, dataset_name)\n        degree = None\n\n        if args.device == 'cpu':\n            amp = torch.cpu.amp.autocast(enabled=args.bf16)\n        elif args.device == 'cuda':\n            amp = torch.amp.autocast('cuda', enabled=False)\n        elif args.device == 'xpu':\n            amp = torch.xpu.amp.autocast(enabled=False)\n        else:\n            amp = nullcontext()\n\n        if args.device == 'xpu' and args.warmup < 1:\n            print('XPU device requires warmup - setting warmup=1')\n            args.warmup = 1\n\n        inputs_channels = data[\n            'paper'].num_features if dataset_name == 'ogbn-mag' \\\n            else data.num_features\n\n        for model_name in args.models:\n            if model_name not in supported_sets[dataset_name]:\n                print(f'Configuration of {dataset_name} + {model_name} '\n                      f'not supported. 
Skipping.')\n                continue\n            print(f'Training bench for {model_name}:')\n\n            for batch_size in args.batch_sizes:\n                num_nodes = int(mask[-1].sum()) if hetero else int(mask.sum())\n                sampler = torch.utils.data.RandomSampler(\n                    range(num_nodes), num_samples=args.num_steps *\n                    batch_size) if args.num_steps != -1 else None\n\n                for layers in args.num_layers:\n                    num_neighbors = args.num_neighbors\n                    if type(num_neighbors) is list:\n                        if len(num_neighbors) == 1:\n                            num_neighbors = num_neighbors * layers\n                    elif type(num_neighbors) is int:\n                        num_neighbors = [num_neighbors] * layers\n\n                    assert len(\n                        num_neighbors) == layers, \\\n                        f'''num_neighbors={num_neighbors} length\n                        != num of layers={layers}'''\n\n                    kwargs = {\n                        'num_neighbors': num_neighbors,\n                        'batch_size': batch_size,\n                        'shuffle': shuffle,\n                        'num_workers': args.num_workers,\n                    }\n                    subgraph_loader = NeighborLoader(\n                        data,\n                        input_nodes=mask,\n                        sampler=sampler,\n                        **kwargs,\n                    )\n                    if args.evaluate:\n                        val_loader = NeighborLoader(\n                            data,\n                            input_nodes=val_mask,\n                            sampler=None,\n                            **kwargs,\n                        )\n                        test_loader = NeighborLoader(\n                            data,\n                            input_nodes=test_mask,\n                            sampler=None,\n 
                           **kwargs,\n                        )\n                    for hidden_channels in args.num_hidden_channels:\n                        print('----------------------------------------------')\n                        print(f'Batch size={batch_size}, '\n                              f'Layers amount={layers}, '\n                              f'Num_neighbors={num_neighbors}, '\n                              f'Hidden features size={hidden_channels}, '\n                              f'Sparse tensor={args.use_sparse_tensor}')\n\n                        params = {\n                            'inputs_channels': inputs_channels,\n                            'hidden_channels': hidden_channels,\n                            'output_channels': num_classes,\n                            'num_heads': args.num_heads,\n                            'num_layers': layers,\n                        }\n\n                        if model_name == 'pna':\n                            if degree is None:\n                                degree = PNAConv.get_degree_histogram(\n                                    subgraph_loader)\n                                print(f'Calculated degree for {dataset_name}.')\n                            params['degree'] = degree\n\n                        model = get_model(\n                            model_name, params,\n                            metadata=data.metadata() if hetero else None)\n                        model = model.to(device)\n                        model.train()\n\n                        if args.compile:\n                            model = compile(model, dynamic=True)\n\n                        optimizer = torch.optim.Adam(model.parameters(),\n                                                     lr=0.001)\n\n                        if args.device == 'xpu':\n                            model, optimizer = ipex.optimize(\n                                model, optimizer=optimizer)\n\n                        progress_bar 
= False if args.no_progress_bar else True\n                        train = train_hetero if hetero else train_homo\n\n                        # Define context manager parameters:\n                        cpu_affinity = subgraph_loader.enable_cpu_affinity(\n                            args.loader_cores\n                        ) if args.cpu_affinity else nullcontext()\n\n                        with amp, cpu_affinity:\n                            for _ in range(args.warmup):\n                                train(\n                                    model,\n                                    subgraph_loader,\n                                    optimizer,\n                                    device,\n                                    progress_bar=progress_bar,\n                                    desc=\"Warmup\",\n                                    trim=args.trim,\n                                )\n                            with timeit(avg_time_divisor=args.num_epochs) as t:\n                                # becomes a no-op if vtune_profile == False\n                                with emit_itt(args.vtune_profile):\n                                    for epoch in range(args.num_epochs):\n                                        train(\n                                            model,\n                                            subgraph_loader,\n                                            optimizer,\n                                            device,\n                                            progress_bar=progress_bar,\n                                            desc=f\"Epoch={epoch}\",\n                                            trim=args.trim,\n                                        )\n                                        if args.evaluate:\n                                            # In evaluate, throughput and\n                                            # latency are not accurate.\n                                            val_acc = 
test(\n                                                model, val_loader, device,\n                                                hetero,\n                                                progress_bar=progress_bar)\n                                            print(\n                                                f'Val Accuracy: {val_acc:.4f}')\n\n                            if args.evaluate:\n                                test_acc = test(model, test_loader, device,\n                                                hetero,\n                                                progress_bar=progress_bar)\n                                print(f'Test Accuracy: {test_acc:.4f}')\n\n                            if args.profile:\n                                if args.device == 'xpu':\n                                    profile = xpu_profile(\n                                        args.export_chrome_trace)\n                                else:\n                                    profile = torch_profile(\n                                        args.export_chrome_trace, csv_data,\n                                        args.write_csv)\n                                with profile:\n                                    train(model, subgraph_loader, optimizer,\n                                          device, progress_bar=progress_bar,\n                                          desc=\"Profile training\")\n                                if args.export_chrome_trace:\n                                    rename_profile_file(\n                                        model_name, dataset_name,\n                                        str(batch_size), str(layers),\n                                        str(hidden_channels),\n                                        str(num_neighbors))\n\n                        total_time = t.duration\n                        if args.num_steps != -1:\n                            total_num_samples = args.num_steps * batch_size\n                  
      else:\n                            total_num_samples = num_nodes\n                        throughput = total_num_samples / total_time\n                        latency = total_time / total_num_samples * 1000\n                        print(f'Throughput: {throughput:.3f} samples/s')\n                        print(f'Latency: {latency:.3f} ms')\n\n                        num_records = 1\n                        if args.write_csv == 'prof':\n                            # For profiling with PyTorch, we save the top-5\n                            # most time consuming operations. Therefore, the\n                            # same data should be entered for each of them.\n                            num_records = 5\n                        for _ in range(num_records):\n                            save_benchmark_data(\n                                csv_data,\n                                batch_size,\n                                layers,\n                                num_neighbors,\n                                hidden_channels,\n                                total_time,\n                                model_name,\n                                dataset_name,\n                                args.use_sparse_tensor,\n                            )\n    if args.write_csv:\n        write_to_csv(csv_data, args.write_csv, training=True)\n\n\nif __name__ == '__main__':\n    argparser = argparse.ArgumentParser('GNN training benchmark')\n    add = argparser.add_argument\n\n    add('--device', choices=['cpu', 'cuda', 'mps', 'xpu'], default='cpu',\n        help='Device to run benchmark on')\n    add('--datasets', nargs='+',\n        default=['ogbn-mag', 'ogbn-products', 'Reddit'], type=str)\n    add('--use-sparse-tensor', action='store_true',\n        help='use torch_sparse.SparseTensor as graph storage format')\n    add('--models', nargs='+',\n        default=['edge_cnn', 'gat', 'gcn', 'pna', 'rgat', 'rgcn'], type=str)\n    add('--root', default='../../data', 
type=str,\n        help='relative path to look for the datasets')\n    add('--batch-sizes', nargs='+', default=[512, 1024, 2048, 4096, 8192],\n        type=int)\n    add('--num-layers', nargs='+', default=[2, 3], type=int)\n    add('--num-hidden-channels', nargs='+', default=[64, 128, 256], type=int)\n    add('--num-heads', default=2, type=int,\n        help='number of hidden attention heads, applies only for gat and rgat')\n    add('--num-neighbors', default=[10], type=ast.literal_eval,\n        help='number of neighbors to sample per layer')\n    add('--num-workers', default=2, type=int)\n    add('--warmup', default=1, type=int)\n    add('--profile', action='store_true')\n    add('--vtune-profile', action='store_true')\n    add('--bf16', action='store_true')\n    add('--no-progress-bar', action='store_true', default=False,\n        help='turn off using progress bar')\n    add('--num-epochs', default=1, type=int)\n    add('--num-steps', default=-1, type=int,\n        help='number of steps, -1 means iterating through all the data')\n    add('--cpu-affinity', action='store_true',\n        help=\"Use DataLoader affinitization.\")\n    add('--loader-cores', nargs='+', default=[], type=int,\n        help=\"List of CPU core IDs to use for DataLoader workers.\")\n    add('--measure-load-time', action='store_true')\n    add('--evaluate', action='store_true')\n    add('--write-csv', choices=[None, 'bench', 'prof'], default=None,\n        help='Write benchmark or PyTorch profile data to CSV')\n    add('--export-chrome-trace', default=True, type=bool,\n        help='Export chrome trace file. Works only with PyTorch profiler')\n    add('--trim', action='store_true', help=\"Use `trim_to_layer` optimization\")\n    add('--compile', action='store_true')\n    args = argparser.parse_args()\n\n    run(args)\n"
  },
  {
    "path": "benchmark/utils/__init__.py",
    "content": "from .utils import emit_itt\nfrom .utils import get_dataset, get_dataset_with_transformation\nfrom .utils import get_model\nfrom .utils import get_split_masks\nfrom .utils import save_benchmark_data, write_to_csv\nfrom .utils import test\n\n__all__ = [\n    'emit_itt',\n    'get_dataset',\n    'get_dataset_with_transformation',\n    'get_model',\n    'get_split_masks',\n    'save_benchmark_data',\n    'write_to_csv',\n    'test',\n]\n"
  },
  {
    "path": "benchmark/utils/hetero_gat.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.nn import GAT, to_hetero\n\n\nclass HeteroGAT(torch.nn.Module):\n    def __init__(self, metadata, hidden_channels, num_layers, output_channels,\n                 num_heads):\n        super().__init__()\n        self.model = to_hetero(\n            GAT((-1, -1), hidden_channels, num_layers, output_channels,\n                add_self_loops=False, heads=num_heads), metadata)\n\n    def forward(self, x_dict, edge_index_dict):\n        return self.model(x_dict, edge_index_dict)\n\n    @torch.no_grad()\n    def inference(self, loader, device, progress_bar=False, **kwargs):\n        self.model.eval()\n        if progress_bar:\n            loader = tqdm(loader, desc=\"Inference\")\n        for batch in loader:\n            batch = batch.to(device)\n            if 'adj_t' in batch:\n                self.model(batch.x_dict, batch.adj_t_dict)\n            else:\n                self.model(batch.x_dict, batch.edge_index_dict)\n\n    @torch.no_grad()\n    def test(self, x, loader, device, progress_bar=False):\n        self.model.eval()\n        total_examples = total_correct = 0\n        if progress_bar:\n            loader = tqdm(loader, desc=\"Evaluate\")\n        for batch in loader:\n            batch = batch.to(device)\n            if 'adj_t' in batch:\n                out = self.model(batch.x_dict, batch.adj_t_dict)\n            else:\n                out = self.model(batch.x_dict, batch.edge_index_dict)\n            batch_size = batch['paper'].batch_size\n            out = out['paper'][:batch_size]\n            pred = out.argmax(dim=-1)\n\n            total_examples += batch_size\n            total_correct += int((pred == batch['paper'].y[:batch_size]).sum())\n\n        return total_correct / total_examples\n"
  },
  {
    "path": "benchmark/utils/hetero_sage.py",
    "content": "import torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.nn import GraphSAGE, to_hetero\n\n\nclass HeteroGraphSAGE(torch.nn.Module):\n    def __init__(self, metadata, hidden_channels, num_layers, output_channels):\n        super().__init__()\n        self.model = to_hetero(\n            GraphSAGE((-1, -1), hidden_channels, num_layers, output_channels),\n            metadata)\n\n    def forward(self, x_dict, edge_index_dict):\n        return self.model(x_dict, edge_index_dict)\n\n    @torch.no_grad()\n    def inference(self, loader, device, progress_bar=False, **kwargs):\n        self.model.eval()\n        if progress_bar:\n            loader = tqdm(loader, desc=\"Inference\")\n        for batch in loader:\n            batch = batch.to(device)\n            if 'adj_t' in batch:\n                self.model(batch.x_dict, batch.adj_t_dict)\n            else:\n                self.model(batch.x_dict, batch.edge_index_dict)\n\n    @torch.no_grad()\n    def test(self, loader, device, progress_bar=False):\n        self.model.eval()\n        total_examples = total_correct = 0\n        if progress_bar:\n            loader = tqdm(loader, desc=\"Evaluate\")\n        for batch in loader:\n            batch = batch.to(device)\n            if 'adj_t' in batch:\n                out = self.model(batch.x_dict, batch.adj_t_dict)\n            else:\n                out = self.model(batch.x_dict, batch.edge_index_dict)\n            batch_size = batch['paper'].batch_size\n            out = out['paper'][:batch_size]\n            pred = out.argmax(dim=-1)\n\n            total_examples += batch_size\n            total_correct += int((pred == batch['paper'].y[:batch_size]).sum())\n\n        return total_correct / total_examples\n"
  },
  {
    "path": "benchmark/utils/utils.py",
    "content": "import os\nimport os.path as osp\nfrom datetime import datetime\n\nimport torch\nfrom ogb.nodeproppred import PygNodePropPredDataset\nfrom tqdm import tqdm\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.datasets import OGB_MAG, Reddit\nfrom torch_geometric.nn import GAT, GCN, PNA, EdgeCNN, GraphSAGE\nfrom torch_geometric.utils import index_to_mask\n\nfrom .hetero_gat import HeteroGAT\nfrom .hetero_sage import HeteroGraphSAGE\n\ntry:\n    from torch.autograd.profiler import emit_itt\nexcept ImportError:\n    from contextlib import contextmanager\n\n    @contextmanager\n    def emit_itt(*args, **kwargs):\n        yield\n\n\nmodels_dict = {\n    'edge_cnn': EdgeCNN,\n    'gat': GAT,\n    'gcn': GCN,\n    'pna': PNA,\n    'sage': GraphSAGE,\n    'rgat': HeteroGAT,\n    'rgcn': HeteroGraphSAGE,\n}\n\n\ndef get_dataset_with_transformation(name, root, use_sparse_tensor=False,\n                                    bf16=False):\n    path = osp.join(osp.dirname(osp.realpath(__file__)), root, name)\n    transform = T.ToSparseTensor(\n        remove_edge_index=False) if use_sparse_tensor else None\n    if name == 'ogbn-mag':\n        if transform is None:\n            transform = T.ToUndirected(merge=True)\n        else:\n            transform = T.Compose([T.ToUndirected(merge=True), transform])\n        dataset = OGB_MAG(root=path, preprocess='metapath2vec',\n                          transform=transform)\n    elif name == 'ogbn-products':\n        if transform is None:\n            transform = T.RemoveDuplicatedEdges()\n        else:\n            transform = T.Compose([T.RemoveDuplicatedEdges(), transform])\n\n        dataset = PygNodePropPredDataset('ogbn-products', root=path,\n                                         transform=transform)\n\n    elif name == 'Reddit':\n        dataset = Reddit(root=path, transform=transform)\n\n    data = dataset[0]\n\n    if name == 'ogbn-products':\n        
split_idx = dataset.get_idx_split()\n        data.train_mask = index_to_mask(split_idx['train'],\n                                        size=data.num_nodes)\n        data.val_mask = index_to_mask(split_idx['valid'], size=data.num_nodes)\n        data.test_mask = index_to_mask(split_idx['test'], size=data.num_nodes)\n        data.y = data.y.squeeze()\n\n    if bf16:\n        if isinstance(data, HeteroData):\n            for node_type in data.node_types:\n                data[node_type].x = data[node_type].x.to(torch.bfloat16)\n        else:\n            data.x = data.x.to(torch.bfloat16)\n\n    return data, dataset.num_classes, transform\n\n\ndef get_dataset(name, root, use_sparse_tensor=False, bf16=False):\n    data, num_classes, _ = get_dataset_with_transformation(\n        name, root, use_sparse_tensor, bf16)\n    return data, num_classes\n\n\ndef get_model(name, params, metadata=None):\n    Model = models_dict.get(name, None)\n    assert Model is not None, f'Model {name} not supported!'\n\n    if name == 'rgat':\n        return Model(metadata, params['hidden_channels'], params['num_layers'],\n                     params['output_channels'], params['num_heads'])\n\n    if name == 'rgcn':\n        return Model(metadata, params['hidden_channels'], params['num_layers'],\n                     params['output_channels'])\n\n    if name == 'gat':\n        return Model(params['inputs_channels'], params['hidden_channels'],\n                     params['num_layers'], params['output_channels'],\n                     heads=params['num_heads'])\n\n    if name == 'pna':\n        return Model(params['inputs_channels'], params['hidden_channels'],\n                     params['num_layers'], params['output_channels'],\n                     aggregators=['mean', 'min', 'max', 'std'],\n                     scalers=['identity', 'amplification',\n                              'attenuation'], deg=params['degree'])\n\n    return Model(params['inputs_channels'], 
params['hidden_channels'],\n                 params['num_layers'], params['output_channels'])\n\n\ndef get_split_masks(data, dataset_name):\n    if dataset_name == 'ogbn-mag':\n        train_mask = ('paper', data['paper'].train_mask)\n        test_mask = ('paper', data['paper'].test_mask)\n        val_mask = ('paper', data['paper'].val_mask)\n    else:\n        train_mask = data.train_mask\n        val_mask = data.val_mask\n        test_mask = data.test_mask\n    return train_mask, val_mask, test_mask\n\n\ndef save_benchmark_data(csv_data, batch_size, layers, num_neighbors,\n                        hidden_channels, total_time, model_name, dataset_name,\n                        use_sparse_tensor):\n    config = f'Batch size={batch_size}, ' \\\n             f'#Layers={layers}, ' \\\n             f'#Neighbors={num_neighbors}, ' \\\n             f'#Hidden features={hidden_channels}'\n    csv_data['DATE'].append(datetime.now().date())\n    csv_data['TIME (s)'].append(round(total_time, 2))\n    csv_data['MODEL'].append(model_name)\n    csv_data['DATASET'].append(dataset_name)\n    csv_data['CONFIG'].append(config)\n    csv_data['SPARSE'].append(use_sparse_tensor)\n\n\ndef write_to_csv(csv_data, write_csv='bench', training=False):\n    import pandas as pd\n    results_path = osp.join(osp.dirname(osp.realpath(__file__)), '../results/')\n    os.makedirs(results_path, exist_ok=True)\n\n    name = 'training' if training else 'inference'\n    if write_csv == 'bench':\n        csv_file_name = f'TOTAL_{name}_benchmark.csv'\n    else:\n        csv_file_name = f'TOTAL_prof_{name}_benchmark.csv'\n    csv_path = osp.join(results_path, csv_file_name)\n    index_label = 'TEST_ID' if write_csv == 'bench' else 'ID'\n\n    with_header = not osp.exists(csv_path)\n    df = pd.DataFrame(csv_data)\n    df.to_csv(csv_path, mode='a', index_label=index_label, header=with_header)\n\n\n@torch.no_grad()\ndef test(model, loader, device, hetero, progress_bar=True,\n         desc=\"Evaluation\") -> 
float:\n    if progress_bar:\n        loader = tqdm(loader, desc=desc)\n    total_examples = total_correct = 0\n    if hetero:\n        for batch in loader:\n            batch = batch.to(device)\n            if 'adj_t' in batch:\n                edge_index_dict = batch.adj_t_dict\n            else:\n                edge_index_dict = batch.edge_index_dict\n            out = model(batch.x_dict, edge_index_dict)\n            batch_size = batch['paper'].batch_size\n            out = out['paper'][:batch_size]\n            pred = out.argmax(dim=-1)\n\n            total_examples += batch_size\n            total_correct += int((pred == batch['paper'].y[:batch_size]).sum())\n    else:\n        for batch in loader:\n            batch = batch.to(device)\n            if 'adj_t' in batch:\n                edge_index = batch.adj_t\n            else:\n                edge_index = batch.edge_index\n            out = model(batch.x, edge_index)\n            batch_size = batch.batch_size\n            out = out[:batch_size]\n            pred = out.argmax(dim=-1)\n\n            total_examples += batch_size\n            total_correct += int((pred == batch.y[:batch_size]).sum())\n    return total_correct / total_examples\n"
  },
  {
    "path": "codecov.yml",
    "content": "# See: https://docs.codecov.io/docs/codecov-yaml\ncoverage:\n  range: 80..100\n  round: down\n  precision: 2\n  status:\n    project:\n      default:\n        target: 80%\n        threshold: 1%\n    patch:\n      default:\n        target: 80%\n        threshold: 1%\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM nvcr.io/nvidia/cuda-dl-base:24.09-cuda12.6-devel-ubuntu22.04\n\n# Based on NGC PyG 24.09 image:\n# https://docs.nvidia.com/deeplearning/frameworks/pyg-release-notes/rel-24-09.html#rel-24-09\n\n# install pip\nRUN apt-get update && apt-get install -y python3-pip\n\n# install PyTorch - latest stable version\nRUN pip install torch torchvision torchaudio\n\n# install graphviz - latest stable version\nRUN apt-get install -y graphviz graphviz-dev\nRUN pip install pygraphviz\n\n# install python packages with NGC PyG 24.09 image versions\nRUN pip install torch_geometric==2.6.0\nRUN pip install triton==3.0.0 numba==0.59.0 requests==2.32.3 opencv-python==4.7.0.72 scipy==1.14.0 jupyterlab==4.2.5\n\n# install cugraph\nRUN pip install cugraph-cu12 cugraph-pyg-cu12 --extra-index-url=https://pypi.nvidia.com\n"
  },
  {
    "path": "docker/Dockerfile.xpu",
    "content": "ARG BASE_IMAGE=\"intel/intel-extension-for-pytorch\"\nARG BASE_TAG=\"2.1.30-xpu\"\n\nFROM ${BASE_IMAGE}:${BASE_TAG}\n\n# meta information\nLABEL org.opencontainers.image.version = \"2.3.1\"\nLABEL org.opencontainers.image.authors = \"PyG authors\"\nLABEL org.opencontainers.image.source = \"https://github.com/pyg-team/pytorch_geometric\"\nLABEL org.opencontainers.image.licenses = \"MIT\"\nLABEL org.opencontainers.image.base.name=${BASE_IMAGE}:${BASE_TAG}\n\n# Create a working directory\nRUN mkdir /app\nWORKDIR /app\n\n# Add the XPU-related package repository for the LTS releases\nRUN . /etc/os-release && \\\n    wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \\\n    sudo gpg --yes --dearmor --output /usr/share/keyrings/intel-graphics.gpg && \\\n    echo \"deb [arch=amd64 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu ${VERSION_CODENAME}/lts/2350 unified\" | \\\n    sudo tee /etc/apt/sources.list.d/intel-gpu-${VERSION_CODENAME}.list\n\n# Install oneCCL\nRUN sudo apt update && apt install -y intel-oneapi-ccl-devel=2021.12.0-309 python3-dev cmake vim\nRUN echo \"source /opt/intel/oneapi/setvars.sh --force\" >> /root/.bashrc\n\n# Install PyG\nRUN pip install ninja wheel ogb && pip install git+https://github.com/pyg-team/pyg-lib.git && \\\n    pip install torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.5.0+cpu.html && \\\n    pip install torch_geometric\n"
  },
  {
    "path": "docker/README.md",
    "content": "# Docker on NVIDIA GPU\n\nThe recommended way to use Docker for NVIDIA hardware is described [here](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg).\n\nYou can also run PyG with CUDA 12.1 inside a docker image. This method is deprecated and we highly recommend the above mentioned official NVIDIA docker containers instead.\n\nThe creation of [our dockerfile](https://github.com/pyg-team/pytorch_geometric/blob/master/docker/Dockerfile) refers to the dockerfiles provided by [NVIDIA](https://gitlab.com/nvidia/cuda/tree/ubuntu18.04) and [PyTorch](https://github.com/anibali/docker-pytorch).\n\n1. Download the dockerfile to your host server.\n1. `$ docker build -t \"custom image name\"`\n1. `$ docker run --rm -it --init --runtime=nvidia --ipc=host --network=host --volume=$PWD:/app -e NVIDIA_VISIBLE_DEVICES=0 \"custom image name\" /bin/bash`\n\nIf you encounter any problems, please feel free to create a GitHub issue.\n\n# Docker on Intel GPU\n\nYou can also run PyG with Intel GPU inside a docker image.\nThe creation of [our dockerfile](https://github.com/pyg-team/pytorch_geometric/blob/master/docker/Dockerfile.xpu) refers to the dockerfiles provided by [Intel](https://github.com/intel/intel-extension-for-pytorch/blob/xpu-main/docker/Dockerfile.prebuilt) and the installation guidance provided by [Intel® Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu&version=v2.1.30%2bxpu&os=linux%2fwsl2&package=pip).\n\n1. Download the dockerfile to your host server.\n1. `$ docker build -f docker/Dockerfile.xpu -t \"custom image name\"`\n1. `$ docker run --rm -it --ipc=host -v /dev/dri:/dev/dri --volume=$PWD:/app \"custom image name\" /bin/bash`\n\n# Singularity\n\nYou can run PyG inside a singularity image. 
An example singularity file can be found in this folder.\n\nDepending on your needs, you might have to modify the following parts of the script:\n\n- **cuda version 10.1**: If you need another version, change `From: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04` to the corresponding tag from <https://hub.docker.com/r/nvidia/cuda>. Same if you want to use anything but Ubuntu 18.04. Your host has to have at least this cuda version!\n- **python version 3.7.2**: If you need another version, change `pyenv install 3.7.2` and the following lines to the corresponding version.\n- **pytorch version 1.3.0**: If you need another version, change `pip install torch==1.3.0`.\n- **pytorch_geometric versions**: This uses specific versions for each of the `pytorch_geometric` requirements (scatter: 1.4.0, sparse: 0.4.3, cluster: 1.4.5, geometric 1.3.2). To change these, change the corresponding `git checkout` lines near the bottom.\n- **cuda compute capability 5.0 and 6.1**: If you run it on multiple systems (likely, with singularity), ensure that all compute capabilities are listed here. If you have the cuda samples installed, check it with (for example) `/usr/local/cuda-10.1/extras/demo_suite/deviceQuery | grep 'CUDA Capability'`; if not, check [here](https://en.wikipedia.org/wiki/CUDA#GPUs_supported).\n\nNote: If your hard disk fills up after multiple builds, this is known and apparently working as intended; delete the `/tmp/sbuild-XXXXXXXXX` files.\n\n## Building and Using the Container\n\nTo build the container, run\n\n`sudo singularity build geometric.sif singularity`\n\nthen wait. Once finished, you can run the GAT example in the folder you built the image in by calling\n\n```\nwget https://raw.githubusercontent.com/pyg-team/pytorch_geometric/master/examples/gat.py\n```\n\n(to download the sample),\n\nthen\n\n```\nsingularity exec geometric.sif python3 gat.py\n```\n\nto run on the CPU, or\n\n```\nsingularity exec --nv geometric.sif python3 gat.py\n```\n\nto run on the GPU.\n"
  },
  {
    "path": "docker/singularity",
    "content": "Bootstrap: docker\nFrom: nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04\n\n%post\n  CURDIR=$(pwd)\n\n  # Set timezone to Etc/UTC for tzdata. See issue #4365 for more details.\n  TZ=Etc/UTC && \\\n    ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && \\\n    echo $TZ > /etc/timezone\n\n  apt-get update -y\n  apt-get install -y tmux nano git wget\n  apt-get install -y --no-install-recommends \\\n    build-essential \\\n    gfortran \\\n    libssl-dev \\\n    zlib1g-dev \\\n    libbz2-dev \\\n    libreadline-dev \\\n    libsqlite3-dev \\\n    wget \\\n    curl \\\n    llvm \\\n    libncurses5-dev \\\n    xz-utils \\\n    tk-dev \\\n    libxml2-dev \\\n    libxmlsec1-dev \\\n    libffi-dev \\\n    liblzma-dev \\\n    liblapack-dev \\\n    libopenblas-dev \\\n    libhdf5-dev\n\n  export PYENV_ROOT=/opt/pyenv\n  export PATH=\"/opt/pyenv/bin:$PATH\"\n  curl -L https://github.com/pyenv/pyenv-installer/raw/master/bin/pyenv-installer | bash\n  pyenv install 3.7.2\n  echo 'export PATH=/opt/pyenv/versions/3.7.2/bin/:$PATH' >> $SINGULARITY_ENVIRONMENT\n  export PATH=/opt/pyenv/versions/3.7.2/bin/:$PATH\n\n  pip install torch==1.3.0\n\n  mkdir -p $SINGULARITY_ROOTFS/tmp/sing_build_cuda\n  cd $SINGULARITY_ROOTFS/tmp/sing_build_cuda\n\n  export TORCH_CUDA_ARCH_LIST=\"5.0 6.1\"\n\n  git clone https://github.com/rusty1s/pytorch_scatter.git && \\\n    cd ./pytorch_scatter && \\\n    git checkout 1.4.0 && \\\n    python3 -m pip install . && \\\n    cd ..\n\n  git clone https://github.com/rusty1s/pytorch_sparse.git && \\\n    cd ./pytorch_sparse && \\\n    git checkout 0.4.3 && \\\n    python3 -m pip install . && \\\n    cd ..\n\n  git clone https://github.com/rusty1s/pytorch_cluster.git && \\\n    cd ./pytorch_cluster && \\\n    git checkout 1.4.5 && \\\n    python3 -m pip install . && \\\n    cd ..\n\n  git clone https://github.com/pyg-team/pytorch_geometric.git && \\\n    cd ./pytorch_geometric && \\\n    git checkout 1.3.2 && \\\n    python3 -m pip install . 
&& \\\n    cd ..\n\n  cd $CURDIR\n  rm -rf $SINGULARITY_ROOTFS/tmp/sing_build_cuda\n"
  },
  {
    "path": "docs/Makefile",
    "content": "SPHINXBUILD   = sphinx-build\nSPHINXPROJ    = pytorch_geometric\nSOURCEDIR     = source\nBUILDDIR      = build\n\n.PHONY: help Makefile\n\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(0)\n"
  },
  {
    "path": "docs/README.md",
    "content": "# Building Documentation\n\nTo build the documentation:\n\n1. [Build and install](https://github.com/pyg-team/pytorch_geometric/blob/master/.github/CONTRIBUTING.md#developing-pytorch-geometric) PyG from source.\n1. Install [Sphinx](https://www.sphinx-doc.org/en/master/) theme via\n   ```\n   pip install git+https://github.com/pyg-team/pyg_sphinx_theme.git\n   ```\n1. Generate the documentation file via:\n   ```\n   cd docs\n   make html\n   ```\n\nThe documentation is now available to view by opening `docs/build/html/index.html`.\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "https://download.pytorch.org/whl/cpu/torch-2.8.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl\nnumpy>=1.19.5\ngit+https://github.com/pyg-team/pyg_sphinx_theme.git\n"
  },
  {
    "path": "docs/source/.gitignore",
    "content": "generated/\n"
  },
  {
    "path": "docs/source/_figures/.gitignore",
    "content": "*.aux\n*.log\n*.pdf\n"
  },
  {
    "path": "docs/source/_figures/build.sh",
    "content": "#!/bin/sh\n\nfor filename in *.tex; do\n  basename=$(basename $filename .tex)\n  pdflatex \"$basename.tex\"\n  pdf2svg \"$basename.pdf\" \"$basename.svg\"\ndone\n"
  },
  {
    "path": "docs/source/_figures/graph.tex",
    "content": "\\documentclass{standalone}\n\n\\usepackage{tikz}\n\n\\begin{document}\n\n\\begin{tikzpicture}\n  \\node[draw,circle,label= left:{$x_1=-1$}] (0) at (0, 0) {0};\n  \\node[draw,circle,label=above:{$x_1=0$}] (1) at (1, 1) {1};\n  \\node[draw,circle,label=right:{$x_1=1$}] (2) at (2, 0) {2};\n\n  \\path[draw] (0) -- (1);\n  \\path[draw] (1) -- (2);\n\\end{tikzpicture}\n\n\\end{document}\n"
  },
  {
    "path": "docs/source/_figures/hg_example.tex",
    "content": "\\documentclass{standalone}\n\n\\usepackage{tikz}\n\n\\begin{document}\n\n\\begin{tikzpicture}\n  \\node[draw,rectangle, align=center] (0) at (0, 0) {\\textbf{Author}\\\\ $1,134,649$ nodes};\n  \\node[draw,rectangle, align=center] (1) at (4, 2) {\\textbf{Paper}\\\\ $736,389$ nodes};\n  \\node[draw,rectangle, align=center] (2) at (8, 0) {\\textbf{Institution}\\\\ $8,740$ nodes};\n  \\node[draw,rectangle, align=center] (3) at (4, 4) {\\textbf{Field of Study}\\\\ $59,965$ nodes};\n\n  \\path[->,>=stealth] (0) edge [above left] node[align=center] {\\textbf{writes}\\\\$7,145,660$ edges} (1.south);\n  \\path[->,>=stealth] (0) edge [below] node[align=center] {\\textbf{affiliated with}\\\\$1,043,998$ edges} (2);\n  \\path[->,>=stealth,every loop/.style={looseness=3}] (1) edge [out=350, in=10, loop, right] node[align=center] {\\textbf{cites}\\\\$5,416,271$ edges} (1);\n  \\path[->,>=stealth] (1) edge [left] node[align=center] {\\textbf{has topic}\\\\$7,505,078$ edges} (3);\n\\end{tikzpicture}\n\n\\end{document}\n"
  },
  {
    "path": "docs/source/_figures/to_hetero.tex",
    "content": "\\documentclass{standalone}\n\n\\usepackage{tikz}\n\n\\definecolor{green}{RGB}{159,213,179}\n\\definecolor{blue}{RGB}{10,153,201}\n\n\\begin{document}\n\n\\begin{tikzpicture}\n  \\tikzset{rect/.style={draw,rectangle,inner sep=0pt,minimum width=2cm,minimum height=0.6cm,rounded corners=2pt}}\n  \\tikzset{arrow/.style={draw,->,>=stealth}}\n\n  \\def\\offset{1.2}\n\n  \\node[inner sep=0pt] at (0,0.7) {\\strut\\textbf{Homogeneous Model}};\n  \\node[rect] (x) at (0,0) {\\strut\\texttt{x}};\n  \\node[rect,fill=green!20!white] (conv1) at (0,-1*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white] (relu1) at (0,-2*\\offset) {\\texttt{ReLU}};\n  \\node[rect,fill=green!20!white] (conv2) at (0,-3*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white] (relu2) at (0,-4*\\offset) {\\texttt{ReLU}};\n  \\node[rect] (out) at (0,-5*\\offset) {\\strut\\texttt{out}};\n\n  \\draw[arrow] (x) -- (conv1);\n  \\draw[arrow] (conv1) -- (relu1);\n  \\draw[arrow] (relu1) -- (conv2);\n  \\draw[arrow] (conv2) -- (relu2);\n  \\draw[arrow] (relu2) -- (out);\n\n  \\node[inner sep=0pt] at (6,0.7) {\\strut\\textbf{Heterogeneous Model}};\n  \\node[rect] (xpaper) at (3.5,-0) {\\strut\\texttt{x\\_paper}};\n  \\node[rect,fill=green!20!white] (conv1paper) at (3.5,-1*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white]  (relu1paper) at (3.5,-2*\\offset) {\\texttt{ReLU}};\n  \\node[rect,fill=green!20!white] (conv2paper) at (3.5,-3*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white]  (relu2paper) at (3.5,-4*\\offset) {\\texttt{ReLU}};\n  \\node[rect] (outpaper) at (3.5,-5*\\offset) {\\strut\\texttt{out\\_paper}};\n\n  \\node[rect,fill=green!20!white] (conv1middle) at (6,-1*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=green!20!white] (conv2middle) at (6,-3*\\offset) {\\texttt{SAGEConv}};\n\n  \\node[rect] (xauthor) at (8.5,-0) {\\strut\\texttt{x\\_author}};\n  \\node[rect,fill=green!20!white] (conv1author) at (8.5,-1*\\offset) 
{\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white]  (relu1author) at (8.5,-2*\\offset) {\\texttt{ReLU}};\n  \\node[rect,fill=green!20!white] (conv2author) at (8.5,-3*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white]  (relu2author) at (8.5,-4*\\offset) {\\texttt{ReLU}};\n  \\node[rect] (outauthor) at (8.5,-5*\\offset) {\\strut\\texttt{out\\_author}};\n\n  \\draw[arrow] (xpaper) -- (conv1paper);\n  \\draw[arrow,out=270+45,in=90+45] (xpaper.south) to (conv1middle.north);\n  \\draw[arrow] (xauthor) -- (conv1author);\n\n  \\draw[arrow] (conv1paper) -- (relu1paper);\n  \\draw[arrow,out=270+45,in=90+45] (conv1middle.south) to (relu1author.north);\n  \\draw[arrow,out=270-60,in=0] (conv1author.south) to (relu1paper.east);\n\n  \\draw[arrow] (relu1paper) -- (conv2paper);\n  \\draw[arrow,out=270+45,in=90+45] (relu1paper.south) to (conv2middle.north);\n  \\draw[arrow] (relu1author) -- (conv2author);\n\n  \\draw[arrow] (conv2paper) -- (relu2paper);\n  \\draw[arrow,out=270+45,in=90+45] (conv2middle.south) to (relu2author.north);\n  \\draw[arrow,out=270-60,in=0] (conv2author.south) to (relu2paper.east);\n\n  \\draw[arrow] (relu2paper) -- (outpaper);\n  \\draw[arrow] (relu2author) -- (outauthor);\n\n\\end{tikzpicture}\n\n\\end{document}\n"
  },
  {
    "path": "docs/source/_figures/to_hetero_with_bases.tex",
    "content": "\\documentclass{standalone}\n\n\\usepackage{tikz}\n\n\\definecolor{green}{RGB}{159,213,179}\n\\definecolor{blue}{RGB}{10,153,201}\n\n\\begin{document}\n\n\\begin{tikzpicture}\n  \\tikzset{rect/.style={draw,rectangle,inner sep=0pt,minimum width=2cm,minimum height=0.6cm,rounded corners=2pt}}\n  \\tikzset{arrow/.style={draw,->,>=stealth}}\n\n  \\def\\offset{1.2}\n\n  \\node[inner sep=0pt] at (0,0.7) {\\strut\\textbf{Homogeneous Model}};\n  \\node[rect] (x) at (0,0) {\\strut\\texttt{x}};\n  \\node[rect,fill=green!20!white] (conv1) at (0,-1*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white] (relu1) at (0,-2*\\offset) {\\texttt{ReLU}};\n  \\node[rect,fill=green!20!white] (conv2) at (0,-3*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=blue!20!white] (relu2) at (0,-4*\\offset) {\\texttt{ReLU}};\n  \\node[rect] (out) at (0,-5*\\offset) {\\texttt{out}};\n\n  \\draw[arrow] (x) -- (conv1);\n  \\draw[arrow] (conv1) -- (relu1);\n  \\draw[arrow] (relu1) -- (conv2);\n  \\draw[arrow] (conv2) -- (relu2);\n  \\draw[arrow] (relu2) -- (out);\n\n  \\node[inner sep=0pt] at (6,2*\\offset+0.7) {\\strut\\textbf{Heterogeneous Model}};\n  \\node[rect] (xpaper) at (3.5,2*\\offset) {\\strut\\texttt{x\\_paper}};\n  \\node[rect] (xauthor) at (8.5,2*\\offset) {\\strut\\texttt{x\\_author}};\n  \\node[rect] (linpaper) at (3.5,1*\\offset) {\\strut\\texttt{Linear}};\n  \\node[rect] (linauthor) at (8.5,1*\\offset) {\\strut\\texttt{Linear}};\n  \\node[rect] (x) at (6,0*\\offset) {\\strut\\texttt{x}};\n  \\node[rect,fill=green!20!white] (conv11) at (3.5,-1*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=green!20!white] (conv12) at (6,-1*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=green!20!white] (conv13) at (8.5,-1*\\offset) {\\texttt{SAGEConv}};\n  \\node[inner sep=1pt] (aggr1) at (6,-1.5*\\offset) {\\footnotesize$+$};\n  \\node[rect,fill=blue!20!white]  (relu1) at (6,-2*\\offset) {\\texttt{ReLU}};\n  \\node[rect,fill=green!20!white] (conv21) at 
(3.5,-3*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=green!20!white] (conv22) at (6,-3*\\offset) {\\texttt{SAGEConv}};\n  \\node[rect,fill=green!20!white] (conv23) at (8.5,-3*\\offset) {\\texttt{SAGEConv}};\n  \\node[inner sep=1pt] (aggr2) at (6,-3.5*\\offset) {\\footnotesize$+$};\n  \\node[rect,fill=blue!20!white]  (relu2) at (6,-4*\\offset) {\\texttt{ReLU}};\n  \\node[rect] (outpaper) at (3.5,-5*\\offset) {\\strut\\texttt{out\\_paper}};\n  \\node[rect] (outauthor) at (8.5,-5*\\offset) {\\strut\\texttt{out\\_author}};\n\n  \\draw[arrow] (xpaper) -- (linpaper);\n  \\draw[arrow] (xauthor) -- (linauthor);\n  \\draw[arrow] (linpaper) -- (x);\n  \\draw[arrow] (linauthor) -- (x);\n  \\draw[arrow] (x) -- node[fill=white,inner sep=1pt] {\\footnotesize $\\mathbf{a}_{\\mathcal{R}, 1}$} (conv11) ;\n  \\draw[arrow] (x) -- node[fill=white,inner sep=1pt] {\\footnotesize $\\mathbf{a}_{\\mathcal{R}, 2}$} (conv12) ;\n  \\draw[arrow] (x) -- node[fill=white,inner sep=1pt] {\\footnotesize $\\mathbf{a}_{\\mathcal{R}, 3}$} (conv13) ;\n  \\draw[arrow] (conv11) -- (aggr1);\n  \\draw[arrow] (conv12) -- (aggr1);\n  \\draw[arrow] (conv13) -- (aggr1);\n  \\draw[arrow] (aggr1) -- (relu1);\n  \\draw[arrow] (relu1) -- node[fill=white,inner sep=0pt] {\\footnotesize $\\mathbf{a}_{\\mathcal{R}, 1}$} (conv21) ;\n  \\draw[arrow] (relu1) -- node[fill=white,inner sep=0pt] {\\footnotesize $\\mathbf{a}_{\\mathcal{R}, 2}$} (conv22) ;\n  \\draw[arrow] (relu1) -- node[fill=white,inner sep=0pt] {\\footnotesize $\\mathbf{a}_{\\mathcal{R}, 3}$} (conv23) ;\n  \\draw[arrow] (conv21) -- (aggr2);\n  \\draw[arrow] (conv22) -- (aggr2);\n  \\draw[arrow] (conv23) -- (aggr2);\n  \\draw[arrow] (aggr2) -- (relu2);\n  \\draw[arrow] (relu2) -- (outpaper);\n  \\draw[arrow] (relu2) -- (outauthor);\n\n\\end{tikzpicture}\n\n\\end{document}\n"
  },
  {
    "path": "docs/source/_static/js/version_alert.js",
    "content": "function warnOnLatestVersion() {\n  if (!window.READTHEDOCS_DATA || window.READTHEDOCS_DATA.version !== \"latest\") {\n    return;  // not on ReadTheDocs and not latest.\n  }\n\n  var note = document.createElement('div');\n  note.setAttribute('class', 'admonition note');\n  note.innerHTML = \"<p class='first admonition-title'>Note</p> \" +\n    \"<p> \" +\n    \"This documentation is for an <b>unreleased development version</b>. \" +\n    \"Click <a href='/en/stable'><b>here</b></a> to access the documentation of the current stable release.\" +\n    \"</p>\";\n\n  var parent = document.querySelector('#pyg-documentation');\n  if (parent)\n    parent.insertBefore(note, parent.querySelector('h1'));\n}\n\ndocument.addEventListener('DOMContentLoaded', warnOnLatestVersion);\n"
  },
  {
    "path": "docs/source/_templates/autosummary/class.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n   :show-inheritance:\n   :members:\n"
  },
  {
    "path": "docs/source/_templates/autosummary/inherited_class.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n   :show-inheritance:\n   :members:\n   :inherited-members:\n   :special-members: __cat_dim__, __inc__\n"
  },
  {
    "path": "docs/source/_templates/autosummary/metrics.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n   :show-inheritance:\n   :members: update, compute, reset\n"
  },
  {
    "path": "docs/source/_templates/autosummary/nn.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n{% if objname != \"MessagePassing\" %}\n.. autoclass:: {{ objname }}\n   :show-inheritance:\n   :members:\n   :exclude-members: forward, reset_parameters, message, message_and_aggregate, edge_update, aggregate, update\n\n   .. automethod:: forward\n   .. automethod:: reset_parameters\n{% else %}\n.. autoclass:: {{ objname }}\n   :show-inheritance:\n   :members:\n{% endif %}\n"
  },
  {
    "path": "docs/source/_templates/autosummary/only_class.rst",
    "content": "{{ fullname | escape | underline}}\n\n.. currentmodule:: {{ module }}\n\n.. autoclass:: {{ objname }}\n   :show-inheritance:\n"
  },
  {
    "path": "docs/source/advanced/batching.rst",
    "content": "Advanced Mini-Batching\n======================\n\nThe creation of mini-batching is crucial for letting the training of a deep learning model scale to huge amounts of data.\nInstead of processing examples one-by-one, a mini-batch groups a set of examples into a unified representation where it can efficiently be processed in parallel.\nIn the image or language domain, this procedure is typically achieved by rescaling or padding each example into a set to equally-sized shapes, and examples are then grouped in an additional dimension.\nThe length of this dimension is then equal to the number of examples grouped in a mini-batch and is typically referred to as the :obj:`batch_size`.\n\nSince graphs are one of the most general data structures that can hold *any* number of nodes or edges, the two approaches described above are either not feasible or may result in a lot of unnecessary memory consumption.\nIn :pyg:`PyG`, we opt for another approach to achieve parallelization across a number of examples.\nHere, adjacency matrices are stacked in a diagonal fashion (creating a giant graph that holds multiple isolated subgraphs), and node and target features are simply concatenated in the node dimension, *i.e.*\n\n.. math::\n\n    \\mathbf{A} = \\begin{bmatrix} \\mathbf{A}_1 & & \\\\ & \\ddots & \\\\ & & \\mathbf{A}_n \\end{bmatrix}, \\qquad \\mathbf{X} = \\begin{bmatrix} \\mathbf{X}_1 \\\\ \\vdots \\\\ \\mathbf{X}_n \\end{bmatrix}, \\qquad \\mathbf{Y} = \\begin{bmatrix} \\mathbf{Y}_1 \\\\ \\vdots \\\\ \\mathbf{Y}_n \\end{bmatrix}.\n\nThis procedure has some crucial advantages over other batching procedures:\n\n1. GNN operators that rely on a message passing scheme do not need to be modified since messages still cannot be exchanged between two nodes that belong to different graphs.\n\n2. 
There is no computational or memory overhead.\n   For example, this batching procedure works completely without any padding of node or edge features.\n   Note that there is no additional memory overhead for adjacency matrices since they are saved in a sparse fashion holding only non-zero entries, *i.e.*, the edges.\n\n:pyg:`PyG` automatically takes care of batching multiple graphs into a single giant graph with the help of the :class:`torch_geometric.loader.DataLoader` class.\nInternally, :class:`~torch_geometric.loader.DataLoader` is just a regular :pytorch:`PyTorch` :class:`torch.utils.data.DataLoader` that overrides its :func:`collate` functionality, *i.e.*, the definition of how a list of examples should be grouped together.\nTherefore, all arguments that can be passed to a :pytorch:`PyTorch` :class:`~torch.utils.data.DataLoader` can also be passed to a :pyg:`PyG` :class:`~torch_geometric.loader.DataLoader`, *e.g.*, the number of workers :obj:`num_workers`.\n\nIn its most general form, the :pyg:`PyG` :class:`~torch_geometric.loader.DataLoader` will automatically increment the :obj:`edge_index` tensor by the cumulative number of nodes of all graphs that got collated before the currently processed graph, and will concatenate :obj:`edge_index` tensors (that are of shape :obj:`[2, num_edges]`) in the second dimension.\nThe same is true for :obj:`face` tensors, *i.e.*, face indices in meshes.\nAll other tensors will just get concatenated in the first dimension without any further increment of their values.\n\nHowever, there are a few special use-cases (as outlined below) where users actively want to modify this behavior to their own needs.\n:pyg:`PyG` allows modification to the underlying batching procedure by overriding the :meth:`torch_geometric.data.Data.__inc__` and :meth:`torch_geometric.data.Data.__cat_dim__` functionalities.\nWithout any modifications, these are defined as follows in the :class:`~torch_geometric.data.Data` class:\n\n.. 
code-block:: python\n\n    def __inc__(self, key, value, *args, **kwargs):\n        if 'index' in key:\n            return self.num_nodes\n        else:\n            return 0\n\n    def __cat_dim__(self, key, value, *args, **kwargs):\n        if 'index' in key:\n            return 1\n        else:\n            return 0\n\nWe can see that :meth:`~torch_geometric.data.Data.__inc__` defines the incremental count between two consecutive graph attributes.\nBy default, :pyg:`PyG` increments attributes by the number of nodes whenever their attribute names contain the substring :obj:`index` (for historical reasons), which comes in handy for attributes such as :obj:`edge_index` or :obj:`node_index`.\nHowever, note that this may lead to unexpected behavior for attributes whose names contain the substring :obj:`index` but should not be incremented.\nTo be safe, it is best practice to always double-check the output of batching.\nFurthermore, :meth:`~torch_geometric.data.Data.__cat_dim__` defines in which dimension graph tensors of the same attribute should be concatenated together.\nBoth functions are called for each attribute stored in the :class:`~torch_geometric.data.Data` class, and get passed the attribute's :obj:`key` and :obj:`value` as arguments.\n\nIn what follows, we present a few use-cases where the modification of :meth:`~torch_geometric.data.Data.__inc__` and :meth:`~torch_geometric.data.Data.__cat_dim__` might be absolutely necessary.\n\nPairs of Graphs\n---------------\n\nIn case you want to store multiple graphs in a single :class:`~torch_geometric.data.Data` object, *e.g.*, for applications such as graph matching, you need to ensure correct batching behavior across all those graphs.\nFor example, consider storing two graphs, a source graph :math:`\\mathcal{G}_s` and a target graph :math:`\\mathcal{G}_t` in a :class:`~torch_geometric.data.Data`, *e.g.*:\n\n.. 
code-block:: python\n\n    from torch_geometric.data import Data\n\n    class PairData(Data):\n        pass\n\n    data = PairData(x_s=x_s, edge_index_s=edge_index_s,  # Source graph.\n                    x_t=x_t, edge_index_t=edge_index_t)  # Target graph.\n\nIn this case, :obj:`edge_index_s` should be increased by the number of nodes in the source graph :math:`\\mathcal{G}_s`, *i.e.*, :obj:`x_s.size(0)`, and :obj:`edge_index_t` should be increased by the number of nodes in the target graph :math:`\\mathcal{G}_t`, *i.e.*, :obj:`x_t.size(0)`:\n\n.. code-block:: python\n\n    class PairData(Data):\n        def __inc__(self, key, value, *args, **kwargs):\n            if key == 'edge_index_s':\n                return self.x_s.size(0)\n            if key == 'edge_index_t':\n                return self.x_t.size(0)\n            return super().__inc__(key, value, *args, **kwargs)\n\nWe can test our :class:`PairData` batching behavior by setting up a simple test script:\n\n.. code-block:: python\n\n    import torch\n    from torch_geometric.loader import DataLoader\n\n    x_s = torch.randn(5, 16)  # 5 nodes.\n    edge_index_s = torch.tensor([\n        [0, 0, 0, 0],\n        [1, 2, 3, 4],\n    ])\n\n    x_t = torch.randn(4, 16)  # 4 nodes.\n    edge_index_t = torch.tensor([\n        [0, 0, 0],\n        [1, 2, 3],\n    ])\n\n    data = PairData(x_s=x_s, edge_index_s=edge_index_s,\n                    x_t=x_t, edge_index_t=edge_index_t)\n\n    data_list = [data, data]\n    loader = DataLoader(data_list, batch_size=2)\n    batch = next(iter(loader))\n\n    print(batch)\n    >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8],\n                      x_t=[8, 16], edge_index_t=[2, 6])\n\n    print(batch.edge_index_s)\n    >>> tensor([[0, 0, 0, 0, 5, 5, 5, 5],\n                [1, 2, 3, 4, 6, 7, 8, 9]])\n\n    print(batch.edge_index_t)\n    >>> tensor([[0, 0, 0, 4, 4, 4],\n                [1, 2, 3, 5, 6, 7]])\n\nEverything looks good so far!\n:obj:`edge_index_s` and :obj:`edge_index_t` get correctly 
batched together, even when using different numbers of nodes for :math:`\\mathcal{G}_s` and :math:`\\mathcal{G}_t`.\nHowever, the :obj:`batch` attribute (that maps each node to its respective graph) is missing since :pyg:`PyG` cannot identify the actual graph in the :class:`PairData` object.\nThat is where the :obj:`follow_batch` argument of the :class:`~torch_geometric.loader.DataLoader` comes into play.\nHere, we can specify for which attributes we want to maintain the batch information:\n\n.. code-block:: python\n\n    loader = DataLoader(data_list, batch_size=2, follow_batch=['x_s', 'x_t'])\n    batch = next(iter(loader))\n\n    print(batch)\n    >>> PairDataBatch(x_s=[10, 16], edge_index_s=[2, 8], x_s_batch=[10],\n                      x_t=[8, 16], edge_index_t=[2, 6], x_t_batch=[8])\n\n    print(batch.x_s_batch)\n    >>> tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n\n    print(batch.x_t_batch)\n    >>> tensor([0, 0, 0, 0, 1, 1, 1, 1])\n\nAs one can see, :obj:`follow_batch=['x_s', 'x_t']` now successfully creates assignment vectors :obj:`x_s_batch` and :obj:`x_t_batch` for the node features :obj:`x_s` and :obj:`x_t`, respectively.\nThat information can now be used to perform reduce operations, *e.g.*, global pooling, on multiple graphs in a single :class:`Batch` object.\n\nBipartite Graphs\n----------------\n\nThe adjacency matrix of a bipartite graph defines the relationship between nodes of two different node types.\nIn general, the number of nodes for each node type does not need to match, resulting in a non-square adjacency matrix of shape :math:`\\mathbf{A} \\in \\{ 0, 1 \\}^{N \\times M}` with :math:`N \\neq M` potentially.\nIn a mini-batching procedure of bipartite graphs, the source nodes of edges in :obj:`edge_index` should get increased differently than the target nodes of edges in :obj:`edge_index`.\nTo achieve this, consider a bipartite graph between two node types with corresponding node features :obj:`x_s` and :obj:`x_t`, respectively:\n\n.. 
code-block:: python\n\n    from torch_geometric.data import Data\n\n    class BipartiteData(Data):\n        pass\n\n    data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)\n\nFor a correct mini-batching procedure in bipartite graphs, we need to tell :pyg:`PyG` that it should increment source and target nodes of edges in :obj:`edge_index` independently:\n\n.. code-block:: python\n\n    class BipartiteData(Data):\n        def __inc__(self, key, value, *args, **kwargs):\n            if key == 'edge_index':\n                return torch.tensor([[self.x_s.size(0)], [self.x_t.size(0)]])\n            return super().__inc__(key, value, *args, **kwargs)\n\nHere, :obj:`edge_index[0]` (the source nodes of edges) get incremented by :obj:`x_s.size(0)` while :obj:`edge_index[1]` (the target nodes of edges) get incremented by :obj:`x_t.size(0)`.\nWe can again test our implementation by running a simple test script:\n\n.. code-block:: python\n\n    from torch_geometric.loader import DataLoader\n\n    x_s = torch.randn(2, 16)  # 2 nodes.\n    x_t = torch.randn(3, 16)  # 3 nodes.\n    edge_index = torch.tensor([\n        [0, 0, 1, 1],\n        [0, 1, 1, 2],\n    ])\n\n    data = BipartiteData(x_s=x_s, x_t=x_t, edge_index=edge_index)\n\n    data_list = [data, data]\n    loader = DataLoader(data_list, batch_size=2)\n    batch = next(iter(loader))\n\n    print(batch)\n    >>> BipartiteDataBatch(x_s=[4, 16], x_t=[6, 16], edge_index=[2, 8])\n\n    print(batch.edge_index)\n    >>> tensor([[0, 0, 1, 1, 2, 2, 3, 3],\n                [0, 1, 1, 2, 3, 4, 4, 5]])\n\nAgain, this is exactly the behavior we aimed for!\n\nBatching Along New Dimensions\n-----------------------------\n\nSometimes, attributes of :obj:`data` objects should be batched by gaining a new batch dimension (as in classical mini-batching), *e.g.*, for graph-level properties or targets.\nSpecifically, a list of attributes of shape :obj:`[num_features]` should be returned as :obj:`[num_examples, num_features]` rather 
than :obj:`[num_examples * num_features]`.\n:pyg:`PyG` achieves this by returning a concatenation dimension of :obj:`None` in :meth:`~torch_geometric.data.Data.__cat_dim__`:\n\n.. code-block:: python\n\n    from torch_geometric.data import Data\n    from torch_geometric.loader import DataLoader\n\n    class MyData(Data):\n        def __cat_dim__(self, key, value, *args, **kwargs):\n            if key == 'foo':\n                return None\n            return super().__cat_dim__(key, value, *args, **kwargs)\n\n    edge_index = torch.tensor([\n       [0, 1, 1, 2],\n       [1, 0, 2, 1],\n    ])\n    foo = torch.randn(16)\n\n    data = MyData(num_nodes=3, edge_index=edge_index, foo=foo)\n\n    data_list = [data, data]\n    loader = DataLoader(data_list, batch_size=2)\n    batch = next(iter(loader))\n\n    print(batch)\n    >>> MyDataBatch(num_nodes=6, edge_index=[2, 8], foo=[2, 16])\n\nAs desired, :obj:`batch.foo` is now described by two dimensions: The batch dimension and the feature dimension.\n"
  },
  {
    "path": "docs/source/advanced/compile.rst",
    "content": "Compiled Graph Neural Networks\n==============================\n\n:meth:`torch.compile` is the latest method to speed up your :pytorch:`PyTorch` code in :obj:`torch >= 2.0.0`!\n:meth:`torch.compile` makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while required minimal code changes.\n\nUnder the hood, :meth:`torch.compile` captures :pytorch:`PyTorch` programs via :obj:`TorchDynamo`, canonicalizes over 2,000 :pytorch:`PyTorch` operators via :obj:`PrimTorch`, and finally generates fast code out of it across multiple accelerators and backends via the deep learning compiler :obj:`TorchInductor`.\n\n.. note::\n    See `here <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__ for a general tutorial on how to leverage :meth:`torch.compile`, and `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__ for a description of its interface.\n\nIn this tutorial, we show how to optimize your custom :pyg:`PyG` model via :meth:`torch.compile`.\n\n.. note::\n    From :pyg:`PyG` 2.5 (and onwards), :meth:`torch.compile` is now fully compatible with all :pyg:`PyG` GNN layers.\n    If you are on an earlier version of :pyg:`PyG`, consider using :meth:`torch_geometric.compile` instead.\n\nBasic Usage\n-----------\n\nOnce you have a :pyg:`PyG` model defined, simply wrap it with :meth:`torch.compile` to obtain its optimized version:\n\n.. code-block:: python\n\n    import torch\n    from torch_geometric.nn import GraphSAGE\n\n    model = GraphSAGE(in_channels, hidden_channels, num_layers, out_channels)\n    model = model.to(device)\n\n    model = torch.compile(model)\n\nand execute it as usual:\n\n.. 
code-block:: python\n\n    from torch_geometric.datasets import Planetoid\n\n    dataset = Planetoid(root, name=\"Cora\")\n    data = dataset[0].to(device)\n\n    out = model(data.x, data.edge_index)\n\nMaximizing Performance\n----------------------\n\nThe :meth:`torch.compile` method provides two important arguments to be aware of:\n\n* Most of the mini-batches observed in :pyg:`PyG` are dynamic by nature, meaning that their shape varies across different mini-batches.\n  For these scenarios, we can enforce dynamic shape tracing in :pytorch:`PyTorch` via the :obj:`dynamic=True` argument:\n\n  .. code-block:: python\n\n      torch.compile(model, dynamic=True)\n\n  With this, :pytorch:`PyTorch` will up-front attempt to generate a kernel that is as dynamic as possible to avoid recompilations when sizes change across mini-batches.\n  Note that when :obj:`dynamic` is set to :obj:`False`, :pytorch:`PyTorch` will *never* generate dynamic kernels, and thus only work when graph sizes are guaranteed to never change (*e.g.*, in full-batch training on small graphs).\n  By default, :obj:`dynamic` is set to :obj:`None` in :pytorch:`PyTorch` :obj:`>= 2.1.0`, and :pytorch:`PyTorch` will automatically detect if dynamism has occurred.\n  Note that support for dynamic shape tracing requires :pytorch:`PyTorch` :obj:`>= 2.1.0` to be installed.\n\n* In order to maximize speedup, graph breaks in the compiled model should be limited.\n  We can force compilation to raise an error upon the first graph break encountered by using the :obj:`fullgraph=True` argument:\n\n  .. code-block:: python\n\n      torch.compile(model, fullgraph=True)\n\n  It is generally a good practice to confirm that your written model does not contain any graph breaks.\n  Importantly, there exist a few operations in :pyg:`PyG` that will currently lead to graph breaks (but workarounds exist), *e.g.*:\n\n  1. 
:meth:`~torch_geometric.nn.pool.global_mean_pool` (and other pooling operators) perform device synchronization in case the batch size :obj:`size` is not passed, leading to a graph break.\n\n  2. :meth:`~torch_geometric.utils.remove_self_loops` and :meth:`~torch_geometric.utils.add_remaining_self_loops` mask the given :obj:`edge_index`, leading to a device synchronization to compute its final output shape.\n     As such, we recommend augmenting your graph *before* inputting it into your GNN, *e.g.*, via the :class:`~torch_geometric.transforms.AddSelfLoops` or :class:`~torch_geometric.transforms.GCNNorm` transformations, and setting :obj:`add_self_loops=False`/:obj:`normalize=False` when initializing layers such as :class:`~torch_geometric.nn.conv.GCNConv`.\n\nExample Scripts\n---------------\n\nWe have incorporated multiple examples in :obj:`examples/compile` that further show the practical usage of :meth:`torch.compile`:\n\n#. `Node Classification <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/compile/gcn.py>`__ via :class:`~torch_geometric.nn.models.GCN` (:obj:`dynamic=False`)\n#. 
`Graph Classification <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/compile/gin.py>`__ via :class:`~torch_geometric.nn.models.GIN` (:obj:`dynamic=True`)\n\nIf you notice that :meth:`torch.compile` fails for a certain :pyg:`PyG` model, do not hesitate to reach out either on :github:`null` `GitHub <https://github.com/pyg-team/pytorch_geometric/issues>`_ or :slack:`null` `Slack <https://data.pyg.org/slack.html>`_.\nWe are very eager to improve :meth:`torch.compile` support across the whole :pyg:`PyG` code base.\n\nBenchmark\n---------\n\n:meth:`torch.compile` works **fantastically well** for many :pyg:`PyG` models.\n**Overall, we observe speedups of up to 2.8x.**\n\nSpecifically, we benchmark :class:`~torch_geometric.nn.models.GCN`, :class:`~torch_geometric.nn.models.GraphSAGE` and :class:`~torch_geometric.nn.models.GIN` and compare runtimes obtained from traditional eager mode and :meth:`torch.compile`.\nWe use a synthetic graph with 10,000 nodes and 200,000 edges, and a hidden feature dimensionality of 64.\nWe report runtimes over 500 optimization steps:\n\n.. 
list-table::\n   :widths: 15 15 15 15 15 15\n   :header-rows: 1\n\n   * - Model\n     - Mode\n     - Forward\n     - Backward\n     - Total\n     - Speedup\n   * - :class:`~torch_geometric.nn.models.GCN`\n     - Eager\n     - 2.6396s\n     - 2.1697s\n     - 4.8093s\n     -\n   * - :class:`~torch_geometric.nn.models.GCN`\n     - **Compiled**\n     - **1.1082s**\n     - **0.5896s**\n     - **1.6978s**\n     - **2.83x**\n   * - :class:`~torch_geometric.nn.models.GraphSAGE`\n     - Eager\n     - 1.6023s\n     - 1.6428s\n     - 3.2451s\n     -\n   * - :class:`~torch_geometric.nn.models.GraphSAGE`\n     - **Compiled**\n     - **0.7033s**\n     - **0.7465s**\n     - **1.4498s**\n     - **2.24x**\n   * - :class:`~torch_geometric.nn.models.GIN`\n     - Eager\n     - 1.6701s\n     - 1.6990s\n     - 3.3690s\n     -\n   * - :class:`~torch_geometric.nn.models.GIN`\n     - **Compiled**\n     - **0.7320s**\n     - **0.7407s**\n     - **1.4727s**\n     - **2.29x**\n\nTo reproduce these results, run\n\n.. code-block:: console\n\n    python test/nn/models/test_basic_gnn.py\n\nfrom the root folder of your checked out :pyg:`PyG` repository from :github:`GitHub`.\n"
  },
  {
    "path": "docs/source/advanced/cpu_affinity.rst",
    "content": "CPU Affinity for PyG Workloads\n==============================\n\nThe performance of :pyg:`PyG` workloads using CPU can be significantly improved by setting a proper affinity mask.\nProcessor affinity, or core binding, is a modification of the native OS queue scheduling algorithm that enables an application to assign a specific set of cores to processes or threads launched during its execution on the CPU.\nIn consequence, it increases the overall effective hardware utilisation by minimizing core stalls and memory bounds.\nIt also secures CPU resources to critical processes or threads, even if the system is under heavy load.\n\nCPU affinity targets the two main performance-critical regions:\n\n* **Execution bind:** Indicates a core where process/thread will run.\n* **Memory bind:** Indicates a preferred memory area where memory pages will be bound (local areas in NUMA machine).\n\nThe following article discusses readily available tools and environment settings that one can use to maximize the performance of Intel CPUs with :pyg:`PyG`.\n\n.. 
note::\n    Overall, CPU affinity can be a useful tool for improving the performance and predictability of certain types of applications, but one configuration does not necessarily fit all cases: it is important to carefully consider whether CPU affinity is appropriate for your use case, and to test and measure the impact of any changes you make.\n\nUsing CPU affinity\n------------------\n\nEach :pyg:`PyG` workload can be parallelized using the :pytorch:`PyTorch` iterator class :class:`MultiProcessingDataLoaderIter`, which is automatically enabled in case :obj:`num_workers > 0` is passed to a :class:`torch.utils.data.DataLoader`.\nUnder the hood, it creates :obj:`num_workers` many sub-processes that will run in parallel to the main process.\nSetting a CPU affinity mask for the data loading processes places :class:`~torch.utils.data.DataLoader` worker threads on specific CPU cores.\nIn effect, it allows for more efficient data batch preparation by allocating pre-fetched batches in local memory.\nEvery time a process or thread moves from one core to another, registers and caches need to be flushed and reloaded.\nThis can become very costly if it happens often, and threads may also no longer be close to their data, or be able to share data in a cache.\n\nSince :pyg:`PyG` (2.3 and beyond), :class:`~torch_geometric.loader.NodeLoader` and :class:`~torch_geometric.loader.LinkLoader` classes officially support a native solution for CPU affinity using the :class:`torch_geometric.loader.AffinityMixin` context manager.\nCPU affinity can be enabled via the :meth:`~torch_geometric.loader.AffinityMixin.enable_cpu_affinity` method for :obj:`num_workers > 0` use-cases,\nand will guarantee that a separate core is assigned to each worker at initialization.\nA user-defined list of core IDs may be assigned using the :attr:`loader_cores` argument.\nOtherwise, cores will be assigned automatically, starting at core ID 0.\nAs of now, only a single core can be assigned to a worker, hence 
multi-threading is disabled in workers' processes by default.\nThe recommended number of workers to start with lies in the range :obj:`[2, 4]`, and the optimum may vary based on workload characteristics:\n\n.. code-block:: python\n\n    from torch_geometric.loader import NeighborLoader\n\n    loader = NeighborLoader(\n        data,\n        num_workers=3,\n        ...,\n    )\n\n    with loader.enable_cpu_affinity(loader_cores=[0, 1, 2]):\n        for batch in loader:\n            pass\n\nIt is generally advisable to use :obj:`filter_per_worker=True` for any multi-process CPU workloads (:obj:`True` by default).\nThe workers then prepare each mini-batch: first by sampling the node indices using a pre-defined sampler, and secondly by filtering node and edge features according to the sampled nodes and edges.\nThe filtering function selects node feature vectors from the complete input :class:`~torch_geometric.data.Data` tensor loaded into DRAM.\nWhen :attr:`filter_per_worker` is set to :attr:`True`, each worker's subprocess performs the filtering within its CPU resource.\nHence, main process resources are relieved and can be secured only for GNN computation.\n\nBinding processes to physical cores\n-----------------------------------\n\nFollowing general performance tuning principles, it is advisable to use only physical cores for deep learning workloads.\nFor example, while two logical threads run :obj:`GEMM` at the same time, they will be sharing the same core resources causing front end bound, such that the overhead from this front end bound is greater than the gain from running both logical threads at the same time.\nThis is because OpenMP threads will contend for the same :obj:`GEMM` execution units, see `here <https://pytorch.org/tutorials/intermediate/torchserve_with_ipex.html>`__.\n\nThe binding can be done in many ways, however the most common tools are:\n\n* :obj:`numactl` (only on Linux):\n\n  .. 
code-block:: console\n\n     --physcpubind=<cpus>, -C <cpus>  or --cpunodebind=<nodes>, -N <nodes>\n\n* `Intel OMP <https://www.intel.com/content/www/us/en/developer/articles/technical/how-to-get-better-performance-on-pytorchcaffe2-with-intel-acceleration.html>`__ :obj:`libiomp`:\n\n  .. code-block:: console\n\n     export KMP_AFFINITY=granularity=fine,proclist=[0-<physical_cores_num-1>],explicit\n\n* GNU :obj:`libgomp`:\n\n  .. code-block:: console\n\n     export GOMP_CPU_AFFINITY=\"0-<physical_cores_num-1>\"\n\nIsolating the :class:`~torch.utils.data.DataLoader` process\n-----------------------------------------------------------\n\nFor best performance, it is necessary to combine main process affinity, set using the tools listed above, with the multi-process :class:`~torch.utils.data.DataLoader` affinity settings.\nIn each parallelized :pyg:`PyG` workload execution, the main process performs message passing updates over GNN layers, while the :class:`~torch.utils.data.DataLoader` worker sub-processes take care of fetching and pre-processing data to be passed to a GNN model.\nIt is advisable to isolate the CPU resources made available to these two processes to achieve the best results.\nTo do this, CPUs assigned to each affinity mask should be mutually exclusive.\nFor example, if four :class:`~torch.utils.data.DataLoader` workers are assigned to CPUs :obj:`[0, 1, 2, 3]`, the main process should use the rest of the available cores, *i.e.* by calling:\n\n.. 
code-block:: console\n\n   numactl -C 4-(N-1) --localalloc python …\n\nwhere :obj:`N` is the total number of physical cores, with the last CPU having core ID :obj:`N-1`.\nAdding :obj:`--localalloc` improves local memory allocation and keeps the cache closer to active cores.\n\nDual socket CPU separation\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWith dual-socket CPUs, it might be beneficial to further isolate the processes between the sockets.\nThis leads to decreased frequency of remote memory calls for the main process.\nThe goal is to `utilize high-speed cache on local memory and reduce memory bound caused by migrating cached data between NUMA nodes <https://pytorch.org/tutorials/intermediate/torchserve_with_ipex.html>`__.\nThis can be achieved by using :class:`~torch.utils.data.DataLoader` affinity, and launching the main process on the cores of the second socket, *i.e.* with:\n\n.. code-block:: console\n\n   numactl -C M-(N-1) -m 1 python …\n\nwhere :obj:`M` is the :obj:`cpuid` of the first core of the second CPU socket.\nAdding a complementary memory-allocation flag :obj:`-m 1` prioritizes cache allocation on the same NUMA node, where the main process is running (alternatively for less strict memory allocation use :obj:`--preferred 1`).\nThis makes the data readily available on the same socket where the computation takes place.\nUsing this setting is very workload-specific and may require some fine-tuning, as one needs to manage a trade-off between using more OMP threads vs. 
limiting the number of remote memory calls.\n\nImproving memory bounds\n-----------------------\n\nFollowing the CPU performance optimization guidelines for :pytorch:`PyTorch`, it is also advised for :pyg:`PyG` to use :obj:`jemalloc` or :obj:`TCMalloc`.\nThese can generally achieve better memory usage than the default :pytorch:`PyTorch` `memory allocator <https://pytorch.org/tutorials/intermediate/torchserve_with_ipex_2.html>`__ :obj:`PTMalloc`.\nA `non-default memory allocator <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`__ can be specified using :obj:`LD_PRELOAD` prior to script execution.\n\nQuick start guidelines\n----------------------\n\nThe general guidelines for achieving the best performance with CPU affinity can be summarized in the following steps:\n\n#. Test if your dataset benefits from using parallel data loaders.\n   For some datasets, it might be more beneficial to use a plain serial data loader, especially when the dimensions of the input :class:`~torch_geometric.data.Data` are relatively small.\n#. Enable multi-process data loaders by setting :attr:`num_workers > 0`.\n   A good estimate for :obj:`num_workers` lies in the range :obj:`[2, 4]`.\n   However, for more complex datasets you might want to experiment with a larger number of workers.\n   Use the :meth:`~torch_geometric.loader.AffinityMixin.enable_cpu_affinity` feature to affinitize :class:`~torch.utils.data.DataLoader` cores.\n#. Bind execution to physical cores.\n   Alternatively, hyperthreading can be disabled completely at a system-level.\n#. Separate the cores used for the main process from the data loader workers' cores by using :obj:`numactl`, :obj:`KMP_AFFINITY` of the :obj:`libiomp5` library, or :obj:`GOMP_CPU_AFFINITY` of the :obj:`libgomp` library.\n#. 
Find the optimum number of OMP threads for your workload.\n   A good starting point is :obj:`N - num_workers`.\n   Generally, well-parallelized models will benefit from many OMP threads.\n   However, if your model computation flow has interlaced parallel and serial regions, the performance will decrease due to resource allocation needed for spawning and maintaining threads between parallel regions.\n#. When using a dual-socket CPU, you might want to experiment with assigning data loading to one socket and the main process to another socket with memory allocation (:obj:`numactl -m`) on the same socket where the main process is executed.\n   This leads to the best cache allocation and often outweighs the benefit of using more OMP threads.\n#. An additional boost in performance can be obtained by using a non-default memory allocator, such as :obj:`jemalloc` or :obj:`TCMalloc`.\n#. Finding an optimal setup for the CPU affinity mask is a problem of managing the proportion of CPU time spent in each iteration for loading and preparing the data vs. time spent during GNN execution.\n   Different results may be obtained by changing model hyperparameters, such as the batch size, number of sampled neighbors, and the number of layers.\n   As a general rule, workloads which require sampling a complex graph may benefit more from reserving some CPU resources just for the data preparation step.\n\nExample results\n---------------\n\nThe figure below presents the outcome of applying a CPU affinity mask to :obj:`benchmark/training/training_benchmark.py`.\nMeasurements were taken for a variable number of workers, while other hyperparameters for each benchmark were constant: :obj:`--warmup 0 --use-sparse-tensor --num-layers 3 --num-hidden-channels 128 --batch-sizes 2048`.\nThree different affinity configurations are presented:\n\n* **Baseline** - only :obj:`OMP_NUM_THREADS` changes:\n\n.. 
code-block:: console\n\n   OMP_NUM_THREADS=(N-num_workers) python training_benchmark.py --num-workers …\n\n* **Aff** - data loader process on first socket, main process on first and second socket, 98-110 threads:\n\n.. code-block:: console\n\n   LD_PRELOAD=(path)/libjemalloc.so (path)/libiomp5.so MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto OMP_NUM_THREADS=(N-num_workers) KMP_AFFINITY=granularity=fine,compact,1,0 KMP_BLOCKTIME=0 numactl -C <num_workers-(N-1)> --localalloc python training_benchmark.py --cpu-affinity --num-workers …\n\n\n* **Aff+SocketSep** - data loader process on first socket, main process on second socket, 60 threads:\n\n.. code-block:: console\n\n   LD_PRELOAD=(path)/libjemalloc.so (path)/libiomp5.so MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto OMP_NUM_THREADS=(N-M) KMP_AFFINITY=granularity=fine,compact,1,0 KMP_BLOCKTIME=0 numactl -C <M-(N-1)> -m 1 python training_benchmark.py --cpu-affinity --num-workers ...\n\nTraining times for each model/dataset combination were obtained by taking a mean of results at a variable number of dataloader workers: :obj:`[0, 2, 4, 8, 16]` for the baseline and :obj:`[2, 4, 8, 16]` workers for each affinity configuration.\nThen, the affinity means were normalized with respect to the mean baseline measurement.\nThis value is denoted on the :math:`y`-axis.\nThe labels above each result indicate the end-to-end performance gain from using the discussed configuration.\nOver all model/dataset samples, the average training time is decreased by **1.53x** for plain affinity and **1.85x** for the affinity with socket separation.\n\n.. figure:: ../_figures/training_affinity.png\n    :width: 100%\n\n    Pre-production dual-socket Intel(R) Xeon(R) Platinum 8481C @ 2.0Ghz (2 x 56) cores CPU.\n"
  },
  {
    "path": "docs/source/advanced/graphgym.rst",
    "content": "Managing Experiments with GraphGym\n==================================\n\nGraphGym is a platform for **designing and evaluating Graph Neural Networks (GNNs)**, as originally proposed in the `\"Design Space for Graph Neural Networks\" <https://arxiv.org/abs/2011.08843>`__ paper.\nWe now officially support GraphGym as part of of :pyg:`PyG`.\n\n.. warning::\n\n    GraphGym API may change in the future as we are continuously working on better and deeper integration with :pyg:`PyG`.\n\nHighlights\n----------\n\n#. **Highly modularized pipeline for GNN:**\n\n   - **Data:** Data loading and data splitting\n   - **Model:** Modularized GNN implementations\n   - **Tasks:** Node-level, edge-level and graph-level tasks\n   - **Evaluation:** Accuracy, ROC AUC, ...\n\n#. **Reproducible experiment configuration:**\n\n   - Each experiment is *fully described by a configuration file*\n\n#. **Scalable experiment management:**\n\n   - Easily launch *thousands of GNN experiments in parallel*\n   - *Auto-generate* experiment analyses and figures across random seeds and experiments\n\n#. **Flexible user customization:**\n\n   - Easily *register your own modules*, such as data loaders, GNN layers, loss functions, etc\n\nWhy GraphGym?\n-------------\n\n**TL;DR:** GraphGym is great for GNN beginners, domain experts and GNN researchers.\n\n**Scenario 1:** You are a beginner to graph representation learning and want to understand how GNNs work:\n\nYou probably have read many exciting papers on GNNs, and try to write your own GNN implementation.\nEven if using raw :pyg:`PyG`, you still have to code up the essential pipeline on your own.\nGraphGym is a perfect place for your to start learning about *standardized GNN implementation and evaluation*.\n\n.. 
figure:: ../_figures/graphgym_design_space.png\n  :align: center\n  :width: 450px\n\n  **Figure 1:** Modularized GNN implementation.\n\n**Scenario 2:** You want to apply GNNs to your exciting application:\n\nYou probably know that there are hundreds of possible GNN models, and selecting the best model is notoriously hard.\nEven worse, the `GraphGym paper <https://arxiv.org/abs/2011.08843>`__ shows that the best GNN designs for different tasks differ drastically.\nGraphGym provides a *simple interface to try out thousands of GNNs in parallel* and understand the best designs for your specific task.\nGraphGym also recommends a \"go-to\" GNN design space, after investigating 10 million GNN model-task combinations.\n\n.. figure:: ../_figures/graphgym_results.png\n  :align: center\n  :width: 100%\n\n  **Figure 2:** A guideline for desirable GNN design choices.\n\n**Scenario 3:** You are a GNN researcher, who wants to innovate new GNN models or propose new GNN tasks:\n\nSay you have proposed a new GNN layer :class:`ExampleConv`.\nGraphGym can help you convincingly argue that :class:`ExampleConv` is better than, *e.g.*, :class:`~torch_geometric.nn.conv.GCNConv`:\nWhen randomly sampling from 10 million possible model-task combinations, how often will :class:`ExampleConv` outperform :class:`~torch_geometric.nn.conv.GCNConv` when everything else is fixed (including computational costs)?\nMoreover, GraphGym can help you easily do hyper-parameter search, and *visualize* what design choices are better.\nIn sum, GraphGym can greatly facilitate your GNN research.\n\n.. figure:: ../_figures/graphgym_evaluation.png\n  :align: center\n  :width: 100%\n\n  **Figure 3:** Evaluation of a given GNN design dimension, *e.g.*, :obj:`BatchNorm`.\n\nBasic Usage\n-----------\n\n.. 
note::\n   To use GraphGym, :pyg:`PyG` requires additional dependencies.\n   You can install those by running :obj:`pip install torch-geometric[graphgym]`.\n\nTo use GraphGym, you need to clone :pyg:`PyG` from :github:`GitHub`, then change to the :obj:`graphgym/` directory.\n\n.. code-block:: bash\n\n    git clone https://github.com/pyg-team/pytorch_geometric.git\n    cd pytorch_geometric/graphgym\n\n#. **Run a single experiment:**\n   Run an experiment using GraphGym via :obj:`run_single.sh`.\n   Configurations are specified in :obj:`configs/pyg/example_node.yaml`.\n   The default experiment performs node classification on the :class:`~torch_geometric.datasets.Planetoid` datasets (using a random 80/20 train/validation split).\n\n   .. code-block:: bash\n\n       bash run_single.sh # run a single experiment\n\n#. **Run a batch of experiments:**\n   Run a batch of experiments using GraphGym via :obj:`run_batch.sh`.\n   Configurations are specified in :obj:`configs/pyg/example_node.yaml` (controls the basic architecture) and :obj:`grids/example.txt` (controls the grid search).\n   The experiment examines 96 models in the recommended GNN design space, on 2 graph classification datasets.\n   Each experiment is repeated 3 times, and up to 8 jobs run concurrently.\n   Depending on your infrastructure, finishing all the experiments may take a long time;\n   you can quit the experiment via :obj:`Ctrl-C` (GraphGym will properly kill all the processes).\n\n   .. code-block:: bash\n\n       bash run_batch.sh # run a batch of experiments\n\n#. **Run GraphGym with CPU backend:**\n   GraphGym supports the CPU backend as well -- you only need to add the line :obj:`accelerator: cpu` to the :obj:`*.yaml` file.\n\nIn-Depth Usage\n--------------\n\nTo use GraphGym, you need to clone :pyg:`PyG` from :github:`GitHub`, then change to the :obj:`graphgym/` directory.\n\n.. 
code-block:: bash\n\n    git clone https://github.com/pyg-team/pytorch_geometric.git\n    cd pytorch_geometric/graphgym\n\n#. **Run a single experiment:**\n   A full example is specified in :obj:`run_single.sh`.\n\n   #. **Specify a configuration file:**\n      In GraphGym, an experiment is fully specified by a :obj:`*.yaml` file.\n      Unspecified configurations in the :obj:`*.yaml` file will be populated by the default values in :meth:`torch_geometric.graphgym.set_cfg`.\n      For example, in :obj:`configs/pyg/example_node.yaml`, there are configurations for the dataset, training procedure, model, etc.\n      Each configuration is described in :meth:`~torch_geometric.graphgym.set_cfg`.\n\n   #. **Launch an experiment:**\n      For example, in :obj:`run_single.sh`:\n\n      .. code-block:: bash\n\n          python main.py --cfg configs/pyg/example_node.yaml --repeat 3\n\n      You can specify the number of random seeds to repeat the experiment with via :obj:`--repeat`.\n\n   #. **Understand the results:**\n      Experimental results will be automatically saved in :obj:`results/${CONFIG_NAME}/`.\n      In the example above, this amounts to :obj:`results/pyg/example_node/`.\n      Results for different random seeds will be saved in different subdirectories, *e.g.*, :obj:`results/pyg/example_node/2`.\n      The aggregated results over all the random seeds are *automatically* generated into :obj:`results/example/agg`, including the mean and standard deviation :obj:`_std` for each metric.\n      Train/validation/test results are further saved into subdirectories, such as :obj:`results/example/agg/val`.\n      Here, :obj:`stats.json` stores the results after each epoch aggregated across random seeds, and :obj:`best.json` stores the results of *the epoch with the highest validation accuracy*.\n\n#. **Run a batch of experiments:**\n   A full example is specified in :obj:`run_batch.sh`.\n\n   #. 
**Specify a base file:**\n      GraphGym supports running a batch of experiments.\n      To start, a user needs to select a base architecture via :obj:`--config`.\n      The batch of experiments will be created by perturbing certain configurations of the base architecture.\n\n   #. **(Optionally) specify a base file for computational budget:**\n      Additionally, GraphGym allows a user to select a base architecture to *control the computational budget* for the grid search via :obj:`--config_budget`.\n      The computational budget is currently measured by the number of trainable parameters, and the control is achieved by auto-adjusting the hidden dimensionality of the underlying GNN.\n      If no :obj:`--config_budget` is provided, GraphGym will not control the computational budget.\n\n   #. **Specify a grid file:**\n      A grid file describes how to perturb the base file, in order to generate the batch of experiments.\n      For example, the base file could specify an experiment of a 3-layer GCN for node classification on the :class:`~torch_geometric.datasets.Planetoid` datasets.\n      Then, the grid file specifies how to perturb the experiment along different dimensions, such as the number of layers, the model architecture, the dataset, the level of task, etc.\n\n   #. **Generate configuration files for the batch of experiments** based on the information specified above:\n      For example, in :obj:`run_batch.sh`:\n\n      .. code-block:: bash\n\n          python configs_gen.py --config configs/${DIR}/${CONFIG}.yaml \\\n            --config_budget configs/${DIR}/${CONFIG}.yaml \\\n            --grid grids/${DIR}/${GRID}.txt \\\n            --out_dir configs\n\n   #. **Launch the batch of experiments:**\n      For example, in :obj:`run_batch.sh`:\n\n      .. 
code-block:: bash\n\n          bash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP\n\n      Each experiment will be repeated :obj:`$REPEAT` times.\n      We implemented a queue system to sequentially launch all the jobs, with :obj:`$MAX_JOBS` concurrent jobs running at the same time.\n      In practice, our system works well when handling thousands of jobs.\n\n   #. **Understand the results:**\n      Experimental results will be automatically saved in directory :obj:`results/${CONFIG_NAME}_grid_${GRID_NAME}/`.\n      In the example above, this amounts to :obj:`results/pyg/example_grid_example/`.\n      After running each experiment, GraphGym additionally averages across different models, with the results saved in :obj:`results/pyg/example_grid_example/agg`.\n      There, :obj:`val.csv` represents the validation accuracy for each model configuration at the *final* epoch,\n      :obj:`val_best.csv` represents the results at the epoch with the highest average validation accuracy, and\n      :obj:`val_best_epoch.csv` represents the results at the epoch with the highest validation accuracy averaged over different random seeds.\n      When a test set split is provided, :obj:`test.csv` represents the test accuracy for each model configuration at the *final* epoch,\n      :obj:`test_best.csv` represents the test set results at the epoch with the highest average validation accuracy, and\n      :obj:`test_best_epoch.csv` represents the test set results at the epoch with the highest validation accuracy averaged over different random seeds.\n\nCustomizing GraphGym\n--------------------\n\nA highlight of GraphGym is that it allows you to easily register customized modules.\nFor each project, you can have a unique GraphGym copy with different customized modules.\nFor example, the `\"Design Space for Graph Neural Networks\" <https://arxiv.org/abs/2011.08843>`__ and `\"Identity-aware Graph Neural Networks\" <https://arxiv.org/abs/2101.10320>`__ papers 
represent two successful projects using customized GraphGym, and you may find more details about them `here <https://github.com/snap-stanford/GraphGym#use-case-design-space-for-graph-neural-networks-neurips-2020-spotlight>`__.\nEventually, every GraphGym-powered project will be unique.\n\nThere are two ways to customize GraphGym:\n\n#. Use the :obj:`graphgym/custom_graphgym` directory outside the :pyg:`PyG` package:\n   You can register your customized modules here without touching :pyg:`PyG`. This option is well suited to your own customized project.\n\n#. Use the :obj:`torch_geometric/graphgym/contrib` directory inside the :pyg:`PyG` package:\n   If you have come up with a nice customized module, you can directly copy your files into :obj:`torch_geometric/graphgym/contrib`, and **create a pull request** to :pyg:`PyG`.\n   This way, your idea can ship with :pyg:`PyG` installations, and will have much higher visibility and impact.\n\nConcretely, the supported customized modules include:\n\n- Activations: :obj:`custom_graphgym/act/`\n- Customized configurations: :obj:`custom_graphgym/config/`\n- Feature augmentations: :obj:`custom_graphgym/feature_augment/`\n- Feature encoders: :obj:`custom_graphgym/feature_encoder/`\n- GNN heads: :obj:`custom_graphgym/head/`\n- GNN layers: :obj:`custom_graphgym/layer/`\n- Data loaders: :obj:`custom_graphgym/loader/`\n- Loss functions: :obj:`custom_graphgym/loss/`\n- GNN network architectures: :obj:`custom_graphgym/network/`\n- Optimizers: :obj:`custom_graphgym/optimizer/`\n- GNN global pooling layers (for graph classification only): :obj:`custom_graphgym/pooling/`\n- GNN stages: :obj:`custom_graphgym/stage/`\n- GNN training pipelines: :obj:`custom_graphgym/train/`\n- Data transformations: :obj:`custom_graphgym/transform/`\n\nWithin each directory, at least one example is provided that shows how to register customized modules via :meth:`torch_geometric.graphgym.register`.\nNote that new customized modules may result in new 
configurations.\nIn these cases, new configuration fields can be registered via :obj:`custom_graphgym/config/`.\n"
  },
  {
    "path": "docs/source/advanced/hgam.rst",
    "content": "Hierarchical Neighborhood Sampling\n==================================\n\nOne of the design principles of :pyg:`PyG` is that models and data loading routines should be exchangeable to allow for flexible GNN and data loading experimentation.\nAs such, models can usually be written in a data loading agnostic fashion, independent of whether one applies full-batch or mini-batch training strategies via, *e.g.*, :class:`~torch_geometric.loader.DataLoader`, :class:`~torch_geometric.loader.NeighborLoader` or :class:`~torch_geometric.loader.ClusterLoader`.\nHowever, in some scenarios, this flexibility comes at the cost of performance, as the model cannot exploit special characteristics of the underlying data loading routine.\nOne such limitation is that a GNN trained with the :class:`~torch_geometric.loader.NeighborLoader` routine iteratively builds representations for *all* nodes at *all* depths of the network, although nodes sampled in later hops do not contribute to the node representations of seed nodes in later GNN layers anymore, thus performing useless computation.\n\n*Hierarchical Neighborhood Sampling* or *Hierarchical Graph Adjacency Matrix (HGAM)* is a technique available in :pyg:`PyG` to eliminate this overhead and speeds up training and inference in mini-batch GNNs.\nIts main idea is to progressively trim the adjacency matrix of the returned subgraph before inputting it to each GNN layer.\nIt works seamlessly across several models, basically reducing the amount of compute necessary to generate the representations for the seed node of the given mini-batch.\n\nCrucially, HGAM recognizes that the computation of the final node representations is only necessary for the seed nodes (which are the real target of the batch computation).\nThus, HGAM allows for every layer of the GNN to compute only the representations of the nodes that are necessary for that layer, leading to a reduction of the computation and a speed up of the training process that grows 
with the depth of the GNN being considered.\nIn practice, this is achieved by **trimming the adjacency matrix** and the various **feature matrices** as the computation proceeds through the GNN layers.\nThis is in line with the fact that in order to compute the representation for the seed/target nodes (from which the mini-batch was built via sampling methods), the depth of the relevant neighborhood shrinks as we proceed through the layers of the GNN.\nThe trimming applied by HGAM is possible as the nodes of the subgraph built via sampling are ordered according to a *Breadth First Search (BFS)* strategy, meaning that the rows and columns of the adjacency matrix refer to a node ordering that starts with the seed nodes (in any order) followed by the 1-hop sampled neighbors of the first seed node, followed by the 1-hop sampled neighbors of the second seed node and so on.\nThe BFS ordering of nodes in a mini-batch allows for incremental trimming (reduction) of the adjacency matrix of the subgraph.\nThis progressive trimming is done in a computationally convenient manner thanks to the BFS ordering, which causes the nodes more distant from the seed nodes to appear later in the list of ordered nodes.\n\nTo support this trimming and implement it effectively, the :class:`~torch_geometric.loader.NeighborLoader` implementation in :pyg:`PyG` and in :pyg:`pyg-lib` additionally returns the number of nodes and edges sampled in each hop.\nThis information allows for fast manipulation of the adjacency matrix, which in turn greatly reduces computation.\nThe :class:`~torch_geometric.loader.NeighborLoader` prepares this metadata via the dedicated attributes :obj:`num_sampled_nodes` and :obj:`num_sampled_edges`.\nIt can be accessed from the :class:`~torch_geometric.data.Batch` object returned for both homogeneous and heterogeneous graphs.\n\nTo sum up, HGAM is a special data structure that enables efficient message passing computation in :class:`~torch_geometric.loader.NeighborLoader` 
scenarios.\nHGAM is implemented in :pyg:`PyG` and can be utilized via the special :meth:`~torch_geometric.utils.trim_to_layer` functionality.\nHGAM is currently an option that :pyg:`PyG` users are free to switch on, or leave it off *(current default)*.\n\nUsage\n-----\n\nHere, we show examples of how to use the HGAM functionality in combination with :class:`~torch_geometric.loader.NeighborLoader`:\n\n* **Homogeneous data example:**\n\n  .. code-block:: python\n\n      from torch_geometric.datasets import Planetoid\n      from torch_geometric.loader import NeighborLoader\n\n      data = Planetoid(path, name='Cora')[0]\n\n      loader = NeighborLoader(\n          data,\n          num_neighbors=[10] * 3,\n          batch_size=128,\n      )\n\n      batch = next(iter(loader))\n      print(batch)\n      >>> Data(x=[1883, 1433], edge_index=[2, 5441], y=[1883], train_mask=[1883],\n               val_mask=[1883], test_mask=[1883], batch_size=128,\n               num_sampled_nodes=[4], num_sampled_edges=[3])\n\n      print(batch.num_sampled_nodes)\n      >>> [128, 425, 702, 628]  # Number of sampled nodes per hop/layer.\n      print(batch.num_sampled_edges)\n      >>> [520, 2036, 2885]  # Number of sampled edges per hop/layer.\n\n* **Heterogeneous data example:**\n\n  .. 
code-block:: python\n\n      from torch_geometric.datasets import OGB_MAG\n      from torch_geometric.loader import NeighborLoader\n\n      data = OGB_MAG(path)[0]\n\n      loader = NeighborLoader(\n          data,\n          num_neighbors=[10] * 3,\n          batch_size=128,\n          input_nodes='paper',\n      )\n\n      batch = next(iter(loader))\n      print(batch)\n      >>> HeteroData(\n          paper={\n              x=[2275, 128],\n              num_sampled_nodes=[3],\n              batch_size=128,\n          },\n          author={\n              num_nodes=2541,\n              num_sampled_nodes=[3],\n          },\n          institution={\n              num_nodes=0,\n              num_sampled_nodes=[3],\n          },\n          field_of_study={\n              num_nodes=0,\n              num_sampled_nodes=[3],\n          },\n          (author, affiliated_with, institution)={\n              edge_index=[2, 0],\n              num_sampled_edges=[2],\n          },\n          (author, writes, paper)={\n              edge_index=[2, 3255],\n              num_sampled_edges=[2],\n          },\n          (paper, cites, paper)={\n              edge_index=[2, 2691],\n              num_sampled_edges=[2],\n          },\n          (paper, has_topic, field_of_study)={\n              edge_index=[2, 0],\n              num_sampled_edges=[2],\n          }\n          )\n      print(batch['paper'].num_sampled_nodes)\n      >>> [128, 508, 1598]  # Number of sampled paper nodes per hop/layer.\n\n      print(batch['author', 'writes', 'paper'].num_sampled_edges)\n      >>> [629, 2621]  # Number of sampled author<>paper edges per hop/layer.\n\nThe attributes :obj:`num_sampled_nodes` and :obj:`num_sampled_edges` can be used by the :meth:`~torch_geometric.utils.trim_to_layer` function inside the GNN:\n\n.. 
code-block:: python\n\n    from typing import List\n\n    import torch\n    from torch import Tensor\n    from torch.nn import Linear, ModuleList\n\n    from torch_geometric.datasets import Reddit\n    from torch_geometric.loader import NeighborLoader\n    from torch_geometric.nn import SAGEConv\n    from torch_geometric.utils import trim_to_layer\n\n    dataset = Reddit(path)\n    data = dataset[0]\n    loader = NeighborLoader(data, num_neighbors=[10, 5, 5], ...)\n\n    class GNN(torch.nn.Module):\n        def __init__(self, in_channels: int, hidden_channels: int,\n                     out_channels: int, num_layers: int):\n            super().__init__()\n\n            self.convs = ModuleList([SAGEConv(in_channels, hidden_channels)])\n            for _ in range(num_layers - 1):\n                self.convs.append(SAGEConv(hidden_channels, hidden_channels))\n            self.lin = Linear(hidden_channels, out_channels)\n\n        def forward(\n            self,\n            x: Tensor,\n            edge_index: Tensor,\n            num_sampled_nodes_per_hop: List[int],\n            num_sampled_edges_per_hop: List[int],\n        ) -> Tensor:\n\n            for i, conv in enumerate(self.convs):\n                # Trim edge and node information to the current layer `i`.\n                x, edge_index, _ = trim_to_layer(\n                    i, num_sampled_nodes_per_hop, num_sampled_edges_per_hop,\n                    x, edge_index)\n\n                x = conv(x, edge_index).relu()\n\n            return self.lin(x)\n\nExamples\n--------\n\nWe provide full examples of HGAM in the :pyg:`PyG` :obj:`examples/` folder:\n\n* :obj:`examples/hierarchical_sampling.py`: An `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hierarchical_sampling.py>`__ to show-case the basic usage of HGAM.\n* :obj:`examples/hetero/hierarchical_sage.py`: An `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hierarchical_sage.py>`__ of HGAM on heterogeneous graphs.\n"
  },
  {
    "path": "docs/source/advanced/jit.rst",
    "content": "TorchScript Support\n===================\n\nTorchScript is a way to create serializable and optimizable models from :pytorch:`PyTorch` code.\nAny TorchScript program can be saved from a :python:`Python` process and loaded in a process where there is no :python:`Python` dependency.\nIf you are unfamilar with TorchScript, we recommend to read the official \"`Introduction to TorchScript <https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html>`_\" tutorial first.\n\nConverting GNN Models\n---------------------\n\n.. note::\n    From :pyg:`PyG` 2.5 (and onwards), GNN layers are now fully compatible with :meth:`torch.jit.script` without any modification needed.\n    If you are on an earlier version of :pyg:`PyG`, consider to convert your GNN layers into \"jittable\" instances first by calling :meth:`~torch_geometric.nn.conv.MessagePassing.jittable`.\n\nConverting your :pyg:`PyG` model to a TorchScript program is straightforward and requires only a few code changes.\nLet's consider the following model:\n\n.. code-block:: python\n\n    import torch\n    import torch.nn.functional as F\n    from torch_geometric.nn import GCNConv\n\n    class GNN(torch.nn.Module):\n        def __init__(self, in_channels, out_channels):\n            super().__init__()\n            self.conv1 = GCNConv(in_channels, 64)\n            self.conv2 = GCNConv(64, out_channels)\n\n        def forward(self, x, edge_index):\n            x = self.conv1(x, edge_index)\n            x = F.relu(x)\n            x = self.conv2(x, edge_index)\n            return F.log_softmax(x, dim=1)\n\n    model = GNN(dataset.num_features, dataset.num_classes)\n\nThe instantiated model can now be directly passed into :meth:`torch.jit.script`:\n\n.. 
code-block:: python\n\n    model = torch.jit.script(model)\n\nThat is all you need to know on how to convert your :pyg:`PyG` models to TorchScript programs.\nYou can have a further look at our JIT examples that show-case how to obtain TorchScript programs for `node <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/jit/gat.py>`_ and `graph classification <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/jit/gin.py>`_ models.\n\nCreating Jittable GNN Operators\n--------------------------------\n\nAll :pyg:`PyG` :class:`~torch_geometric.nn.conv.MessagePassing` operators are tested to be convertible to a TorchScript program.\nHowever, if you want your own GNN module to be compatible with :meth:`torch.jit.script`, you need to account for the following two things:\n\n1. As one would expect, your :meth:`forward` code may need to be adjusted so that it passes the TorchScript compiler requirements, *e.g.*, by adding type annotations.\n2. You need to tell the :class:`~torch_geometric.nn.conv.MessagePassing` module the types that you pass to its :meth:`~torch_geometric.nn.conv.MessagePassing.propagate` function.\n   This can be achieved in two different ways:\n\n   1. Declaring the type of propagation arguments in a dictionary called :obj:`propagate_type`:\n\n    .. code-block:: python\n\n        from typing import Optional\n        from torch import Tensor\n        from torch_geometric.nn import MessagePassing\n\n        class MyConv(MessagePassing):\n            propagate_type = {'x': Tensor, 'edge_weight': Optional[Tensor]}\n\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Tensor,\n                edge_weight: Optional[Tensor] = None,\n            ) -> Tensor:\n                return self.propagate(edge_index, x=x, edge_weight=edge_weight)\n\n   2. Declaring the type of propagation arguments as a comment inside your module:\n\n    .. 
code-block:: python\n\n        from typing import Optional\n        from torch import Tensor\n        from torch_geometric.nn import MessagePassing\n\n        class MyConv(MessagePassing):\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Tensor,\n                edge_weight: Optional[Tensor] = None,\n            ) -> Tensor:\n                # propagate_type: (x: Tensor, edge_weight: Optional[Tensor])\n                return self.propagate(edge_index, x=x, edge_weight=edge_weight)\n\n   If none of these options are given, the :class:`~torch_geometric.nn.conv.MessagePassing` module will infer the arguments of :meth:`~torch_geometric.nn.conv.MessagePassing.propagate` to be of type :class:`torch.Tensor` (mimicking the default type that TorchScript is inferring for non-annotated arguments).\n"
  },
  {
    "path": "docs/source/advanced/remote.rst",
    "content": "Scaling Up GNNs via Remote Backends\n===================================\n\n:pyg:`PyG` (2.2 and beyond) includes numerous primitives to easily integrate with simple paradigms for scalable graph machine learning, enabling users to train GNNs on graphs far larger than the size of their machine's available memory.\nIt does so by introducing simple, easy-to-use, and extensible abstractions of a :class:`torch_geometric.data.FeatureStore` and a :class:`torch_geometric.data.GraphStore` that plug directly into existing familiar :pyg:`PyG` interfaces.\nDefining a :class:`~torch_geometric.data.FeatureStore` allows users to leverage node (and soon, edge) features stored remotely, and defining a :class:`~torch_geometric.data.GraphStore` allows users to leverage graph structure information stored remotely.\nTogether, they allow for powerful GNN scalability with low developer friction.\n\n.. warning::\n\n    The remote backend APIs discussed here may change in the future as we continuously work to improve their ease-of-use and generalizability.\n\n.. 
note::\n\n    Currently, the :class:`~torch_geometric.data.FeatureStore` and :class:`~torch_geometric.data.GraphStore` only support *heterogeneous graphs*, and do not support edge features.\n    Homogeneous graph and edge feature support is coming soon.\n\nBackground\n----------\n\nAn instantiated Graph Neural Network consists of two types of data:\n\n- **Node and/or edge feature information:** Dense vectors corresponding to attributes of the nodes and edges in a graph\n- **Graph structure information:** The nodes in the graph and the edges that connect them\n\nAn immediate observation of GNNs is that scaling to data larger than the available memory of a chosen accelerator requires training on sampled subgraphs (which form mini-batches), instead of the full graph at once (full-batch training).\nWhile this method adds stochasticity to the learning process, it reduces the memory requirements of the accelerator to those of the sampled subgraphs.\n\n.. figure:: ../_figures/remote_1.png\n  :align: center\n  :width: 100%\n\n  **Figure 1:** The classical mini-batch GNN training paradigm.\n\nHowever, while mini-batch training reduces the memory requirements of the chosen accelerator, it is not a silver bullet for all graph learning scalability problems.\nIn particular, since one must sample subgraphs to pass to the accelerator at each iteration of the learning process, the graph and features are traditionally required to be stored in the CPU DRAM of a user's machine.\nAt large scale, this requirement can become quite burdensome:\n\n- Acquiring instances with enough CPU DRAM to store a graph and features is challenging\n- Training with data parallelism requires replicating the graph and features in each compute node\n- Graphs and features can easily be much larger than the memory of a single machine\n\nScalability to very large graphs and features beyond the memory requirements of a single machine thus requires moving these data structures out-of-core and only processing 
sampled subgraphs on a node that performs computation.\nIn order to achieve this goal, :pyg:`PyG` relies on two primary abstractions to store feature information and graph structure:\nFeatures are stored in a key-value :class:`~torch_geometric.data.FeatureStore`, which must support efficient random access.\nGraph information is stored in a :class:`~torch_geometric.data.GraphStore`, which must support efficient sampling for the samplers defined to operate on the :class:`~torch_geometric.data.GraphStore` instance.\n\n.. figure:: ../_figures/remote_2.png\n  :align: center\n  :width: 100%\n\n  **Figure 2:** Graph data storage layout between remote storage and a training instance.\n\nIn :pyg:`PyG` (2.2 and beyond), the separation of graph data into its features and structure information, the storage of this information in locations potentially remote to the actual training node, and the interactions between these components, are all completely abstracted from the end user.\nAs long as the :class:`~torch_geometric.data.FeatureStore` and :class:`~torch_geometric.data.GraphStore` are defined appropriately (keeping in mind the aforementioned performance requirements), :pyg:`PyG` handles the rest!\n\nFeature Store\n-------------\n\nA :class:`torch_geometric.data.FeatureStore` holds features for the nodes and edges of a graph.\nFeature storage is often the primary storage bottleneck in graph learning applications, as storing a graph's layout information (*i.e.* the :obj:`edge_index`) is relatively cheap (~32 bytes per edge).\n:pyg:`PyG` provides a common interface for various :class:`~torch_geometric.data.FeatureStore` implementations to interface with its core learning API.\n\nThe implementation details of a :class:`~torch_geometric.data.FeatureStore` are abstracted from :pyg:`PyG` through a CRUD-like interface.\nIn particular, implementors of the :class:`~torch_geometric.data.FeatureStore` abstraction are expected to primarily override 
:meth:`~torch_geometric.data.FeatureStore.put_tensor`, :meth:`~torch_geometric.data.FeatureStore.get_tensor`, and :meth:`~torch_geometric.data.FeatureStore.remove_tensor` functionalities.\nDoing so both enables :pyg:`PyG` to leverage the features stored in the implementation and allows a user to employ a pythonic interface to inspect and modify the :class:`~torch_geometric.data.FeatureStore` elements:\n\n.. code-block:: python\n\n    feature_store = CustomFeatureStore()\n\n    paper_features = ...  # [num_papers, num_paper_features]\n    author_features = ...  # [num_authors, num_author_features]\n\n    # Add features:\n    feature_store['paper', 'x', None] = paper_features\n    feature_store['author', 'x', None] = author_features\n\n    # Access features:\n    assert torch.equal(feature_store['paper', 'x'], paper_features)\n    assert torch.equal(feature_store['paper'].x, paper_features)\n    assert torch.equal(feature_store['author', 'x', 0:20], author_features[0:20])\n\nCommon implementations of the :class:`~torch_geometric.data.FeatureStore` abstractions are key-value stores, *e.g.*, backends such as :obj:`memcached`, :obj:`LevelDB`, :obj:`RocksDB` are all viable performant options.\n\nGraph Store and Sampler\n-----------------------\n\nA :class:`torch_geometric.data.GraphStore` holds the edge indices that define relationships between nodes in a graph.\nThe goal of the :class:`~torch_geometric.data.GraphStore` is to store graph information in a manner that allows for efficient sampling from root nodes, according to a sampling algorithm of the developer's choice.\n\nSimilar to the :class:`~torch_geometric.data.FeatureStore`, :pyg:`PyG` provides a common interface for various :class:`~torch_geometric.data.GraphStore` implementations to interface with its core learning API.\nHowever, unlike the :class:`~torch_geometric.data.FeatureStore`, the :class:`~torch_geometric.data.GraphStore` does not need to provide efficient random access for all its elements; rather, it 
needs to define a representation that provides efficient subgraph sampling.\nAn example usage of the interface is shown below:\n\n.. code-block:: python\n\n    graph_store = CustomGraphStore()\n\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    # Put edges:\n    graph_store['edge', 'coo'] = edge_index\n\n    # Access edges:\n    row, col = graph_store['edge', 'coo']\n    assert torch.equal(row, edge_index[0])\n    assert torch.equal(col, edge_index[1])\n\nCommon implementations of the :class:`~torch_geometric.data.GraphStore` are graph databases, *e.g.*, :obj:`Neo4j`, :obj:`TigerGraph`, :obj:`ArangoDB`, :obj:`Kùzu` are all viable performant options.\nWe provide an example of using :pyg:`PyG` in combination with the :obj:`Kùzu` database `here <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/distributed/kuzu>`__.\n\nA graph sampler is tightly coupled to the given :class:`~torch_geometric.data.GraphStore`, and operates on the :class:`~torch_geometric.data.GraphStore` to produce sampled subgraphs from input nodes.\nDifferent sampling algorithms are implemented behind the :class:`torch_geometric.sampler.BaseSampler` interface.\nBy default, :pyg:`PyG's` in-memory sampler pulls all edge indices from the :class:`~torch_geometric.data.GraphStore` into the training node memory, converts them to compressed sparse column (CSC) format, and leverages pre-built in-memory sampling routines.\nHowever, custom sampler implementations may choose to call specialized :class:`~torch_geometric.data.GraphStore` methods by implementing the :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` and/or :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges` methods of the :class:`~torch_geometric.sampler.BaseSampler` class for efficiency reasons (*e.g.*, for performing sampling directly on the remote :class:`~torch_geometric.data.GraphStore`):\n\n.. 
code-block:: python\n\n    # `CustomGraphSampler` knows how to sample on `CustomGraphStore`:\n    node_sampler = CustomGraphSampler(\n        graph_store=graph_store,\n        num_neighbors=[10, 20],\n        ...\n    )\n\nData Loader\n-----------\n\n:pyg:`PyG` does not define a domain-specific language for sampling that must be implemented by the :class:`~torch_geometric.data.GraphStore`; rather, the sampler and the :class:`~torch_geometric.data.GraphStore` are tightly coupled together through a data loader.\n\n:pyg:`PyG` provides two data loaders out-of-the-box: a :class:`torch_geometric.loader.NodeLoader` that samples subgraphs from input nodes for use in node classification tasks, and a :class:`torch_geometric.loader.LinkLoader` that samples subgraphs from either side of an edge for use in link prediction tasks.\nThese data loaders require a :class:`~torch_geometric.data.FeatureStore`, a :class:`~torch_geometric.data.GraphStore`, and a graph sampler as input, and internally call the sampler's :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` or :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges` method to perform subgraph sampling:\n\n.. 
code-block:: python\n\n    # Instead of passing PyG data objects, we now pass a tuple\n    # of the `FeatureStore` and `GraphStore` as input data:\n    loader = NodeLoader(\n        data=(feature_store, graph_store),\n        node_sampler=node_sampler,\n        batch_size=20,\n        input_nodes='paper',\n    )\n\n    for batch in loader:\n        pass\n\nPutting it All Together\n-----------------------\n\nAt a high level, the components listed above all work together to provide support for scaling up GNNs within :pyg:`PyG`.\n\n- The **data loader** (more precisely, each worker) leverages a :class:`~torch_geometric.sampler.BaseSampler` to make a sampling request to the :class:`~torch_geometric.data.GraphStore`.\n- Upon receipt of a response, the data loader queries the :class:`~torch_geometric.data.FeatureStore` for features associated with the nodes and edges of the sampled subgraphs.\n- The data loader then constructs a final mini-batch from graph structure and feature information to send to the accelerator for forward/backward passes.\n- These steps repeat until convergence.\n\nAll of the outlined classes speak through common interfaces, making them extensible, generalizable, and easy to integrate with the :pyg:`PyG` you use today:\n\n.. 
figure:: ../_figures/remote_3.png\n  :align: center\n  :width: 80%\n\n  **Figure 3:** The common interfaces (and data flow) uniting the :class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`, graph sampler, and data loader.\n\nTo get started with scalability, we recommend inspecting the interfaces listed above and defining your own :class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.data.GraphStore`, and :class:`~torch_geometric.sampler.BaseSampler` implementations behind them.\nOnce a :class:`~torch_geometric.data.FeatureStore`, a :class:`~torch_geometric.data.GraphStore`, and a :class:`~torch_geometric.sampler.BaseSampler` are correctly implemented, simply pass them as parameters to a :class:`~torch_geometric.loader.NodeLoader` or a :class:`~torch_geometric.loader.LinkLoader`, and the rest of :pyg:`PyG` will work seamlessly and similar to any pure in-memory application.\n\nSince this feature is still undergoing heavy development, please feel free to reach out to the :pyg:`PyG` core team either on :github:`null` `GitHub <https://github.com/pyg-team/pytorch_geometric/discussions>`_ or :slack:`null` `Slack <https://data.pyg.org/slack.html>`_ if you have any questions, comments or concerns.\n"
  },
  {
    "path": "docs/source/advanced/sparse_tensor.rst",
    "content": "Memory-Efficient Aggregations\n=============================\n\nThe :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface of :pyg:`PyG` relies on a gather-scatter scheme to aggregate messages from neighboring nodes.\nFor example, consider the message passing layer\n\n.. math::\n\n    \\mathbf{x}^{\\prime}_i = \\sum_{j \\in \\mathcal{N}(i)} \\textrm{MLP}(\\mathbf{x}_j - \\mathbf{x}_i),\n\nthat can be implemented as:\n\n.. code-block:: python\n\n    from torch_geometric.nn import MessagePassing\n\n    x = ...           # Node features of shape [num_nodes, num_features]\n    edge_index = ...  # Edge indices of shape [2, num_edges]\n\n    class MyConv(MessagePassing):\n        def __init__(self):\n            super().__init__(aggr=\"add\")\n\n        def forward(self, x, edge_index):\n            return self.propagate(edge_index, x=x)\n\n        def message(self, x_i, x_j):\n            return MLP(x_j - x_i)\n\nUnder the hood, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementation produces a code that looks as follows:\n\n.. code-block:: python\n\n    from torch_geometric.utils import scatter\n\n    x = ...           # Node features of shape [num_nodes, num_features]\n    edge_index = ...  
# Edge indices of shape [2, num_edges]\n\n    x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]\n    x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]\n\n    msg = MLP(x_j - x_i)  # Compute message for each edge\n\n    # Aggregate messages based on target node indices\n    out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')\n\nWhile the gather-scatter formulation generalizes to a lot of useful GNN implementations, it has the disadvantage of explicitly materializing :obj:`x_j` and :obj:`x_i`, resulting in a high memory footprint on large and dense graphs.\n\nLuckily, not all GNNs need to be implemented by explicitly materializing :obj:`x_j` and/or :obj:`x_i`.\nIn some cases, GNNs can also be implemented as a simple sparse-matrix multiplication.\nAs a general rule of thumb, this holds true for GNNs that do not make use of the central node features :obj:`x_i` or multi-dimensional edge features when computing messages.\nFor example, the :class:`~torch_geometric.nn.conv.GINConv` layer\n\n.. math::\n\n    \\mathbf{x}^{\\prime}_i = \\textrm{MLP} \\left( (1 + \\epsilon) \\cdot \\mathbf{x}_i + \\sum_{j \\in \\mathcal{N}(i)} \\mathbf{x}_j \\right),\n\nis equivalent to computing\n\n.. 
math::\n\n    \\mathbf{X}^{\\prime} = \\textrm{MLP} \\left( (1 + \\epsilon) \\cdot \\mathbf{X} + \\mathbf{A}\\mathbf{X} \\right),\n\nwhere :math:`\\mathbf{A}` denotes a sparse adjacency matrix of shape :obj:`[num_nodes, num_nodes]`.\nThis formulation allows us to leverage dedicated and fast sparse-matrix multiplication implementations.\n\nIn :pyg:`null` **PyG >= 1.6.0**, we officially introduce better support for sparse-matrix multiplication GNNs, resulting in a **lower memory footprint** and a **faster execution time**.\nTo this end, we introduce the :class:`SparseTensor` class (from the :obj:`torch_sparse` package), which implements fast forward and backward passes for sparse-matrix multiplication based on the `\"Design Principles for Sparse Matrix Multiplication on the GPU\" <https://arxiv.org/abs/1803.08601>`_ paper.\n\nUsing the :class:`SparseTensor` class is straightforward and similar to the way :obj:`scipy` treats sparse matrices:\n\n.. code-block:: python\n\n    from torch_sparse import SparseTensor\n\n    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,\n                       sparse_sizes=(num_nodes, num_nodes))\n    # value is optional and can be None\n\n    # Obtain different representations (COO, CSR, CSC):\n    row,    col, value = adj.coo()\n    rowptr, col, value = adj.csr()\n    colptr, row, value = adj.csc()\n\n    adj = adj[:100, :100]  # Slicing, indexing and masking support\n    adj = adj.set_diag()   # Add diagonal entries\n    adj_t = adj.t()        # Transpose\n    out = adj.matmul(x)    # Sparse-dense matrix multiplication\n    adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication\n\n    # Creating SparseTensor instances:\n    adj = SparseTensor.from_dense(mat)\n    adj = SparseTensor.eye(100, 100)\n    adj = SparseTensor.from_scipy(mat)\n\nOur :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface can handle both :obj:`torch.Tensor` and :class:`SparseTensor` as input for propagating 
messages.\nHowever, when holding a directed graph in :class:`SparseTensor`, you need to make sure to input the **transposed sparse matrix** to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`:\n\n.. code-block:: python\n\n    conv = GCNConv(16, 32)\n    out1 = conv(x, edge_index)\n    out2 = conv(x, adj.t())\n    assert torch.allclose(out1, out2)\n\n    conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32)))\n    out1 = conv(x, edge_index)\n    out2 = conv(x, adj.t())\n    assert torch.allclose(out1, out2)\n\nTo leverage sparse-matrix multiplications, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface introduces the :func:`~torch_geometric.nn.conv.message_passing.message_and_aggregate` function (which fuses the :func:`~torch_geometric.nn.conv.message_passing.message` and :func:`~torch_geometric.nn.conv.message_passing.aggregate` functions into a single computation step), which gets called whenever it is implemented and receives a :class:`SparseTensor` as input for :obj:`edge_index`.\nWith it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemented as follows:\n\n.. code-block:: python\n\n    import torch_sparse\n\n    class GINConv(MessagePassing):\n        def __init__(self):\n            super().__init__(aggr=\"add\")\n\n        def forward(self, x, edge_index):\n            out = self.propagate(edge_index, x=x)\n            return MLP((1 + eps) * x + out)\n\n        def message(self, x_j):\n            return x_j\n\n        def message_and_aggregate(self, adj_t, x):\n            return torch_sparse.matmul(adj_t, x, reduce=self.aggr)\n\nPlaying around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box.\nTo convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform:\n\n.. 
code-block:: python\n\n    import torch\n    import torch.nn.functional as F\n\n    from torch_geometric.nn import GCNConv\n    import torch_geometric.transforms as T\n    from torch_geometric.datasets import Planetoid\n\n    dataset = Planetoid(\"Planetoid\", name=\"Cora\", transform=T.ToSparseTensor())\n    data = dataset[0]\n    >>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)\n\n\n    class GNN(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.conv1 = GCNConv(dataset.num_features, 16, cached=True)\n            self.conv2 = GCNConv(16, dataset.num_classes, cached=True)\n\n        def forward(self, x, adj_t):\n            x = self.conv1(x, adj_t)\n            x = F.relu(x)\n            x = self.conv2(x, adj_t)\n            return F.log_softmax(x, dim=1)\n\n    model = GNN()\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    def train(data):\n        model.train()\n        optimizer.zero_grad()\n        out = model(data.x, data.adj_t)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        return float(loss)\n\n    for epoch in range(1, 201):\n        loss = train(data)\n\nAll code remains the same as before, except for the :obj:`data` transform via :obj:`T.ToSparseTensor()`.\nAs an additional advantage, :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementations that utilize the :class:`SparseTensor` class are deterministic on the GPU since aggregations no longer rely on atomic operations.\n\nNotably, the GNN layer execution slightly changes in case GNNs incorporate single or multi-dimensional edge information :obj:`edge_weight` or :obj:`edge_attr` into their message passing formulation, respectively.\nIn particular, it is now expected that these attributes are directly added as values to the :class:`SparseTensor` object.\nInstead of calling the GNN as\n\n.. 
code-block:: python\n\n    conv = GMMConv(16, 32, dim=3)\n    out = conv(x, edge_index, edge_attr)\n\nwe now execute our GNN operator as\n\n.. code-block:: python\n\n    conv = GMMConv(16, 32, dim=3)\n    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)\n    out = conv(x, adj.t())\n\n.. note::\n\n    Since this feature is still experimental, some operations, *e.g.*, graph pooling methods, may still require you to input the :obj:`edge_index` format.\n    You can convert :obj:`adj_t` back to :obj:`(edge_index, edge_attr)` via:\n\n    .. code-block:: python\n\n        row, col, edge_attr = adj_t.t().coo()\n        edge_index = torch.stack([row, col], dim=0)\n\nPlease let us know what you think of :class:`SparseTensor`, how we can improve it, and whenever you encounter any unexpected behavior.\n"
  },
  {
    "path": "docs/source/cheatsheet/data_cheatsheet.rst",
    "content": "Dataset Cheatsheet\n==================\n\n.. note::\n\n    This dataset statistics table is a **work in progress**.\n    Please consider helping us filling its content by providing statistics for individual datasets.\n    See `here <https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/karate.py#L25-L37>`__ and `here <https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/tu_dataset.py#L56-L108>`__ for examples on how to do so.\n\nHomogeneous Datasets\n--------------------\n\n.. list-table::\n    :widths: 50 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - #graphs\n      - #nodes\n      - #edges\n      - #features\n      - #classes/#tasks\n{% for cls in torch_geometric.datasets.homo_datasets %}\n    * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %}\n      - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }}\n    {% for child in torch_geometric.datasets.utils.get_children(cls) %}\n    * - └─ {{ child }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }}\n      - {{ 
torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }}\n    {% endfor %}\n{% endfor %}\n\nHeterogeneous Datasets\n----------------------\n\n.. list-table::\n    :widths: 50 30 10 10\n    :header-rows: 1\n\n    * - Name\n      - #nodes/#edges\n      - #features\n      - #classes/#tasks\n{% for cls in torch_geometric.datasets.hetero_datasets %}\n    * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %}\n      -\n      -\n      -\n    {% for child in torch_geometric.datasets.utils.get_children(cls) %}\n    * - └─ **{{torch_geometric.datasets.utils.get_type(child)}} Type**: {{ child }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes/#edges', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }}\n    {% endfor %}\n{% endfor %}\n\nSynthetic Datasets\n------------------\n\n.. 
list-table::\n    :widths: 50 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - #graphs\n      - #nodes\n      - #edges\n      - #features\n      - #classes/#tasks\n{% for cls in torch_geometric.datasets.synthetic_datasets %}\n    * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %}\n      - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }}\n    {% for child in torch_geometric.datasets.utils.get_children(cls) %}\n    * - └─ {{ child }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }}\n    {% endfor %}\n{% endfor %}\n"
  },
  {
    "path": "docs/source/cheatsheet/gnn_cheatsheet.rst",
    "content": "GNN Cheatsheet\n==============\n\n* :class:`~torch_sparse.SparseTensor`: If checked (✓), supports message passing based on :class:`torch_sparse.SparseTensor`, *e.g.*, :obj:`GCNConv(...).forward(x, adj_t)`. See `here <../advanced/sparse_tensor.html>`__ for the accompanying tutorial.\n\n* :obj:`edge_weight`: If checked (✓), supports message passing with one-dimensional edge weight information, *e.g.*, :obj:`GraphConv(...).forward(x, edge_index, edge_weight)`.\n\n* :obj:`edge_attr`: If checked (✓), supports message passing with multi-dimensional edge feature information, *e.g.*, :obj:`GINEConv(...).forward(x, edge_index, edge_attr)`.\n\n* **bipartite**: If checked (✓), supports message passing in bipartite graphs with potentially different feature dimensionalities for source and destination nodes, *e.g.*, :obj:`SAGEConv(in_channels=(16, 32), out_channels=64)`.\n\n* **static**: If checked (✓), supports message passing in static graphs, *e.g.*, :obj:`GCNConv(...).forward(x, edge_index)` with :obj:`x` having shape :obj:`[batch_size, num_nodes, in_channels]`.\n\n* **lazy**: If checked (✓), supports lazy initialization of message passing layers, *e.g.*, :obj:`SAGEConv(in_channels=-1, out_channels=64)`.\n\nGraph Neural Network Operators\n------------------------------\n\n.. 
list-table::\n    :widths: 40 10 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - :class:`~torch_sparse.SparseTensor`\n      - :obj:`edge_weight`\n      - :obj:`edge_attr`\n      - bipartite\n      - static\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if not torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) and\n      not torch_geometric.nn.conv.utils.processes_hypergraphs(cls) and\n      not torch_geometric.nn.conv.utils.processes_point_clouds(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n\nHeterogeneous Graph Neural Network Operators\n--------------------------------------------\n\n.. 
list-table::\n    :widths: 40 10 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - :class:`~torch_sparse.SparseTensor`\n      - :obj:`edge_weight`\n      - :obj:`edge_attr`\n      - bipartite\n      - static\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n\nHypergraph Neural Network Operators\n-----------------------------------\n\n.. 
list-table::\n    :widths: 40 10 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - :class:`~torch_sparse.SparseTensor`\n      - :obj:`edge_weight`\n      - :obj:`edge_attr`\n      - bipartite\n      - static\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if torch_geometric.nn.conv.utils.processes_hypergraphs(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n\nPoint Cloud Neural Network Operators\n------------------------------------\n\n.. list-table::\n    :widths: 80 10 10\n    :header-rows: 1\n\n    * - Name\n      - bipartite\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if torch_geometric.nn.conv.utils.processes_point_clouds(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` {% if torch_geometric.nn.conv.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__){% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n"
  },
  {
    "path": "docs/source/conf.py",
    "content": "import datetime\nimport os.path as osp\nimport sys\n\nimport pyg_sphinx_theme\n\nimport torch_geometric\n\nauthor = 'PyG Team'\nproject = 'pytorch_geometric'\nversion = torch_geometric.__version__\ncopyright = f'{datetime.datetime.now().year}, {author}'\n\nsys.path.append(osp.join(osp.dirname(pyg_sphinx_theme.__file__), 'extension'))\n\nextensions = [\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.mathjax',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.viewcode',\n    'sphinx_autodoc_typehints',\n    'sphinx_copybutton',\n    'nbsphinx',\n    'pyg',\n]\n\nhtml_theme = 'pyg_sphinx_theme'\nhtml_logo = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/'\n             'master/pyg_sphinx_theme/static/img/pyg_logo.png')\nhtml_favicon = ('https://raw.githubusercontent.com/pyg-team/pyg_sphinx_theme/'\n                'master/pyg_sphinx_theme/static/img/favicon.png')\nhtml_static_path = ['_static']\ntemplates_path = ['_templates']\n\nadd_module_names = False\nautodoc_member_order = 'bysource'\n\nsuppress_warnings = ['autodoc.import_object']\n\nintersphinx_mapping = {\n    'python': ('https://docs.python.org/', None),\n    # 'numpy': ('http://docs.scipy.org/doc/numpy', None),\n    'pandas': ('https://pandas.pydata.org/docs', None),\n    'torch': ('https://pytorch.org/docs/main', None),\n}\n\ntypehints_use_rtype = False\ntypehints_defaults = 'comma'\n\nnbsphinx_thumbnails = {\n    'tutorial/create_gnn': '_static/thumbnails/create_gnn.png',\n    'tutorial/heterogeneous': '_static/thumbnails/heterogeneous.png',\n    'tutorial/create_dataset': '_static/thumbnails/create_dataset.png',\n    'tutorial/load_csv': '_static/thumbnails/load_csv.png',\n    'tutorial/dataset_splitting': '_static/thumbnails/dataset_splitting.png',\n    'tutorial/neighbor_loader': '_static/thumbnails/neighbor_loader.png',\n    'tutorial/point_cloud': '_static/thumbnails/point_cloud.png',\n    'tutorial/explain': 
'_static/thumbnails/explain.png',\n    'tutorial/shallow_node_embeddings':\n    '_static/thumbnails/shallow_node_embeddings.png',\n    'tutorial/distributed_pyg': '_static/thumbnails/distributed_pyg.png',\n    'tutorial/multi_gpu_vanilla': '_static/thumbnails/multi_gpu_vanilla.png',\n    'tutorial/multi_node_multi_gpu_vanilla':\n    '_static/thumbnails/multi_gpu_vanilla.png',\n    'tutorial/graph_transformer': '_static/thumbnails/graph_transformer.png',\n}\n\n\ndef rst_jinja_render(app, _, source):\n    if hasattr(app.builder, 'templates'):\n        rst_context = {'torch_geometric': torch_geometric}\n        source[0] = app.builder.templates.render_string(source[0], rst_context)\n\n\ndef setup(app):\n    r\"\"\"Setup sphinx application.\"\"\"\n    app.connect('source-read', rst_jinja_render)\n    app.add_js_file('js/version_alert.js')\n\n    # Do not drop type hints in signatures:\n    del app.events.listeners['autodoc-process-signature']\n"
  },
  {
    "path": "docs/source/external/resources.rst",
    "content": "External Resources\n==================\n\n* Fey *et al.*: **PyG 2.0: Scalable Learning on Real World Graphs** [`Paper <https://arxiv.org/abs/2507.16991>`__]\n\n* Matthias Fey and Jan E. Lenssen: **Fast Graph Representation Learning with** :pyg:`null` **PyTorch Geometric** [`Paper <https://arxiv.org/abs/1903.02428>`_, `Slides (3.3MB) <http://rusty1s.github.io/pyg_slides.pdf>`__, `Poster (2.3MB) <http://rusty1s.github.io/pyg_poster.pdf>`__, `Notebook <http://htmlpreview.github.io/?https://github.com/rusty1s/rusty1s.github.io/blob/master/pyg_notebook.html>`__]\n\n* :stanford:`Stanford CS224W: Machine Learning with Graphs`: **Graph Machine Learning lectures** [:youtube:`null` `Youtube <https://www.youtube.com/watch?v=JAB_plj2rbA>`__]\n\n* :stanford:`Stanford University`: **A collection of graph machine learning tutorial blog posts**, fully realized with :pyg:`null` **PyG** [`Website <https://medium.com/stanford-cs224w>`__]\n\n* Soumith Chintala: **Automatic Differentiation,** :pytorch:`null` **PyTorch and Graph Neural Networks** [`Talk (starting from 26:15) <http://www.ipam.ucla.edu/abstract/?tid=15592&pcode=GLWS4>`__]\n\n* Stanford University: **Graph Neural Networks using** :pyg:`null` **PyTorch Geometric** [:youtube:`null` `YouTube (starting from 33:33) <https://www.youtube.com/watch?v=-UjytpbqX4A&feature=youtu.be>`__]\n\n* Antonio Longa, Gabriele Santin and Giovanni Pellegrini: :pyg:`null` **PyTorch Geometric Tutorial** [`Website <https://antoniolonga.github.io/Pytorch_geometric_tutorials>`__, :github:`null` `GitHub <https://github.com/AntonioLonga/PytorchGeometricTutorial>`__]\n\n* DAIR.AI | elvis: **Introduction to GNNs with** :pyg:`null` **PyTorch Geometric** [`Website <https://github.com/dair-ai/GNNs-Recipe>`__, :colab:`null` `Colab <https://colab.research.google.com/drive/1d0jLDwgNBtjBVQOFe8lO_1WrqTVeVZx9?usp=sharing>`__]\n\n* Nicolas Chaulet *et al.*: **PyTorch Points 3D** - A framework for running common deep learning models for point cloud 
analysis tasks that heavily relies on PyTorch Geometric [:github:`null` `GitHub <https://github.com/nicolas-chaulet/torch-points3d>`__, `Documentation <https://torch-points3d.readthedocs.io/en/latest/>`__]\n\n* Weihua Hu *et al.*: :ogb:`null` **Open Graph Benchmark** - A collection of large-scale benchmark datasets, data loaders, and evaluators for graph machine learning, including :pyg:`PyG` support and examples [`Website <https://ogb.stanford.edu>`__, :github:`null` `GitHub <https://github.com/snap-stanford/ogb>`__]\n\n* **DeepSNAP** - A :pytorch:`PyTorch` library that bridges between graph libraries such as NetworkX and :pyg:`PyG` [:github:`null` `GitHub <https://github.com/snap-stanford/deepsnap>`__, `Documentation <https://snap.stanford.edu/deepsnap/>`__]\n\n* **Quiver** - A distributed graph learning library for :pyg:`PyG` [:github:`null` `GitHub <https://github.com/quiver-team/torch-quiver>`__]\n\n* Benedek Rozemberczki: **PyTorch Geometric Temporal** - A temporal GNN library built upon :pyg:`PyG` [:github:`null` `GitHub <https://github.com/benedekrozemberczki/pytorch_geometric_temporal>`__, `Documentation <https://pytorch-geometric-temporal.readthedocs.io/en/latest/>`__]\n\n* Yixuan He: **PyTorch Geometric Signed Directed** - A signed and directed GNN library built upon :pyg:`PyG` [:github:`null` `GitHub <https://github.com/SherylHYX/pytorch_geometric_signed_directed>`__, `Documentation <https://pytorch-geometric-signed-directed.readthedocs.io/en/latest/>`__]\n\n* Steeve Huang: **Hands-on Graph Neural Networks with** :pytorch:`null` **PyTorch &** :pyg:`null` **PyTorch Geometric** [`Tutorial <https://towardsdatascience.com/hands-on-graph-neural-networks-with-pytorch-pytorch-geometric-359487e221a8>`__, `Code <https://github.com/khuangaf/Pytorch-Geometric-YooChoose>`__]\n\n* Francesco Landolfi: :pyg:`null` **PyTorch Geometric Tutorial** [`PDF (0.4MB) <http://pages.di.unipi.it/citraro/files/slides/Landolfi_tutorial.pdf>`__]\n\n* Sachin Sharma: **How to Deploy 
(almost) any** :pyg:`null` **PyTorch Geometric Model on Nvidia's Triton Inference Server with an Application to Amazon Product Recommendation and ArangoDB** [`Blog <https://sachinsharma9780.medium.com/how-to-deploy-almost-any-pytorch-geometric-model-on-nvidias-triton-inference-server-with-an-218d0c0c679c>`__]\n\n* Amitoz Azad: **torch_pdegraph** - Solving PDEs on Graphs with :pyg:`PyG` [`Devpost <https://devpost.com/software/gdfgddfd>`__, :github:`null` `GitHub <https://github.com/aGIToz/Pytorch_pdegraph>`__]\n\n* Amitoz Azad: **Primal-Dual Algorithm for Total Variation Processing on Graphs** [`Jupyter <https://nbviewer.jupyter.org/github/aGIToz/Graph_Signal_Processing/tree/main>`__]\n\n* Manan Goel: **Recommending Amazon Products using Graph Neural Networks in** :pyg:`null` **PyTorch Geometric** [:wandb:`null` `W&B Report <https://wandb.ai/manan-goel/gnn-recommender/reports/Recommending-Amazon-Products-using-Graph-Neural-Networks-in-PyTorch-Geometric--VmlldzozMTA3MzYw>`__]\n\n* Kùzu: **Remote Backend for** :pyg:`null` **PyTorch Geometric** [:colab:`null` `Colab <https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6>`__]\n\n* Aniket Saxena: **Graph Neural Networks-based Explanation App using** :pyg:`null` **PyTorch Geometric** [`Website <https://graph-explainability.streamlit.app/>`__, :github:`null` `GitHub <https://github.com/fork123aniket/End-to-End-Node-and-Graph-Classification-and-Explanation-App>`__]\n\n* Mashaan Alshammari: **Graph Attention in** :pyg:`null` **PyTorch Geometric** [:youtube:`null` `Youtube <https://youtu.be/AWkPjrZshug>`__, :github:`null` `GitHub <https://github.com/mashaan14/YouTube-channel/blob/main/notebooks/2024_02_05_GAT.ipynb>`__]\n\n* Mashaan Alshammari: **Graph Convolutional Networks (GCNs) in** :pytorch:`null` **PyTorch** [:youtube:`null` `Youtube <https://youtu.be/G6c6zk0RhRM>`__, :github:`null` `GitHub <https://github.com/mashaan14/YouTube-channel/blob/main/notebooks/2023_12_04_GCN_introduction.ipynb>`__]\n\n* 
Mashaan Alshammari: **GCN and SGC in** :pytorch:`null` **PyTorch** [:youtube:`null` `Youtube <https://youtu.be/PQT2QblNegY>`__, :github:`null` `GitHub <https://github.com/mashaan14/YouTube-channel/blob/main/notebooks/2023_12_13_GCN_and_SGC.ipynb>`__]\n\n* Mashaan Alshammari: **GCN Variants SGC and ASGC in** :pytorch:`null` **PyTorch** [:youtube:`null` `Youtube <https://youtu.be/ZNMV5i84fmM>`__, :github:`null` `GitHub <https://github.com/mashaan14/YouTube-channel/blob/main/notebooks/2024_01_31_SGC_and_ASGC.ipynb>`__]\n"
  },
  {
    "path": "docs/source/get_started/colabs.rst",
    "content": "Colab Notebooks and Video Tutorials\n===================================\n\nOfficial Examples\n-----------------\n\nWe have prepared a list of :colab:`Colab` notebooks that practically introduces you to the world of **Graph Neural Networks** with :pyg:`PyG`:\n\n1. `Introduction: Hands-on Graph Neural Networks <https://colab.research.google.com/drive/1h3-vJGRVloF5zStxL5I0rSy4ZUPNsjy8?usp=sharing>`__\n2. `Node Classification with Graph Neural Networks <https://colab.research.google.com/drive/14OvFnAXggxB8vM4e8vSURUp1TaKnovzX?usp=sharing>`__\n3. `Graph Classification with Graph Neural Networks <https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing>`__\n4. `Scaling Graph Neural Networks <https://colab.research.google.com/drive/1XAjcjRHrSR_ypCk_feIWFbcBKyT4Lirs?usp=sharing>`__\n5. `Point Cloud Classification with Graph Neural Networks <https://colab.research.google.com/drive/1D45E5bUK3gQ40YpZo65ozs7hg5l-eo_U?usp=sharing>`__\n6. `Explaining GNN Model Predictions using <https://colab.research.google.com/drive/1fLJbFPz0yMCQg81DdCP5I8jXw9LoggKO?usp=sharing>`__ :captum:`null` `Captum <https://colab.research.google.com/drive/1fLJbFPz0yMCQg81DdCP5I8jXw9LoggKO?usp=sharing>`__\n7. `Customizing Aggregations within Message Passing <https://colab.research.google.com/drive/1KKw-VUDQuHhMo7sCd7ZaRROza3leBjRR?usp=sharing>`__\n8. `Node Classification Instrumented with <https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pyg/8_Node_Classification_(with_W&B).ipynb>`__ :wandb:`null` `Weights&Biases <https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pyg/8_Node_Classification_(with_W&B).ipynb>`__\n9. 
`Graph Classification Instrumented with <https://colab.research.google.com/github/wandb/examples/blob/pyg/graph-classification/colabs/pyg/Graph_Classification_with_PyG_and_W%26B.ipynb>`__ :wandb:`null` `Weights&Biases <https://colab.research.google.com/github/wandb/examples/blob/pyg/graph-classification/colabs/pyg/Graph_Classification_with_PyG_and_W%26B.ipynb>`__\n10. `Link Prediction on MovieLens <https://colab.research.google.com/drive/1xpzn1Nvai1ygd_P5Yambc_oe4VBPK_ZT?usp=sharing>`__\n11. `Link Regression on MovieLens <https://colab.research.google.com/drive/1N3LvAO0AXV4kBPbTMX866OwJM9YS6Ji2?usp=sharing>`__\n12. `Pooling in Graph Neural Networks with <https://colab.research.google.com/github/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/hierarchical_gnns.ipynb>`__ :tgp:`null` `tgp <https://colab.research.google.com/github/tgp-team/torch-geometric-pool/blob/main/docs/source/tutorials/hierarchical_gnns.ipynb>`__\n\nAll :colab:`Colab` notebooks are released under the MIT license.\n\nStanford CS224W Tutorials\n-------------------------\n\n.. image:: https://data.pyg.org/img/cs224w_tutorials.png\n  :align: center\n  :width: 941px\n  :target: https://medium.com/stanford-cs224w\n\n.. raw:: html\n\n   <br/>\n\nThe :stanford:`null` `Stanford CS224W <http://web.stanford.edu/class/cs224w/>`__ course has collected a set of `graph machine learning tutorial blog posts <https://medium.com/stanford-cs224w>`__, fully realized with :pyg:`PyG`.\nStudents worked on projects spanning all kinds of tasks, model architectures and applications.\nAll tutorials also link to a :colab:`Colab` with the code in the tutorial for you to follow along with as you read it!\n\nPyTorch Geometric Tutorial Project\n----------------------------------\n\nThe :pyg:`null` `PyTorch Geometric Tutorial <https://github.com/AntonioLonga/PytorchGeometricTutorial>`__ project provides **video tutorials and** :colab:`null` **Colab notebooks** for a variety of different methods in :pyg:`PyG`:\n\n1. 
Introduction [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=JtDgmmQ60x8>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial1/Tutorial1.ipynb>`__]\n2. :pytorch:`PyTorch` basics [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=UHrhp2l_knU>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial2/Tutorial2.ipynb>`__]\n3. Graph Attention Networks (GATs) [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=CwsPoa7z2c8>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial3/Tutorial3.ipynb>`__]\n4. Spectral Graph Convolutional Layers [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=Ghw-fp_2HFM>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial4/Tutorial4.ipynb>`__]\n5. Aggregation Functions in GNNs [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=tGXovxQ7hKU>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial5/Aggregation%20Tutorial.ipynb>`__]\n6. (Variational) Graph Autoencoders (GAE and VGAE) [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=qA6U4nIK62E>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial6/Tutorial6.ipynb>`__]\n7. Adversarially Regularized Graph Autoencoders (ARGA and ARGVA) [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=hZkLu2OaHD0>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial7/Tutorial7.ipynb>`__]\n8. Graph Generation [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=embpBq1gHAE>`__]\n9. 
Recurrent Graph Neural Networks [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=v7TQ2DUoaBY>`__, :colab:`null` `Colab (Part 1) <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial9/Tutorial9.ipynb>`__, :colab:`null` `Colab (Part 2) <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial9/RecGNN_tutorial.ipynb>`__]\n10. DeepWalk and Node2Vec [:youtube:`null` `YouTube (Theory) <https://www.youtube.com/watch?v=QZQBnl1QbCQ>`__, :youtube:`null` `YouTube (Practice) <https://youtu.be/5YOcpI3dB7I>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial11/Tutorial11.ipynb>`__]\n11. Edge analysis [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=m1G7oS9hmwE>`__, :colab:`null` `Colab (Link Prediction) <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial12/Tutorial12%20GAE%20for%20link%20prediction.ipynb>`__, :colab:`null` `Colab (Label Prediction) <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial12/Tutorial12%20Node2Vec%20for%20label%20prediction.ipynb>`__]\n12. Data handling in :pyg:`PyG` (Part 1) [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=Vz5bT8Xw6Dc>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial14/Tutorial14.ipynb>`__]\n13. Data handling in :pyg:`PyG` (Part 2) [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=Q5T-JdyVCfs>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial15/Tutorial15.ipynb>`__]\n14. MetaPath2vec [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=GtPoGehuKYY>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial13/Tutorial13.ipynb>`__]\n15. 
Graph pooling (DiffPool) [:youtube:`null` `YouTube <https://www.youtube.com/watch?v=Uqc3O3-oXxM>`__, :colab:`null` `Colab <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial16/Tutorial16.ipynb>`__]\n"
  },
  {
    "path": "docs/source/get_started/introduction.rst",
    "content": "Introduction by Example\n=======================\n\nWe briefly introduce the fundamental concepts of :pyg:`PyG` through self-contained examples.\n\nFor an introduction to Graph Machine Learning, we refer the interested reader to the :stanford:`null` `Stanford CS224W: Machine Learning with Graphs <https://www.youtube.com/watch?v=JAB_plj2rbA>`__ lectures.\nFor an interactive introduction to :pyg:`PyG`, we recommend our carefully curated :colab:`null` `Google Colab <colabs.html>`__ notebooks.\n\nAt its core, :pyg:`PyG` provides the following main features:\n\n.. contents::\n    :local:\n\nData Handling of Graphs\n-----------------------\n\nA graph is used to model pairwise relations (edges) between objects (nodes).\nA single graph in :pyg:`PyG` is described by an instance of :class:`torch_geometric.data.Data`, which holds the following attributes by default:\n\n- :obj:`data.x`: Node feature matrix with shape :obj:`[num_nodes, num_node_features]`\n- :obj:`data.edge_index`: Graph connectivity in `COO format <https://pytorch.org/docs/stable/sparse.html#sparse-coo-docs>`_ with shape :obj:`[2, num_edges]` and type :obj:`torch.long`\n- :obj:`data.edge_attr`: Edge feature matrix with shape :obj:`[num_edges, num_edge_features]`\n- :obj:`data.y`: Target to train against (may have arbitrary shape), *e.g.*, node-level targets of shape :obj:`[num_nodes, *]` or graph-level targets of shape :obj:`[1, *]`\n- :obj:`data.pos`: Node position matrix with shape :obj:`[num_nodes, num_dimensions]`\n\nNone of these attributes are required.\nIn fact, the :class:`~torch_geometric.data.Data` object is not even restricted to these attributes.\nWe can, *e.g.*, extend it by :obj:`data.face` to save the connectivity of triangles from a 3D mesh in a tensor with shape :obj:`[3, num_faces]` and type :obj:`torch.long`.\n\n.. 
Note::\n    :pytorch:`PyTorch` and :obj:`torchvision` define an example as a tuple of an image and a target.\n    We omit this notation in :pyg:`PyG` to allow for various data structures in a clean and understandable way.\n\nWe show a simple example of an unweighted and undirected graph with three nodes and four edges.\nEach node contains exactly one feature:\n\n.. code-block:: python\n\n    import torch\n    from torch_geometric.data import Data\n\n    edge_index = torch.tensor([[0, 1, 1, 2],\n                               [1, 0, 2, 1]], dtype=torch.long)\n    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)\n\n    data = Data(x=x, edge_index=edge_index)\n    >>> Data(edge_index=[2, 4], x=[3, 1])\n\n.. image:: ../_figures/graph.svg\n  :align: center\n  :width: 300px\n\n|\n\nNote that :obj:`edge_index`, *i.e.* the tensor defining the source and target nodes of all edges, is **not** a list of index tuples.\nIf you want to write your indices this way, you should transpose and call :obj:`contiguous` on it before passing them to the data constructor:\n\n.. code-block:: python\n\n    import torch\n    from torch_geometric.data import Data\n\n    edge_index = torch.tensor([[0, 1],\n                               [1, 0],\n                               [1, 2],\n                               [2, 1]], dtype=torch.long)\n    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)\n\n    data = Data(x=x, edge_index=edge_index.t().contiguous())\n    >>> Data(edge_index=[2, 4], x=[3, 1])\n\nAlthough the graph has only two edges, we need to define four index tuples to account for both directions of an edge.\n\n.. 
Note::\n    You can print out your data object anytime and receive a short summary of its attributes and their shapes.\n\nNote that it is necessary that the elements in :obj:`edge_index` only hold indices in the range :obj:`{ 0, ..., num_nodes - 1}`.\nThis is needed as we want our final data representation to be as compact as possible, *e.g.*, we want to index the source and destination node features of the first edge :obj:`(0, 1)` via :obj:`x[0]` and :obj:`x[1]`, respectively.\nYou can always check that your final :class:`~torch_geometric.data.Data` objects fulfill these requirements by running :meth:`~torch_geometric.data.Data.validate`:\n\n.. code-block:: python\n\n   data.validate(raise_on_error=True)\n\nBesides holding a number of node-level, edge-level or graph-level attributes, :class:`~torch_geometric.data.Data` provides a number of useful utility functions, *e.g.*:\n\n.. code-block:: python\n\n    print(data.keys())\n    >>> ['x', 'edge_index']\n\n    print(data['x'])\n    >>> tensor([[-1.0],\n                [0.0],\n                [1.0]])\n\n    for key, item in data:\n        print(f'{key} found in data')\n    >>> x found in data\n    >>> edge_index found in data\n\n    'edge_attr' in data\n    >>> False\n\n    data.num_nodes\n    >>> 3\n\n    data.num_edges\n    >>> 4\n\n    data.num_node_features\n    >>> 1\n\n    data.has_isolated_nodes()\n    >>> False\n\n    data.has_self_loops()\n    >>> False\n\n    data.is_directed()\n    >>> False\n\n    # Transfer data object to GPU.\n    device = torch.device('cuda')\n    data = data.to(device)\n\nYou can find a complete list of all methods at :class:`torch_geometric.data.Data`.\n\nCommon Benchmark Datasets\n-------------------------\n\n:pyg:`PyG` contains a large number of common benchmark datasets, *e.g.*, all Planetoid datasets (Cora, Citeseer, Pubmed), all graph classification datasets from `TUDatasets <https://chrsmrrs.github.io/datasets/>`_ and their `cleaned versions 
<https://github.com/nd7141/graph_datasets>`_, the QM7 and QM9 datasets, and a handful of 3D mesh/point cloud datasets like FAUST, ModelNet10/40 and ShapeNet.\n\nInitializing a dataset is straightforward.\nAn initialization of a dataset will automatically download its raw files and process them to the previously described :class:`~torch_geometric.data.Data` format.\n*E.g.*, to load the ENZYMES dataset (consisting of 600 graphs within 6 classes), type:\n\n.. code-block:: python\n\n    from torch_geometric.datasets import TUDataset\n\n    dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')\n    >>> ENZYMES(600)\n\n    len(dataset)\n    >>> 600\n\n    dataset.num_classes\n    >>> 6\n\n    dataset.num_node_features\n    >>> 3\n\nWe now have access to all 600 graphs in the dataset:\n\n.. code-block:: python\n\n    data = dataset[0]\n    >>> Data(edge_index=[2, 168], x=[37, 3], y=[1])\n\n    data.is_undirected()\n    >>> True\n\nWe can see that the first graph in the dataset contains 37 nodes, each one having 3 features.\nThere are 168/2 = 84 undirected edges and the graph is assigned to exactly one class.\nIn addition, the data object is holding exactly one graph-level target.\n\nWe can even use slices, long or bool tensors to split the dataset.\n*E.g.*, to create a 90/10 train/test split, type:\n\n.. code-block:: python\n\n    train_dataset = dataset[:540]\n    >>> ENZYMES(540)\n\n    test_dataset = dataset[540:]\n    >>> ENZYMES(60)\n\nIf you are unsure whether the dataset is already shuffled before you split, you can randomly permute it by running:\n\n.. code-block:: python\n\n    dataset = dataset.shuffle()\n    >>> ENZYMES(600)\n\nThis is equivalent to doing:\n\n.. code-block:: python\n\n    perm = torch.randperm(len(dataset))\n    dataset = dataset[perm]\n    >>> ENZYMES(600)\n\nLet's try another one! Let's download Cora, the standard benchmark dataset for semi-supervised graph node classification:\n\n.. 
code-block:: python\n\n    from torch_geometric.datasets import Planetoid\n\n    dataset = Planetoid(root='/tmp/Cora', name='Cora')\n    >>> Cora()\n\n    len(dataset)\n    >>> 1\n\n    dataset.num_classes\n    >>> 7\n\n    dataset.num_node_features\n    >>> 1433\n\nHere, the dataset contains only a single, undirected citation graph:\n\n.. code-block:: python\n\n    data = dataset[0]\n    >>> Data(edge_index=[2, 10556], test_mask=[2708],\n             train_mask=[2708], val_mask=[2708], x=[2708, 1433], y=[2708])\n\n    data.is_undirected()\n    >>> True\n\n    data.train_mask.sum().item()\n    >>> 140\n\n    data.val_mask.sum().item()\n    >>> 500\n\n    data.test_mask.sum().item()\n    >>> 1000\n\nThis time, the :class:`~torch_geometric.data.Data` object holds a label for each node, and additional node-level attributes: :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask`, where\n\n- :obj:`train_mask` denotes against which nodes to train (140 nodes),\n- :obj:`val_mask` denotes which nodes to use for validation, *e.g.*, to perform early stopping (500 nodes),\n- :obj:`test_mask` denotes against which nodes to test (1000 nodes).\n\nMini-batches\n------------\n\nNeural networks are usually trained in a batch-wise fashion.\n:pyg:`PyG` achieves parallelization over a mini-batch by creating sparse block diagonal adjacency matrices (defined by :obj:`edge_index`) and concatenating feature and target matrices in the node dimension.\nThis composition allows differing number of nodes and edges over examples in one batch:\n\n.. 
math::\n\n    \\mathbf{A} = \\begin{bmatrix} \\mathbf{A}_1 & & \\\\ & \\ddots & \\\\ & & \\mathbf{A}_n \\end{bmatrix}, \\qquad \\mathbf{X} = \\begin{bmatrix} \\mathbf{X}_1 \\\\ \\vdots \\\\ \\mathbf{X}_n \\end{bmatrix}, \\qquad \\mathbf{Y} = \\begin{bmatrix} \\mathbf{Y}_1 \\\\ \\vdots \\\\ \\mathbf{Y}_n \\end{bmatrix}\n\n:pyg:`PyG` contains its own :class:`torch_geometric.loader.DataLoader`, which already takes care of this concatenation process.\nLet's learn about it in an example:\n\n.. code-block:: python\n\n    from torch_geometric.datasets import TUDataset\n    from torch_geometric.loader import DataLoader\n\n    dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)\n    loader = DataLoader(dataset, batch_size=32, shuffle=True)\n\n    for batch in loader:\n        batch\n        >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])\n\n        batch.num_graphs\n        >>> 32\n\n:class:`torch_geometric.data.Batch` inherits from :class:`torch_geometric.data.Data` and contains an additional attribute called :obj:`batch`.\n\n:obj:`batch` is a column vector which maps each node to its respective graph in the batch:\n\n.. math::\n\n    \\mathrm{batch} = {\\begin{bmatrix} 0 & \\cdots & 0 & 1 & \\cdots & n - 2 & n -1 & \\cdots & n - 1 \\end{bmatrix}}^{\\top}\n\nYou can use it to, *e.g.*, average node features in the node dimension for each graph individually:\n\n.. 
code-block:: python\n\n    from torch_geometric.utils import scatter\n    from torch_geometric.datasets import TUDataset\n    from torch_geometric.loader import DataLoader\n\n    dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)\n    loader = DataLoader(dataset, batch_size=32, shuffle=True)\n\n    for data in loader:\n        data\n        >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])\n\n        data.num_graphs\n        >>> 32\n\n        x = scatter(data.x, data.batch, dim=0, reduce='mean')\n        x.size()\n        >>> torch.Size([32, 21])\n\nYou can learn more about the internal batching procedure of :pyg:`PyG`, *e.g.*, how to modify its behavior, `here <../advanced/batching.html>`__.\nFor documentation of scatter operations, we refer the interested reader to the :obj:`torch_scatter` `documentation <https://pytorch-scatter.readthedocs.io>`_.\n\nData Transforms\n---------------\n\nTransforms are a common way in :obj:`torchvision` to transform images and perform augmentation.\n:pyg:`PyG` comes with its own transforms, which expect a :class:`~torch_geometric.data.Data` object as input and return a new transformed :class:`~torch_geometric.data.Data` object.\nTransforms can be chained together using :class:`torch_geometric.transforms.Compose` and are applied before saving a processed dataset on disk (:obj:`pre_transform`) or before accessing a graph in a dataset (:obj:`transform`).\n\nLet's look at an example, where we apply transforms on the ShapeNet dataset (containing 17,000 3D shape point clouds and per point labels from 16 shape categories).\n\n.. code-block:: python\n\n    from torch_geometric.datasets import ShapeNet\n\n    dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'])\n\n    dataset[0]\n    >>> Data(pos=[2518, 3], y=[2518])\n\nWe can convert the point cloud dataset into a graph dataset by generating nearest neighbor graphs from the point clouds via transforms:\n\n.. 
code-block:: python\n\n    import torch_geometric.transforms as T\n    from torch_geometric.datasets import ShapeNet\n\n    dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],\n                        pre_transform=T.KNNGraph(k=6))\n\n    dataset[0]\n    >>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])\n\n.. note::\n    We use the :obj:`pre_transform` to convert the data before saving it to disk (leading to faster loading times).\n    Note that the next time the dataset is initialized it will already contain graph edges, even if you do not pass any transform.\n    If the :obj:`pre_transform` does not match with the one from the already processed dataset, you will be given a warning.\n\nIn addition, we can use the :obj:`transform` argument to randomly augment a :class:`~torch_geometric.data.Data` object, *e.g.*, translating each node position by a small number:\n\n.. code-block:: python\n\n    import torch_geometric.transforms as T\n    from torch_geometric.datasets import ShapeNet\n\n    dataset = ShapeNet(root='/tmp/ShapeNet', categories=['Airplane'],\n                        pre_transform=T.KNNGraph(k=6),\n                        transform=T.RandomJitter(0.01))\n\n    dataset[0]\n    >>> Data(edge_index=[2, 15108], pos=[2518, 3], y=[2518])\n\nYou can find a complete list of all implemented transforms at :mod:`torch_geometric.transforms`.\n\nLearning Methods on Graphs\n--------------------------\n\nAfter learning about data handling, datasets, loader and transforms in :pyg:`PyG`, it's time to implement our first graph neural network!\n\nWe will use a simple GCN layer and replicate the experiments on the Cora citation dataset.\nFor a high-level explanation on GCN, have a look at its `blog post <http://tkipf.github.io/graph-convolutional-networks/>`_.\n\nWe first need to load the Cora dataset:\n\n.. 
code-block:: python\n\n    from torch_geometric.datasets import Planetoid\n\n    dataset = Planetoid(root='/tmp/Cora', name='Cora')\n    >>> Cora()\n\nNote that we do not need to use transforms or a dataloader.\nNow let's implement a two-layer GCN:\n\n.. code-block:: python\n\n    import torch\n    import torch.nn.functional as F\n    from torch_geometric.nn import GCNConv\n\n    class GCN(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.conv1 = GCNConv(dataset.num_node_features, 16)\n            self.conv2 = GCNConv(16, dataset.num_classes)\n\n        def forward(self, data):\n            x, edge_index = data.x, data.edge_index\n\n            x = self.conv1(x, edge_index)\n            x = F.relu(x)\n            x = F.dropout(x, training=self.training)\n            x = self.conv2(x, edge_index)\n\n            return F.log_softmax(x, dim=1)\n\nThe constructor defines two :class:`~torch_geometric.nn.conv.GCNConv` layers which get called in the forward pass of our network.\nNote that the non-linearity is not integrated in the :obj:`conv` calls and hence needs to be applied afterwards (something which is consistent across all operators in :pyg:`PyG`).\nHere, we chose to use ReLU as our intermediate non-linearity and finally output a softmax distribution over the number of classes.\nLet's train this model on the training nodes for 200 epochs:\n\n.. code-block:: python\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = GCN().to(device)\n    data = dataset[0].to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\n    model.train()\n    for epoch in range(200):\n        optimizer.zero_grad()\n        out = model(data)\n        loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n        loss.backward()\n        optimizer.step()\n\nFinally, we can evaluate our model on the test nodes:\n\n.. 
code-block:: python\n\n    model.eval()\n    pred = model(data).argmax(dim=1)\n    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()\n    acc = int(correct) / int(data.test_mask.sum())\n    print(f'Accuracy: {acc:.4f}')\n    >>> Accuracy: 0.8150\n\nThis is all it takes to implement your first graph neural network.\nThe easiest way to learn more about Graph Neural Networks is to study the examples in the :obj:`examples/` directory and to browse :mod:`torch_geometric.nn`.\nHappy hacking!\n\nExercises\n---------\n\n1. What does :obj:`edge_index.t().contiguous()` do?\n\n2. Load the :obj:`\"IMDB-BINARY\"` dataset from the :class:`~torch_geometric.datasets.TUDataset` benchmark suite and randomly split it into 80%/10%/10% training, validation and test graphs.\n\n3. What does each number of the following output mean?\n\n   .. code-block:: python\n\n       print(batch)\n       >>> DataBatch(batch=[1082], edge_index=[2, 4066], x=[1082, 21], y=[32])\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": ":github_url: https://github.com/pyg-team/pytorch_geometric\n\nPyG Documentation\n=================\n\n:pyg:`null` **PyG** *(PyTorch Geometric)* is a library built upon :pytorch:`null` `PyTorch <https://pytorch.org>`_ to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.\n\nIt consists of various methods for deep learning on graphs and other irregular structures, also known as `geometric deep learning <http://geometricdeeplearning.com/>`_, from a variety of published papers.\nIn addition, it consists of easy-to-use mini-batch loaders for operating on many small and single giant graphs, `multi GPU-support <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/multi_gpu>`_, `torch.compile <https://pytorch-geometric.readthedocs.io/en/latest/advanced/compile.html>`_ support, `DataPipe <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/datapipe.py>`_ support, a large number of common benchmark datasets (based on simple interfaces to create your own), and helpful transforms, both for learning on arbitrary graphs as well as on 3D meshes or point clouds.\n\n.. slack_button::\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Install PyG\n\n   install/installation\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Get Started\n\n   get_started/introduction\n   get_started/colabs\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Tutorials\n\n   tutorial/gnn_design\n   tutorial/dataset\n   tutorial/application\n   tutorial/distributed\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Advanced Concepts\n\n   advanced/batching\n   advanced/sparse_tensor\n   advanced/hgam\n   advanced/compile\n   advanced/jit\n   advanced/remote\n   advanced/graphgym\n   advanced/cpu_affinity\n\n.. 
toctree::\n   :maxdepth: 1\n   :caption: Package Reference\n\n   modules/root\n   modules/nn\n   modules/data\n   modules/loader\n   modules/sampler\n   modules/datasets\n   modules/llm\n   modules/transforms\n   modules/utils\n   modules/explain\n   modules/metrics\n   modules/distributed\n   modules/contrib\n   modules/graphgym\n   modules/profile\n\n.. toctree::\n   :maxdepth: 1\n   :caption: Cheatsheets\n\n   cheatsheet/gnn_cheatsheet\n   cheatsheet/data_cheatsheet\n\n.. toctree::\n   :maxdepth: 1\n   :caption: External Resources\n\n   external/resources\n"
  },
  {
    "path": "docs/source/install/installation.rst",
    "content": "Installation\n============\n\n:pyg:`PyG` is available for :python:`Python 3.10` to :python:`Python 3.14`.\n\n.. note::\n   We do not recommend installation as a root user on your system :python:`Python`.\n   Please set up a virtual environment, *e.g.*, via `venv <https://virtualenv.pypa.io/en/latest>`_, :conda:`null` `Anaconda/Miniconda <https://conda.io/projects/conda/en/latest/user-guide/install>`_, or create a `Docker image <https://www.docker.com/>`_.\n\nQuick Start\n-----------\n\n.. raw:: html\n   :file: quick-start.html\n\nInstallation via PyPI\n---------------------\n\nFrom :pyg:`null` **PyG 2.3** onwards, you can install and use :pyg:`PyG` **without any external library** required except for :pytorch:`PyTorch`.\nFor this, simply run:\n\n.. code-block:: none\n\n   pip install torch_geometric\n\nAdditional Libraries\n~~~~~~~~~~~~~~~~~~~~\n\nIf you want to utilize the full set of features from :pyg:`PyG`, there exist several additional libraries you may want to install:\n\n* `pyg-lib <https://github.com/pyg-team/pyg-lib>`__: Heterogeneous GNN operators, graph sampling routines, and :class:`~torch_geometric.nn.conv.SplineConv` support\n* `torch-scatter <https://github.com/rusty1s/pytorch_scatter>`__: Accelerated and efficient sparse reductions\n* `torch-sparse <https://github.com/rusty1s/pytorch_sparse>`__: :class:`SparseTensor` support, see `here <https://pytorch-geometric.readthedocs.io/en/latest/advanced/sparse_tensor.html>`__\n* `torch-cluster <https://github.com/rusty1s/pytorch_cluster>`__: Graph clustering routines\n\n.. 
note::\n   ``torch-spline-conv`` is no longer required as a separate package.\n   Its functionality has been migrated to ``pyg-lib>=0.6.0``.\n\nThese packages come with their own CPU and GPU kernel implementations based on the :pytorch:`null` `PyTorch C++/CUDA/hip(ROCm) extension interface <https://github.com/pytorch/extension-cpp/>`_.\nFor a basic usage of :pyg:`PyG`, these dependencies are **fully optional**.\nWe recommend starting with a minimal installation and installing additional dependencies once you actually need them.\n\nInstallation from Wheels\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nFor ease of installation of these extensions, we provide :obj:`pip` wheels for these packages for all major OS, :pytorch:`PyTorch` and CUDA combinations, see `here <https://data.pyg.org/whl>`__:\n\n#. Ensure that at least :pytorch:`PyTorch` 1.13.0 is installed:\n\n   .. code-block:: none\n\n      python -c \"import torch; print(torch.__version__)\"\n      >>> 2.10.0\n\n#. Find the CUDA version :pytorch:`PyTorch` was installed with:\n\n   .. code-block:: none\n\n      python -c \"import torch; print(torch.version.cuda)\"\n      >>> 12.8\n\n#. Install the relevant packages:\n\n   .. 
code-block:: none\n\n      pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html\n\n   where :obj:`${TORCH}` and :obj:`${CUDA}` should be replaced by the specific :pytorch:`PyTorch` and CUDA versions, respectively:\n\n   * :pytorch:`PyTorch` 2.10.*: :obj:`${TORCH}=2.10.0` and :obj:`${CUDA}=cpu|cu126|cu128|cu130`\n   * :pytorch:`PyTorch` 2.9.*: :obj:`${TORCH}=2.9.0` and :obj:`${CUDA}=cpu|cu126|cu128|cu130`\n   * :pytorch:`PyTorch` 2.8.*: :obj:`${TORCH}=2.8.0` and :obj:`${CUDA}=cpu|cu126|cu128|cu129`\n   * :pytorch:`PyTorch` 2.7.*: :obj:`${TORCH}=2.7.0` and :obj:`${CUDA}=cpu|cu118|cu126|cu128`\n   * :pytorch:`PyTorch` 2.6.*: :obj:`${TORCH}=2.6.0` and :obj:`${CUDA}=cpu|cu118|cu124|cu126`\n   * :pytorch:`PyTorch` 2.5.*: :obj:`${TORCH}=2.5.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124`\n   * :pytorch:`PyTorch` 2.4.*: :obj:`${TORCH}=2.4.0` and :obj:`${CUDA}=cpu|cu118|cu121|cu124`\n   * :pytorch:`PyTorch` 2.3.*: :obj:`${TORCH}=2.3.0` and :obj:`${CUDA}=cpu|cu118|cu121`\n   * :pytorch:`PyTorch` 2.2.*: :obj:`${TORCH}=2.2.0` and :obj:`${CUDA}=cpu|cu118|cu121`\n   * :pytorch:`PyTorch` 2.1.*: :obj:`${TORCH}=2.1.0` and :obj:`${CUDA}=cpu|cu118|cu121`\n   * :pytorch:`PyTorch` 2.0.*: :obj:`${TORCH}=2.0.0` and :obj:`${CUDA}=cpu|cu117|cu118`\n   * :pytorch:`PyTorch` 1.13.*: :obj:`${TORCH}=1.13.0` and :obj:`${CUDA}=cpu|cu116|cu117`\n\n   For example, for :pytorch:`PyTorch` 2.10.* and CUDA 13.0, type:\n\n   .. code-block:: none\n\n      pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.10.0+cu130.html\n\n   For example, for :pytorch:`PyTorch` 2.9.* and CUDA 12.8, type:\n\n   .. 
code-block:: none\n\n      pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.9.0+cu128.html\n\n**Note:** Binaries of older versions are also provided for :pytorch:`PyTorch` 1.4.0, 1.5.0, 1.6.0, 1.7.0/1.7.1, 1.8.0/1.8.1, 1.9.0, 1.10.0/1.10.1/1.10.2, 1.11.0, 1.12.0/1.12.1, 1.13.0/1.13.1, 2.0.0/2.0.1, 2.1.0/2.1.1/2.1.2, 2.2.0/2.2.1/2.2.2, 2.3.0/2.3.1, 2.4.0/2.4.1, 2.5.0/2.5.1, 2.6.0, and 2.7.0/2.7.1 (following the same procedure).\n**For older versions, you need to explicitly specify the latest supported version number** or install via :obj:`pip install --no-index` to prevent a manual installation from source.\nYou can look up the latest supported version number `here <https://data.pyg.org/whl>`__.\n\n**ROCm:** The external `pyg-rocm-build repository <https://github.com/Looong01/pyg-rocm-build>`__ provides wheels and detailed instructions on how to install :pyg:`PyG` for ROCm.\nIf you have any questions about it, please open an issue `here <https://github.com/Looong01/pyg-rocm-build/issues>`__.\n\nInstallation from Source\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nIn case a specific version is not supported by `our wheels <https://data.pyg.org/whl>`_, you can alternatively install the extension packages from source:\n\n#. Ensure that your CUDA is set up correctly (optional):\n\n   #. Check if :pytorch:`PyTorch` is installed with CUDA support:\n\n      .. code-block:: none\n\n         python -c \"import torch; print(torch.cuda.is_available())\"\n         >>> True\n\n   #. Add CUDA to :obj:`$PATH` and :obj:`$CPATH` (note that your actual CUDA path may vary from :obj:`/usr/local/cuda`):\n\n      .. code-block:: none\n\n         export PATH=/usr/local/cuda/bin:$PATH\n         echo $PATH\n         >>> /usr/local/cuda/bin:...\n\n         export CPATH=/usr/local/cuda/include:$CPATH\n         echo $CPATH\n         >>> /usr/local/cuda/include:...\n\n   #. 
Add CUDA to :obj:`$LD_LIBRARY_PATH` on Linux and to :obj:`$DYLD_LIBRARY_PATH` on macOS (note that your actual CUDA path may vary from :obj:`/usr/local/cuda`):\n\n      .. code-block:: none\n\n         export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH\n         echo $LD_LIBRARY_PATH\n         >>> /usr/local/cuda/lib64:...\n\n         export DYLD_LIBRARY_PATH=/usr/local/cuda/lib:$DYLD_LIBRARY_PATH\n         echo $DYLD_LIBRARY_PATH\n         >>> /usr/local/cuda/lib:...\n\n   #. Verify that :obj:`nvcc` is accessible from the terminal:\n\n      .. code-block:: none\n\n         nvcc --version\n         >>> 11.8\n\n   #. Ensure that :pytorch:`PyTorch` and system CUDA versions match:\n\n      .. code-block:: none\n\n         python -c \"import torch; print(torch.version.cuda)\"\n         >>> 11.8\n\n         nvcc --version\n         >>> 11.8\n\n#. Install the relevant packages:\n\n   .. code-block:: none\n\n      pip install --verbose git+https://github.com/pyg-team/pyg-lib.git\n      pip install --verbose torch_scatter\n      pip install --verbose torch_sparse\n      pip install --verbose torch_cluster\n\nIn rare cases, CUDA or :python:`Python` path problems can prevent a successful installation.\n:obj:`pip` may even signal a successful installation, but execution simply crashes with :obj:`Segmentation fault (core dumped)`.\nWe have collected common installation errors in the `Frequently Asked Questions <installation.html#frequently-asked-questions>`__ subsection.\nIf the FAQ does not help you solve your problem, please create an `issue <https://github.com/pyg-team/pytorch_geometric/issues>`_.\nBefore doing so, please verify that your CUDA is set up correctly by following the official `installation guide <https://docs.nvidia.com/cuda>`_.\n\nInstallation via Anaconda\n-------------------------\n\n.. 
warning::\n   Conda packages are no longer available for :pytorch:`PyTorch` :obj:`>2.5.0`.\n   Please use :obj:`pip` instead.\n\nFor earlier :pytorch:`PyTorch` versions (:obj:`torch<=2.5.0`), you can install :pyg:`PyG` via :conda:`null` `Anaconda <https://anaconda.org/pyg/pyg>`_ for all major OS and CUDA combinations.\nIf you have not yet installed :pytorch:`PyTorch`, install it via :conda:`null` :obj:`conda install` as described in its `official documentation <https://pytorch.org/get-started/locally/>`_.\nGiven that you have :pytorch:`PyTorch` installed, run\n\n.. code-block:: none\n\n   conda install pyg -c pyg\n\nIf :conda:`null` :obj:`conda` does not pick up the correct CUDA version of :pyg:`PyG`, you can enforce it as follows:\n\n.. code-block:: none\n\n   conda install pyg=*=*cu* -c pyg\n\n.. _install-cugraph:\n\nAccelerating PyG with NVIDIA cuGraph GNN\n----------------------------------------\n\n:pyg:`PyG` can optionally leverage NVIDIA's `cuGraph <https://github.com/rapidsai/cugraph>`_ to accelerate neighbor sampling and achieve better scalability for multi-GPU training on large-scale graphs (2x-8x data loading speedups on billion-edge graphs).\n\nNVIDIA currently recommends the `NVIDIA PyG Container <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_ on NGC as the most reliable way to use cuGraph integration with :pyg:`PyG`. For other installation methods, refer to the `cuGraph GNN repository <https://github.com/rapidsai/cugraph-gnn>`_ and/or the `RAPIDS installation guide <https://docs.rapids.ai/install>`_.\n\n.. note::\n\n   **cuGraph GNN is optional**: all :pyg:`PyG` functionality, including multi-GPU training, works without it. 
However, for users with NVIDIA GPUs, cuGraph can provide significant speedups and better scalability for neighbor sampling and data loading, especially on large-scale graphs.\n\n`cuGraph <https://github.com/rapidsai/cugraph>`_ is a collection of packages focused on GPU-accelerated graph analytics including support for property graphs and scaling up to thousands of GPUs. cuGraph supports the creation and manipulation of graphs followed by the execution of scalable fast graph algorithms. It is part of the `RAPIDS <https://rapids.ai>`_ accelerated data science framework.\n\n`cuGraph GNN <https://github.com/rapidsai/cugraph-gnn>`_ is a collection of GPU-accelerated plugins that support :pytorch:`PyTorch` and :pyg:`PyG` natively through the *cuGraph-PyG* and *WholeGraph* subprojects. cuGraph GNN is built on top of cuGraph, leveraging its low-level `pylibcugraph <https://github.com/rapidsai/cugraph/tree/branch-25.08/python/pylibcugraph>`_ API and C++ primitives for sampling and other GNN operations (`libcugraph <https://github.com/rapidsai/cugraph/tree/branch-25.08/python/libcugraph>`_). It also includes the :obj:`libwholegraph` and :obj:`pylibwholegraph` libraries for high-performance distributed edgelist and embedding storage. 
Users have the option of working with these lower-level libraries directly, or through the higher-level API in cuGraph-PyG that directly implements the :class:`~torch_geometric.data.GraphStore`, :class:`~torch_geometric.data.FeatureStore`, :class:`~torch_geometric.loader.NodeLoader`, and :class:`~torch_geometric.loader.LinkLoader` interfaces.\n\nComplete documentation on RAPIDS graph packages, including ``cugraph``, ``cugraph-pyg``, ``pylibwholegraph``, and ``pylibcugraph``, is available on the `RAPIDS docs pages <https://docs.rapids.ai/api/cugraph/nightly/graph_support>`_.\n\nSee `rapidsai/cugraph-gnn examples on GitHub <https://github.com/rapidsai/cugraph-gnn/tree/branch-25.12/python/cugraph-pyg/cugraph_pyg/examples>`_ for fully scalable PyG example workflows.\n\nFrequently Asked Questions\n--------------------------\n\n#. :obj:`undefined symbol: **make_function_schema**`: This issue signals (1) a **version conflict** between your installed :pytorch:`PyTorch` version and the :obj:`${TORCH}` version specified to install the extension packages, or (2) a version conflict between the installed CUDA version of :pytorch:`PyTorch` and the :obj:`${CUDA}` version specified to install the extension packages.\n   Please verify that your :pytorch:`PyTorch` version and its CUDA version **match** your installation command:\n\n   .. code-block:: none\n\n      python -c \"import torch; print(torch.__version__)\"\n      python -c \"import torch; print(torch.version.cuda)\"\n      nvcc --version\n\n   For re-installation, ensure that you do not run into any caching issues by using the :obj:`pip --force-reinstall --no-cache-dir` flags.\n   In addition, the :obj:`pip --verbose` option may help to track down any issues during installation.\n   If installation still fails, please try to install the extension packages `from source <installation.html#installation-from-source>`__.\n"
  },
  {
    "path": "docs/source/install/quick-start.html",
    "content": "<style>\n  .quick-start {\n    display: flex;\n    flex-direction: row;\n    flex-wrap: nowrap;\n    margin-bottom: 20px;\n  }\n\n  .title-column {\n    flex-grow: 0;\n  }\n\n  .content-column {\n    flex-grow: 1;\n  }\n\n  .row {\n    display: flex;\n    flex-direction: row;\n    flex-wrap: nowrap;\n  }\n\n  .title-column div, .row div {\n    white-space: nowrap;\n  }\n\n  .title-column div {\n    padding: 14px 10px 12px 0;\n    font-weight: 700;\n  }\n\n  .row div {\n    flex-grow: 1;\n    text-align: center;\n    margin: 2px;\n    padding: 12px 0 10px 0;\n    background: #e3e3e3;\n    cursor: pointer;\n  }\n\n  .row div.selected {\n    background: rgba(59,155,239,0.7);\n    color: #ffffff;\n  }\n\n  #command {\n    margin: 2px;\n    padding: 12px 10px 10px 10px;\n  }\n\n  #command pre {\n    padding: 0;\n    margin: 0;\n    white-space: pre-wrap;\n  }\n\n</style>\n\n<div class=\"quick-start\">\n  <div class=\"title-column\">\n    <div>PyTorch</div>\n    <div>Your OS</div>\n    <div>Package</div>\n    <div>CUDA</div>\n    <div>Run:</div>\n  </div>\n  <div class=\"content-column\">\n    <div class=\"row\" id=\"torch\"></div>\n    <div class=\"row\" id=\"os\"></div>\n    <div class=\"row\" id=\"package\"></div>\n    <div class=\"row\" id=\"cuda\"></div>\n    <div class=\"row\" id=\"command\"><pre></pre></div>\n  </div>\n</div>\n\n<script>\n  var torchList = [\n    ['torch-2.10.0', 'PyTorch 2.10.*'],\n    ['torch-2.9.0', 'PyTorch 2.9.*'],\n    ['torch-2.8.0', 'PyTorch 2.8.*'],\n  ];\n\n  var osList = [\n    ['linux', 'Linux'],\n    ['mac', 'Mac'],\n    ['windows', 'Windows'],\n  ];\n\n  var packageList = [\n    ['pip', 'Pip'],\n    ['conda', 'Conda'],\n  ];\n\n  var cudaList = [\n    ['cu126', '12.6'],\n    ['cu128', '12.8'],\n    ['cu129', '12.9'],\n    ['cu130', '13.0'],\n    ['cpu', 'CPU'],\n  ];\n\n  torchList.forEach(x => $(\"#torch\").append(`<div id=\"${x[0]}\">${x[1]}</div>`));\n  osList.forEach(x => $(\"#os\").append(`<div 
id=\"${x[0]}\">${x[1]}</div>`));\n  packageList.forEach(x => $(\"#package\").append(`<div id=\"${x[0]}\">${x[1]}</div>`));\n  cudaList.forEach(x => $(\"#cuda\").append(`<div id=\"${x[0]}\">${x[1]}</div>`));\n\n  function updateCommand() {\n    var torch = $(\"#command\").attr(\"torch\");\n    var os = $(\"#command\").attr(\"os\");\n    var package = $(\"#command\").attr(\"package\");\n    var cuda = $(\"#command\").attr(\"cuda\");\n\n    if (os == \"mac\" && cuda != \"cpu\") {\n      $(\"#command pre\").text('# macOS binaries do not support CUDA');\n    }\n\n    else if (torch == \"torch-2.8.0\" && cuda == \"cu130\") {\n      $(\"#command pre\").text('# PyTorch version does not support CUDA 13.0');\n    }\n\n    else if (torch == \"torch-2.9.0\" && cuda == \"cu129\") {\n      $(\"#command pre\").text('# PyTorch version does not support CUDA 12.9');\n    }\n\n    else if (torch == \"torch-2.10.0\" && cuda == \"cu129\") {\n      $(\"#command pre\").text('# PyTorch version does not support CUDA 12.9');\n    }\n\n    else if (package == \"conda\") {\n      $(\"#command pre\").text('# Conda packages are no longer available since PyTorch >2.5.0. Please use pip instead.');\n    }\n\n    else {\n      $(\"#command pre\").text(`pip install torch_geometric\\n\\n# Optional dependencies:\\npip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/${$(\"#command\").attr(\"torch\")}+${$(\"#command\").attr(\"cuda\")}.html`);\n    }\n  }\n\n  $(\".quick-start .content-column .row div\").click(function() {\n    $(this).parent().children().removeClass(\"selected\");\n    $(this).addClass(\"selected\");\n    $(\"#command\").attr($(this).parent().attr(\"id\"), $(this).attr(\"id\"));\n    updateCommand();\n  });\n\n  $(\"#torch\").children().get(0).click();\n  $(\"#linux\").click();\n  $(\"#pip\").click();\n  $(\"#cpu\").click();\n\n</script>\n"
  },
  {
    "path": "docs/source/modules/contrib.rst",
    "content": "torch_geometric.contrib\n=======================\n\n.. currentmodule:: torch_geometric.contrib\n\n:obj:`torch_geometric.contrib` is a staging area for early stage experimental code.\nModules might be moved to the main library in the future.\n\n.. warning::\n\n    This module contains experimental code, which is not guaranteed to be stable.\n\n.. contents:: Contents\n    :local:\n\nConvolutional Layers\n--------------------\n\n.. currentmodule:: torch_geometric.contrib.nn.conv\n\n.. autosummary::\n   :nosignatures:\n   {% for cls in torch_geometric.contrib.nn.conv.classes %}\n     {{ cls }}\n   {% endfor %}\n\n.. automodule:: torch_geometric.contrib.nn.conv\n   :members:\n   :undoc-members:\n   :exclude-members: message, aggregate, message_and_aggregate, update, MessagePassing, training, initialize_parameters\n\nModels\n------\n\n.. currentmodule:: torch_geometric.contrib.nn.models\n\n.. autosummary::\n   :nosignatures:\n   {% for cls in torch_geometric.contrib.nn.models.classes %}\n     {{ cls }}\n   {% endfor %}\n\n.. automodule:: torch_geometric.contrib.nn.models\n   :members:\n   :undoc-members:\n   :exclude-members: message, aggregate, message_and_aggregate, update, MessagePassing, training, init_conv\n\nDatasets\n--------\n\n.. currentmodule:: torch_geometric.contrib.datasets\n\n.. autosummary::\n    :nosignatures:\n    {% for cls in torch_geometric.contrib.datasets.classes %}\n      {{ cls }}\n    {% endfor %}\n\n.. automodule:: torch_geometric.contrib.datasets\n    :members:\n    :exclude-members: download, process, processed_file_names, raw_file_names, num_classes, get\n\nTransforms\n----------\n\n.. currentmodule:: torch_geometric.contrib.transforms\n\n.. autosummary::\n    :nosignatures:\n    {% for cls in torch_geometric.contrib.transforms.classes %}\n      {{ cls }}\n    {% endfor %}\n\n.. automodule:: torch_geometric.contrib.transforms\n    :members:\n\n\nExplainer\n---------\n\n.. currentmodule:: torch_geometric.contrib.explain\n\n.. 
autosummary::\n    :nosignatures:\n    {% for cls in torch_geometric.contrib.explain.classes %}\n      {{ cls }}\n    {% endfor %}\n\n.. automodule:: torch_geometric.contrib.explain\n    :members:\n"
  },
  {
    "path": "docs/source/modules/data.rst",
    "content": "torch_geometric.data\n====================\n\n.. contents:: Contents\n    :local:\n\nData Objects\n------------\n\n.. currentmodule:: torch_geometric.data\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/inherited_class.rst\n\n   {% for name in torch_geometric.data.data_classes %}\n     {{ name }}\n   {% endfor %}\n\nRemote Backend Interfaces\n-------------------------\n\n.. currentmodule:: torch_geometric.data\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.data.remote_backend_classes %}\n     {{ name }}\n   {% endfor %}\n\nDatabases\n---------\n\n.. currentmodule:: torch_geometric.data\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/inherited_class.rst\n\n   {% for name in torch_geometric.data.database_classes %}\n     {{ name }}\n   {% endfor %}\n\nPyTorch Lightning Wrappers\n--------------------------\n\n.. currentmodule:: torch_geometric.data.lightning\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.data.lightning.classes %}\n     {{ name }}\n   {% endfor %}\n\nHelper Functions\n----------------\n\n.. currentmodule:: torch_geometric.data\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.data.helper_functions %}\n     {{ name }}\n   {% endfor %}\n"
  },
  {
    "path": "docs/source/modules/datasets.rst",
    "content": "torch_geometric.datasets\n========================\n\n.. contents:: Contents\n    :local:\n\nHomogeneous Datasets\n--------------------\n\n.. currentmodule:: torch_geometric.datasets\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.datasets.homo_datasets %}\n     {{ name }}\n   {% endfor %}\n\nHeterogeneous Datasets\n----------------------\n\n.. currentmodule:: torch_geometric.datasets\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.datasets.hetero_datasets %}\n     {{ name }}\n   {% endfor %}\n\nHypergraph Datasets\n-------------------\n\n.. currentmodule:: torch_geometric.datasets\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.datasets.hyper_datasets %}\n     {{ name }}\n   {% endfor %}\n\nSynthetic Datasets\n------------------\n\n.. currentmodule:: torch_geometric.datasets\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.datasets.synthetic_datasets %}\n     {{ name }}\n   {% endfor %}\n\nGraph Generators\n----------------\n\n.. currentmodule:: torch_geometric.datasets.graph_generator\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.datasets.graph_generator.classes %}\n     {{ name }}\n   {% endfor %}\n\nMotif Generators\n----------------\n\n.. currentmodule:: torch_geometric.datasets.motif_generator\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.datasets.motif_generator.classes %}\n     {{ name }}\n   {% endfor %}\n"
  },
  {
    "path": "docs/source/modules/distributed.rst",
    "content": "torch_geometric.distributed\n===========================\n\n.. warning::\n    ``torch_geometric.distributed`` has been deprecated since 2.7.0 and will\n    no longer be maintained. For distributed training, refer to :ref:`our\n    tutorials on distributed training <distributed_tutorials>` or `cuGraph\n    examples <https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples>`_.\n\n.. currentmodule:: torch_geometric.distributed\n\n.. autosummary::\n   :nosignatures:\n   {% for cls in torch_geometric.distributed.classes %}\n     {{ cls }}\n   {% endfor %}\n\n.. automodule:: torch_geometric.distributed\n    :members:\n"
  },
  {
    "path": "docs/source/modules/explain.rst",
    "content": "torch_geometric.explain\n=======================\n\n.. currentmodule:: torch_geometric.explain\n\n.. warning::\n\n    This module is in active development and may not be stable.\n    Access requires installing :pyg:`PyG` from master.\n\n.. contents:: Contents\n    :local:\n\nPhilosophy\n----------\n\nThis module provides a set of tools to explain the predictions of a PyG model or to explain the underlying phenomenon of a dataset (see the `\"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks\" <https://arxiv.org/abs/2206.09677>`_ paper for more details).\n\nWe represent explanations using the :class:`torch_geometric.explain.Explanation` class, which is a :class:`~torch_geometric.data.Data` object containing masks for the nodes, edges, features and any attributes of the data.\n\nThe :class:`torch_geometric.explain.Explainer` class is designed to handle all explainability parameters (see the :class:`torch_geometric.explain.config.ExplainerConfig` class for more details):\n\n- which algorithm from the :class:`torch_geometric.explain.algorithm` module to use (*e.g.*, :class:`~torch_geometric.explain.algorithm.GNNExplainer`)\n- the type of explanation to compute (*e.g.*, :obj:`explanation_type=\"phenomenon\"` or :obj:`explanation_type=\"model\"`)\n- the different type of masks for node and edges (*e.g.*, :obj:`mask=\"object\"` or :obj:`mask=\"attributes\"`)\n- any postprocessing of the masks (*e.g.*, :obj:`threshold_type=\"topk\"` or :obj:`threshold_type=\"hard\"`)\n\nThis class allows the user to easily compare different explainability methods and to easily switch between different types of masks, while making sure the high-level framework stays the same.\n\nExplainer\n---------\n\n.. autoclass:: torch_geometric.explain.Explainer\n   :show-inheritance:\n   :members:\n   :special-members: __call__\n\n.. autoclass:: torch_geometric.explain.config.ExplainerConfig\n   :members:\n\n.. 
autoclass:: torch_geometric.explain.config.ModelConfig\n   :members:\n\n.. autoclass:: torch_geometric.explain.config.ThresholdConfig\n   :members:\n\nExplanations\n------------\n\n.. autoclass:: torch_geometric.explain.Explanation\n   :show-inheritance:\n   :members:\n\n.. autoclass:: torch_geometric.explain.HeteroExplanation\n   :show-inheritance:\n   :members:\n\nExplainer Algorithms\n--------------------\n\n.. currentmodule:: torch_geometric.explain.algorithm\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.explain.algorithm.classes %}\n     {{ name }}\n   {% endfor %}\n\nExplanation Metrics\n-------------------\n\nThe quality of an explanation can be judged by a variety of different methods.\nPyG supports the following metrics out-of-the-box:\n\n.. currentmodule:: torch_geometric.explain.metric\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.explain.metric.classes %}\n     {{ name }}\n   {% endfor %}\n"
  },
  {
    "path": "docs/source/modules/graphgym.rst",
    "content": "torch_geometric.graphgym\n========================\n\n.. contents:: Contents\n    :local:\n\nWorkflow and Register Modules\n-----------------------------\n\n.. currentmodule:: torch_geometric.graphgym\n.. autosummary::\n   :nosignatures:\n   {% for cls in torch_geometric.graphgym.classes %}\n     {{ cls }}\n   {% endfor %}\n\n.. automodule:: torch_geometric.graphgym\n   :members:\n   :exclude-members:\n\nModel Modules\n-------------\n\n.. currentmodule:: torch_geometric.graphgym.models\n.. autosummary::\n    :nosignatures:\n    {% for cls in torch_geometric.graphgym.models.classes %}\n      {{ cls }}\n    {% endfor %}\n\n.. automodule:: torch_geometric.graphgym.models\n    :members:\n    :exclude-members: forward\n\nUtility Modules\n---------------\n\n.. currentmodule:: torch_geometric.graphgym.utils\n.. autosummary::\n   :nosignatures:\n    {% for cls in torch_geometric.graphgym.utils.classes %}\n     {{ cls }}\n    {% endfor %}\n\n.. automodule:: torch_geometric.graphgym.utils\n    :members:\n    :exclude-members:\n"
  },
  {
    "path": "docs/source/modules/llm.rst",
    "content": "torch_geometric.llm\n=======================\n\n.. currentmodule:: torch_geometric.llm\n\n.. autosummary::\n   :nosignatures:\n   {% for cls in torch_geometric.llm.classes %}\n     {{ cls }}\n   {% endfor %}\n\n.. automodule:: torch_geometric.llm\n   :members:\n\n\nModels\n----------------\n\n.. currentmodule:: torch_geometric.llm.models\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.llm.models.classes %}\n     {{ name }}\n   {% endfor %}\n\nUtils\n----------------\n\n.. currentmodule:: torch_geometric.llm.utils\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.llm.utils.classes %}\n     {{ name }}\n   {% endfor %}\n"
  },
  {
    "path": "docs/source/modules/loader.rst",
    "content": "torch_geometric.loader\n======================\n\n.. currentmodule:: torch_geometric.loader\n\n.. autosummary::\n   :nosignatures:\n   {% for cls in torch_geometric.loader.classes %}\n     {{ cls }}\n   {% endfor %}\n\n.. automodule:: torch_geometric.loader\n    :members:\n"
  },
  {
    "path": "docs/source/modules/metrics.rst",
    "content": "torch_geometric.metrics\n=======================\n\n.. contents:: Contents\n    :local:\n\nLink Prediction Metrics\n-----------------------\n\n.. currentmodule:: torch_geometric.metrics\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/metrics.rst\n\n   {% for name in torch_geometric.metrics.link_pred_metrics %}\n     {{ name }}\n   {% endfor %}\n"
  },
  {
    "path": "docs/source/modules/nn.rst",
    "content": "torch_geometric.nn\n==================\n\n.. contents:: Contents\n    :local:\n\n.. autoclass:: torch_geometric.nn.sequential.Sequential\n\n.. currentmodule:: torch_geometric.nn.dense\n\n   {% for name in torch_geometric.nn.dense.lin_classes %}\n.. autoclass:: {{ name }}\n   :members:\n   {% endfor %}\n\nConvolutional Layers\n--------------------\n\n.. currentmodule:: torch_geometric.nn.conv\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/nn.rst\n\n   {% for name in torch_geometric.nn.conv.classes %}\n     {{ name }}\n   {% endfor %}\n\nAggregation Operators\n---------------------\n\n.. currentmodule:: torch_geometric.nn.aggr\n\nAggregation functions play an important role in the message passing framework and the readout functions of Graph Neural Networks.\nSpecifically, many works in the literature (`Hamilton et al. (2017) <https://arxiv.org/abs/1706.02216>`__, `Xu et al. (2018) <https://arxiv.org/abs/1810.00826>`__, `Corso et al. (2020) <https://arxiv.org/abs/2004.05718>`__, `Li et al. (2020) <https://arxiv.org/abs/2006.07739>`__, `Tailor et al. (2021) <https://arxiv.org/abs/2104.01481>`__) demonstrate that the choice of aggregation functions contributes significantly to the representational power and performance of the model.\nFor example, **mean aggregation** captures the distribution (or proportions) of elements, **max aggregation** proves to be advantageous to identify representative elements, and **sum aggregation** enables the learning of structural graph properties (`Xu et al. (2018) <https://arxiv.org/abs/1810.00826>`__).\nRecent works also show that using **multiple aggregations** (`Corso et al. (2020) <https://arxiv.org/abs/2004.05718>`__, `Tailor et al. (2021) <https://arxiv.org/abs/2104.01481>`__) and **learnable aggregations** (`Li et al. 
(2020) <https://arxiv.org/abs/2006.07739>`__) can potentially provide substantial improvements.\nAnother line of research studies optimization-based and implicitly-defined aggregations (`Bartunov et al. (2022) <https://arxiv.org/abs/2202.12795>`__).\nFurthermore, an interesting discussion concerns the trade-off between representational power (usually gained through learnable functions implemented as neural networks) and the formal property of permutation invariance (`Buterez et al. (2022) <https://arxiv.org/abs/2211.04952>`__).\n\nTo facilitate further experimentation and unify the concepts of aggregation within GNNs across both :class:`~torch_geometric.nn.conv.MessagePassing` and global readouts, we have made the concept of :class:`~torch_geometric.nn.aggr.Aggregation` a first-class principle in :pyg:`PyG`.\nAs of now, :pyg:`PyG` provides support for various aggregations --- from rather simple ones (*e.g.*, :obj:`mean`, :obj:`max`, :obj:`sum`), to advanced ones (*e.g.*, :obj:`median`, :obj:`var`, :obj:`std`), learnable ones (*e.g.*, :class:`~torch_geometric.nn.aggr.SoftmaxAggregation`, :class:`~torch_geometric.nn.aggr.PowerMeanAggregation`, :class:`~torch_geometric.nn.aggr.SetTransformerAggregation`), and exotic ones (*e.g.*, :class:`~torch_geometric.nn.aggr.MLPAggregation`, :class:`~torch_geometric.nn.aggr.LSTMAggregation`, :class:`~torch_geometric.nn.aggr.SortAggregation`, :class:`~torch_geometric.nn.aggr.EquilibriumAggregation`):\n\n.. 
code-block:: python\n\n   from torch_geometric.nn import aggr\n\n   # Simple aggregations:\n   mean_aggr = aggr.MeanAggregation()\n   max_aggr = aggr.MaxAggregation()\n\n   # Advanced aggregations:\n   median_aggr = aggr.MedianAggregation()\n\n   # Learnable aggregations:\n   softmax_aggr = aggr.SoftmaxAggregation(learn=True)\n   powermean_aggr = aggr.PowerMeanAggregation(learn=True)\n\n   # Exotic aggregations:\n   lstm_aggr = aggr.LSTMAggregation(in_channels=..., out_channels=...)\n   sort_aggr = aggr.SortAggregation(k=4)\n\nWe can then easily apply these aggregations over a batch of sets of potentially varying size.\nFor this, an :obj:`index` vector defines the mapping from input elements to their location in the output:\n\n.. code-block:: python\n\n   import torch\n\n   # Feature matrix holding 1000 elements with 64 features each:\n   x = torch.randn(1000, 64)\n\n   # Randomly assign elements to 100 sets:\n   index = torch.randint(0, 100, (1000, ))\n\n   output = mean_aggr(x, index)  # Output shape: [100, 64]\n\nNotably, all aggregations share the same set of forward arguments, as described in detail in the :class:`torch_geometric.nn.aggr.Aggregation` base class.\n\nEach of the provided aggregations can be used within :class:`~torch_geometric.nn.conv.MessagePassing` as well as for hierarchical/global pooling to obtain graph-level representations:\n\n.. 
code-block:: python\n\n   import torch\n   from torch_geometric.nn import MessagePassing, aggr\n\n   class MyConv(MessagePassing):\n       def __init__(self, ...):\n           # Use a learnable softmax neighborhood aggregation:\n           super().__init__(aggr=aggr.SoftmaxAggregation(learn=True))\n\n       def forward(self, x, edge_index):\n           ...\n\n\n   class MyGNN(torch.nn.Module):\n       def __init__(self, ...):\n           super().__init__()\n\n           self.conv = MyConv(...)\n           # Use a global sort aggregation:\n           self.global_pool = aggr.SortAggregation(k=4)\n           self.classifier = torch.nn.Linear(...)\n\n       def forward(self, x, edge_index, batch):\n           x = self.conv(x, edge_index).relu()\n           x = self.global_pool(x, batch)\n           x = self.classifier(x)\n           return x\n\nIn addition, the aggregation package of :pyg:`PyG` introduces two new concepts:\nFirst, aggregations can be **resolved from pure strings** via a lookup table, following the design principles of the `class-resolver <https://github.com/cthoyt/class-resolver>`__ library, *e.g.*, by simply passing in :obj:`\"median\"` to the :class:`~torch_geometric.nn.conv.MessagePassing` module.\nThis will automatically resolve to the :class:`~torch_geometric.nn.aggr.MedianAggregation` class:\n\n.. code-block:: python\n\n   class MyConv(MessagePassing):\n       def __init__(self, ...):\n           super().__init__(aggr=\"median\")\n\nSecondly, **multiple aggregations** can be combined and stacked via the :class:`~torch_geometric.nn.aggr.MultiAggregation` module in order to enhance the representational power of GNNs (`Corso et al. (2020) <https://arxiv.org/abs/2004.05718>`__, `Tailor et al. (2021) <https://arxiv.org/abs/2104.01481>`__):\n\n.. code-block:: python\n\n   class MyConv(MessagePassing):\n       def __init__(self, ...):\n           # Combines a set of aggregations and concatenates their results,\n           # i.e. 
its output will be `[num_nodes, 3 * out_channels]` here.\n           # Note that the interface also supports automatic resolution.\n           super().__init__(aggr=aggr.MultiAggregation(\n               ['mean', 'std', aggr.SoftmaxAggregation(learn=True)]))\n\nImportantly, :class:`~torch_geometric.nn.aggr.MultiAggregation` provides various options to combine the outputs of its underlying aggregations (*e.g.*, using concatenation, summation, attention, ...) via its :obj:`mode` argument.\nThe default :obj:`mode` performs concatenation (:obj:`\"cat\"`).\nFor combining via attention, we need to additionally specify the :obj:`in_channels`, :obj:`out_channels`, and :obj:`num_heads`:\n\n.. code-block:: python\n\n   multi_aggr = aggr.MultiAggregation(\n       aggrs=['mean', 'std'],\n       mode='attn',\n       mode_kwargs=dict(in_channels=64, out_channels=64, num_heads=4),\n   )\n\nIf aggregations are given as a list, they will be automatically resolved to a :class:`~torch_geometric.nn.aggr.MultiAggregation`, *e.g.*, :obj:`aggr=['mean', 'std', 'median']`.\n\nFinally, we added full support for customization of aggregations in the :class:`~torch_geometric.nn.conv.SAGEConv` layer --- simply override its :obj:`aggr` argument and **utilize the power of aggregation within your GNN**.\n\n.. note::\n\n   You can read more about the :class:`torch_geometric.nn.aggr` package in this `blog post <https://medium.com/@pytorch_geometric/a-principled-approach-to-aggregations-983c086b10b3>`__.\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.aggr.classes %}\n     {{ name }}\n   {% endfor %}\n\nAttention\n---------\n\n.. currentmodule:: torch_geometric.nn.attention\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/nn.rst\n\n   {% for name in torch_geometric.nn.attention.classes %}\n     {{ name }}\n   {% endfor %}\n\nNormalization Layers\n--------------------\n\n.. 
currentmodule:: torch_geometric.nn.norm\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.norm.classes %}\n     {{ name }}\n   {% endfor %}\n\nPooling Layers\n--------------\n\n.. currentmodule:: torch_geometric.nn.pool\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.pool.classes %}\n     {{ name }}\n   {% endfor %}\n\nUnpooling Layers\n----------------\n\n.. currentmodule:: torch_geometric.nn.unpool\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.unpool.classes %}\n     {{ name }}\n   {% endfor %}\n\nModels\n------\n\n.. currentmodule:: torch_geometric.nn.models\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/nn.rst\n\n   {% for name in torch_geometric.nn.models.classes %}\n     {{ name }}\n   {% endfor %}\n\nKGE Models\n----------\n\n.. currentmodule:: torch_geometric.nn.kge\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.kge.classes %}\n     {{ name }}\n   {% endfor %}\n\nEncodings\n---------\n\n.. currentmodule:: torch_geometric.nn.encoding\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.encoding.classes %}\n     {{ name }}\n   {% endfor %}\n\nFunctional\n----------\n\n.. py:currentmodule:: torch_geometric.nn.functional\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.functional.classes %}\n     {{ name }}\n   {% endfor %}\n\nDense Convolutional Layers\n--------------------------\n\n.. currentmodule:: torch_geometric.nn.dense\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.dense.conv_classes %}\n     {{ name }}\n   {% endfor %}\n\nDense Pooling Layers\n--------------------\n\n.. 
currentmodule:: torch_geometric.nn.dense\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   {% for name in torch_geometric.nn.dense.pool_classes %}\n     {{ name }}\n   {% endfor %}\n\nModel Transformations\n---------------------\n\n.. autoclass:: torch_geometric.nn.fx.Transformer\n   :members:\n   :undoc-members:\n   :exclude-members: graph, find_by_target, find_by_name\n\n.. autofunction:: torch_geometric.nn.to_hetero_transformer.to_hetero\n\n.. autofunction:: torch_geometric.nn.to_hetero_with_bases_transformer.to_hetero_with_bases\n\nDataParallel Layers\n-------------------\n\n.. warning::\n   :class:`~torch_geometric.nn.data_parallel.DataParallel` is deprecated. Please use :class:`torch.nn.parallel.DistributedDataParallel` instead.\n\n.. automodule:: torch_geometric.nn.data_parallel\n   :members:\n\nModel Hub\n---------\n\n.. automodule:: torch_geometric.nn.model_hub\n   :members:\n\nModel Summary\n-------------\n\n.. automodule:: torch_geometric.nn.summary\n   :members:\n"
  },
  {
    "path": "docs/source/modules/profile.rst",
    "content": "torch_geometric.profile\n=======================\n\n.. currentmodule:: torch_geometric.profile\n\n.. autosummary::\n    :nosignatures:\n    {% for cls in torch_geometric.profile.classes %}\n      {{ cls }}\n    {% endfor %}\n\n.. automodule:: torch_geometric.profile\n    :members:\n    :undoc-members:\n"
  },
  {
    "path": "docs/source/modules/root.rst",
    "content": "torch_geometric\n===============\n\nTensor Objects\n--------------\n\n.. currentmodule:: torch_geometric\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n\n   Index\n   EdgeIndex\n   HashTensor\n\nFunctions\n---------\n\n.. automodule:: torch_geometric.seed\n    :members:\n\n.. automodule:: torch_geometric.home\n    :members:\n\n.. automodule:: torch_geometric._compile\n    :members:\n    :exclude-members: compile\n\n.. automodule:: torch_geometric.debug\n    :members:\n\n.. automodule:: torch_geometric.experimental\n    :members:\n"
  },
  {
    "path": "docs/source/modules/sampler.rst",
    "content": "torch_geometric.sampler\n=======================\n\n.. currentmodule:: torch_geometric.sampler\n\n.. autosummary::\n   :nosignatures:\n   {% for cls in torch_geometric.sampler.classes %}\n     {{ cls }}\n   {% endfor %}\n\n.. autoclass:: torch_geometric.sampler.BaseSampler\n   :members:\n\n.. automodule:: torch_geometric.sampler\n   :members:\n   :exclude-members: sample_from_nodes, sample_from_edges, edge_permutation, BaseSampler\n"
  },
  {
    "path": "docs/source/modules/transforms.rst",
    "content": "torch_geometric.transforms\n==========================\n\n.. contents:: Contents\n    :local:\n\nTransforms are a general way to modify and customize :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects, either by implicitly passing them as an argument to a :class:`~torch_geometric.data.Dataset`, or by applying them explicitly to individual :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` objects:\n\n.. code-block:: python\n\n   import torch_geometric.transforms as T\n   from torch_geometric.datasets import TUDataset\n\n   transform = T.Compose([T.ToUndirected(), T.AddSelfLoops()])\n\n   dataset = TUDataset(path, name='MUTAG', transform=transform)\n   data = dataset[0]  # Implicitly transform data on every access.\n\n   data = TUDataset(path, name='MUTAG')[0]\n   data = transform(data)  # Explicitly transform data.\n\nGeneral Transforms\n------------------\n\n.. currentmodule:: torch_geometric.transforms\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.transforms.general_transforms %}\n     {{ name }}\n   {% endfor %}\n\nGraph Transforms\n----------------\n\n.. currentmodule:: torch_geometric.transforms\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.transforms.graph_transforms %}\n     {{ name }}\n   {% endfor %}\n\nVision Transforms\n-----------------\n\n.. currentmodule:: torch_geometric.transforms\n\n.. autosummary::\n   :nosignatures:\n   :toctree: ../generated\n   :template: autosummary/only_class.rst\n\n   {% for name in torch_geometric.transforms.vision_transforms %}\n     {{ name }}\n   {% endfor %}\n"
  },
  {
    "path": "docs/source/modules/utils.rst",
    "content": "torch_geometric.utils\n=====================\n\n.. currentmodule:: torch_geometric.utils\n\n.. autosummary::\n    :nosignatures:\n    {% for cls in torch_geometric.utils.classes %}\n      {{ cls }}\n    {% endfor %}\n\n.. automodule:: torch_geometric.utils\n    :members:\n"
  },
  {
    "path": "docs/source/notes/batching.rst",
    "content": ":orphan:\n\n.. include:: ../advanced/batching.rst\n"
  },
  {
    "path": "docs/source/notes/cheatsheet.rst",
    "content": ":orphan:\n\nGNN Cheatsheet\n==============\n\n* :class:`~torch_sparse.SparseTensor`: If checked (✓), supports message passing based on :class:`torch_sparse.SparseTensor`, *e.g.*, :obj:`GCNConv(...).forward(x, adj_t)`. See `here <../advanced/sparse_tensor.html>`__ for the accompanying tutorial.\n\n* :obj:`edge_weight`: If checked (✓), supports message passing with one-dimensional edge weight information, *e.g.*, :obj:`GraphConv(...).forward(x, edge_index, edge_weight)`.\n\n* :obj:`edge_attr`: If checked (✓), supports message passing with multi-dimensional edge feature information, *e.g.*, :obj:`GINEConv(...).forward(x, edge_index, edge_attr)`.\n\n* **bipartite**: If checked (✓), supports message passing in bipartite graphs with potentially different feature dimensionalities for source and destination nodes, *e.g.*, :obj:`SAGEConv(in_channels=(16, 32), out_channels=64)`.\n\n* **static**: If checked (✓), supports message passing in static graphs, *e.g.*, :obj:`GCNConv(...).forward(x, edge_index)` with :obj:`x` having shape :obj:`[batch_size, num_nodes, in_channels]`.\n\n* **lazy**: If checked (✓), supports lazy initialization of message passing layers, *e.g.*, :obj:`SAGEConv(in_channels=-1, out_channels=64)`.\n\nGraph Neural Network Operators\n------------------------------\n\n.. 
list-table::\n    :widths: 40 10 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - :class:`~torch_sparse.SparseTensor`\n      - :obj:`edge_weight`\n      - :obj:`edge_attr`\n      - bipartite\n      - static\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if not torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) and\n      not torch_geometric.nn.conv.utils.processes_hypergraphs(cls) and\n      not torch_geometric.nn.conv.utils.processes_point_clouds(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__)\n      - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n\nHeterogeneous Graph Neural Network Operators\n--------------------------------------------\n\n.. 
list-table::\n    :widths: 40 10 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - :class:`~torch_sparse.SparseTensor`\n      - :obj:`edge_weight`\n      - :obj:`edge_attr`\n      - bipartite\n      - static\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if torch_geometric.nn.conv.utils.processes_heterogeneous_graphs(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__)\n      - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n\nHypergraph Neural Network Operators\n-----------------------------------\n\n.. 
list-table::\n    :widths: 40 10 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - :class:`~torch_sparse.SparseTensor`\n      - :obj:`edge_weight`\n      - :obj:`edge_attr`\n      - bipartite\n      - static\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if torch_geometric.nn.conv.utils.processes_hypergraphs(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__)\n      - {% if torch_geometric.nn.conv.utils.supports_sparse_tensor(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_weights(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_edge_features(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_static_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n\nPoint Cloud Neural Network Operators\n------------------------------------\n\n.. list-table::\n    :widths: 80 10 10\n    :header-rows: 1\n\n    * - Name\n      - bipartite\n      - lazy\n{% for cls in torch_geometric.nn.conv.classes[1:] %}\n{% if torch_geometric.nn.conv.utils.processes_point_clouds(cls) %}\n    * - :class:`~torch_geometric.nn.conv.{{ cls }}` (`Paper <{{ torch_geometric.nn.conv.utils.paper_link(cls) }}>`__)\n      - {% if torch_geometric.nn.conv.utils.supports_bipartite_graphs(cls) %}✓{% endif %}\n      - {% if torch_geometric.nn.conv.utils.supports_lazy_initialization(cls) %}✓{% endif %}\n{% endif %}\n{% endfor %}\n"
  },
  {
    "path": "docs/source/notes/colabs.rst",
    "content": ":orphan:\n\n.. include:: ../get_started/colabs.rst\n"
  },
  {
    "path": "docs/source/notes/create_dataset.rst",
    "content": ":orphan:\n\n.. include:: ../tutorial/create_dataset.rst\n"
  },
  {
    "path": "docs/source/notes/create_gnn.rst",
    "content": ":orphan:\n\n.. include:: ../tutorial/create_gnn.rst\n"
  },
  {
    "path": "docs/source/notes/data_cheatsheet.rst",
    "content": ":orphan:\n\nDataset Cheatsheet\n==================\n\n.. note::\n\n    This dataset statistics table is a **work in progress**.\n    Please consider helping us filling its content by providing statistics for individual datasets.\n    See `here <https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/karate.py#L25-L37>`__ and `here <https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/datasets/tu_dataset.py#L56-L108>`__ for examples on how to do so.\n\nHomogeneous Datasets\n--------------------\n\n.. list-table::\n    :widths: 50 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - #graphs\n      - #nodes\n      - #edges\n      - #features\n      - #classes/#tasks\n{% for cls in torch_geometric.datasets.homo_datasets %}\n    * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %}\n      - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }}\n    {% for child in torch_geometric.datasets.utils.get_children(cls) %}\n    * - └─ {{ child }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }}\n      - {{ 
torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }}\n    {% endfor %}\n{% endfor %}\n\nHeterogeneous Datasets\n----------------------\n\n.. list-table::\n    :widths: 50 30 10 10\n    :header-rows: 1\n\n    * - Name\n      - #nodes/#edges\n      - #features\n      - #classes/#tasks\n{% for cls in torch_geometric.datasets.hetero_datasets %}\n    * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %}\n      -\n      -\n      -\n    {% for child in torch_geometric.datasets.utils.get_children(cls) %}\n    * - └─ **{{torch_geometric.datasets.utils.get_type(child)}} Type**: {{ child }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes/#edges', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }}\n    {% endfor %}\n{% endfor %}\n\nSynthetic Datasets\n------------------\n\n.. 
list-table::\n    :widths: 50 10 10 10 10 10\n    :header-rows: 1\n\n    * - Name\n      - #graphs\n      - #nodes\n      - #edges\n      - #features\n      - #classes/#tasks\n{% for cls in torch_geometric.datasets.synthetic_datasets %}\n    * - :class:`~torch_geometric.datasets.{{ cls }}` {% if torch_geometric.datasets.utils.paper_link(cls) %}(`Paper <{{ torch_geometric.datasets.utils.paper_link(cls) }}>`__){% endif %}\n      - {%if torch_geometric.datasets.utils.has_stats(cls) %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default=1) }}{% else %}{{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', default='') }}{% endif %}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', default='') }}\n    {% for child in torch_geometric.datasets.utils.get_children(cls) %}\n    * - └─ {{ child }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#graphs', child, default=1) }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#nodes', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#edges', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#features', child, default='') }}\n      - {{ torch_geometric.datasets.utils.get_stat(cls, '#classes', child, default='') }}{{ torch_geometric.datasets.utils.get_stat(cls, '#tasks', child, default='') }}\n    {% endfor %}\n{% endfor %}\n"
  },
  {
    "path": "docs/source/notes/explain.rst",
    "content": ":orphan:\n\n.. include:: ../tutorial/explain.rst\n"
  },
  {
    "path": "docs/source/notes/graphgym.rst",
    "content": ":orphan:\n\n.. include:: ../advanced/graphgym.rst\n"
  },
  {
    "path": "docs/source/notes/heterogeneous.rst",
    "content": ":orphan:\n\n.. include:: ../tutorial/heterogeneous.rst\n"
  },
  {
    "path": "docs/source/notes/installation.rst",
    "content": ":orphan:\n\n.. meta::\n   :http-equiv=refresh: 0; URL=../install/installation.html\n\nInstallation\n============\n\nThis page has moved to :doc:`/install/installation`.\n"
  },
  {
    "path": "docs/source/notes/introduction.rst",
    "content": ":orphan:\n\n.. include:: ../get_started/introduction.rst\n"
  },
  {
    "path": "docs/source/notes/jit.rst",
    "content": ":orphan:\n\n.. include:: ../advanced/jit.rst\n"
  },
  {
    "path": "docs/source/notes/load_csv.rst",
    "content": ":orphan:\n\n.. include:: ../tutorial/load_csv.rst\n"
  },
  {
    "path": "docs/source/notes/remote.rst",
    "content": ":orphan:\n\n.. include:: ../advanced/remote.rst\n"
  },
  {
    "path": "docs/source/notes/resources.rst",
    "content": ":orphan:\n\n.. include:: ../external/resources.rst\n"
  },
  {
    "path": "docs/source/notes/sparse_tensor.rst",
    "content": ":orphan:\n\n.. include:: ../advanced/sparse_tensor.rst\n"
  },
  {
    "path": "docs/source/tutorial/application.rst",
    "content": "Use-Cases & Applications\n========================\n\n.. nbgallery::\n    :name: rst-gallery\n\n    neighbor_loader\n    point_cloud\n    explain\n    shallow_node_embeddings\n    graph_transformer\n"
  },
  {
    "path": "docs/source/tutorial/compile.rst",
    "content": ":orphan:\n\n.. include:: ../advanced/compile.rst\n"
  },
  {
    "path": "docs/source/tutorial/create_dataset.rst",
    "content": "Creating Graph Datasets\n=======================\n\nAlthough :pyg:`PyG` already contains a lot of useful datasets, you may wish to create your own dataset with self-recorded or non-publicly available data.\n\nImplementing datasets by yourself is straightforward and you may want to take a look at the source code to find out how the various datasets are implemented.\nHowever, we give a brief introduction on what is needed to setup your own dataset.\n\nWe provide two abstract classes for datasets: :class:`torch_geometric.data.Dataset` and :class:`torch_geometric.data.InMemoryDataset`.\n:class:`torch_geometric.data.InMemoryDataset` inherits from :class:`torch_geometric.data.Dataset` and should be used if the whole dataset fits into CPU memory.\n\nFollowing the :obj:`torchvision` convention, each dataset gets passed a root folder which indicates where the dataset should be stored.\nWe split up the root folder into two folders: the :obj:`raw_dir`, where the dataset gets downloaded to, and the :obj:`processed_dir`, where the processed dataset is being saved.\n\nIn addition, each dataset can be passed a :obj:`transform`, a :obj:`pre_transform` and a :obj:`pre_filter` function, which are :obj:`None` by default.\nThe :obj:`transform` function dynamically transforms the data object before accessing (so it is best used for data augmentation).\nThe :obj:`pre_transform` function applies the transformation before saving the data objects to disk (so it is best used for heavy precomputation which needs to be only done once).\nThe :obj:`pre_filter` function can manually filter out data objects before saving.\nUse cases may involve the restriction of data objects being of a specific class.\n\nCreating \"In Memory Datasets\"\n-----------------------------\n\nIn order to create a :class:`torch_geometric.data.InMemoryDataset`, you need to implement four fundamental methods:\n\n* :func:`torch_geometric.data.InMemoryDataset.raw_file_names`: A list of files in the 
:obj:`raw_dir` which needs to be found in order to skip the download.\n\n* :func:`torch_geometric.data.InMemoryDataset.processed_file_names`: A list of files in the :obj:`processed_dir` which needs to be found in order to skip the processing.\n\n* :func:`torch_geometric.data.InMemoryDataset.download`: Downloads raw data into :obj:`raw_dir`.\n\n* :func:`torch_geometric.data.InMemoryDataset.process`: Processes raw data and saves it into the :obj:`processed_dir`.\n\nYou can find helpful methods to download and extract data in :mod:`torch_geometric.data`.\n\nThe real magic happens in the body of :meth:`~torch_geometric.data.InMemoryDataset.process`.\nHere, we need to read and create a list of :class:`~torch_geometric.data.Data` objects and save it into the :obj:`processed_dir`.\nBecause saving a huge python list is quite slow, we collate the list into one huge :class:`~torch_geometric.data.Data` object via :meth:`torch_geometric.data.InMemoryDataset.collate` before saving.\nThe collated data object concatenates all examples into one big data object and, in addition, returns a :obj:`slices` dictionary to reconstruct single examples from this object.\nFinally, we need to load these two objects in the constructor into the properties :obj:`self.data` and :obj:`self.slices`.\n\n.. note::\n\n    From :pyg:`null` **PyG >= 2.4**, the functionalities of :meth:`torch.save` and :meth:`torch_geometric.data.InMemoryDataset.collate` are unified and implemented behind :meth:`torch_geometric.data.InMemoryDataset.save`.\n    Additionally, :obj:`self.data` and :obj:`self.slices` are implicitly loaded via :meth:`torch_geometric.data.InMemoryDataset.load`.\n\nLet's see this process in a simplified example:\n\n.. 
code-block:: python\n\n    import torch\n    from torch_geometric.data import InMemoryDataset, download_url\n\n\n    class MyOwnDataset(InMemoryDataset):\n        def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):\n            super().__init__(root, transform, pre_transform, pre_filter)\n            self.load(self.processed_paths[0])\n            # For PyG<2.4:\n            # self.data, self.slices = torch.load(self.processed_paths[0])\n\n        @property\n        def raw_file_names(self):\n            return ['some_file_1', 'some_file_2', ...]\n\n        @property\n        def processed_file_names(self):\n            return ['data.pt']\n\n        def download(self):\n            # Download to `self.raw_dir`.\n            download_url(url, self.raw_dir)\n            ...\n\n        def process(self):\n            # Read data into huge `Data` list.\n            data_list = [...]\n\n            if self.pre_filter is not None:\n                data_list = [data for data in data_list if self.pre_filter(data)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(data) for data in data_list]\n\n            self.save(data_list, self.processed_paths[0])\n            # For PyG<2.4:\n            # torch.save(self.collate(data_list), self.processed_paths[0])\n\nCreating \"Larger\" Datasets\n--------------------------\n\nFor creating datasets which do not fit into memory, the :class:`torch_geometric.data.Dataset` can be used, which closely follows the concepts of the :obj:`torchvision` datasets.\nIt expects the following methods to be implemented in addition:\n\n* :func:`torch_geometric.data.Dataset.len`: Returns the number of examples in your dataset.\n\n* :func:`torch_geometric.data.Dataset.get`: Implements the logic to load a single graph.\n\nInternally, :meth:`torch_geometric.data.Dataset.__getitem__` gets data objects from :meth:`torch_geometric.data.Dataset.get` and optionally transforms them 
according to :obj:`transform`.\n\nLet's see this process in a simplified example:\n\n.. code-block:: python\n\n    import os.path as osp\n\n    import torch\n    from torch_geometric.data import Data, Dataset, download_url\n\n\n    class MyOwnDataset(Dataset):\n        def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):\n            super().__init__(root, transform, pre_transform, pre_filter)\n\n        @property\n        def raw_file_names(self):\n            return ['some_file_1', 'some_file_2', ...]\n\n        @property\n        def processed_file_names(self):\n            return ['data_0.pt', 'data_1.pt', ...]\n\n        def download(self):\n            # Download to `self.raw_dir`.\n            path = download_url(url, self.raw_dir)\n            ...\n\n        def process(self):\n            idx = 0\n            for raw_path in self.raw_paths:\n                # Read data from `raw_path`.\n                data = Data(...)\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))\n                idx += 1\n\n        def len(self):\n            return len(self.processed_file_names)\n\n        def get(self, idx):\n            data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))\n            return data\n\nHere, each graph data object gets saved individually in :meth:`~torch_geometric.data.Dataset.process`, and is manually loaded in :meth:`~torch_geometric.data.Dataset.get`.\n\nFrequently Asked Questions\n--------------------------\n\n#. **How can I skip the execution of** :meth:`download` **and/or** :meth:`process` **?**\n\n    You can skip downloading and/or processing by just not overriding the :meth:`download` and :meth:`process` methods:\n\n    .. 
code-block:: python\n\n        class MyOwnDataset(Dataset):\n            def __init__(self, transform=None, pre_transform=None):\n                super().__init__(None, transform, pre_transform)\n\n#. **Do I really need to use these dataset interfaces?**\n\n    No! Just as in regular :pytorch:`PyTorch`, you do not have to use datasets, *e.g.*, when you want to create synthetic data on the fly without saving it explicitly to disk.\n    In this case, simply create a regular Python list holding :class:`torch_geometric.data.Data` objects and pass it to :class:`torch_geometric.loader.DataLoader`:\n\n    .. code-block:: python\n\n        from torch_geometric.data import Data\n        from torch_geometric.loader import DataLoader\n\n        data_list = [Data(...), ..., Data(...)]\n        loader = DataLoader(data_list, batch_size=32)\n\nExercises\n---------\n\nConsider the following :class:`~torch_geometric.data.InMemoryDataset` constructed from a list of :obj:`~torch_geometric.data.Data` objects:\n\n.. code-block:: python\n\n    class MyDataset(InMemoryDataset):\n        def __init__(self, root, data_list, transform=None):\n            self.data_list = data_list\n            super().__init__(root, transform)\n            self.load(self.processed_paths[0])\n\n        @property\n        def processed_file_names(self):\n            return 'data.pt'\n\n        def process(self):\n            self.save(self.data_list, self.processed_paths[0])\n\n1. What is the output of :obj:`self.processed_paths[0]`?\n\n2. What does :meth:`~torch_geometric.data.InMemoryDataset.save` do?\n"
  },
  {
    "path": "docs/source/tutorial/create_gnn.rst",
    "content": "Creating Message Passing Networks\n=================================\n\nGeneralizing the convolution operator to irregular domains is typically expressed as a *neighborhood aggregation* or *message passing* scheme.\nWith :math:`\\mathbf{x}^{(k-1)}_i \\in \\mathbb{R}^F` denoting node features of node :math:`i` in layer :math:`(k-1)` and :math:`\\mathbf{e}_{j,i} \\in \\mathbb{R}^D` denoting (optional) edge features from node :math:`j` to node :math:`i`, message passing graph neural networks can be described as\n\n.. math::\n  \\mathbf{x}_i^{(k)} = \\gamma^{(k)} \\left( \\mathbf{x}_i^{(k-1)}, \\bigoplus_{j \\in \\mathcal{N}(i)} \\, \\phi^{(k)}\\left(\\mathbf{x}_i^{(k-1)}, \\mathbf{x}_j^{(k-1)},\\mathbf{e}_{j,i}\\right) \\right),\n\nwhere :math:`\\bigoplus` denotes a differentiable, permutation invariant function, *e.g.*, sum, mean or max, and :math:`\\gamma` and :math:`\\phi` denote differentiable functions such as MLPs (Multi Layer Perceptrons).\n\n.. contents::\n    :local:\n\nThe \"MessagePassing\" Base Class\n-------------------------------\n\n:pyg:`PyG` provides the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation.\nThe user only has to define the functions :math:`\\phi` , *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, and :math:`\\gamma` , *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`, as well as the aggregation scheme to use, *i.e.* :obj:`aggr=\"add\"`, :obj:`aggr=\"mean\"` or :obj:`aggr=\"max\"`.\n\nThis is done with the help of the following methods:\n\n* :obj:`MessagePassing(aggr=\"add\", flow=\"source_to_target\", node_dim=-2)`: Defines the aggregation scheme to use (:obj:`\"add\"`, :obj:`\"mean\"` or :obj:`\"max\"`) and the flow direction of message passing (either :obj:`\"source_to_target\"` or 
:obj:`\"target_to_source\"`).\n  Furthermore, the :obj:`node_dim` attribute indicates along which axis to propagate.\n* :obj:`MessagePassing.propagate(edge_index, size=None, **kwargs)`:\n  The initial call to start propagating messages.\n  Takes in the edge indices and all additional data which is needed to construct messages and to update node embeddings.\n  Note that :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate` is not limited to exchanging messages in square adjacency matrices of shape :obj:`[N, N]` only, but can also exchange messages in general sparse assignment matrices, *e.g.*, bipartite graphs, of shape :obj:`[N, M]` by passing :obj:`size=(N, M)` as an additional argument.\n  If set to :obj:`None`, the assignment matrix is assumed to be a square matrix.\n  For bipartite graphs with two independent sets of nodes and indices, and each set holding its own information, this split can be marked by passing the information as a tuple, *e.g.* :obj:`x=(x_N, x_M)`.\n* :obj:`MessagePassing.message(...)`: Constructs messages to node :math:`i` in analogy to :math:`\\phi` for each edge :math:`(j,i) \\in \\mathcal{E}` if :obj:`flow=\"source_to_target\"` and :math:`(i,j) \\in \\mathcal{E}` if :obj:`flow=\"target_to_source\"`.\n  Can take any argument which was initially passed to :meth:`propagate`.\n  In addition, tensors passed to :meth:`propagate` can be mapped to the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.\n  Note that we generally refer to :math:`i` as the central node that aggregates information, and to :math:`j` as its neighboring nodes, since this is the most common notation.\n* :obj:`MessagePassing.update(aggr_out, ...)`: Updates node embeddings in analogy to :math:`\\gamma` for each node :math:`i \\in \\mathcal{V}`.\n  Takes in the output of aggregation as first argument and any argument which was initially passed to 
:func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`.\n\nLet us verify this by re-implementing two popular GNN variants, the `GCN layer from Kipf and Welling <https://arxiv.org/abs/1609.02907>`_ and the `EdgeConv layer from Wang et al. <https://arxiv.org/abs/1801.07829>`_.\n\nImplementing the GCN Layer\n--------------------------\n\nThe `GCN layer <https://arxiv.org/abs/1609.02907>`_ is mathematically defined as\n\n.. math::\n\n    \\mathbf{x}_i^{(k)} = \\sum_{j \\in \\mathcal{N}(i) \\cup \\{ i \\}} \\frac{1}{\\sqrt{\\deg(i)} \\cdot \\sqrt{\\deg(j)}} \\cdot \\left( \\mathbf{W}^{\\top} \\cdot \\mathbf{x}_j^{(k-1)} \\right) + \\mathbf{b},\n\nwhere neighboring node features are first transformed by a weight matrix :math:`\\mathbf{W}`, normalized by their degree, and finally summed up.\nLastly, we apply the bias vector :math:`\\mathbf{b}` to the aggregated output.\nThis formula can be divided into the following steps:\n\n1. Add self-loops to the adjacency matrix.\n2. Linearly transform node feature matrix.\n3. Compute normalization coefficients.\n4. Normalize node features in :math:`\\phi`.\n5. Sum up neighboring node features (:obj:`\"add\"` aggregation).\n6. Apply a final bias vector.\n\nSteps 1-3 are typically computed before message passing takes place.\nSteps 4-5 can be easily processed using the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class.\nThe full layer implementation is shown below:\n\n.. 
code-block:: python\n\n    import torch\n    from torch.nn import Linear, Parameter\n    from torch_geometric.nn import MessagePassing\n    from torch_geometric.utils import add_self_loops, degree\n\n    class GCNConv(MessagePassing):\n        def __init__(self, in_channels, out_channels):\n            super().__init__(aggr='add')  # \"Add\" aggregation (Step 5).\n            self.lin = Linear(in_channels, out_channels, bias=False)\n            self.bias = Parameter(torch.empty(out_channels))\n\n            self.reset_parameters()\n\n        def reset_parameters(self):\n            self.lin.reset_parameters()\n            self.bias.data.zero_()\n\n        def forward(self, x, edge_index):\n            # x has shape [N, in_channels]\n            # edge_index has shape [2, E]\n\n            # Step 1: Add self-loops to the adjacency matrix.\n            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n\n            # Step 2: Linearly transform node feature matrix.\n            x = self.lin(x)\n\n            # Step 3: Compute normalization.\n            row, col = edge_index\n            deg = degree(col, x.size(0), dtype=x.dtype)\n            deg_inv_sqrt = deg.pow(-0.5)\n            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n\n            # Step 4-5: Start propagating messages.\n            out = self.propagate(edge_index, x=x, norm=norm)\n\n            # Step 6: Apply a final bias vector.\n            out = out + self.bias\n\n            return out\n\n        def message(self, x_j, norm):\n            # x_j has shape [E, out_channels]\n\n            # Step 4: Normalize node features.\n            return norm.view(-1, 1) * x_j\n\n:class:`~torch_geometric.nn.conv.GCNConv` inherits from :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` with :obj:`\"add\"` propagation.\nAll the logic of the layer takes place in its :meth:`forward` method.\nHere, we first add self-loops to our 
edge indices using the :meth:`torch_geometric.utils.add_self_loops` function (step 1), as well as linearly transform node features by calling the :class:`torch.nn.Linear` instance (step 2).\n\nThe normalization coefficients are derived by the node degrees :math:`\\deg(i)` for each node :math:`i` which gets transformed to :math:`1/(\\sqrt{\\deg(i)} \\cdot \\sqrt{\\deg(j)})` for each edge :math:`(j,i) \\in \\mathcal{E}`.\nThe result is saved in the tensor :obj:`norm` of shape :obj:`[num_edges, ]` (step 3).\n\nWe then call :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`, which internally calls :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.aggregate` and :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`.\nWe pass the node embeddings :obj:`x` and the normalization coefficients :obj:`norm` as additional arguments for message propagation.\n\nIn the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we need to normalize the neighboring node features :obj:`x_j` by :obj:`norm`.\nHere, :obj:`x_j` denotes a *lifted* tensor, which contains the source node features of each edge, *i.e.*, the neighbors of each node.\nNode features can be automatically lifted by appending :obj:`_i` or :obj:`_j` to the variable name.\nIn fact, any tensor can be converted this way, as long as they hold source or destination node features.\n\nThat is all that it takes to create a simple message passing layer.\nYou can use this layer as a building block for deep architectures.\nInitializing and calling it is straightforward:\n\n.. code-block:: python\n\n    conv = GCNConv(16, 32)\n    x = conv(x, edge_index)\n\nImplementing the Edge Convolution\n---------------------------------\n\nThe `edge convolutional layer <https://arxiv.org/abs/1801.07829>`_ processes graphs or point clouds and is mathematically defined as\n\n.. 
math::\n\n    \\mathbf{x}_i^{(k)} = \\max_{j \\in \\mathcal{N}(i)} h_{\\mathbf{\\Theta}} \\left( \\mathbf{x}_i^{(k-1)}, \\mathbf{x}_j^{(k-1)} - \\mathbf{x}_i^{(k-1)} \\right),\n\nwhere :math:`h_{\\mathbf{\\Theta}}` denotes an MLP.\nIn analogy to the GCN layer, we can use the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` class to implement this layer, this time using the :obj:`\"max\"` aggregation:\n\n.. code-block:: python\n\n    import torch\n    from torch.nn import Sequential as Seq, Linear, ReLU\n    from torch_geometric.nn import MessagePassing\n\n    class EdgeConv(MessagePassing):\n        def __init__(self, in_channels, out_channels):\n            super().__init__(aggr='max') #  \"Max\" aggregation.\n            self.mlp = Seq(Linear(2 * in_channels, out_channels),\n                           ReLU(),\n                           Linear(out_channels, out_channels))\n\n        def forward(self, x, edge_index):\n            # x has shape [N, in_channels]\n            # edge_index has shape [2, E]\n\n            return self.propagate(edge_index, x=x)\n\n        def message(self, x_i, x_j):\n            # x_i has shape [E, in_channels]\n            # x_j has shape [E, in_channels]\n\n            tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]\n            return self.mlp(tmp)\n\nInside the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we use :obj:`self.mlp` to transform both the target node features :obj:`x_i` and the relative source node features :obj:`x_j - x_i` for each edge :math:`(j,i) \\in \\mathcal{E}`.\n\nThe edge convolution is actually a dynamic convolution, which recomputes the graph for each layer using nearest neighbors in the feature space.\nLuckily, :pyg:`PyG` comes with a GPU accelerated batch-wise k-NN graph generation method named :meth:`torch_geometric.nn.pool.knn_graph`:\n\n.. 
code-block:: python\n\n    from torch_geometric.nn import knn_graph\n\n    class DynamicEdgeConv(EdgeConv):\n        def __init__(self, in_channels, out_channels, k=6):\n            super().__init__(in_channels, out_channels)\n            self.k = k\n\n        def forward(self, x, batch=None):\n            edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)\n            return super().forward(x, edge_index)\n\nHere, :meth:`~torch_geometric.nn.pool.knn_graph` computes a nearest neighbor graph, which is further used to call the :meth:`forward` method of :class:`~torch_geometric.nn.conv.EdgeConv`.\n\nThis leaves us with a clean interface for initializing and calling this layer:\n\n.. code-block:: python\n\n    conv = DynamicEdgeConv(3, 128, k=6)\n    x = conv(x, batch)\n\nExercises\n---------\n\nImagine we are given the following :obj:`~torch_geometric.data.Data` object:\n\n.. code-block:: python\n\n    import torch\n    from torch_geometric.data import Data\n\n    edge_index = torch.tensor([[0, 1],\n                               [1, 0],\n                               [1, 2],\n                               [2, 1]], dtype=torch.long)\n    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)\n\n    data = Data(x=x, edge_index=edge_index.t().contiguous())\n\nTry to answer the following questions related to :class:`~torch_geometric.nn.conv.GCNConv`:\n\n1. What information does :obj:`row` and :obj:`col` hold?\n\n2. What does :meth:`~torch_geometric.utils.degree` do?\n\n3. Why do we use :obj:`degree(col, ...)` rather than :obj:`degree(row, ...)`?\n\n4. What does :obj:`deg_inv_sqrt[col]` and :obj:`deg_inv_sqrt[row]` do?\n\n5. What information does :obj:`x_j` hold in the :meth:`~torch_geometric.nn.conv.MessagePassing.message` function? If :obj:`self.lin` denotes the identity function, what is the exact content of :obj:`x_j`?\n\n6. 
Add an :meth:`~torch_geometric.nn.conv.MessagePassing.update` function to :class:`~torch_geometric.nn.conv.GCNConv` that adds transformed central node features to the aggregated output.\n\nTry to answer the following questions related to :class:`~torch_geometric.nn.conv.EdgeConv`:\n\n1. What are :obj:`x_i` and :obj:`x_j - x_i`?\n\n2. What does :obj:`torch.cat([x_i, x_j - x_i], dim=1)` do? Why :obj:`dim=1`?\n"
  },
  {
    "path": "docs/source/tutorial/dataset.rst",
    "content": "Working with Graph Datasets\n===========================\n\n.. nbgallery::\n    :name: rst-gallery\n\n    create_dataset\n    load_csv\n    dataset_splitting\n"
  },
  {
    "path": "docs/source/tutorial/dataset_splitting.rst",
    "content": "Dataset Splitting\n=================\n\nDataset splitting is a critical step in graph machine learning, where we divide our dataset into subsets for training, validation, and testing.\nIt ensures that our models are evaluated properly, preventing overfitting, and enabling generalization.\nIn this tutorial, we will explore the basics of dataset splitting, focusing on three fundamental tasks: node prediction, link prediction, and graph prediction.\nWe will introduce commonly used techniques, including :class:`~torch_geometric.transforms.RandomNodeSplit` and :class:`~torch_geometric.transforms.RandomLinkSplit` transformations.\nAdditionally, we will also cover how to create custom dataset splits beyond random ones.\n\nNode Prediction\n---------------\n\n.. note::\n\n    In this section, we'll learn how to use :class:`~torch_geometric.transforms.RandomNodeSplit` of :pyg:`PyG` to randomly divide nodes into training, validation, and test sets.\n    A fully working example on dataset :class:`~torch_geometric.datasets.Planetoid` is available in `examples/cora.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/cora.py>`_.\n\nThe :class:`~torch_geometric.transforms.RandomNodeSplit` is initialized to split nodes for both a :pyg:`PyG` :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` object.\n\n* :obj:`split` defines the dataset's split type.\n* :obj:`num_splits` defines the number of splits to add.\n* :obj:`num_train_per_class` defines the number of training nodes per class.\n* :obj:`num_val` defines the number of validation nodes after data splitting.\n* :obj:`num_test` defines the number of test nodes after data splitting.\n* :obj:`key` defines the name of the ground-truth labels.\n\n.. 
code-block:: python\n\n    import torch\n    from torch_geometric.data import Data\n    from torch_geometric.transforms import RandomNodeSplit\n\n    x = torch.randn(8, 32)  # Node features of shape [num_nodes, num_features]\n    y = torch.randint(0, 4, (8, ))  # Node labels of shape [num_nodes]\n    edge_index = torch.tensor([\n        [2, 3, 3, 4, 5, 6, 7],\n        [0, 0, 1, 1, 2, 3, 4]],\n    )\n\n    #   0  1\n    #  / \\/ \\\n    # 2  3  4\n    # |  |  |\n    # 5  6  7\n\n    data = Data(x=x, y=y, edge_index=edge_index)\n    node_transform = RandomNodeSplit(num_val=2, num_test=3)\n    node_splits = node_transform(data)\n\nHere, we initialize a :class:`~torch_geometric.transforms.RandomNodeSplit` transformation to split the graph data by nodes.\nAfter the transformation, :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask` will be attached to the graph data.\n\n.. code-block:: python\n\n    node_splits.train_mask\n    >>> tensor([ True, False, False, False, True, True, False, False])\n    node_splits.val_mask\n    >>> tensor([False, False, False, False, False, False, True, True])\n    node_splits.test_mask\n    >>> tensor([False, True, True, True, False, False, False, False])\n\nIn this example, there are 8 nodes: we sample 2 nodes for validation and 3 nodes for testing, and use the rest for training.\nIn the run shown above, nodes :obj:`0, 4, 5` form the training set, nodes :obj:`6, 7` the validation set, and nodes :obj:`1, 2, 3` the test set.\n\nLink Prediction\n---------------\n\n.. 
note::\n\n    In this section, we'll learn how to use :class:`~torch_geometric.transforms.RandomLinkSplit` of :pyg:`PyG` to randomly divide edges into training, validation, and test sets.\n    A fully working example on dataset :class:`~torch_geometric.datasets.Planetoid` is available in `examples/link_pred.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/link_pred.py>`_.\n\nThe :class:`~torch_geometric.transforms.RandomLinkSplit` is initialized to split edges for both a :pyg:`PyG` :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` object.\n\n* :obj:`num_val` defines the number of validation edges after data splitting.\n* :obj:`num_test` defines the number of test edges after data splitting.\n* :obj:`is_undirected` defines whether the graph is assumed as undirected.\n\n.. code-block:: python\n\n    import torch\n    from torch_geometric.data import Data\n    from torch_geometric.transforms import RandomLinkSplit\n\n    x = torch.randn(8, 32)  # Node features of shape [num_nodes, num_features]\n    y = torch.randint(0, 4, (8, ))  # Node labels of shape [num_nodes]\n    edge_index = torch.tensor([\n        [2, 3, 3, 4, 5, 6, 7],\n        [0, 0, 1, 1, 2, 3, 4]],\n    )\n\n    edge_y = torch.tensor([0, 0, 0, 0, 1, 1, 1])\n    #   0  1\n    #  / \\/ \\\n    # 2  3  4\n    # |  |  |\n    # 5  6  7\n\n    data = Data(x=x, y=y, edge_index=edge_index, edge_y=edge_y)\n    edge_transform = RandomLinkSplit(num_val=0.2, num_test=0.2, key='edge_y',\n                                    is_undirected=False, add_negative_train_samples=False)\n    train_data, val_data, test_data = edge_transform(data)\n\nSimilar to node splitting, we initialize a :class:`~torch_geometric.transforms.RandomLinkSplit` transformation to split the graph data by edges.\nBelow, we can see the splitting results.\n\n.. 
code-block:: python\n\n    train_data\n    >>> Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[5], edge_y_index=[2, 5])\n    val_data\n    >>> Data(x=[8, 32], edge_index=[2, 5], y=[8], edge_y=[2], edge_y_index=[2, 2])\n    test_data\n    >>> Data(x=[8, 32], edge_index=[2, 6], y=[8], edge_y=[2], edge_y_index=[2, 2])\n\n:obj:`train_data.edge_index` and :obj:`val_data.edge_index` refer to the edges that are used for message passing.\nAs such, during training and validation, we are allowed to propagate information based on the training edges.\nDuring testing, we can propagate information based on the union of training and validation edges.\nFor evaluation and testing, :obj:`val_data.edge_y_index` and :obj:`test_data.edge_y_index` (named after :obj:`key='edge_y'`; with the default key, the attribute is :obj:`edge_label_index`) hold a batch of positive and negative samples that should be used to evaluate and test our model on.\n\nGraph Prediction\n----------------\n\n.. note::\n\n    In this section, we'll learn how to randomly divide graphs into training, validation, and test sets.\n    A fully working example on dataset :class:`~torch_geometric.datasets.PPI` is available in `examples/ppi.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/ppi.py>`_.\n\nIn a graph prediction task, each graph is an independent sample.\nUsually we need to divide a graph dataset according to a certain ratio.\n:pyg:`PyG` provides some datasets that already come with predefined training, validation and test splits, such as :class:`~torch_geometric.datasets.PPI`.\n\n.. 
code-block:: python\n\n    from torch_geometric.datasets import PPI\n\n    path = './data/PPI'\n    train_dataset = PPI(path, split='train')\n    val_dataset = PPI(path, split='val')\n    test_dataset = PPI(path, split='test')\n\nIn addition, we can also use :obj:`scikit-learn` or :obj:`numpy` to randomly divide a :pyg:`PyG` dataset.\n\nCreating Custom Splits\n----------------------\n\nIf random splitting doesn't suit our specific use case, then we can create custom node splits.\nThis requirement generally occurs in real business scenarios.\nFor example, there are large-scale heterogeneous graphs in e-commerce scenarios, and nodes can be used to represent users, products, merchants, etc.\nWe may separate new users from old users to evaluate the performance of the model on new users.\nBecause such splits depend on the application, we do not provide a specific example here.\n"
  },
  {
    "path": "docs/source/tutorial/distributed.rst",
    "content": ".. _distributed_tutorials:\n\nDistributed Training\n====================\n\n.. nbgallery::\n    :name: rst-gallery\n\n    multi_gpu_vanilla\n    multi_node_multi_gpu_vanilla\n    distributed_pyg\n"
  },
  {
    "path": "docs/source/tutorial/distributed_pyg.rst",
    "content": "Distributed Training in PyG\n===========================\n\n.. warning::\n    ``torch_geometric.distributed`` has been deprecated and will no longer be maintained.\n    For distributed training with cuGraph, refer to `cuGraph examples <https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples>`_.\n\n.. figure:: ../_figures/intel_kumo.png\n   :width: 400px\n\n.. note::\n    We are thrilled to announce the first **in-house distributed training solution** for :pyg:`PyG` via :class:`torch_geometric.distributed`, available from version 2.5 onwards.\n    Developers and researchers can now take full advantage of distributed training on large-scale datasets which cannot be fully loaded in memory of one machine at the same time.\n    This implementation doesn't require any additional packages to be installed on top of the default :pyg:`PyG` stack.\n\nIn real life applications, graphs often consists of billions of nodes that cannot fit into a single system memory.\nThis is when distributed training of Graph Neural Networks comes in handy.\nBy allocating a number of partitions of the large graph into a cluster of CPUs, one can deploy synchronized model training on the whole dataset at once by making use of :pytorch:`PyTorch's` `Distributed Data Parallel (DDP) <https://pytorch.org/docs/stable/notes/ddp.html>`_ capabilities.\nThis architecture seamlessly distributes training of Graph Neural Networks across multiple nodes via `Remote Procedure Calls (RPCs) <https://pytorch.org/docs/stable/rpc.html>`_ for efficient sampling and retrieval of non-local features with traditional DDP for model training.\nThis new technique in :pyg:`PyG` was produced by engineers from `Intel <https://intel.com>`_ and `Kumo AI <https://kumo.ai/>`_.\n\n\nKey Advantages\n--------------\n\n#. **Balanced graph partitioning** via METIS ensures minimal communication overhead when sampling subgraphs across compute nodes.\n#. 
Utilizing **DDP for model training** in conjunction with **RPC for remote sampling and feature fetching routines** (with TCP/IP protocol and `gloo <https://github.com/facebookincubator/gloo>`_ communication backend) allows for data parallelism with distinct data partitions at each node.\n#. The implementation via custom :class:`~torch_geometric.data.GraphStore` and :class:`~torch_geometric.data.FeatureStore` APIs provides a flexible and tailored interface for distributing large graph structure information and feature storage.\n#. **Distributed neighbor sampling** is capable of sampling in both local and remote partitions through RPC communication channels.\n   All advanced functionality of single-node sampling are also applicable for distributed training, *e.g.*, heterogeneous sampling, link-level sampling, temporal sampling, *etc*.\n#. **Distributed data loaders** offer a high-level abstraction for managing sampler processes, ensuring simplicity and seamless integration with standard :pyg:`PyG`  data loaders.\n#. Incorporating the Python `asyncio <https://docs.python.org/3/library/asyncio.html>`_ library for asynchronous processing on top of :pytorch:`PyTorch`-based RPCs further enhances the system's responsiveness and overall performance.\n\nArchitecture Components\n-----------------------\n\n.. 
note::\n    The purpose of this tutorial is to guide you through the most important steps of deploying distributed training applications in :pyg:`PyG`.\n    For code examples, please refer to `examples/distributed/pyg <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/distributed/pyg>`_.\n\nOverall, :class:`torch_geometric.distributed` is divided into the following components:\n\n* :class:`~torch_geometric.distributed.Partitioner` partitions the graph into multiple parts, such that each node only needs to load its local data in memory.\n* :class:`~torch_geometric.distributed.LocalGraphStore` and :class:`~torch_geometric.distributed.LocalFeatureStore` store the graph topology and features per partition, respectively.\n  In addition, they maintain a mapping between local and global IDs for efficient assignment of nodes and feature lookup.\n* :class:`~torch_geometric.distributed.DistNeighborSampler` implements the distributed sampling algorithm, which includes local+remote sampling and the final merge between local/remote sampling results based on :pytorch:`PyTorch's` RPC mechanisms.\n* :class:`~torch_geometric.distributed.DistNeighborLoader` manages the distributed neighbor sampling and feature fetching processes via multiple RPC workers.\n  Finally, it assembles sampled nodes, edges, and their features into the classic :pyg:`PyG` data format.\n\n.. 
figure:: ../_figures/dist_proc.png\n   :align: center\n   :width: 100%\n\n   Schematic breakdown of the main components of :class:`torch_geometric.distributed`.\n\nGraph Partitioning\n~~~~~~~~~~~~~~~~~~\n\nThe first step for distributed training is to split the graph into multiple smaller portions, which can then be loaded locally into nodes of the cluster.\nPartitioning is built on top of :pyg:`null` :obj:`pyg-lib`'s `implementation <https://pyg-lib.readthedocs.io/en/latest/modules/partition.html#pyg_lib.partition.metis>`_ of the METIS algorithm, which partitions graphs efficiently, even at large scale.\nNote that METIS requires undirected, homogeneous graphs as input.\n:class:`~torch_geometric.distributed.Partitioner` performs the necessary processing steps to partition heterogeneous data objects with correct distribution and indexing.\n\nBy default, METIS tries to balance the number of nodes of each type in each partition while minimizing the number of edges between partitions.\nThis ensures that the resulting partitions provide maximal local access of neighbors, enabling samplers to perform local computations without the need for communication between different compute nodes.\nThrough this partitioning approach, every node receives a distinct assignment, while \"halo nodes\" (1-hop neighbors that fall into a different partition) are replicated.\nHalo nodes ensure that neighbor sampling for a single node in a single layer stays purely local.\n\n.. 
figure:: ../_figures/dist_part.png\n   :align: center\n   :width: 100%\n\n   Graph partitioning with halo nodes.\n\nIn our distributed training example, we prepared the `partition_graph.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/distributed/pyg/partition_graph.py>`_ script to demonstrate how to apply partitioning on a selected subset of both homogeneous and heterogeneous graphs.\nThe :class:`~torch_geometric.distributed.Partitioner` can also preserve node features, edge features, and any temporal attributes at the level of nodes and edges.\nLater on, each node in the cluster then owns a single partition of this graph.\n\n.. warning::\n    Partitioning via METIS is non-deterministic and as such may differ between iterations.\n    However, all compute nodes should access the same partition data.\n    Therefore, generate the partitions on one node and copy the data to all members of the cluster, or place the folder into a shared location.\n\nThe resulting structure of partitioning for a two-part split on the homogeneous :obj:`ogbn-products` is shown below:\n\n.. 
code-block:: none\n\n    partitions\n    └─ ogbn-products\n       ├─ ogbn-products-partitions\n       │  ├─ part_0\n       │  ├─ part_1\n       │  ├─ META.json\n       │  ├─ node_map.pt\n       │  └─ edge_map.pt\n       ├─ ogbn-products-label\n       │  └─ label.pt\n       ├─ ogbn-products-test-partitions\n       │  ├─ partition0.pt\n       │  └─ partition1.pt\n       └─ ogbn-products-train-partitions\n          ├─ partition0.pt\n          └─ partition1.pt\n\nDistributed Data Storage\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nTo maintain distributed data partitions, we utilize instantiations of :pyg:`PyG's` :class:`~torch_geometric.data.GraphStore` and :class:`~torch_geometric.data.FeatureStore` remote interfaces.\nTogether with an integrated API for sending and receiving RPC requests, they provide a powerful tool for inter-connected distributed data storage.\nBoth stores can be filled with data in a number of ways, *e.g.*, from :class:`~torch_geometric.data.Data` and :class:`~torch_geometric.data.HeteroData` objects or initialized directly from generated partition files.\n\n:class:`~torch_geometric.distributed.LocalGraphStore` is a class designed to act as a **container for graph topology information**.\nIt holds the edge indices that define relationships between nodes in a graph.\nIt offers methods that provide mapping information for nodes and edges to individual partitions and supports both homogeneous and heterogeneous data formats.\n\n**Key Features:**\n\n* It only stores information about local graph connections and its halo nodes within a partition.\n* Remote connectivity: The affiliation information of individual nodes and edges to partitions (both local and global) can be retrieved through node and edge \"partition books\", *i.e.* mappings of partition IDs to global node/edge IDs.\n* It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions.\n\n:class:`~torch_geometric.distributed.LocalFeatureStore` is a class that serves as 
both a **node-level and edge-level feature storage**.\nIt provides efficient :obj:`put` and :obj:`get` routines for attribute retrieval for both local and remote node/edge IDs.\nThe :class:`~torch_geometric.distributed.LocalFeatureStore` is responsible for retrieving and updating features across different partitions and machines during the training process.\n\n**Key Features:**\n\n* It provides functionalities for storing, retrieving, and distributing node and edge features.\n  Within the managed partition of a machine, node and edge features are stored locally.\n* Remote feature lookup: It implements mechanisms for looking up features in both local and remote nodes during distributed training processes through RPC requests.\n  The class is designed to work seamlessly in distributed training scenarios, allowing for efficient feature handling across partitions.\n* It maintains global identifiers for nodes and edges, allowing for consistent mapping across partitions.\n\nBelow is an example of how :class:`~torch_geometric.distributed.LocalFeatureStore` is used internally to retrieve both local+remote features:\n\n.. 
code-block:: python\n\n    import torch\n    from torch_geometric.distributed import LocalFeatureStore\n    from torch_geometric.distributed.event_loop import to_asyncio_future\n\n    # Create a `LocalFeatureStore` instance:\n    feature_store = LocalFeatureStore(...)\n\n    async def get_node_features():\n        # Retrieve node features for specific node IDs:\n        node_id = torch.tensor([1])\n        future = feature_store.lookup_features(node_id)\n\n        return await to_asyncio_future(future)\n\nDistributed Neighbor Sampling\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n:class:`~torch_geometric.distributed.DistNeighborSampler` is a class designed for efficient distributed training of Graph Neural Networks.\nIt addresses the challenges of sampling neighbors in a distributed environment, whereby graph data is partitioned across multiple machines or devices.\nThe sampler ensures that GNNs can effectively learn from large-scale graphs, maintaining scalability and performance.\n\n**Asynchronous Neighbor Sampling and Feature Collection:**\n\nDistributed neighbor sampling is implemented using asynchronous :class:`torch.distributed.rpc` calls.\nIt allows machines to independently sample neighbors without strict synchronization.\nEach machine autonomously selects neighbors from its local graph partition, without waiting for others to complete their sampling processes.\nThis approach enhances parallelism, as machines can progress asynchronously, and leads to faster training.\nIn addition to asynchronous sampling, distributed neighbor sampling also provides asynchronous feature collection.\n\n**Customizable Sampling Strategies:**\n\nUsers can customize neighbor sampling strategies based on their specific requirements.\nThe :class:`~torch_geometric.distributed.DistNeighborSampler` class provides full flexibility in defining sampling techniques, such as:\n\n* Node sampling vs. edge sampling\n* Homogeneous vs. heterogeneous sampling\n* Temporal sampling vs. 
static sampling\n\n**Distributed Neighbor Sampling Workflow:**\n\nA batch of seed nodes goes through three main steps before the data loader makes it available for the model's :meth:`forward` pass:\n\n#. **Distributed node sampling:** While the underlying principles of neighbor sampling hold for the distributed case as well, the implementation slightly diverges from single-machine sampling.\n   In distributed training, seed nodes can belong to different partitions, leading to simultaneous sampling on multiple machines for a single batch.\n   Consequently, synchronization of sampling results across machines is necessary to obtain seed nodes for the subsequent layer, requiring modifications to the basic algorithm.\n   For nodes within a local partition, the sampling occurs on the local machine.\n   Conversely, for nodes associated with a remote partition, the neighbor sampling is conducted on the machine responsible for storing the respective partition.\n   Sampling then happens layer-wise, where sampled nodes act as seed nodes in follow-up layers.\n#. **Distributed feature lookup:** Each partition stores an array of features of nodes and edges that are within that partition.\n   Consequently, if the output of a sampler on a specific machine includes sampled nodes or edges which do not belong to its partition, the machine initiates an RPC request to the remote server that owns these nodes (or edges).\n#. **Data conversion:** Based on the sampler output and the acquired node (or edge) features, a :pyg:`PyG` :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object is created.\n   This object forms a batch used in subsequent computational operations of the model.\n\n.. 
figure:: ../_figures/dist_sampling.png\n   :align: center\n   :width: 450px\n\n   Local and remote neighbor sampling.\n\nDistributed Data Loading\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nDistributed data loaders such as :class:`~torch_geometric.distributed.DistNeighborLoader` and :class:`~torch_geometric.distributed.DistLinkNeighborLoader` provide a simple API for the sampling engine described above because they entirely wrap initialization and cleanup of sampler processes internally.\nNotably, the distributed data loaders inherit from the standard :pyg:`PyG` single-node :class:`~torch_geometric.loader.NodeLoader` and :class:`~torch_geometric.loader.LinkLoader` loaders, making their application inside training scripts nearly identical.\n\nBatch generation is slightly different from the single-node case in that the step of (local+remote) feature fetching happens within the sampler, rather than being split into two separate steps (sampling->feature fetching).\nThis limits the number of RPCs.\nDue to the asynchronous processing between all sampler sub-processes, the samplers then return their output to a :class:`torch.multiprocessing.Queue`.\n\nSetting up Communication using DDP & RPC\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIn this distributed training implementation, two :pytorch:`PyTorch` communication technologies are used:\n\n* :class:`torch.distributed.rpc` for remote sampling calls and distributed feature retrieval\n* :class:`torch.nn.parallel.DistributedDataParallel` (DDP) for data parallel model training\n\nOur solution opts for :class:`torch.distributed.rpc` over alternatives such as gRPC because :pytorch:`PyTorch` RPC natively handles tensor-type data.\nUnlike other RPC methods, which require the serialization or digitization of JSON or other user data into tensor types, using this method helps avoid additional serialization and digitization overhead.\n\nThe DDP group is initialized in a standard way in the main training script:\n\n.. 
code-block:: python\n\n    torch.distributed.init_process_group(\n        backend='gloo',\n        rank=current_ctx.rank,\n        world_size=current_ctx.world_size,\n        init_method=f'tcp://{master_addr}:{ddp_port}',\n    )\n\n.. note::\n    For CPU-based sampling, we recommend the `gloo <https://github.com/facebookincubator/gloo>`_ communication backend.\n\nRPC group initialization is more complicated because it happens in each sampler subprocess, which is achieved via the :meth:`~torch.utils.data.DataLoader.worker_init_fn` of the data loader, which is called by :pytorch:`PyTorch` directly at the initialization step of worker processes.\nThis function first defines a distributed context for each worker and assigns it a group and rank, subsequently initializes its own distributed neighbor sampler, and finally registers a new member in the RPC group.\nThis RPC connection remains open as long as the subprocess exists.\nAdditionally, we opted for the `atexit <https://docs.python.org/3/library/atexit.html>`_ module to register additional cleanup behaviors that are triggered when the process is terminated.\n\nResults and Performance\n-----------------------\n\nWe collected the benchmarking results on :pytorch:`PyTorch` 2.1 using the system configuration at the bottom of this page.\nThe table below shows the scaling performance on the :obj:`ogbn-products` dataset of a :class:`~torch_geometric.nn.models.GraphSAGE` model under different partition configurations (1/2/4/8/16).\n\n.. 
list-table::\n   :widths: 15 15 15 15\n   :header-rows: 1\n\n   * - #Partitions\n     - :obj:`batch_size=1024`\n     - :obj:`batch_size=4096`\n     - :obj:`batch_size=8192`\n   * - 1\n     - 98s\n     - 47s\n     - 38s\n   * - 2\n     - 45s\n     - 30s\n     - 24s\n   * - 4\n     - 38s\n     - 21s\n     - 16s\n   * - 8\n     - 29s\n     - 14s\n     - 10s\n   * - 16\n     - 22s\n     - 13s\n     - 9s\n\n* **Hardware:** 2x Intel(R) Xeon(R) Platinum 8360Y CPU @ 2.40GHz, 36 cores, HT On, Turbo On, NUMA 2, Integrated Accelerators Available [used]: DLB 0 [0], DSA 0 [0], IAA 0 [0], QAT 0 [0], Total Memory 256GB (16x16GB DDR4 3200 MT/s [3200 MT/s]), BIOS SE5C620.86B.01.01.0003.2104260124, microcode 0xd000389, 2x Ethernet Controller X710 for 10GbE SFP+, 1x MT28908 Family [ConnectX-6], 1x 894.3G INTEL SSDSC2KG96, Rocky Linux 8.8 (Green Obsidian), 4.18.0-477.21.1.el8_8.x86_64\n* **Software:** :python:`Python` 3.9, :pytorch:`PyTorch` 2.1, :pyg:`PyG` 2.5, :pyg:`null` :obj:`pyg-lib` 0.4.0\n"
  },
  {
    "path": "docs/source/tutorial/explain.rst",
    "content": "Explaining Graph Neural Networks\n================================\n\nInterpreting GNN models is crucial for many use cases.\n:pyg:`PyG` (2.3 and beyond) provides the :class:`torch_geometric.explain` package for first-class GNN explainability support that currently includes\n\n#. a flexible interface to generate a variety of explanations via the :class:`~torch_geometric.explain.Explainer` class,\n\n#. several underlying explanation algorithms including, *e.g.*, :class:`~torch_geometric.explain.algorithm.GNNExplainer`,  :class:`~torch_geometric.explain.algorithm.PGExplainer` and :class:`~torch_geometric.explain.algorithm.CaptumExplainer`,\n\n#. support to visualize explanations via the :class:`~torch_geometric.explain.Explanation` or the :class:`~torch_geometric.explain.HeteroExplanation` class,\n\n#. and metrics to evaluate explanations via the :class:`~torch_geometric.explain.metric` package.\n\n.. warning::\n\n   The explanation APIs discussed here may change in the future as we continuously work to improve their ease-of-use and generalizability.\n\nExplainer Interface\n-------------------\n\nThe :class:`torch_geometric.explain.Explainer` class is designed to handle all explainability parameters (see the :class:`~torch_geometric.explain.config.ExplainerConfig` class for more details):\n\n#. which algorithm from the :class:`torch_geometric.explain.algorithm` module to use (*e.g.*, :class:`~torch_geometric.explain.algorithm.GNNExplainer`)\n\n#. the type of explanation to compute, *i.e.* :obj:`explanation_type=\"phenomenon\"` to explain the underlying phenomenon of a dataset, and :obj:`explanation_type=\"model\"` to explain the prediction of a GNN model (see the `\"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks\" <https://arxiv.org/abs/2206.09677>`_ paper for more details).\n\n#. the different type of masks for node and edges (*e.g.*, :obj:`mask=\"object\"` or :obj:`mask=\"attributes\"`)\n\n#. 
any postprocessing of the masks (*e.g.*, :obj:`threshold_type=\"topk\"` or :obj:`threshold_type=\"hard\"`)\n\nThis class allows the user to easily compare different explainability methods and to easily switch between different types of masks, while making sure the high-level framework stays the same.\nThe :class:`~torch_geometric.explain.Explainer` generates an :class:`~torch_geometric.explain.Explanation` or :class:`~torch_geometric.explain.HeteroExplanation` object which contains the final information about which nodes, edges and features are crucial to explain a GNN model.\n\n.. note::\n\n   You can read more about the :class:`torch_geometric.explain` package in this `blog post <https://medium.com/@pytorch_geometric/graph-machine-learning-explainability-with-pyg-ff13cffc23c2>`__.\n\nExamples\n--------\n\nIn what follows, we discuss a few use-cases with corresponding code examples.\n\nExplaining node classification on a homogeneous graph\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAssume we have a GNN :obj:`model` that does node classification on a homogeneous graph.\nWe can use the :class:`torch_geometric.explain.algorithm.GNNExplainer` algorithm to generate an :class:`~torch_geometric.explain.Explanation`.\nWe configure the :class:`~torch_geometric.explain.Explainer` to use both a :obj:`node_mask_type` and an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.Explanation` object contains (1) a :obj:`node_mask` (indicating which nodes and features are crucial for prediction), and (2) an :obj:`edge_mask` (indicating which edges are crucial for prediction).\n\n.. code-block:: python\n\n    from torch_geometric.data import Data\n    from torch_geometric.explain import Explainer, GNNExplainer\n\n    data = Data(...)  
# A homogeneous graph data object.\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GNNExplainer(epochs=200),\n        explanation_type='model',\n        node_mask_type='attributes',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='multiclass_classification',\n            task_level='node',\n            return_type='log_probs',  # Model returns log probabilities.\n        ),\n    )\n\n    # Generate explanation for the node at index `10`:\n    explanation = explainer(data.x, data.edge_index, index=10)\n    print(explanation.edge_mask)\n    print(explanation.node_mask)\n\nFinally, we can visualize both feature importance and the crucial subgraph of the explanation:\n\n.. code-block:: python\n\n    explanation.visualize_feature_importance(top_k=10)\n\n    explanation.visualize_graph()\n\nTo evaluate the explanation from the :class:`~torch_geometric.explain.algorithm.GNNExplainer`, we can utilize the :class:`torch_geometric.explain.metric` module.\nFor example, to compute the :meth:`~torch_geometric.explain.metric.unfaithfulness` of an explanation, run:\n\n.. code-block:: python\n\n    from torch_geometric.explain import unfaithfulness\n\n    metric = unfaithfulness(explainer, explanation)\n    print(metric)\n\nExplaining node classification on a heterogeneous graph\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAssume we have a heterogeneous GNN :obj:`model` that does node classification on a heterogeneous graph.\nWe can use the :class:`IntegratedGradients` attribution method from :captum:`null` `Captum <https://captum.ai/docs/extension/integrated_gradients>`__ via the :class:`torch_geometric.explain.algorithm.CaptumExplainer` algorithm to generate a :class:`~torch_geometric.explain.HeteroExplanation`.\n\n.. 
note::\n    :class:`~torch_geometric.explain.algorithm.CaptumExplainer` is a wrapper around the :captum:`null` `Captum <https://captum.ai>`__ library with support for most attribution methods to explain *any* homogeneous or heterogeneous :pyg:`PyG` model.\n\nWe configure the :class:`~torch_geometric.explain.Explainer` to use both a :obj:`node_mask_type` and an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.HeteroExplanation` object contains (1) a :obj:`node_mask` for *each* node type (indicating which nodes and features for each node type are crucial for prediction), and (2) an :obj:`edge_mask` for *each* edge type (indicating which edges for each edge type are crucial for prediction).\n\n.. code-block:: python\n\n    import torch\n\n    from torch_geometric.data import HeteroData\n    from torch_geometric.explain import Explainer, CaptumExplainer\n\n    hetero_data = HeteroData(...)  # A heterogeneous graph data object.\n\n    explainer = Explainer(\n        model,  # It is assumed that model outputs a single tensor.\n        algorithm=CaptumExplainer('IntegratedGradients'),\n        explanation_type='model',\n        node_mask_type='attributes',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='multiclass_classification',\n            task_level=task_level,\n            return_type='probs',  # Model returns probabilities.\n        ),\n    )\n\n    # Generate batch-wise heterogeneous explanations for\n    # the nodes at index `1` and `3`:\n    hetero_explanation = explainer(\n        hetero_data.x_dict,\n        hetero_data.edge_index_dict,\n        index=torch.tensor([1, 3]),\n    )\n    print(hetero_explanation.edge_mask_dict)\n    print(hetero_explanation.node_mask_dict)\n\nExplaining graph regression on a homogeneous graph\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAssume we have a GNN :obj:`model` that does graph regression on a homogeneous graph.\nWe can use the 
:class:`torch_geometric.explain.algorithm.PGExplainer` algorithm to generate an :class:`~torch_geometric.explain.Explanation`.\nWe configure the :class:`~torch_geometric.explain.Explainer` to use an :obj:`edge_mask_type` such that the final :class:`~torch_geometric.explain.Explanation` object contains an :obj:`edge_mask` (indicating which edges are crucial for prediction).\nImportantly, passing a :obj:`node_mask_type` to the :class:`~torch_geometric.explain.Explainer` will throw an error since :class:`~torch_geometric.explain.algorithm.PGExplainer` cannot explain the importance of nodes:\n\n.. code-block:: python\n\n    from torch_geometric.explain import Explainer, PGExplainer\n    from torch_geometric.loader import DataLoader\n\n    dataset = ...\n    loader = DataLoader(dataset, batch_size=1, shuffle=True)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=PGExplainer(epochs=30, lr=0.003),\n        explanation_type='phenomenon',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n            return_type='raw',\n        ),\n        # Include only the top 10 most important edges:\n        threshold_config=dict(threshold_type='topk', value=10),\n    )\n\n    # PGExplainer needs to be trained separately since it is a parametric\n    # explainer, i.e., it uses a neural network to generate explanations:\n    for epoch in range(30):\n        for batch in loader:\n            loss = explainer.algorithm.train(\n                epoch, model, batch.x, batch.edge_index, target=batch.target)\n\n    # Generate the explanation for a particular graph:\n    explanation = explainer(dataset[0].x, dataset[0].edge_index)\n    print(explanation.edge_mask)\n\nSince this feature is still undergoing heavy development, please feel free to reach out to the :pyg:`PyG` core team either on :github:`null` `GitHub <https://github.com/pyg-team/pytorch_geometric/discussions>`_ or :slack:`null` `Slack 
<https://data.pyg.org/slack.html>`_ if you have any questions, comments or concerns.\n"
  },
  {
    "path": "docs/source/tutorial/gnn_design.rst",
    "content": "Design of Graph Neural Networks\n===============================\n\n.. nbgallery::\n    :name: rst-gallery\n\n    create_gnn\n    heterogeneous\n"
  },
  {
    "path": "docs/source/tutorial/graph_transformer.rst",
    "content": "Graph Transformer\n=================\n\n`Transformer <https://arxiv.org/abs/1706.03762>`_ is an effictive architecture in `natural language processing <https://arxiv.org/abs/1810.04805>`_ and `computer vision <https://arxiv.org/abs/2010.11929>`_.\nRecently, there have been some applications(`Grover <https://arxiv.org/abs/2007.02835>`_, `GraphGPS <https://arxiv.org/abs/2205.12454>`_, etc) that combine transformers on graphs.\nIn this tutorial, we will present how to build a graph transformer model via :pyg:`PyG`. See `our webinar <https://youtu.be/wAYryx3GjLw?si=2vB7imfenP5tUvqd>`_ for in-depth learning on this topic.\n\n.. note::\n    Click `here <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_gps.py>`_ to download the full example code\n\nTransformers on Graphs\n------------------------------\n\nCompared to Graph Transformers, MPNNs have several drawbacks: (1) WL test: 1-order MPNNs have limited expressivity; (2) Over-smoothing: the features tend to\nconverge to the same value while increasing the number of GNN layers; (3) Over-squashing: Losing information when trying to aggregate messages from many neighbors into a single vector;\n(4) Cannot capture long-range dependencies.\n\nFeeding the whole graph into the Transformer also brings several pros and cons\n\n**Pros**\n\n* Computation graph structure is decoupled from the input graph structure.\n* Long-range connections can be handled as all nodes are connected to each other.\n\n**Cons**\n\n* Loss of inductive bias that enables GNNs to work so well on graphs with pronounced locality. Particularly in graphs where edges represent relatedness/closeness.\n* Language input is squential, but graphs are permutation invariant to node ordering.\n* Square computational complexity :math:`O(N^2)` in the number of nodes whereas message passing GNNs are linear in the number of edges :math:`O(E)`. Graphs are often sparse :math:`N \\approx E`.\n\nAttention\n+++++++++\n\n.. 
math::\n    Q = XW_Q, K = XW_K, V = XW_V\n.. math::\n    Attention(Q, K, V) = softmax(\\frac{QK^T}{\\sqrt{d_k}})V\n\nIn Transformer, attention can be multi-head, which consists of multiple attention weights.\n\nPositional and Structural Encodings\n+++++++++++++++++++++++++++++++++++\n\nWe organized PE/SE into 3 categories based on their locality: (1) Local, (2) Global, (3) Relative.\nPositional encodings (PE) provide an idea of the position in space of a given node within the graph. When two nodes are close to each other within a graph or subgraph, their PE should also be close.\nStructure encodings (SE) provide an embedding of the structure of graphs or subgraphs to help increase the expressivity and the generalizability of GNNs.\nWhen two nodes share similar subgraphs, or when two graphs are similar, their SE should also be close.\n\n.. list-table::\n   :widths: 10 20 20\n   :header-rows: 1\n\n   * - Encoding type\n     - Positional encodings (PE)\n     - Structure encodings (SE)\n   * - Local (node)\n     - (1)Distance to cluster center; (2)Sum of non-diagonal elements in m-step random walk.\n     - (1)Node degree; (2)Random walk diagonals; (3) Enumerate substructures(triangles, rings).\n   * - Global (node)\n     - (1)Eigenvectors of A/L or distance matrix; (2)Distance to graph's centroid; (3)Unique ID for each node.\n     - (1)Eigenvalues of A/L; (2) Graph diameter, girth, degree, etc.\n   * - Relative (edge)\n     - (1)Pair-wise distance from: Heat Kernels, Random Walks, Graph geodesic, etc; (2)Gradient of eigenvectors\n     - (1)Gradient of any Local SE; (2)Gradient of sub-structure enumeration\n\nGPS Layer and GraphGPS Model\n--------------------------------------\n\nFirstly, we introduce the GPS layer, which combines a local MPNN with a global Transformer, followed by a 2-layer MLP and skip-connections.\nLocal MPNN can provide locality bias that is difficult or expensive to achieve in Transformer.\nIn addition, features of edges can be 
updated and encoded into the node features (`GatedGCN <https://arxiv.org/abs/1711.07553>`_, `GINE <https://arxiv.org/abs/1905.12265>`_).\nTransformer can utilize positional and structural encodings. As we don't need to consider edge features, we can use the existing linear Transformer architecture to reduce the time complexity from :math:`O(N^2)` to :math:`O(N + E)`, like `Performer <https://arxiv.org/abs/2009.14794>`_ and `BigBird <https://arxiv.org/abs/2007.14062>`_.\n\n.. warning::\n    `BigBird <https://arxiv.org/abs/2007.14062>`_ is currently not supported and will be added in the future.\n\n.. figure:: ../_figures/graphgps_layer.png\n    :align: center\n    :width: 100%\n\nThe update function of each layer is described by the equations below.\n\nLocal MPNN\n++++++++++\n\n.. math::\n    \\hat{X}_M^{l + 1}, E^{l + 1} = MPNN_e^l(X^l, E^l, A)\n.. math::\n    X_M^{l + 1} = BatchNorm(Dropout(\\hat{X}_M^{l + 1}) + X^l)\n\n.. code-block:: python\n\n    h = self.conv(x, edge_index, **kwargs)\n    h = F.dropout(h, p=self.dropout, training=self.training)\n    h = h + x\n    if self.norm1 is not None:\n        if self.norm_with_batch:\n            h = self.norm1(h, batch=batch)\n        else:\n            h = self.norm1(h)\n    hs.append(h)\n\nGlobal Attention\n++++++++++++++++\n\n.. math::\n    \\hat{X}_T^{l + 1} = GlobalAttn^l(X^l)\n.. math::\n    X_T^{l + 1} = BatchNorm(Dropout(\\hat{X}_T^{l + 1}) + X^l)\n\n.. 
code-block:: python\n\n    h, mask = to_dense_batch(x, batch)\n\n    if isinstance(self.attn, torch.nn.MultiheadAttention):\n        h, _ = self.attn(h, h, h, key_padding_mask=~mask,\n                        need_weights=False)\n    elif isinstance(self.attn, PerformerAttention):\n        h = self.attn(h, mask=mask)\n\n    h = h[mask]\n    h = F.dropout(h, p=self.dropout, training=self.training)\n    h = h + x  # Residual connection.\n    if self.norm2 is not None:\n        if self.norm_with_batch:\n            h = self.norm2(h, batch=batch)\n        else:\n            h = self.norm2(h)\n    hs.append(h)\n\nCombine local and global outputs\n++++++++++++++++++++++++++++++++\n\n.. math::\n    X^{l + 1} = MLP^l(X_M^{l + 1} + X_T^{l + 1})\n\n.. code-block:: python\n\n    out = sum(hs)\n\n    out = out + self.mlp(out)\n    if self.norm3 is not None:\n        if self.norm_with_batch:\n            out = self.norm3(out, batch=batch)\n        else:\n            out = self.norm3(out)\n\nNext, we introduce the GraphGPS architecture. The difference between `GraphGPS <https://arxiv.org/abs/2205.12454>`_ and `GraphTrans <https://arxiv.org/abs/2201.08821>`_ is the organization of MPNN and transformer.\nIn GraphTrans, a few layers of MPNNs are stacked before the Transformer, which may be limited by problems of over-smoothing, over-squashing and low expressivity against the WL test.\nThese layers could irreparably fail to keep some information in the early stage. The design of GraphGPS is a stacking of MPNN + transformer hybrid, which resolves\nthe local expressivity bottlenecks by allowing information to spread across the graph via full-connectivity.\n\nTrain GraphGPS on graph-structured data\n--------------------------------------------------\n\nIn this part, we'll show how to train a :class:`~torch_geometric.nn.GPSConv` GNN model on the :class:`~torch_geometric.datasets.ZINC` dataset.\n\nLoad dataset\n++++++++++++\n\n.. 
code-block:: python\n\n    transform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')\n    train_dataset = ZINC(path, subset=True, split='train', pre_transform=transform)\n    val_dataset = ZINC(path, subset=True, split='val', pre_transform=transform)\n    test_dataset = ZINC(path, subset=True, split='test', pre_transform=transform)\n\n    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n    val_loader = DataLoader(val_dataset, batch_size=64)\n    test_loader = DataLoader(test_dataset, batch_size=64)\n\n\nDefine model\n++++++++++++\n\n.. code-block:: python\n\n    class RedrawProjection:\n        def __init__(self, model: torch.nn.Module,\n                    redraw_interval: Optional[int] = None):\n            self.model = model\n            self.redraw_interval = redraw_interval\n            self.num_last_redraw = 0\n\n        def redraw_projections(self):\n            if not self.model.training or self.redraw_interval is None:\n                return\n            if self.num_last_redraw >= self.redraw_interval:\n                fast_attentions = [\n                    module for module in self.model.modules()\n                    if isinstance(module, PerformerAttention)\n                ]\n                for fast_attention in fast_attentions:\n                    fast_attention.redraw_projection_matrix()\n                self.num_last_redraw = 0\n                return\n            self.num_last_redraw += 1\n\n    class GPS(torch.nn.Module):\n        def __init__(self, channels: int, pe_dim: int, num_layers: int,\n                    attn_type: str, attn_kwargs: Dict[str, Any]):\n            super().__init__()\n\n            self.node_emb = Embedding(28, channels - pe_dim)\n            self.pe_lin = Linear(20, pe_dim)\n            self.pe_norm = BatchNorm1d(20)\n            self.edge_emb = Embedding(4, channels)\n\n            self.convs = ModuleList()\n            for _ in range(num_layers):\n                nn = Sequential(\n         
           Linear(channels, channels),\n                    ReLU(),\n                    Linear(channels, channels),\n                )\n                conv = GPSConv(channels, GINEConv(nn), heads=4,\n                            attn_type=attn_type, attn_kwargs=attn_kwargs)\n                self.convs.append(conv)\n\n            self.mlp = Sequential(\n                Linear(channels, channels // 2),\n                ReLU(),\n                Linear(channels // 2, channels // 4),\n                ReLU(),\n                Linear(channels // 4, 1),\n            )\n            self.redraw_projection = RedrawProjection(\n                self.convs,\n                redraw_interval=1000 if attn_type == 'performer' else None)\n\n        def forward(self, x, pe, edge_index, edge_attr, batch):\n            x_pe = self.pe_norm(pe)\n            x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)\n            edge_attr = self.edge_emb(edge_attr)\n\n            for conv in self.convs:\n                x = conv(x, edge_index, batch, edge_attr=edge_attr)\n            x = global_add_pool(x, batch)\n            return self.mlp(x)\n\n\n\nTrain and evaluate\n+++++++++++++++++++\n\n.. 
code-block:: python\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    attn_kwargs = {'dropout': 0.5}\n    model = GPS(channels=64, pe_dim=8, num_layers=10, attn_type=args.attn_type,\n                attn_kwargs=attn_kwargs).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)\n    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,\n                                min_lr=0.00001)\n\n\n    def train():\n        model.train()\n\n        total_loss = 0\n        for data in train_loader:\n            data = data.to(device)\n            optimizer.zero_grad()\n            model.redraw_projection.redraw_projections()\n            out = model(data.x, data.pe, data.edge_index, data.edge_attr,\n                        data.batch)\n            loss = (out.squeeze() - data.y).abs().mean()\n            loss.backward()\n            total_loss += loss.item() * data.num_graphs\n            optimizer.step()\n        return total_loss / len(train_loader.dataset)\n\n\n    @torch.no_grad()\n    def test(loader):\n        model.eval()\n\n        total_error = 0\n        for data in loader:\n            data = data.to(device)\n            out = model(data.x, data.pe, data.edge_index, data.edge_attr,\n                        data.batch)\n            total_error += (out.squeeze() - data.y).abs().sum().item()\n        return total_error / len(loader.dataset)\n\n\n    for epoch in range(1, 101):\n        loss = train()\n        val_mae = test(val_loader)\n        test_mae = test(test_loader)\n        scheduler.step(val_mae)\n        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '\n            f'Test: {test_mae:.4f}')\n\n.. 
code-block:: text\n\n    Epoch: 01, Loss: 0.7216, Val: 0.5316, Test: 0.5454\n    Epoch: 02, Loss: 0.5519, Val: 0.5895, Test: 0.6288\n    Epoch: 03, Loss: 0.5009, Val: 0.5029, Test: 0.4924\n    Epoch: 04, Loss: 0.4751, Val: 0.4801, Test: 0.4786\n    Epoch: 05, Loss: 0.4363, Val: 0.4438, Test: 0.4352\n    Epoch: 06, Loss: 0.4276, Val: 0.4931, Test: 0.4994\n    Epoch: 07, Loss: 0.3956, Val: 0.3502, Test: 0.3439\n    Epoch: 08, Loss: 0.4021, Val: 0.3143, Test: 0.3296\n    Epoch: 09, Loss: 0.3761, Val: 0.4012, Test: 0.3858\n    Epoch: 10, Loss: 0.3739, Val: 0.3343, Test: 0.3032\n    Epoch: 11, Loss: 0.3532, Val: 0.3679, Test: 0.3334\n    Epoch: 12, Loss: 0.3683, Val: 0.3094, Test: 0.2754\n    Epoch: 13, Loss: 0.3457, Val: 0.4007, Test: 0.4023\n    Epoch: 14, Loss: 0.3460, Val: 0.3986, Test: 0.3589\n    Epoch: 15, Loss: 0.3369, Val: 0.3478, Test: 0.3124\n    Epoch: 16, Loss: 0.3222, Val: 0.3043, Test: 0.2651\n    Epoch: 17, Loss: 0.3190, Val: 0.4496, Test: 0.4070\n    Epoch: 18, Loss: 0.3317, Val: 0.3803, Test: 0.3450\n    Epoch: 19, Loss: 0.3179, Val: 0.2671, Test: 0.2408\n    Epoch: 20, Loss: 0.3143, Val: 0.4168, Test: 0.3901\n    Epoch: 21, Loss: 0.3238, Val: 0.3183, Test: 0.2926\n    Epoch: 22, Loss: 0.3132, Val: 0.9534, Test: 1.0879\n    Epoch: 23, Loss: 0.3088, Val: 0.3705, Test: 0.3360\n    Epoch: 24, Loss: 0.3032, Val: 0.3051, Test: 0.2692\n    Epoch: 25, Loss: 0.2968, Val: 0.2829, Test: 0.2571\n    Epoch: 26, Loss: 0.2915, Val: 0.3145, Test: 0.2820\n    Epoch: 27, Loss: 0.2871, Val: 0.3127, Test: 0.2965\n    Epoch: 28, Loss: 0.2953, Val: 0.4415, Test: 0.4144\n    Epoch: 29, Loss: 0.2916, Val: 0.3118, Test: 0.2733\n    Epoch: 30, Loss: 0.3074, Val: 0.4497, Test: 0.4418\n"
  },
  {
    "path": "docs/source/tutorial/heterogeneous.rst",
    "content": "Heterogeneous Graph Learning\n============================\n\nA large set of real-world datasets are stored as heterogeneous graphs, motivating the introduction of specialized functionality for them in :pyg:`PyG`.\nFor example, most graphs in the area of recommendation, such as social graphs, are heterogeneous, as they store information about different types of entities and their different types of relations.\nThis tutorial introduces how heterogeneous graphs are mapped to :pyg:`PyG` and how they can be used as input to Graph Neural Network models.\n\nHeterogeneous graphs come with different types of information attached to nodes and edges.\nThus, a single node or edge feature tensor cannot hold all node or edge features of the whole graph, due to differences in type and dimensionality.\nInstead, a set of types need to be specified for nodes and edges, respectively, each having its own data tensors.\nAs a consequence of the different data structure, the message passing formulation changes accordingly, allowing the computation of message and update function conditioned on node or edge type.\n\nExample Graph\n-------------\n\nAs a guiding example, we take a look at the heterogeneous `ogbn-mag <https://ogb.stanford.edu/docs/nodeprop>`__ network from the :ogb:`null` `dataset suite <https://ogb.stanford.edu>`_:\n\n.. 
image:: ../_figures/hg_example.svg\n  :align: center\n  :width: 500px\n\nThe given heterogeneous graph has 1,939,743 nodes, split between the four node types **author**, **paper**, **institution** and **field of study**.\nIt further has 21,111,007 edges, which also are of one of four types:\n\n* **writes**: An author *writes* a specific paper\n* **affiliated with**: An author is *affiliated with* a specific institution\n* **cites**: A paper *cites* another paper\n* **has topic**: A paper *has a topic* of a specific field of study\n\nThe task for this graph is to infer the venue of each paper (conference or journal) given the information stored in the graph.\n\nCreating Heterogeneous Graphs\n-----------------------------\n\nFirst, we can create a data object of type :class:`torch_geometric.data.HeteroData`, for which we define node feature tensors, edge index tensors and edge feature tensors individually for each type:\n\n.. code-block:: python\n\n   from torch_geometric.data import HeteroData\n\n   data = HeteroData()\n\n   data['paper'].x = ... # [num_papers, num_features_paper]\n   data['author'].x = ... # [num_authors, num_features_author]\n   data['institution'].x = ... # [num_institutions, num_features_institution]\n   data['field_of_study'].x = ... # [num_field, num_features_field]\n\n   data['paper', 'cites', 'paper'].edge_index = ... # [2, num_edges_cites]\n   data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges_writes]\n   data['author', 'affiliated_with', 'institution'].edge_index = ... # [2, num_edges_affiliated]\n   data['paper', 'has_topic', 'field_of_study'].edge_index = ... # [2, num_edges_topic]\n\n   data['paper', 'cites', 'paper'].edge_attr = ... # [num_edges_cites, num_features_cites]\n   data['author', 'writes', 'paper'].edge_attr = ... # [num_edges_writes, num_features_writes]\n   data['author', 'affiliated_with', 'institution'].edge_attr = ... 
# [num_edges_affiliated, num_features_affiliated]\n   data['paper', 'has_topic', 'field_of_study'].edge_attr = ... # [num_edges_topic, num_features_topic]\n\nNode or edge tensors will be automatically created upon first access and indexed by string keys.\nNode types are identified by a single string while edge types are identified by using a triplet :obj:`(source_node_type, edge_type, destination_node_type)` of strings: the edge type identifier and the two node types between which the edge type can exist.\nAs such, the data object allows different feature dimensionalities for each type.\n\nDictionaries containing the heterogeneous information grouped by attribute names rather than by node or edge type can directly be accessed via :obj:`data.{attribute_name}_dict` and serve as input to GNN models later:\n\n.. code-block:: python\n\n    model = HeteroGNN(...)\n\n    output = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)\n\nIf the dataset exists in the `list of PyTorch Geometric datasets <https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html>`_, it can directly be imported and used.\nIn particular, it will be downloaded to :obj:`root` and processed automatically.\n\n.. code-block:: python\n\n    from torch_geometric.datasets import OGB_MAG\n\n    dataset = OGB_MAG(root='./data', preprocess='metapath2vec')\n    data = dataset[0]\n\nThe :obj:`data` object can be printed for verification.\n\n.. 
code-block:: text\n\n    HeteroData(\n      paper={\n        x=[736389, 128],\n        y=[736389],\n        train_mask=[736389],\n        val_mask=[736389],\n        test_mask=[736389]\n      },\n      author={ x=[1134649, 128] },\n      institution={ x=[8740, 128] },\n      field_of_study={ x=[59965, 128] },\n      (author, affiliated_with, institution)={ edge_index=[2, 1043998] },\n      (author, writes, paper)={ edge_index=[2, 7145660] },\n      (paper, cites, paper)={ edge_index=[2, 5416271] },\n      (paper, has_topic, field_of_study)={ edge_index=[2, 7505078] }\n    )\n\n.. note::\n\n   The original `ogbn-mag <https://ogb.stanford.edu/docs/nodeprop>`__ network does only provide features for \"paper\" nodes.\n   In :class:`~torch_geometric.datasets.OGB_MAG`, we provide the option to download a processed version of it in which structural features (obtained from either :obj:`\"metapath2vec\"` or :obj:`\"TransE\"`) are added to featureless nodes, as it is commonly done in the top ranked submissions to the `OGB leaderboards <https://ogb.stanford.edu/docs/leader_nodeprop>`_.\n\nUtility Functions\n~~~~~~~~~~~~~~~~~\n\nThe :class:`torch_geometric.data.HeteroData` class provides a number of useful utility functions to modify and analyze the given graph.\n\nFor example, single node or edge stores can be individually indexed:\n\n.. code-block:: python\n\n    paper_node_data = data['paper']\n    cites_edge_data = data['paper', 'cites', 'paper']\n\nIn case the edge type can be uniquely identified by only the pair of source and destination node types or the edge type, the following operations work as well:\n\n.. code-block:: python\n\n    cites_edge_data = data['paper', 'paper']\n    cites_edge_data = data['cites']\n\nWe can add new node types or tensors and remove them:\n\n.. code-block:: python\n\n    data['paper'].year = ...    
# Setting a new paper attribute\n    del data['field_of_study']  # Deleting 'field_of_study' node type\n    del data['has_topic']       # Deleting 'has_topic' edge type\n\nWe can access the meta-data of the :obj:`data` object, holding information of all present node and edge types:\n\n.. code-block:: python\n\n    node_types, edge_types = data.metadata()\n    print(node_types)\n    ['paper', 'author', 'institution']\n    print(edge_types)\n    [('paper', 'cites', 'paper'),\n    ('author', 'writes', 'paper'),\n    ('author', 'affiliated_with', 'institution')]\n\nThe :obj:`data` object can be transferred between devices as usual:\n\n.. code-block:: python\n\n    data = data.to('cuda:0')\n    data = data.cpu()\n\nWe further have access to additional helper functions to analyze the given graph\n\n.. code-block:: python\n\n    data.has_isolated_nodes()\n    data.has_self_loops()\n    data.is_undirected()\n\nand can convert it to a homogeneous \"typed\" graph via :meth:`~torch_geometric.data.HeteroData.to_homogeneous` which is able to maintain features in case their dimensionalities match across different types:\n\n.. code-block:: python\n\n    homogeneous_data = data.to_homogeneous()\n    print(homogeneous_data)\n    Data(x=[1879778, 128], edge_index=[2, 13605929], edge_type=[13605929])\n\nHere, :obj:`homogeneous_data.edge_type` represents an edge-level vector that holds the edge type of each edge as an integer.\n\nHeterogeneous Graph Transformations\n-----------------------------------\n\nMost `transformations <https://pytorch-geometric.readthedocs.io/en/latest/modules/transforms.html>`_ for preprocessing regular graphs work as well on the heterogeneous graph :obj:`data` object.\n\n.. 
code-block:: python\n\n    import torch_geometric.transforms as T\n\n    data = T.ToUndirected()(data)\n    data = T.AddSelfLoops()(data)\n    data = T.NormalizeFeatures()(data)\n\nHere, :meth:`~torch_geometric.transforms.ToUndirected` transforms a directed graph into (the :pyg:`PyG` representation of) an undirected graph, by adding reverse edges for all edges in the graph.\nThus, future message passing is performed in both directions of all edges.\nThe function may add reverse edge types to the heterogeneous graph, if necessary.\n\nFor all nodes of type :obj:`'node_type'` and all existing edge types of the form :obj:`('node_type', 'edge_type', 'node_type')`, the function :meth:`~torch_geometric.transforms.AddSelfLoops` will add self-loop edges.\nAs a result, each node might receive one or more (one per appropriate edge type) messages from itself during message passing.\n\nThe transform :meth:`~torch_geometric.transforms.NormalizeFeatures` works as in the homogeneous case, and normalizes all specified features (of all types) to sum up to one.\n\nCreating Heterogeneous GNNs\n---------------------------\n\nStandard Message Passing GNNs (MP-GNNs) cannot trivially be applied to heterogeneous graph data, as node and edge features from different types cannot be processed by the same functions due to differences in feature type.\nA natural way to circumvent this is to implement message and update functions individually for each edge type.\nDuring runtime, the MP-GNN algorithm would need to iterate over edge type dictionaries during message computation and over node type dictionaries during node updates.\n\nTo avoid unnecessary runtime overheads and to make the creation of heterogeneous MP-GNNs as simple as possible, PyTorch Geometric provides three ways for the user to create models on heterogeneous graph data:\n\n#. 
Automatically convert a homogeneous GNN model to a heterogeneous GNN model by making use of :meth:`torch_geometric.nn.to_hetero` or :meth:`torch_geometric.nn.to_hetero_with_bases`\n#. Define individual functions for different types using :pyg:`PyG's` wrapper :class:`torch_geometric.nn.conv.HeteroConv` for heterogeneous convolution\n#. Deploy existing (or write your own) heterogeneous GNN operators\n\nIn the following, each option is introduced in detail.\n\nAutomatically Converting GNN Models\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nPyTorch Geometric can automatically convert any :pyg:`PyG` GNN model to a model for heterogeneous input graphs, using the built-in functions :meth:`torch_geometric.nn.to_hetero` or :meth:`torch_geometric.nn.to_hetero_with_bases`.\nThe following `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py>`__ shows how to apply it:\n\n.. code-block:: python\n\n    import torch\n\n    import torch_geometric.transforms as T\n    from torch_geometric.datasets import OGB_MAG\n    from torch_geometric.nn import SAGEConv, to_hetero\n\n\n    dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())\n    data = dataset[0]\n\n    class GNN(torch.nn.Module):\n        def __init__(self, hidden_channels, out_channels):\n            super().__init__()\n            self.conv1 = SAGEConv((-1, -1), hidden_channels)\n            self.conv2 = SAGEConv((-1, -1), out_channels)\n\n        def forward(self, x, edge_index):\n            x = self.conv1(x, edge_index).relu()\n            x = self.conv2(x, edge_index)\n            return x\n\n\n    model = GNN(hidden_channels=64, out_channels=dataset.num_classes)\n    model = to_hetero(model, data.metadata(), aggr='sum')\n\nThe process takes an existing GNN model and duplicates the message functions to work on each edge type individually, as detailed in the following figure.\n\n.. 
image:: ../_figures/to_hetero.svg\n   :align: center\n   :width: 90%\n\nAs a result, the model now expects dictionaries with node and edge types as keys as input arguments, rather than single tensors utilized in homogeneous graphs.\nNote that we pass in a tuple of :obj:`in_channels` to :class:`~torch_geometric.nn.conv.SAGEConv` in order to allow for message passing in bipartite graphs.\n\n.. _lazyinit:\n\n.. note::\n   Since the number of input features and thus the size of tensors varies between different types, :pyg:`PyG` can make use of **lazy initialization** to initialize parameters in heterogeneous GNNs (as denoted by :obj:`-1` as the :obj:`in_channels` argument).\n   This allows us to avoid calculating and keeping track of all tensor sizes of the computation graph.\n   Lazy initialization is supported for all existing :pyg:`PyG` operators.\n   We can initialize the model's parameters by calling it once:\n\n   .. code-block:: python\n\n        with torch.no_grad():  # Initialize lazy modules.\n            out = model(data.x_dict, data.edge_index_dict)\n\nBoth :meth:`~torch_geometric.nn.to_hetero` and :meth:`~torch_geometric.nn.to_hetero_with_bases` are very flexible with respect to the homogeneous architectures that can be automatically converted to heterogeneous ones, *e.g.*, applying skip-connections, jumping knowledge or other techniques are supported out-of-the-box.\nFor example, this is all it takes to implement a heterogeneous graph attention network with learnable skip-connections:\n\n.. 
code-block:: python\n\n    from torch_geometric.nn import GATConv, Linear, to_hetero\n\n    class GAT(torch.nn.Module):\n        def __init__(self, hidden_channels, out_channels):\n            super().__init__()\n            self.conv1 = GATConv((-1, -1), hidden_channels, add_self_loops=False)\n            self.lin1 = Linear(-1, hidden_channels)\n            self.conv2 = GATConv((-1, -1), out_channels, add_self_loops=False)\n            self.lin2 = Linear(-1, out_channels)\n\n        def forward(self, x, edge_index):\n            x = self.conv1(x, edge_index) + self.lin1(x)\n            x = x.relu()\n            x = self.conv2(x, edge_index) + self.lin2(x)\n            return x\n\n\n    model = GAT(hidden_channels=64, out_channels=dataset.num_classes)\n    model = to_hetero(model, data.metadata(), aggr='sum')\n\nNote that we disable the creation of self loops via the :obj:`add_self_loops=False` argument.\nThis is done because the concept of self-loops is not well-defined in bipartite graphs (message passing for an edge type with distinct source and destination node types), and we would mistakenly add the edges :obj:`[(0, 0), (1, 1), ...]` to the bipartite graph.\nTo preserve central node information, we thus utilize a learnable skip-connection via :obj:`conv(x, edge_index) + lin(x)` instead, which will perform attention-based message passing from source to destination node features, and its output is then summed up to the existing destination node features.\n\nAfterwards, the created model can be trained as usual:\n\n.. _trainfunc:\n\n.. 
code-block:: python\n\n    def train():\n        model.train()\n        optimizer.zero_grad()\n        out = model(data.x_dict, data.edge_index_dict)\n        mask = data['paper'].train_mask\n        loss = F.cross_entropy(out['paper'][mask], data['paper'].y[mask])\n        loss.backward()\n        optimizer.step()\n        return float(loss)\n\nUsing the Heterogeneous Convolution Wrapper\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe heterogeneous convolution wrapper :class:`torch_geometric.nn.conv.HeteroConv` allows to define custom heterogeneous message and update functions to build arbitrary MP-GNNs for heterogeneous graphs from scratch.\nWhile the automatic converter :meth:`~torch_geometric.nn.to_hetero` uses the same operator for all edge types, the wrapper allows to define different operators for different edge types.\nHere, :class:`~torch_geometric.nn.conv.HeteroConv` takes a dictionary of submodules as input, one for each edge type in the graph data.\nThe following `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hetero_conv_dblp.py>`__ shows how to apply it.\n\n.. 
code-block:: python\n\n    import torch_geometric.transforms as T\n    from torch_geometric.datasets import OGB_MAG\n    from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear\n\n\n    dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())\n    data = dataset[0]\n\n    class HeteroGNN(torch.nn.Module):\n        def __init__(self, hidden_channels, out_channels, num_layers):\n            super().__init__()\n\n            self.convs = torch.nn.ModuleList()\n            for _ in range(num_layers):\n                conv = HeteroConv({\n                    ('paper', 'cites', 'paper'): GCNConv(-1, hidden_channels),\n                    ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),\n                    ('paper', 'rev_writes', 'author'): GATConv((-1, -1), hidden_channels, add_self_loops=False),\n                }, aggr='sum')\n                self.convs.append(conv)\n\n            self.lin = Linear(hidden_channels, out_channels)\n\n        def forward(self, x_dict, edge_index_dict):\n            for conv in self.convs:\n                x_dict = conv(x_dict, edge_index_dict)\n                x_dict = {key: x.relu() for key, x in x_dict.items()}\n            return self.lin(x_dict['author'])\n\n    model = HeteroGNN(hidden_channels=64, out_channels=dataset.num_classes,\n                      num_layers=2)\n\nWe can initialize the model by calling it once (see :ref:`here<lazyinit>` for more details about lazy initialization)\n\n.. 
code-block:: python\n\n    with torch.no_grad():  # Initialize lazy modules.\n         out = model(data.x_dict, data.edge_index_dict)\n\nand run the standard training procedure as outlined :ref:`here<trainfunc>`.\n\nDeploy Existing Heterogeneous Operators\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n:pyg:`PyG` provides operators (*e.g.*, :class:`torch_geometric.nn.conv.HGTConv`), which are specifically designed for heterogeneous graphs.\nThese operators can be directly used to build heterogeneous GNN models as can be seen in the following `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/hgt_dblp.py>`__:\n\n.. code-block:: python\n\n    import torch_geometric.transforms as T\n    from torch_geometric.datasets import OGB_MAG\n    from torch_geometric.nn import HGTConv, Linear\n\n\n    dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())\n    data = dataset[0]\n\n    class HGT(torch.nn.Module):\n        def __init__(self, hidden_channels, out_channels, num_heads, num_layers):\n            super().__init__()\n\n            self.lin_dict = torch.nn.ModuleDict()\n            for node_type in data.node_types:\n                self.lin_dict[node_type] = Linear(-1, hidden_channels)\n\n            self.convs = torch.nn.ModuleList()\n            for _ in range(num_layers):\n                conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),\n                               num_heads, group='sum')\n                self.convs.append(conv)\n\n            self.lin = Linear(hidden_channels, out_channels)\n\n        def forward(self, x_dict, edge_index_dict):\n            for node_type, x in x_dict.items():\n                x_dict[node_type] = self.lin_dict[node_type](x).relu_()\n\n            for conv in self.convs:\n                x_dict = conv(x_dict, edge_index_dict)\n\n            return self.lin(x_dict['author'])\n\n    model = HGT(hidden_channels=64, out_channels=dataset.num_classes,\n         
       num_heads=2, num_layers=2)\n\nWe can initialize the model by calling it once (see :ref:`here<lazyinit>` for more details about lazy initialization)\n\n.. code-block:: python\n\n    with torch.no_grad():  # Initialize lazy modules.\n        out = model(data.x_dict, data.edge_index_dict)\n\nand run the standard training procedure as outlined :ref:`here<trainfunc>`.\n\nHeterogeneous Graph Samplers\n----------------------------\n\n:pyg:`PyG` provides various functionalities for sampling heterogeneous graphs, *e.g.*, in the standard :class:`torch_geometric.loader.NeighborLoader` class or in dedicated heterogeneous graph samplers such as :class:`torch_geometric.loader.HGTLoader`.\nThis is especially useful for efficient representation learning on large heterogeneous graphs, where processing the full number of neighbors is too computationally expensive.\nHeterogeneous graph support for other samplers such as :class:`torch_geometric.loader.ClusterLoader` or :class:`torch_geometric.loader.GraphSAINTLoader` will be added soon.\nOverall, all heterogeneous graph loaders will produce a :class:`~torch_geometric.data.HeteroData` object as output, holding a subset of the original data, and mainly differ in the way their sampling procedures work.\nAs such, only minimal code changes are required to convert the training procedure from :ref:`full-batch training<trainfunc>` to mini-batch training.\n\nPerforming neighbor sampling using :class:`~torch_geometric.loader.NeighborLoader` works as outlined in the following `example <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py>`__:\n\n.. 
code-block:: python\n\n    import torch_geometric.transforms as T\n    from torch_geometric.datasets import OGB_MAG\n    from torch_geometric.loader import NeighborLoader\n\n    transform = T.ToUndirected()  # Add reverse edge types.\n    data = OGB_MAG(root='./data', preprocess='metapath2vec', transform=transform)[0]\n\n    train_loader = NeighborLoader(\n        data,\n        # Sample 15 neighbors for each node and each edge type for 2 iterations:\n        num_neighbors=[15] * 2,\n        # Use a batch size of 128 for sampling training nodes of type \"paper\":\n        batch_size=128,\n        input_nodes=('paper', data['paper'].train_mask),\n    )\n\n    batch = next(iter(train_loader))\n\n\nNotably, :class:`~torch_geometric.loader.NeighborLoader` works for both homogeneous *and* heterogeneous graphs.\nWhen operating in heterogeneous graphs, more fine-grained control over the amount of sampled neighbors of individual edge types is possible, but not necessary, *e.g.*:\n\n.. code-block:: python\n\n    num_neighbors = {key: [15] * 2 for key in data.edge_types}\n\nUsing the :obj:`input_nodes` argument, we further specify the type and indices of nodes from which we want to sample local neighborhoods, *i.e.* all the \"paper\" nodes marked as training nodes according to :obj:`data['paper'].train_mask`.\n\nPrinting :obj:`batch` then yields the following output:\n\n.. 
code-block:: text\n\n    HeteroData(\n      paper={\n        x=[20799, 256],\n        y=[20799],\n        train_mask=[20799],\n        val_mask=[20799],\n        test_mask=[20799],\n        batch_size=128\n      },\n      author={ x=[4419, 128] },\n      institution={ x=[302, 128] },\n      field_of_study={ x=[2605, 128] },\n      (author, affiliated_with, institution)={ edge_index=[2, 0] },\n      (author, writes, paper)={ edge_index=[2, 5927] },\n      (paper, cites, paper)={ edge_index=[2, 11829] },\n      (paper, has_topic, field_of_study)={ edge_index=[2, 10573] },\n      (institution, rev_affiliated_with, author)={ edge_index=[2, 829] },\n      (paper, rev_writes, author)={ edge_index=[2, 5512] },\n      (field_of_study, rev_has_topic, paper)={ edge_index=[2, 10499] }\n    )\n\nAs such, :obj:`batch` holds a total of 28,187 nodes involved for computing the embeddings of 128 \"paper\" nodes.\nSampled nodes are always sorted based on the order in which they were sampled.\nThus, the first :obj:`batch['paper'].batch_size` nodes represent the set of original mini-batch nodes, making it easy to obtain the final output embeddings via slicing.\n\nTraining our heterogeneous GNN model in mini-batch mode is then similar to training it in full-batch mode, except that we now iterate over the mini-batches produced by :obj:`train_loader` and optimize model parameters based on individual mini-batches:\n\n.. 
code-block:: python\n\n    def train():\n        model.train()\n\n        total_examples = total_loss = 0\n        for batch in train_loader:\n            optimizer.zero_grad()\n            batch = batch.to('cuda:0')\n            batch_size = batch['paper'].batch_size\n            out = model(batch.x_dict, batch.edge_index_dict)\n            loss = F.cross_entropy(out['paper'][:batch_size],\n                                   batch['paper'].y[:batch_size])\n            loss.backward()\n            optimizer.step()\n\n            total_examples += batch_size\n            total_loss += float(loss) * batch_size\n\n        return total_loss / total_examples\n\nImportantly, we only make use of the first 128 \"paper\" nodes during loss computation.\nWe do so by slicing both \"paper\" labels :obj:`batch['paper'].y` and \"paper\" output predictions :obj:`out['paper']` based on :obj:`batch['paper'].batch_size`, representing the labels and final output predictions of original mini-batch nodes, respectively.\n"
  },
  {
    "path": "docs/source/tutorial/load_csv.rst",
    "content": "Loading Graphs from CSV\n=======================\n\nIn this example, we will show how to load a set of :obj:`*.csv` files as input and construct a **heterogeneous graph** from it, which can be used as input to a `heterogeneous graph model <heterogeneous.html>`__.\nThis tutorial is also available as an executable `example script <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/hetero/load_csv.py>`_ in the :obj:`examples/hetero` directory.\n\nWe are going to use the `MovieLens dataset <https://grouplens.org/datasets/movielens/>`_ collected by the GroupLens research group.\nThis toy dataset describes 5-star rating and tagging activity from MovieLens.\nThe dataset contains approximately 100k ratings across more than 9k movies from more than 600 users.\nWe are going to use this dataset to generate two node types holding data for **movies** and **users**, respectively, and one edge type connecting **users and movies**, representing the relation of how a user has rated a specific movie.\n\nFirst, we download the dataset to an arbitrary folder (in this case, the current directory):\n\n.. code-block:: python\n\n    from torch_geometric.data import download_url, extract_zip\n\n    url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'\n    extract_zip(download_url(url, '.'), '.')\n\n    movie_path = './ml-latest-small/movies.csv'\n    rating_path = './ml-latest-small/ratings.csv'\n\nBefore we create the heterogeneous graph, let's take a look at the data.\n\n.. code-block:: python\n\n    import pandas as pd\n\n    print(pd.read_csv(movie_path).head())\n    print(pd.read_csv(rating_path).head())\n\n.. 
list-table:: Head of :obj:`movies.csv`\n    :widths: 5 40 60\n    :header-rows: 1\n\n    * - movieId\n      - title\n      - genres\n    * - 1\n      - Toy Story (1995)\n      - Adventure|Animation|Children|Comedy|Fantasy\n    * - 2\n      - Jumanji (1995)\n      - Adventure|Children|Fantasy\n    * - 3\n      - Grumpier Old Men (1995)\n      - Comedy|Romance\n    * - 4\n      - Waiting to Exhale (1995)\n      - Comedy|Drama|Romance\n    * - 5\n      - Father of the Bride Part II (1995)\n      - Comedy\n\nWe see that the :obj:`movies.csv` file provides three columns: :obj:`movieId` assigns a unique identifier to each movie, while the :obj:`title` and :obj:`genres` columns represent the title and genres of the given movie.\nWe can make use of those two columns to define a feature representation that can be easily interpreted by machine learning models.\n\n.. list-table:: Head of :obj:`ratings.csv`\n    :widths: 5 5 10 30\n    :header-rows: 1\n\n    * - userId\n      - movieId\n      - rating\n      - timestamp\n    * - 1\n      - 1\n      - 4.0\n      - 964982703\n    * - 1\n      - 3\n      - 4.0\n      - 964981247\n    * - 1\n      - 6\n      - 4.0\n      - 964982224\n    * - 1\n      - 47\n      - 5.0\n      - 964983815\n    * - 1\n      - 50\n      - 5.0\n      - 964982931\n\nThe :obj:`ratings.csv` data connects users (as given by :obj:`userId`) and movies (as given by :obj:`movieId`), and defines how a given user has rated a specific movie (:obj:`rating`).\nFor simplicity, we do not make use of the additional :obj:`timestamp` information.\n\nFor representing this data in the :pyg:`PyG` data format, we first define a method :meth:`load_node_csv` that reads in a :obj:`*.csv` file and returns a node-level feature representation :obj:`x` of shape :obj:`[num_nodes, num_features]`:\n\n.. 
code-block:: python\n\n    import torch\n\n    def load_node_csv(path, index_col, encoders=None, **kwargs):\n        df = pd.read_csv(path, index_col=index_col, **kwargs)\n        mapping = {index: i for i, index in enumerate(df.index.unique())}\n\n        x = None\n        if encoders is not None:\n            xs = [encoder(df[col]) for col, encoder in encoders.items()]\n            x = torch.cat(xs, dim=-1)\n\n        return x, mapping\n\nHere, :meth:`load_node_csv` reads the :obj:`*.csv` file from :obj:`path`, and creates a dictionary :obj:`mapping` that maps its index column to a consecutive value in the range :obj:`{ 0, ..., num_rows - 1 }`.\nThis is needed as we want our final data representation to be as compact as possible, *e.g.*, the representation of a movie in the first row should be accessible via :obj:`x[0]`.\n\nWe further utilize the concept of encoders, which define how the values of specific columns should be encoded into a numerical feature representation.\nFor example, we can define a sentence encoder that encodes raw column strings into low-dimensional embeddings.\nFor this, we make use of the excellent `sentence-transformers <https://www.sbert.net/>`_ library which provides a large number of state-of-the-art pretrained NLP embedding models:\n\n.. code-block:: bash\n\n    pip install sentence-transformers\n\n.. 
code-block:: python\n\n    from sentence_transformers import SentenceTransformer\n\n    class SequenceEncoder:\n        def __init__(self, model_name='all-MiniLM-L6-v2', device=None):\n            self.device = device\n            self.model = SentenceTransformer(model_name, device=device)\n\n        @torch.no_grad()\n        def __call__(self, df):\n            x = self.model.encode(df.values, show_progress_bar=True,\n                                  convert_to_tensor=True, device=self.device)\n            return x.cpu()\n\nThe :class:`SequenceEncoder` class loads a pre-trained NLP model as given by :obj:`model_name`, and uses it to encode a list of strings into a :pytorch:`PyTorch` tensor of shape :obj:`[num_strings, embedding_dim]`.\nWe can use this :class:`SequenceEncoder` to encode the :obj:`title` of the :obj:`movies.csv` file.\n\nIn a similar fashion, we can create another encoder that converts the genres of movies, *e.g.*, :obj:`Adventure|Children|Fantasy`, into categorical labels.\nFor this, we first need to find all existing genres present in the data, create a feature representation :obj:`x` of shape :obj:`[num_movies, num_genres]`, and assign a :obj:`1` to :obj:`x[i, j]` in case the genre :obj:`j` is present in movie :obj:`i`:\n\n.. code-block:: python\n\n    class GenresEncoder:\n        def __init__(self, sep='|'):\n            self.sep = sep\n\n        def __call__(self, df):\n            genres = set(g for col in df.values for g in col.split(self.sep))\n            mapping = {genre: i for i, genre in enumerate(genres)}\n\n            x = torch.zeros(len(df), len(mapping))\n            for i, col in enumerate(df.values):\n                for genre in col.split(self.sep):\n                    x[i, mapping[genre]] = 1\n            return x\n\nWith this, we can obtain our final representation of movies via:\n\n.. 
code-block:: python\n\n    movie_x, movie_mapping = load_node_csv(\n        movie_path, index_col='movieId', encoders={\n            'title': SequenceEncoder(),\n            'genres': GenresEncoder()\n        })\n\nSimilarly, we can utilize :meth:`load_node_csv` for obtaining a user mapping from :obj:`userId` to consecutive values as well.\nHowever, there is no additional feature information for users present in this dataset.\nAs such, we do not define any encoders:\n\n.. code-block:: python\n\n    _, user_mapping = load_node_csv(rating_path, index_col='userId')\n\nWith this, we are ready to initialize our :class:`~torch_geometric.data.HeteroData` object and pass two node types into it:\n\n.. code-block:: python\n\n    from torch_geometric.data import HeteroData\n\n    data = HeteroData()\n\n    data['user'].num_nodes = len(user_mapping)  # Users do not have any features.\n    data['movie'].x = movie_x\n\n    print(data)\n    HeteroData(\n      user={ num_nodes=610 },\n      movie={ x=[9742, 404] }\n    )\n\nAs users do not have any node-level information, we solely define their number of nodes.\nAs a result, we likely need to learn distinct user embeddings via :class:`torch.nn.Embedding` in an end-to-end fashion during training of a heterogeneous graph model.\n\nNext, we take a look at connecting users with movies as defined by their ratings.\nFor this, we define a method :meth:`load_edge_csv` that returns the final :obj:`edge_index` representation of shape :obj:`[2, num_ratings]` from :obj:`ratings.csv`, as well as any additional features present in the raw :obj:`*.csv` file:\n\n.. 
code-block:: python\n\n    def load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,\n                      encoders=None, **kwargs):\n        df = pd.read_csv(path, **kwargs)\n\n        src = [src_mapping[index] for index in df[src_index_col]]\n        dst = [dst_mapping[index] for index in df[dst_index_col]]\n        edge_index = torch.tensor([src, dst])\n\n        edge_attr = None\n        if encoders is not None:\n            edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]\n            edge_attr = torch.cat(edge_attrs, dim=-1)\n\n        return edge_index, edge_attr\n\nHere, :obj:`src_index_col` and :obj:`dst_index_col` define the index columns of source and destination nodes, respectively.\nWe further make use of the node-level mappings :obj:`src_mapping` and :obj:`dst_mapping` to ensure that raw indices are mapped to the correct consecutive indices in our final representation.\nFor every edge defined in the file, it looks up the forward indices in :obj:`src_mapping` and :obj:`dst_mapping`, and moves the data appropriately.\n\nSimilarly to :meth:`load_node_csv`, encoders are used to return additional edge-level feature information.\nFor example, for loading the ratings from the :obj:`rating` column in :obj:`ratings.csv`, we can define an :class:`IdentityEncoder` that simply converts a list of floating-point values into a :pytorch:`PyTorch` tensor:\n\n.. code-block:: python\n\n    class IdentityEncoder:\n        def __init__(self, dtype=None):\n            self.dtype = dtype\n\n        def __call__(self, df):\n            return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)\n\nWith this, we are ready to finalize our :class:`~torch_geometric.data.HeteroData` object:\n\n.. 
code-block:: python\n\n    edge_index, edge_label = load_edge_csv(\n        rating_path,\n        src_index_col='userId',\n        src_mapping=user_mapping,\n        dst_index_col='movieId',\n        dst_mapping=movie_mapping,\n        encoders={'rating': IdentityEncoder(dtype=torch.long)},\n    )\n\n    data['user', 'rates', 'movie'].edge_index = edge_index\n    data['user', 'rates', 'movie'].edge_label = edge_label\n\n    print(data)\n    HeteroData(\n      user={ num_nodes=610 },\n      movie={ x=[9742, 404] },\n      (user, rates, movie)={\n        edge_index=[2, 100836],\n        edge_label=[100836, 1]\n      }\n    )\n\nThis :class:`~torch_geometric.data.HeteroData` object is the native format of heterogeneous graphs in :pyg:`PyG` and can be used as input for `heterogeneous graph models <heterogeneous.html>`__.\n\n.. note::\n\n    Click `here <https://github.com/pyg-team/pytorch_geometric/tree/master/examples/hetero/load_csv.py>`_ to see the final example script.\n"
  },
  {
    "path": "docs/source/tutorial/multi_gpu_vanilla.rst",
    "content": "Multi-GPU Training in Pure PyTorch\n==================================\n\n.. note::\n    For multi-GPU training with cuGraph, refer to `cuGraph examples <https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples>`_.\n\n\nFor many large scale, real-world datasets, it may be necessary to scale-up training across multiple GPUs.\nThis tutorial goes over how to set up a multi-GPU training  pipeline in :pyg:`PyG` with :pytorch:`PyTorch` via :class:`torch.nn.parallel.DistributedDataParallel`, without the need for any other third-party libraries (such as :lightning:`PyTorch Lightning`).\nNote that this approach is based on data-parallelism.\nThis means that each GPU runs an identical copy of the model; you might want to look into `PyTorch FSDP <https://arxiv.org/abs/2304.11277>`_ if you want to scale your model across devices.\nData-parallelism allows you to increase the batch size of your model by aggregating gradients across GPUs and then sharing the same optimizer step within every model replica.\nThis `DDP+MNIST-tutorial <https://github.com/PrincetonUniversity/multi_gpu_training/tree/main/02_pytorch_ddp#overall-idea-of-distributed-data-parallel>`_  by the Princeton University has some nice illustrations of the process.\n\nSpecifically this tutorial shows how to train a :class:`~torch_geometric.nn.models.GraphSAGE` GNN model on the :class:`~torch_geometric.datasets.Reddit` dataset.\nFor this, we will use :class:`torch.nn.parallel.DistributedDataParallel` to scale-up training across all available GPUs.\nWe will do this by spawning multiple processes from our :python:`Python` code which will all execute the same function.\nPer process, we set up our model instance and feed data through it by utilizing the :class:`~torch_geometric.loader.NeighborLoader`.\nGradients are synchronized by wrapping the model in :class:`torch.nn.parallel.DistributedDataParallel` (as described in its `official tutorial 
<https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`_), which in turn relies on the IPC facilities of :obj:`torch.distributed`.\n\n.. note::\n    The complete script of this tutorial can be found at `examples/multi_gpu/distributed_sampling.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling.py>`_.\n\nDefining a Spawnable Runner\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nTo create our training script, we use the :pytorch:`PyTorch`-provided wrapper of the vanilla :python:`Python` :class:`multiprocessing` module.\nHere, the :obj:`world_size` corresponds to the number of GPUs we will be using at once.\n:meth:`torch.multiprocessing.spawn` will take care of spawning :obj:`world_size` processes.\nEach process will load the same script as a module and subsequently execute the :meth:`run`-function:\n\n.. code-block:: python\n\n    import torch\n    import torch.multiprocessing as mp\n    from torch_geometric.datasets import Reddit\n\n    def run(rank: int, world_size: int, dataset: Reddit):\n        pass\n\n    if __name__ == '__main__':\n        dataset = Reddit('./data/Reddit')\n        world_size = torch.cuda.device_count()\n        mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)\n\nNote that we initialize the dataset *before* spawning any processes.\nWith this, we only initialize the dataset once, and any data inside it will be automatically moved to shared memory via :obj:`torch.multiprocessing` such that processes do not need to create their own replica of the data.\nIn addition, note how the :meth:`run` function accepts :obj:`rank` as its first argument.\nThis argument is not explicitly provided by us.\nIt corresponds to the process ID (starting at :obj:`0`) injected by :pytorch:`PyTorch`.\nLater, we will use this to select a unique GPU for every :obj:`rank`.\n\nWith this, we can start to implement our spawnable runner function.\nThe first step is to initialize a process group with :obj:`torch.distributed`.\nTo this 
point, processes are not aware of each other and we set a hardcoded server-address for rendezvous using the :obj:`nccl` protocol.\nMore details can be found in the `\"Writing Distributed Applications with PyTorch\" <https://pytorch.org/tutorials/intermediate/dist_tuto.html>`_ tutorial:\n\n.. code-block:: python\n\n    import os\n    import torch.distributed as dist\n    import torch\n\n    def run(rank: int, world_size: int, dataset: Reddit):\n        os.environ['MASTER_ADDR'] = 'localhost'\n        os.environ['MASTER_PORT'] = '12345'\n        dist.init_process_group('nccl', rank=rank, world_size=world_size)\n\nNext, we split training indices into :obj:`world_size` many chunks for each GPU, and initialize the :class:`~torch_geometric.loader.NeighborLoader` class to only operate on its specific chunk of the training set:\n\n.. code-block:: python\n\n    from torch_geometric.loader import NeighborLoader\n\n    def run(rank: int, world_size: int, dataset: Reddit):\n        ...\n\n        data = dataset[0]\n\n        train_index = data.train_mask.nonzero().view(-1)\n        train_index = train_index.split(train_index.size(0) // world_size)[rank]\n\n        train_loader = NeighborLoader(\n            data,\n            input_nodes=train_index,\n            num_neighbors=[25, 10],\n            batch_size=1024,\n            num_workers=4,\n            shuffle=True,\n        )\n\nNote that our :meth:`run` function is called for each rank, which means that each rank holds a separate :class:`~torch_geometric.loader.NeighborLoader` instance.\n\nSimilarly, we create a :class:`~torch_geometric.loader.NeighborLoader` instance for evaluation.\nFor simplicity, we only do this on rank :obj:`0` such that computation of metrics does not need to communicate across different processes.\nWe recommend taking a look at the `torchmetrics <https://torchmetrics.readthedocs.io/en/stable/>`_ package for distributed computation of metrics.\n\n.. 
code-block:: python\n\n    def run(rank: int, world_size: int, dataset: Reddit):\n        ...\n\n        if rank == 0:\n            val_index = data.val_mask.nonzero().view(-1)\n            val_loader = NeighborLoader(\n                data,\n                input_nodes=val_index,\n                num_neighbors=[25, 10],\n                batch_size=1024,\n                num_workers=4,\n                shuffle=False,\n            )\n\nNow that we have our data loaders defined, we initialize our :class:`~torch_geometric.nn.GraphSAGE` model and wrap it inside :class:`torch.nn.parallel.DistributedDataParallel`.\nWe also move the model to its exclusive GPU using the :obj:`rank` as a shortcut for the full device identifier.\nThe wrapper on our model manages communication between each rank and synchronizes gradients across all ranks before updating the model parameters across all ranks:\n\n.. code-block:: python\n\n    from torch.nn.parallel import DistributedDataParallel\n    from torch_geometric.nn import GraphSAGE\n\n    def run(rank: int, world_size: int, dataset: Reddit):\n        ...\n\n        torch.manual_seed(12345)\n        model = GraphSAGE(\n            in_channels=dataset.num_features,\n            hidden_channels=256,\n            num_layers=2,\n            out_channels=dataset.num_classes,\n        ).to(rank)\n        model = DistributedDataParallel(model, device_ids=[rank])\n\nFinally, we can set up our optimizer and define our training loop, which follows a similar flow as usual single GPU training loops - the actual magic of gradient and model weight synchronization across different processes will happen behind the scenes within :class:`~torch.nn.parallel.DistributedDataParallel`:\n\n.. 
code-block:: python\n\n    import torch.nn.functional as F\n\n    def run(rank: int, world_size: int, dataset: Reddit):\n        ...\n\n        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n        for epoch in range(1, 11):\n            model.train()\n            for batch in train_loader:\n                batch = batch.to(rank)\n                optimizer.zero_grad()\n                out = model(batch.x, batch.edge_index)[:batch.batch_size]\n                loss = F.cross_entropy(out, batch.y[:batch.batch_size])\n                loss.backward()\n                optimizer.step()\n\nAfter each training epoch, we evaluate and report validation metrics.\nAs previously mentioned, we do this on a single GPU only.\nTo synchronize all processes and to ensure that the model weights have been updated, we need to call :meth:`torch.distributed.barrier`:\n\n.. code-block:: python\n\n            dist.barrier()\n\n            if rank == 0:\n                print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')\n\n            if rank == 0:\n                model.eval()\n                count = correct = 0\n                with torch.no_grad():\n                    for batch in val_loader:\n                        batch = batch.to(rank)\n                        out = model(batch.x, batch.edge_index)[:batch.batch_size]\n                        pred = out.argmax(dim=-1)\n                        correct += (pred == batch.y[:batch.batch_size]).sum()\n                        count += batch.batch_size\n                print(f'Validation Accuracy: {correct/count:.4f}')\n\n            dist.barrier()\n\nAfter finishing training, we can clean up processes and destroy the process group via:\n\n.. 
code-block:: python\n\n        dist.destroy_process_group()\n\nAnd that's it.\nPutting it all together gives a working multi-GPU example whose training flow is similar to single-GPU training.\nYou can run this tutorial yourself using `examples/multi_gpu/distributed_sampling.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling.py>`_.\n"
  },
  {
    "path": "docs/source/tutorial/multi_node_multi_gpu_vanilla.rst",
    "content": "Multi-Node Training using SLURM\n===============================\n\n.. note::\n    For multi-GPU training with cuGraph, refer to `cuGraph examples <https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples>`_.\n\n\nThis tutorial introduces a skeleton on how to perform distributed training on multiple GPUs over multiple nodes using the `SLURM workload manager <https://slurm.schedmd.com/>`_ available at many supercomputing centers.\nThe code is based on `our tutorial on single-node multi-GPU training <multi_gpu_vanilla.html>`_.\nPlease go there first to understand the basics if you are unfamiliar with the concepts of distributed training in :pytorch:`PyTorch`.\n\n.. note::\n    The complete script of this tutorial can be found at `examples/multi_gpu/distributed_sampling_multinode.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling_multinode.py>`_.\n    You can find the example :obj:`*.sbatch` file `next to it <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling_multinode.sbatch>`_ and tune it to your needs.\n\nA submission script to manage startup\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAs we are now running on multiple nodes, we can no longer use our :obj:`__main__` entrypoint and start processes from there.\nThis is where the workload manager comes in as it allows us to prepare a special :obj:`*.sbatch` file.\nThis file is a standard bash script with instructions on how to setup the processes and your environment.\n\nOur example starts with the usual shebang :obj:`#!/bin/bash` and special comments instructing which resources the SLURM system should reserve for our training run.\nConfiguration of the specifics usually depends on your site (and your usage limits!).\nThe following is a minimal example which works with a quite unrestricted configuration available to us:\n\n.. 
code-block:: bash\n\n    #!/bin/bash\n    #SBATCH --job-name=pyg-multinode-tutorial # identifier for the job listings\n    #SBATCH --output=pyg-multinode.log        # outputfile\n    #SBATCH --partition=gpucloud              # ADJUST this to your system\n    #SBATCH -N 2                              # number of nodes you want to use\n    #SBATCH --ntasks=4                        # number of processes to be run\n    #SBATCH --gpus-per-task=1                 # every process wants one GPU!\n    #SBATCH --gpu-bind=none                   # NCCL can't deal with task-binding...\n\nThis example will create two processes each on two nodes with each process having a single GPU reserved.\n\nIn the following part, we have to set up some environment variables for :obj:`torch.distributed` to properly do the rendezvous procedure.\nIn theory you could also set those inside the :python:`Python` process:\n\n.. code-block:: bash\n\n    export MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))\n    export MASTER_ADDR=$(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | head -n 1)\n    echo \"MASTER_ADDR:MASTER_PORT=\"${MASTER_ADDR}:${MASTER_PORT}\n\nIf you do not want to let your script randomly open a port and listen for incoming connections, you can also use a file on your shared filesystem.\n\nNow the only thing left to add is the execution of the training script:\n\n.. code-block:: console\n\n    srun python distributed_sampling_multinode.py\n\nNote how the :obj:`python` call is prefixed with the :obj:`srun` command and thus :obj:`--ntasks` replicas will be started.\n\nFinally, to submit the :obj:`*.sbatch` file itself into the work queue, use the :obj:`sbatch` utility in your shell:\n\n.. 
code-block:: console\n\n    sbatch distributed_sampling_multinode.sbatch\n\nUsing a cluster configured with pyxis-containers\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nIf your cluster supports the :obj:`pyxis` plugin developed by NVIDIA, you can use a ready-to-use :pyg:`PyG` container that is updated each month with the latest from NVIDIA and :pyg:`PyG`, see `here <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg>`_ for more information.\nThe container sets up all necessary environment variables from which you can now directly run the example using :obj:`srun` from your command prompt:\n\n.. code-block:: console\n\n    srun --partition=<partitionname> -N<num_nodes> --ntasks=<number of GPUS in total> --gpus-per-task=1 --gpu-bind=none --container-name=pyg-test --container-image=<image_url> --container-mounts='.:/workspace' python3 distributed_sampling_multinode.py\n\nNote that :obj:`--container-mounts='.:/workspace'` makes the current folder (which should include the example code) available in the default startup folder :obj:`workspace` of the container.\n\nIf you eventually want to customize packages in the container without having access to :obj:`docker` (very likely on a public HPC), you can create your own image by following `this tutorial <https://doku.lrz.de/9-creating-and-reusing-a-custom-enroot-container-image-10746637.html>`_.\n\nModifying the training script\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nAs SLURM now takes care of creating multiple :python:`Python` processes and we cannot share any data (each process will have the full dataset loaded!), our :obj:`__main__` section now has to query the environment for the process setup generated by SLURM or the :obj:`pyxis` container:\n\n.. 
code-block:: python\n\n    # Get the world size from the WORLD_SIZE variable or directly from SLURM:\n    world_size = int(os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))\n    # Likewise for RANK and LOCAL_RANK:\n    rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))\n    local_rank = int(os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))\n    run(world_size, rank, local_rank)\n\nThe :meth:`torch.distributed.init_process_group` function will now pick up the :obj:`MASTER_ADDR` and :obj:`MASTER_PORT` from the environment:\n\n.. code-block:: python\n\n    def run(world_size: int, rank: int, local_rank: int):\n        dist.init_process_group('nccl', world_size=world_size, rank=rank)\n\nWe also have to replace the usage of :obj:`rank` depending on whether we want to use it for node-local purposes like selecting a GPU or global tasks such as data splitting:\n\n.. code-block:: python\n\n    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)\n    train_idx = train_idx.split(train_idx.size(0) // world_size)[rank]\n\nwhile we need to assign the model to a node-local GPU and thus use :obj:`local_rank`:\n\n.. code-block:: python\n\n    model = SAGE(dataset.num_features, 256, dataset.num_classes).to(local_rank)\n    model = DistributedDataParallel(model, device_ids=[local_rank])\n"
  },
  {
    "path": "docs/source/tutorial/neighbor_loader.rst",
    "content": "Scaling GNNs via Neighbor Sampling\n==================================\n\nOne of the challenges of Graph Neural Networks is to scale them to large graphs, *e.g.*, in industrial and social applications.\nTraditional deep neural networks are known to scale well to large amounts of data by decomposing the training loss into individual samples (called a *mini-batch*) and approximating exact gradients stochastically.\nIn contrast, applying stochastic mini-batch training in GNNs is challenging since the embedding of a given node depends recursively on all its neighbor’s embeddings, leading to high inter-dependency between nodes that grows exponentially with respect to the number of layers.\nThis phenomenon is often referred to as *neighbor explosion*.\nAs a simple workaround, GNNs are typically executed in a full-batch fashion (see `here <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/gcn.py>`_ for an example), where the GNN has access to all hidden node representations in all its layers.\nHowever, this is not feasible in large-scale graphs due to memory limitations and slow convergence.\n\nScalability techniques are indispensable for applying GNNs to large-scale graphs in order to alleviate the neighbor explosion problem induced by mini-batch training, *i.e.* **node-wise**, **layer-wise** or **subgraph-wise** sampling techniques, or to **decouple propagations from predictions**.\nIn this tutorial, we take a closer look at the most common node-wise sampling approach, originally introduced in the `\"Inductive Representation Learning on Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper.\n\nNeighbor Sampling\n-----------------\n\n:pyg:`PyG` implements neighbor sampling via its :class:`torch_geometric.loader.NeighborLoader` class.\nNeighbor sampling works by recursively sampling a fixed number of at most :math:`k` neighbors for a node :math:`v \\in \\mathcal{V}`, *i.e.* :math:`\\tilde{\\mathcal{N}}(v) \\subset \\mathcal{N}(v)` with 
:math:`|\\tilde{\\mathcal{N}}| \\le k`, leading to an overall bounded :math:`L`-hop neighborhood size of :math:`\\mathcal{O}(k^L)`.\nThat is, starting from a set of seed nodes :math:`\\mathcal{B} \\subset \\mathcal{V}`, we sample at most :math:`k` neighbors for every node :math:`v \\in \\mathcal{B}`, and then proceed to sample neighbors for every sampled node in the previous hop, and so on.\nThe resulting graph structure holds a **directed** :math:`L`-hop subgraph around every node :math:`v \\in \\mathcal{B}`, for which it is guaranteed that every node has at least one path of at most length :math:`L` to at least one of the seed nodes in :math:`\\mathcal{B}`.\nAs such, a message passing GNN with :math:`L` layers will incorporate the full set of sampled nodes in its computation graph.\n\n.. figure:: ../_static/thumbnails/neighbor_loader.png\n  :align: center\n  :width: 40%\n\n|\n\nIt is important to note that neighbor sampling can only mitigate the neighbor explosion problem to some extent since the overall neighborhood size still increases exponentially with the number of layers.\nAs a result, sampling for more than two or three iterations is generally not feasible.\n\nOftentimes, the number of sampled hops and the number of message passing layers are kept in sync.\nSpecifically, it is very wasteful to sample for more hops than there exist message passing layers since the GNN will never be able to incorporate the features of the nodes sampled in later hops into the final node representation of its seed nodes.\nHowever, it is nonetheless possible to utilize deeper GNNs, but one needs to be careful to convert the sampled subgraph into a bidirectional variant to ensure correct message passing flow.\n:pyg:`PyG` provides support for this via an additional argument in :class:`~torch_geometric.loader.NeighborLoader`, while other mini-batch techniques are designed for this use-case out-of-the-box, *e.g.*, :class:`~torch_geometric.loader.ClusterLoader`, 
:class:`~torch_geometric.loader.GraphSAINTSampler` and :class:`~torch_geometric.loader.ShaDowKHopSampler`.\n\nBasic Usage\n-----------\n\n.. note::\n\n    In this section of the tutorial, we will learn how to utilize the :class:`~torch_geometric.loader.NeighborLoader` class of :pyg:`PyG` to train GNNs on single graphs in a mini-batch fashion.\n    A fully working example on large-scale real-world data is available in `examples/reddit.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py>`_.\n\nThe :class:`~torch_geometric.loader.NeighborLoader` is initialized from a :pyg:`PyG` :class:`~torch_geometric.data.Data` or :class:`~torch_geometric.data.HeteroData` object and defines how sampling should be performed:\n\n* :obj:`input_nodes` defines the set of seed nodes from which we want to start sampling.\n* :obj:`num_neighbors` defines the number of neighbors to sample for each node in each hop.\n* :obj:`batch_size` defines the number of seed nodes we want to consider at once.\n* :obj:`replace` defines whether to sample with or without replacement.\n* :obj:`shuffle` defines whether seed nodes should be shuffled at every epoch.\n\n.. 
code-block:: python\n\n    import torch\n    from torch_geometric.data import Data\n    from torch_geometric.loader import NeighborLoader\n\n    x = torch.randn(8, 32)  # Node features of shape [num_nodes, num_features]\n    y = torch.randint(0, 4, (8, ))  # Node labels of shape [num_nodes]\n    edge_index = torch.tensor([\n        [2, 3, 3, 4, 5, 6, 7],\n        [0, 0, 1, 1, 2, 3, 4]],\n    )\n\n    #   0  1\n    #  / \\/ \\\n    # 2  3  4\n    # |  |  |\n    # 5  6  7\n\n    data = Data(x=x, y=y, edge_index=edge_index)\n\n    loader = NeighborLoader(\n        data,\n        input_nodes=torch.tensor([0, 1]),\n        num_neighbors=[2, 1],\n        batch_size=1,\n        replace=False,\n        shuffle=False,\n    )\n\nHere, we initialize the :class:`~torch_geometric.loader.NeighborLoader` to sample subgraphs for the first two nodes, where we want to sample 2 neighbors in the first hop, and 1 neighbor in the second hop.\nOur :obj:`batch_size` is set to :obj:`1`, such that :obj:`input_nodes` will be split into chunks of size :obj:`1`.\n\nIn the execution of :class:`~torch_geometric.loader.NeighborLoader`, we expect that the seed node :obj:`0` samples nodes :obj:`2` and :obj:`3` in the first hop. In the second hop, node :obj:`2` samples node :obj:`5`, and node :obj:`3` samples node :obj:`6`.\nLet's confirm by looking at the output of the :obj:`loader`:\n\n.. 
code-block:: python\n\n    batch = next(iter(loader))\n\n    batch.edge_index\n    >>> tensor([[1, 2, 3, 4],\n                [0, 0, 1, 2]])\n\n    batch.n_id\n    >>> tensor([0, 2, 3, 5, 6])\n\n    batch.batch_size\n    >>> 1\n\nThe :class:`~torch_geometric.loader.NeighborLoader` will return a :class:`~torch_geometric.data.Data` object, which contains the following attributes:\n\n* :obj:`batch.edge_index` contains the edge indices of the subgraph.\n* :obj:`batch.n_id` contains the original node indices of all the sampled nodes.\n* :obj:`batch.batch_size` contains the number of seed nodes/the batch size.\n\nIn addition, node and edge features will be filtered to only contain the features of sampled nodes/edges, respectively.\n\nImportantly, :obj:`batch.edge_index` contains the sampled subgraph with relabeled node indices, such that its indices range from :obj:`0` to :obj:`batch.num_nodes - 1`.\nIf you want to reconstruct the original node indices of :obj:`batch.edge_index`, do:\n\n.. code-block:: python\n\n    batch.n_id[batch.edge_index]\n    >>> tensor([[2, 3, 5, 6],\n                [0, 0, 2, 3]])\n\nFurthermore, while :class:`~torch_geometric.loader.NeighborLoader` starts sampling *from* seed nodes, the resulting subgraph will hold edges that point *to* the seed nodes.\nThis aligns well with the default :pyg:`PyG` message passing flow from source to destination nodes.\n\nLastly, nodes in the output of :class:`~torch_geometric.loader.NeighborLoader` are guaranteed to be sorted.\nIn particular, the first :obj:`batch_size` sampled nodes will exactly match with the seed nodes that were used for sampling:\n\n.. code-block:: python\n\n    batch.n_id[:batch.batch_size]\n    >>> tensor([0])\n\nAfterwards, we can use :class:`~torch_geometric.loader.NeighborLoader` as a data loading routine to train GNNs on large-scale graphs in mini-batch fashion.\nFor this, let's create a simple two-layer :class:`~torch_geometric.nn.models.GraphSAGE` model:\n\n.. 
code-block:: python\n\n    from torch_geometric.nn import GraphSAGE\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    model = GraphSAGE(\n        in_channels=32,\n        hidden_channels=64,\n        out_channels=4,\n        num_layers=2\n    ).to(device)\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\nWe can now combine the :obj:`loader` and :obj:`model` to define our training routine:\n\n.. code-block:: python\n\n    import torch.nn.functional as F\n\n    for batch in loader:\n        optimizer.zero_grad()\n        batch = batch.to(device)\n        out = model(batch.x, batch.edge_index)\n\n        # NOTE Only consider predictions and labels of seed nodes:\n        y = batch.y[:batch.batch_size]\n        out = out[:batch.batch_size]\n\n        loss = F.cross_entropy(out, y)\n        loss.backward()\n        optimizer.step()\n\nThe training loop follows a similar design to any other :pytorch:`PyTorch` training loop.\nThe only important difference is that by default the model will output a matrix of shape :obj:`[batch.num_nodes, *]`, while we are only interested in the predictions of the seed nodes.\nAs such, we can use efficient slicing both on the node predictions and the ground-truth information :obj:`batch.y` to only obtain predictions and ground-truth information of actual seed nodes.\nThis ensures that we are only making use of the first :obj:`batch_size` many nodes for loss and metric computation.\n\nHierarchical Extension\n----------------------\n\nA drawback of :class:`~torch_geometric.loader.NeighborLoader` is that it computes representations for *all* sampled nodes at *all* depths of the network.\nHowever, nodes sampled in later hops no longer contribute to the node representations of seed nodes in later GNN layers, thus performing useless computation.\n:class:`~torch_geometric.loader.NeighborLoader` will be marginally slower since we are computing node embeddings for nodes we no longer need.\nThis 
is a trade-off we make to obtain a clean, modular and experiment-friendly GNN design, which does not tie the definition of the model to its utilized data loader routine.\nThe `Hierarchical Neighborhood Sampling <../advanced/hgam.html>`__ tutorial shows how to eliminate this overhead and speed up training and inference in mini-batch GNNs further.\n\nAdvanced Options\n----------------\n\n:class:`~torch_geometric.loader.NeighborLoader` provides many more features for advanced usage.\nIn particular,\n\n* :class:`~torch_geometric.loader.NeighborLoader` supports both sampling on homogeneous and heterogeneous graphs out-of-the-box.\n  For sampling on heterogeneous graphs, simply initialize it with a :class:`~torch_geometric.data.HeteroData` object.\n  Sampling on heterogeneous graphs via :class:`~torch_geometric.loader.NeighborLoader` allows for fine-grained control of sampling parameters, *e.g.*, it allows you to specify the number of neighbors to sample for each edge type individually.\n  Take a look at the `Heterogeneous Graph Learning <../advanced/heterogeneous.html>`__ tutorial for additional information.\n\n* By default, :class:`~torch_geometric.loader.NeighborLoader` fuses sampled nodes across different seed nodes into a single subgraph.\n  This way, shared neighbors of seed nodes will not be duplicated in the resulting subgraph and hence save memory.\n  You can disable this behavior by passing the :obj:`disjoint=True` option to the :class:`~torch_geometric.loader.NeighborLoader`.\n\n* By default, the subgraphs returned from :class:`~torch_geometric.loader.NeighborLoader` will be **directed**, which restricts its use to GNNs with equal depth to the number of sampling hops.\n  If you want to utilize deeper GNNs, specify the :obj:`subgraph_type` option.\n  If set to :obj:`\"bidirectional\"`, sampled edges are converted to bidirectional edges.\n  If set to :obj:`\"induced\"`, the returned subgraph will contain the induced subgraph of all sampled nodes.\n\n* 
:class:`~torch_geometric.loader.NeighborLoader` is designed to perform sampling from individual seed nodes.\n  As such, it is not directly applicable in a link prediction scenario.\n  For this use case, we developed the :class:`~torch_geometric.loader.LinkNeighborLoader`, which expects a set of input edges, and will return subgraphs that were created via neighbor sampling from both source and destination nodes.\n"
  },
  {
    "path": "docs/source/tutorial/point_cloud.rst",
    "content": "Point Cloud Processing\n======================\n\nThis tutorial explains how to leverage Graph Neural Networks (GNNs) for operating and training on point cloud data.\nAlthough point clouds do not come with a graph structure by default, we can utilize :pyg:`PyG` transformations to make them applicable for the full suite of GNNs available in :pyg:`PyG`.\nThe key idea is to create a synthetic graph from point clouds, from which we can learn meaningful local geometric structures via a GNN's message passing scheme.\nThese point representations can then be used to, *e.g.*, perform point cloud classification or segmentation.\n\n3D Point Cloud Datasets\n-----------------------\n\n:pyg:`PyG` provides several point cloud datasets, such as the :class:`~torch_geometric.datasets.PCPNetDataset`, :class:`~torch_geometric.datasets.S3DIS` and :class:`~torch_geometric.datasets.ShapeNet` datasets.\nTo get started, we also provide the :class:`~torch_geometric.datasets.GeometricShapes` dataset, which is a toy dataset that contains various geometric shapes such cubes, spheres or pyramids.\nNotably, the :class:`~torch_geometric.datasets.GeometricShapes` dataset contains meshes instead of point clouds by default, represented via :obj:`pos` and :obj:`face` attributes, which hold the information of vertices and their triangular connectivity, respectively:\n\n.. code-block:: python\n\n    from torch_geometric.datasets import GeometricShapes\n\n    dataset = GeometricShapes(root='data/GeometricShapes')\n    print(dataset)\n    >>> GeometricShapes(40)\n\n    data = dataset[0]\n    print(data)\n    >>> Data(pos=[32, 3], face=[3, 30], y=[1])\n\nWhen visualizing the first mesh in the dataset, we can see that it represents a circle:\n\n.. 
figure:: ../_figures/point_cloud1.png\n  :align: center\n  :width: 40%\n\n|\n\nSince we are interested in point clouds, we can transform our meshes into points via the usage of :class:`torch_geometric.transforms`.\nIn particular, :pyg:`PyG` provides the :class:`~torch_geometric.transforms.SamplePoints` transformation, which will uniformly sample a fixed number of points on the mesh faces according to their face area.\n\nWe can add this transformation to the dataset by simply setting it via :obj:`dataset.transform = SamplePoints(num=...)`.\nEach time an example is accessed from the dataset, the transformation procedure will get called, converting our mesh into a point cloud.\nNote that sampling points is stochastic, and so you will receive a new point cloud upon every access:\n\n.. code-block:: python\n\n    import torch_geometric.transforms as T\n\n    dataset.transform = T.SamplePoints(num=256)\n\n    data = dataset[0]\n    print(data)\n    >>> Data(pos=[256, 3], y=[1])\n\nNote that we now have :obj:`256` points in our example, and the triangular connectivity stored in :obj:`face` has been removed.\nVisualizing the points now shows that we have correctly sampled points on the surface of the initial mesh:\n\n.. figure:: ../_figures/point_cloud2.png\n  :align: center\n  :width: 40%\n\n|\n\nFinally, let's convert our point cloud into a graph.\nSince we are interested in learning local geometric structures, we want to construct a graph in such a way that nearby points are connected.\nTypically, this is either done via :math:`k`-nearest neighbor search or via ball queries (which connect all points that are within a certain radius to the query point).\n:pyg:`PyG` provides utilities for such graph generation via the :class:`~torch_geometric.transforms.KNNGraph` and :class:`~torch_geometric.transforms.RadiusGraph` transformations, respectively.\n\n.. 
code-block:: python\n\n    import torch_geometric.transforms as T\n\n    dataset.transform = T.Compose([T.SamplePoints(num=256), T.KNNGraph(k=6)])\n\n    data = dataset[0]\n    print(data)\n    >>> Data(pos=[256, 3], edge_index=[2, 1536], y=[1])\n\nYou can see that the :obj:`data` object now also contains an :obj:`edge_index` representation, holding :obj:`1536` edges in total, 6 edges for each of the 256 points.\nWe can confirm that our graph looks good via the following visualization:\n\n.. figure:: ../_figures/point_cloud3.png\n  :align: center\n  :width: 40%\n\n|\n\nPointNet++ Implementation\n-------------------------\n\n`PointNet++ <https://arxiv.org/abs/1706.02413>`_ is a pioneering work that proposes a Graph Neural Network architecture for point cloud classification and segmentation.\nPointNet++ processes point clouds iteratively by following a simple grouping, neighborhood aggregation and downsampling scheme:\n\n.. figure:: ../_figures/point_cloud4.png\n  :align: center\n  :width: 100%\n\n|\n\n1. The **grouping phase** constructs a graph via :math:`k`-nearest neighbor search or via ball queries as described above.\n\n2. The **neighborhood aggregation** phase executes a GNN layer that, for each point, aggregates information from its direct neighbors (given by the graph constructed in the previous phase).\n   This allows PointNet++ to capture local context at different scales.\n\n3. The **downsampling phase** implements a pooling scheme suitable for point clouds with potentially different sizes.\n   For simplicity, we will ignore this phase for now.\n   We recommend taking a look at `examples/pointnet2_classification.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/pointnet2_classification.py>`_ for guidance on how to implement this step.\n\nNeighborhood Aggregation\n~~~~~~~~~~~~~~~~~~~~~~~~\n\nThe PointNet++ layer follows a simple neural message passing scheme defined via\n\n.. 
math::\n\n    \\mathbf{h}^{(\\ell + 1)}_i = \\max_{j \\in \\mathcal{N}(i)} \\textrm{MLP} \\left( \\mathbf{h}_j^{(\\ell)}, \\mathbf{p}_j - \\mathbf{p}_i \\right)\n\nwhere\n\n    * :math:`\\mathbf{h}_i^{(\\ell)} \\in \\mathbb{R}^d` denotes the hidden features of point :math:`i` in layer :math:`\\ell`, and\n    * :math:`\\mathbf{p}_i \\in \\mathbb{R}^3` denotes the position of point :math:`i`.\n\nWe can make use of the :class:`~torch_geometric.nn.conv.MessagePassing` interface in :pyg:`PyG` to implement this layer from scratch.\nThe :class:`~torch_geometric.nn.conv.MessagePassing` interface helps us in **creating message passing graph neural networks** by automatically taking care of message propagation.\nHere, we only need to define its :meth:`~torch_geometric.nn.conv.MessagePassing.message` function and which aggregation scheme we want to use, *e.g.*, :obj:`aggr=\"max\"` (see `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/create_gnn.html>`_ for the accompanying tutorial):\n\n.. 
code-block:: python\n\n    import torch\n    from torch import Tensor\n    from torch.nn import Sequential, Linear, ReLU\n\n    from torch_geometric.nn import MessagePassing\n\n\n    class PointNetLayer(MessagePassing):\n        def __init__(self, in_channels: int, out_channels: int):\n            # Message passing with \"max\" aggregation.\n            super().__init__(aggr='max')\n\n            # Initialization of the MLP:\n            # Here, the number of input features corresponds to the hidden\n            # node dimensionality plus point dimensionality (=3).\n            self.mlp = Sequential(\n                Linear(in_channels + 3, out_channels),\n                ReLU(),\n                Linear(out_channels, out_channels),\n            )\n\n        def forward(self,\n            h: Tensor,\n            pos: Tensor,\n            edge_index: Tensor,\n        ) -> Tensor:\n            # Start propagating messages.\n            return self.propagate(edge_index, h=h, pos=pos)\n\n        def message(self,\n            h_j: Tensor,\n            pos_j: Tensor,\n            pos_i: Tensor,\n        ) -> Tensor:\n            # h_j: The features of neighbors as shape [num_edges, in_channels]\n            # pos_j: The position of neighbors as shape [num_edges, 3]\n            # pos_i: The central node position as shape [num_edges, 3]\n\n            edge_feat = torch.cat([h_j, pos_j - pos_i], dim=-1)\n            return self.mlp(edge_feat)\n\nAs one can see, implementing the PointNet++ layer is quite straightforward in :pyg:`PyG`.\nIn the :meth:`__init__` function, we first define that we want to apply **max aggregation**, and afterwards initialize an MLP that takes care of transforming node features of neighbors and the spatial relation between source and destination nodes into a (trainable) message.\n\nIn the :meth:`forward` function, we can start **propagating messages** based on :obj:`edge_index`, and pass in everything needed in order to create messages.\nIn the :meth:`message` 
function, we can now access neighbor and central node information via :obj:`*_j` and :obj:`*_i` suffixes, respectively, and return a message for each edge.\n\nNetwork Architecture\n~~~~~~~~~~~~~~~~~~~~\n\nWe can make use of the above :class:`PointNetLayer` to define our network architecture (or use its equivalent :class:`torch_geometric.nn.conv.PointNetConv` directly integrated in :pyg:`PyG`).\nWith this, our overall :class:`PointNet` architecture looks as follows:\n\n.. code-block:: python\n\n    from torch_geometric.nn import global_max_pool\n\n\n    class PointNet(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n\n            self.conv1 = PointNetLayer(3, 32)\n            self.conv2 = PointNetLayer(32, 32)\n            self.classifier = Linear(32, dataset.num_classes)\n\n        def forward(self,\n            pos: Tensor,\n            edge_index: Tensor,\n            batch: Tensor,\n        ) -> Tensor:\n\n            # Perform two layers of message passing:\n            h = self.conv1(h=pos, pos=pos, edge_index=edge_index)\n            h = h.relu()\n            h = self.conv2(h=h, pos=pos, edge_index=edge_index)\n            h = h.relu()\n\n            # Global Pooling:\n            h = global_max_pool(h, batch)  # [num_examples, hidden_channels]\n\n            # Classifier:\n            return self.classifier(h)\n\n\n    model = PointNet()\n\nIf we inspect the model, we can see that everything is initialized correctly:\n\n.. code-block:: python\n\n    print(model)\n    >>> PointNet(\n    ...   (conv1): PointNetLayer()\n    ...   (conv2): PointNetLayer()\n    ...   (classifier): Linear(in_features=32, out_features=40, bias=True)\n    ... 
)\n\nHere, we create our network architecture by inheriting from :class:`torch.nn.Module` and initialize **two** :class:`PointNetLayer` **modules** and a **final linear classifier** in its constructor.\n\nIn the :meth:`forward` method, we apply two graph-based convolutional operators and enhance them by ReLU non-linearities.\nThe first operator takes in 3 input features (the positions of nodes) and maps them to 32 output features, and the second operator maps these 32 features to 32 output features again.\nAfter that, each point holds information about its 2-hop neighborhood, and should already be able to distinguish between simple local shapes.\n\nNext, we apply a global graph readout function, *i.e.*, :meth:`~torch_geometric.nn.pool.global_max_pool`, which takes the maximum value along the node dimension for each example.\nIn order to map the different nodes to their corresponding examples, we use the :obj:`batch` vector which will be automatically created for us when using the mini-batch :class:`torch_geometric.loader.DataLoader`.\nLast, we apply a linear classifier to map the global 32 features per point cloud to one of the 40 classes.\n\nTraining Procedure\n~~~~~~~~~~~~~~~~~~\n\nWe are now ready to write two simple procedures to train and test our model on the training and test datasets, respectively.\nIf you are not new to :pytorch:`PyTorch`, this scheme should appear familiar to you.\nOtherwise, the :pytorch:`PyTorch` documentation provides a `good introduction on how to train a neural network in PyTorch <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#define-a-loss-function-and-optimizer>`_:\n\n.. 
code-block:: python\n\n    from torch_geometric.loader import DataLoader\n\n    train_dataset = GeometricShapes(root='data/GeometricShapes', train=True)\n    train_dataset.transform = T.Compose([SamplePoints(num=256), KNNGraph(k=6)])\n    test_dataset = GeometricShapes(root='data/GeometricShapes', train=False)\n    test_dataset.transform = T.Compose([SamplePoints(num=256), KNNGraph(k=6)])\n\n    train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)\n    test_loader = DataLoader(test_dataset, batch_size=10)\n\n    model = PointNet()\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n    criterion = torch.nn.CrossEntropyLoss()\n\n    def train():\n        model.train()\n\n        total_loss = 0\n        for data in train_loader:\n            optimizer.zero_grad()\n            logits = model(data.pos, data.edge_index, data.batch)\n            loss = criterion(logits, data.y)\n            loss.backward()\n            optimizer.step()\n            total_loss += float(loss) * data.num_graphs\n\n        return total_loss / len(train_loader.dataset)\n\n\n    @torch.no_grad()\n    def test():\n        model.eval()\n\n        total_correct = 0\n        for data in test_loader:\n            logits = model(data.pos, data.edge_index, data.batch)\n            pred = logits.argmax(dim=-1)\n            total_correct += int((pred == data.y).sum())\n\n        return total_correct / len(test_loader.dataset)\n\n    for epoch in range(1, 51):\n        loss = train()\n        test_acc = test()\n        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Acc: {test_acc:.4f}')\n\nUsing this setup, you should get around **75%-80% test set accuracy**, even when training only on a single example per class.\n"
  },
  {
    "path": "docs/source/tutorial/shallow_node_embeddings.rst",
    "content": "Shallow Node Embeddings\n=======================\n\nIn this tutorial, we will take a closer look at how to learn *shallow node embeddings* in an unsupervised fashion via :pyg:`PyG`.\n\nIntroduction\n------------\n\nThe key difference between *shallow* node embeddings (*e.g.,* :class:`~torch_geometric.nn.models.Node2Vec`) and *deep* node embeddings (*e.g.,* GNNs) is the choice of the encoder :math:`\\textrm{ENC}(v, \\mathcal{G}) = \\mathbf{z}_v \\in \\mathbb{R}^d`.\nSpecifically, shallow node embedding techniques rely on embedding nodes into low-dimensional vectorial representations :math:`\\mathbf{z}_v` via a *shallow embedding lookup table* such that the likelihood of preserving neighborhoods is maximized, *i.e.*, nearby nodes should receive similar embeddings while distant nodes should receive distinct embeddings.\nThese techniques generalize the famous `SkipGram <https://arxiv.org/abs/1310.4546>`_ model for obtaining low-dimensional word embeddings, in which sequences of words are now interpreted as sequences of nodes, *e.g.*, given via randomly-generated walks:\n\n.. figure:: ../_figures/shallow_node_embeddings.png\n  :align: center\n  :width: 100%\n\n|\n\nSpecifically, given a *random walk* :math:`\\mathcal{W} = (v_{\\pi(1)}, \\ldots, v_{\\pi(k)})` of length :math:`k` starting at node :math:`v \\in \\mathcal{V}`, the objective is to maximize the likelihood of observing node :math:`v_{\\pi(i)}` given node :math:`v`.\nThis objective can be efficiently trained via stochastic gradient descent in a contrastive learning scenario\n\n.. 
math::\n    \\mathcal{L} = \\sum_{w \\in \\mathcal{W}} - \\log \\left(\\sigma(\\mathbf{z}_v^{\\top} \\mathbf{z}_w) \\right) + \\sum_{w \\sim \\mathcal{V} \\setminus \\mathcal{W}} - \\log \\left( 1 - \\sigma(\\mathbf{z}_v^{\\top} \\mathbf{z}_w) \\right),\n\nin which non-existent walks (so-called *negative examples*) are sampled and trained jointly, and :math:`\\sigma` denotes the :math:`\\textrm{sigmoid}` function.\nNotably, the dot product :math:`\\mathbf{z}_v^{\\top} \\mathbf{z}_w` between the embeddings is usually used to measure similarity, but other similarity measures are applicable as well.\n\nImportantly, shallow node embeddings are trained in an unsupervised fashion, and can eventually be used as input for a given down-stream task, *e.g.*, in node-level tasks :math:`\\mathbf{z}_v` can directly be used as input to a final classifier.\nFor edge-level tasks, edge-level representations can be obtained via averaging :math:`\\frac{1}{2} (\\mathbf{z}_v + \\mathbf{z}_w)` or via the Hadamard product :math:`\\mathbf{z}_v \\odot \\mathbf{z}_w`.\n\nDespite the simplicity of node embedding techniques, they are also subject to certain shortcomings.\nIn particular, they fail to incorporate rich feature information attached to nodes and edges, and cannot be trivially applied to unseen\ngraphs as learnable parameters are fixed to the nodes of a particular graph (making this approach transductive by nature and hard to scale due to the :math:`\\mathcal{O}(|\\mathcal{V}| \\cdot d)` parameter complexity).\nHowever, it is still a commonly used technique to preserve structural graph information in fixed-size vectors, and is often also used to generate inputs to GNNs for further processing in case the initial set of node features is not rich.\n\nNode2Vec\n--------\n\n.. 
note::\n\n    In this section of the tutorial, we will learn node embeddings for **homogeneous graphs** using the :class:`~torch_geometric.nn.models.Node2Vec` module of :pyg:`PyG`.\n    The code is available in `examples/node2vec.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/node2vec.py>`_ and as a `Google Colab tutorial notebook <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial11/Tutorial11.ipynb>`_.\n\n:class:`~torch_geometric.nn.models.Node2Vec` is a method for learning shallow node embeddings, which allows for flexible\ncontrol of random walk procedures based on breadth-first or depth-first samplers.\nIn particular, its parameter :obj:`p` dictates the likelihood of immediately revisiting a node in the walk, while its parameter :obj:`q` interpolates between breadth-first and depth-first strategies.\n\nTo begin the example, let us load in the needed packages and the data that we will be working with:\n\n.. code-block:: python\n\n    from torch_geometric.datasets import Planetoid\n    from torch_geometric.nn import Node2Vec\n\n    data = Planetoid('./data/Planetoid', name='Cora')[0]\n\nWe are now ready to initialize our :class:`~torch_geometric.nn.models.Node2Vec` module:\n\n.. 
code-block:: python\n\n    import torch\n    from torch_geometric.nn import Node2Vec\n\n    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n\n    model = Node2Vec(\n        data.edge_index,\n        embedding_dim=128,\n        walks_per_node=10,\n        walk_length=20,\n        context_size=10,\n        p=1.0,\n        q=1.0,\n        num_negative_samples=1,\n    ).to(device)\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n:class:`~torch_geometric.nn.models.Node2Vec` takes the graph structure :obj:`edge_index` as input (but none of its feature information), the :obj:`embedding_dim` of the shallow embeddings, and additional parameters to control the random walk and negative sampling procedures.\nIn particular, :obj:`walks_per_node` and :obj:`walk_length` specify the number of walks to perform for each node and their length, respectively.\nThe :obj:`context_size` then denotes how many nodes in the walk are actually used for gradient optimization, *i.e.*, :class:`~torch_geometric.nn.models.Node2Vec` slides over each sampled walk and splits it into windows of size :obj:`context_size`.\nAs previously mentioned, :obj:`p` and :obj:`q` denote how random walks are generated.\nFinally, :obj:`num_negative_samples` specifies how many negative walks we want to generate for each positive walk.\n\nAfter initializing, we can go ahead and train our :class:`~torch_geometric.nn.models.Node2Vec` model right away.\nWe start this by creating a data loader that will generate positive and negative random walks for us:\n\n.. code-block:: python\n\n    loader = model.loader(batch_size=128, shuffle=True, num_workers=4)\n\nTo generate random walks, we can simply iterate over the data loader, *e.g.*:\n\n.. 
code-block:: python\n\n    pos_rw, neg_rw = next(iter(loader))\n\nHere, :obj:`pos_rw` will contain the node indices of positive random walks and :obj:`neg_rw` will contain the node indices of negative walks.\nIn particular, :obj:`pos_rw` is a two-dimensional matrix of shape :obj:`[batch_size * walks_per_node * (2 + walk_length - context_size), context_size]`, and :obj:`neg_rw` is a two-dimensional matrix of shape :obj:`[num_negative_samples * pos_rw.size(0), context_size]`.\n\nUsing this :obj:`loader` and the built-in contrastive :meth:`~torch_geometric.nn.models.Node2Vec.loss` function, we can define our :meth:`train` function as follows:\n\n.. code-block:: python\n\n    def train():\n        model.train()\n        total_loss = 0\n        for pos_rw, neg_rw in loader:\n            optimizer.zero_grad()\n            loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n            loss.backward()\n            optimizer.step()\n            total_loss += loss.item()\n        return total_loss / len(loader)\n\nAfter finishing training, we can obtain the final node embeddings from the model as follows:\n\n.. code-block:: python\n\n    z = model()  # Full node-level embeddings.\n    z = model(torch.tensor([0, 1, 2]))  # Embeddings of first three nodes.\n\n\nMetaPath2Vec\n------------\n\n.. 
note::\n\n   In this section of the tutorial, we will learn node embeddings for **heterogeneous graphs** using the :class:`~torch_geometric.nn.models.MetaPath2Vec` module of :pyg:`PyG`.\n   The code is available as `examples/hetero/metapath2vec.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/hetero/metapath2vec.py>`_ and as a `Google Colab tutorial notebook <https://colab.research.google.com/github/AntonioLonga/PytorchGeometricTutorial/blob/main/Tutorial11/Tutorial11.ipynb>`_.\n\n\nAn extension of :class:`~torch_geometric.nn.models.Node2Vec` to *heterogeneous graphs* is the :class:`~torch_geometric.nn.models.MetaPath2Vec` model.\n:class:`~torch_geometric.nn.models.MetaPath2Vec` works similarly to :class:`~torch_geometric.nn.models.Node2Vec` but expects a dictionary of edge indices as input (holding the :obj:`edge_index` for each edge type in the graph), and samples random walks based on a given :obj:`metapath` formulation, *e.g.*,\n\n.. code-block:: python\n\n    metapath = [\n        ('author', 'writes', 'paper'),\n        ('paper', 'published_in', 'venue'),\n        ('venue', 'publishes', 'paper'),\n        ('paper', 'written_by', 'author'),\n    ]\n\ndenotes that random walk sampling is performed from author nodes to paper nodes to venue nodes, then back to paper nodes and finally to author nodes.\nOtherwise, initialization and training of the model stay the same as in the :class:`~torch_geometric.nn.models.Node2Vec` case.\n"
  },
  {
    "path": "examples/README.md",
    "content": "# Examples\n\nThis folder contains a plethora of examples covering different GNN use-cases.\nThis readme highlights some key examples.\n\n> [!NOTE]\n> We recommend the [NVIDIA PyG Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pyg/tags) for best results and easiest setup with NVIDIA GPUs. See the [cuGraph installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html#accelerating-pyg-with-nvidia-cugraph-gnn) for details.\n\nA great and simple example to start with is [`gcn.py`](./gcn.py), showing a user how to train a [`GCN`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GCN.html) model for node-level prediction on small-scale homogeneous data.\n\nFor a simple GNN based link prediction example, see [`link_pred.py`](./link_pred.py).\n\nFor an improved GNN based link prediction approach using Attract-Repel embeddings that can significantly boost accuracy (up to 23% improvement in AUC), see [`ar_link_pred.py`](./ar_link_pred.py). 
This approach is based on [Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs](https://arxiv.org/abs/2106.09671).\n\nFor an example of link prediction with an advanced Graph Transformer called [`LPFormer`](https://arxiv.org/abs/2310.11009), see \\[`lpformer.py`\\].\n\nFor examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see the `ogbn_*.py` examples:\n\n- [`ogbn_train.py`](./ogbn_train.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset (approximately 1.6B edges) or the medium-scale `ogbn-products` dataset (approximately 62M edges).\n  - Uses SGFormer (a kind of GraphTransformer) by default.\n  - [SGFormer Paper](https://arxiv.org/pdf/2306.10759)\n  - [Polynormer](https://arxiv.org/pdf/2403.01232)\n  - [Kumo.ai x NVIDIA x Stanford Graph Transformer Webinar](https://www.youtube.com/watch?v=wAYryx3GjLw)\n- [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) is an example to showcase how to train deep GNNs on the `ogbn-proteins` dataset.\n- [`ogbn_train_perforatedai.py`](https://github.com/PerforatedAI/PerforatedAI-Examples/tree/master/otherExamples/torch_geometric/OGBNProducts) shows how to optimize the `ogbn_train.py` workflow using [Perforated AI](https://github.com/PerforatedAI/PerforatedAI-API). 
Perforated AI provides a PyTorch add-on which increases network accuracy by empowering each artificial neuron with artificial dendrites.\n\nFor an example on [Relational Deep Learning](https://arxiv.org/abs/2312.04615) with the [RelBench datasets](https://relbench.stanford.edu/), see [`rdl.py`](./rdl.py).\n\nFor examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).\n\nFor examples on scaling PyG up via multi-GPUs, see the examples under [`examples/multi_gpu`](./multi_gpu).\n\nFor examples on working with heterogeneous data, see the examples under [`examples/hetero`](./hetero).\n\nFor examples on co-training LLMs with GNNs, see the examples under [`examples/llm`](./llm).\n\n- [Stanford GNN+LLM Talk](https://www.nvidia.com/en-us/on-demand/session/other25-nv-0003/)\n\nWe recommend looking into the [PyTorch documentation](https://docs.pytorch.org/tutorials/beginner/dist_overview.html) for examples on setting up model-parallel GNNs.\n\n### Scale to Trillions of Edges with cuGraph\n\n[cuGraph](https://github.com/rapidsai/cugraph) is a collection of packages focused on GPU-accelerated graph analytics including support for property graphs and scaling up to thousands of GPUs. cuGraph supports the creation and manipulation of graphs followed by the execution of scalable fast graph algorithms. It is part of the [RAPIDS](https://rapids.ai) accelerated data science framework.\n\n[cuGraph GNN](https://github.com/rapidsai/cugraph-gnn) is a collection of GPU-accelerated plugins that support PyTorch and PyG natively through the _cuGraph-PyG_ and _WholeGraph_ subprojects. cuGraph GNN is built on top of cuGraph, leveraging its low-level [pylibcugraph](https://github.com/rapidsai/cugraph/python/pylibcugraph) API and C++ primitives for sampling and other GNN operations ([libcugraph](https://github.com/rapidai/cugraph/python/libcugraph)). 
It also includes the `libwholegraph` and `pylibwholegraph` libraries for high-performance distributed edgelist and embedding storage. Users have the option of working with these lower-level libraries directly, or through the higher-level API in cuGraph-PyG that directly implements the `GraphStore`, `FeatureStore`, `NodeLoader`, and `LinkLoader` interfaces.\n\nComplete documentation on RAPIDS graph packages, including `cugraph`, `cugraph-pyg`, `pylibwholegraph`, and `pylibcugraph` is available on the [RAPIDS docs pages](https://docs.rapids.ai/api/cugraph/nightly/graph_support).\n\nSee [`rapidsai/cugraph-gnn/tree/branch-25.12/python/cugraph-pyg/cugraph_pyg/examples` on GitHub](https://github.com/rapidsai/cugraph-gnn/tree/branch-25.12/python/cugraph-pyg/cugraph_pyg/examples) for fully scalable PyG example workflows.\n"
  },
  {
    "path": "examples/agnn.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import AGNNConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = torch.nn.Linear(dataset.num_features, 16)\n        self.prop1 = AGNNConv(requires_grad=False)\n        self.prop2 = AGNNConv(requires_grad=True)\n        self.lin2 = torch.nn.Linear(16, dataset.num_classes)\n\n    def forward(self):\n        x = F.dropout(data.x, training=self.training)\n        x = F.relu(self.lin1(x))\n        x = self.prop1(x, data.edge_index)\n        x = self.prop2(x, data.edge_index)\n        x = F.dropout(x, training=self.training)\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out, accs = model(), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 201):\n    train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '\n          f'Val: 
{best_val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/ar_link_pred.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.utils import negative_sampling, train_test_split_edges\n\n\nclass GCNEncoder(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        self.conv2 = GCNConv(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        return self.conv2(x, edge_index)\n\n\nclass LinkPredictor(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels):\n        super().__init__()\n        self.lin1 = torch.nn.Linear(in_channels * 2, hidden_channels)\n        self.lin2 = torch.nn.Linear(hidden_channels, 1)\n\n    def forward(self, z_i, z_j):\n        x = torch.cat([z_i, z_j], dim=1)\n        x = self.lin1(x).relu()\n        x = self.lin2(x)\n        return x.view(-1)\n\n\nclass ARLinkPredictor(torch.nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        # Split dimensions between attract and repel\n        self.attract_dim = in_channels // 2\n        self.repel_dim = in_channels - self.attract_dim\n\n    def forward(self, z_i, z_j):\n        # Split into attract and repel parts\n        z_i_attr = z_i[:, :self.attract_dim]\n        z_i_repel = z_i[:, self.attract_dim:]\n\n        z_j_attr = z_j[:, :self.attract_dim]\n        z_j_repel = z_j[:, self.attract_dim:]\n\n        # Calculate AR score\n        attract_score = (z_i_attr * z_j_attr).sum(dim=1)\n        repel_score = (z_i_repel * z_j_repel).sum(dim=1)\n\n        return attract_score - repel_score\n\n\ndef train(encoder, predictor, data, optimizer):\n    encoder.train()\n    predictor.train()\n\n    # Forward pass and calculate loss\n    
optimizer.zero_grad()\n    z = encoder(data.x, data.train_pos_edge_index)\n\n    # Positive edges\n    pos_out = predictor(z[data.train_pos_edge_index[0]],\n                        z[data.train_pos_edge_index[1]])\n\n    # Sample and predict on negative edges\n    neg_edge_index = negative_sampling(\n        edge_index=data.train_pos_edge_index,\n        num_nodes=data.num_nodes,\n        num_neg_samples=data.train_pos_edge_index.size(1),\n    )\n    neg_out = predictor(z[neg_edge_index[0]], z[neg_edge_index[1]])\n\n    # Calculate loss\n    pos_loss = F.binary_cross_entropy_with_logits(pos_out,\n                                                  torch.ones_like(pos_out))\n    neg_loss = F.binary_cross_entropy_with_logits(neg_out,\n                                                  torch.zeros_like(neg_out))\n    loss = pos_loss + neg_loss\n\n    loss.backward()\n    optimizer.step()\n\n    return loss.item()\n\n\n@torch.no_grad()\ndef test(encoder, predictor, data):\n    encoder.eval()\n    predictor.eval()\n\n    z = encoder(data.x, data.train_pos_edge_index)\n\n    pos_val_out = predictor(z[data.val_pos_edge_index[0]],\n                            z[data.val_pos_edge_index[1]])\n    neg_val_out = predictor(z[data.val_neg_edge_index[0]],\n                            z[data.val_neg_edge_index[1]])\n\n    pos_test_out = predictor(z[data.test_pos_edge_index[0]],\n                             z[data.test_pos_edge_index[1]])\n    neg_test_out = predictor(z[data.test_neg_edge_index[0]],\n                             z[data.test_neg_edge_index[1]])\n\n    val_auc = compute_auc(pos_val_out, neg_val_out)\n    test_auc = compute_auc(pos_test_out, neg_test_out)\n\n    return val_auc, test_auc\n\n\ndef compute_auc(pos_out, neg_out):\n    pos_out = torch.sigmoid(pos_out).cpu().numpy()\n    neg_out = torch.sigmoid(neg_out).cpu().numpy()\n\n    # Simple AUC calculation\n    from sklearn.metrics import roc_auc_score\n    y_true = torch.cat(\n        
[torch.ones(pos_out.shape[0]),\n         torch.zeros(neg_out.shape[0])])\n    y_score = torch.cat([torch.tensor(pos_out), torch.tensor(neg_out)])\n\n    return roc_auc_score(y_true, y_score)\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--dataset', type=str, default='Cora',\n                        choices=['Cora', 'CiteSeer', 'PubMed'])\n    parser.add_argument('--hidden_channels', type=int, default=128)\n    parser.add_argument('--out_channels', type=int, default=64)\n    parser.add_argument('--epochs', type=int, default=200)\n    parser.add_argument('--use_ar', action='store_true',\n                        help='Use Attract-Repel embeddings')\n    parser.add_argument('--lr', type=float, default=0.01)\n    args = parser.parse_args()\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    # Load dataset\n    transform = T.Compose([\n        T.NormalizeFeatures(),\n        T.ToDevice(device),\n    ])\n\n    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',\n                    args.dataset)\n    dataset = Planetoid(path, args.dataset, transform=transform)\n    data = dataset[0]\n\n    # Process data for link prediction\n    data = train_test_split_edges(data)\n\n    # Initialize encoder\n    encoder = GCNEncoder(\n        in_channels=dataset.num_features,\n        hidden_channels=args.hidden_channels,\n        out_channels=args.out_channels,\n    ).to(device)\n\n    # Choose predictor based on args\n    if args.use_ar:\n        predictor = ARLinkPredictor(in_channels=args.out_channels).to(device)\n        print(f\"Running link prediction on {args.dataset} \"\n              f\"with Attract-Repel embeddings\")\n    else:\n        predictor = LinkPredictor(\n            in_channels=args.out_channels,\n            hidden_channels=args.hidden_channels).to(device)\n        print(f\"Running link prediction on {args.dataset} \"\n              f\"with Traditional embeddings\")\n\n    optimizer = 
torch.optim.Adam(\n        list(encoder.parameters()) + list(predictor.parameters()), lr=args.lr)\n\n    best_val_auc = 0\n    final_test_auc = 0\n\n    for epoch in range(1, args.epochs + 1):\n        loss = train(encoder, predictor, data, optimizer)\n        val_auc, test_auc = test(encoder, predictor, data)\n\n        if val_auc > best_val_auc:\n            best_val_auc = val_auc\n            final_test_auc = test_auc\n\n        if epoch % 10 == 0:\n            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '\n                  f'Val AUC: {val_auc:.4f}, '\n                  f'Test AUC: {test_auc:.4f}')\n\n    print(f'Final results - Val AUC: {best_val_auc:.4f}, '\n          f'Test AUC: {final_test_auc:.4f}')\n\n    # Calculate R-fraction if using AR\n    if args.use_ar:\n        with torch.no_grad():\n            z = encoder(data.x, data.train_pos_edge_index)\n            attr_dim = args.out_channels // 2\n\n            z_attr = z[:, :attr_dim]\n            z_repel = z[:, attr_dim:]\n\n            attract_norm_squared = torch.sum(z_attr**2)\n            repel_norm_squared = torch.sum(z_repel**2)\n\n            r_fraction = repel_norm_squared / (attract_norm_squared +\n                                               repel_norm_squared)\n            print(f\"R-fraction: {r_fraction.item():.4f}\")\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/argva_node_clustering.py",
    "content": "import os.path as osp\n\nimport matplotlib.pyplot as plt\nimport torch\nfrom sklearn.cluster import KMeans\nfrom sklearn.manifold import TSNE\nfrom sklearn.metrics.cluster import (\n    completeness_score,\n    homogeneity_score,\n    v_measure_score,\n)\nfrom torch.nn import Linear\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import ARGVA, GCNConv\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\ntransform = T.Compose([\n    T.ToDevice(device),\n    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n                      split_labels=True, add_negative_train_samples=False),\n])\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora', transform=transform)\ntrain_data, val_data, test_data = dataset[0]\n\n\nclass Encoder(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        self.conv_mu = GCNConv(hidden_channels, out_channels)\n        self.conv_logstd = GCNConv(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)\n\n\nclass Discriminator(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.lin1 = Linear(in_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, hidden_channels)\n        self.lin3 = Linear(hidden_channels, out_channels)\n\n    def forward(self, x):\n        x = self.lin1(x).relu()\n        x = self.lin2(x).relu()\n        return self.lin3(x)\n\n\nencoder = 
Encoder(train_data.num_features, hidden_channels=32, out_channels=32)\ndiscriminator = Discriminator(in_channels=32, hidden_channels=64,\n                              out_channels=32)\nmodel = ARGVA(encoder, discriminator).to(device)\n\nencoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=0.005)\ndiscriminator_optimizer = torch.optim.Adam(discriminator.parameters(),\n                                           lr=0.001)\n\n\ndef train():\n    model.train()\n    encoder_optimizer.zero_grad()\n    z = model.encode(train_data.x, train_data.edge_index)\n\n    # We optimize the discriminator more frequently than the encoder.\n    for _ in range(5):\n        discriminator_optimizer.zero_grad()\n        discriminator_loss = model.discriminator_loss(z)\n        discriminator_loss.backward()\n        discriminator_optimizer.step()\n\n    loss = model.recon_loss(z, train_data.pos_edge_label_index)\n    loss = loss + model.reg_loss(z)\n    loss = loss + (1 / train_data.num_nodes) * model.kl_loss()\n    loss.backward()\n    encoder_optimizer.step()\n    return float(loss.detach())\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    z = model.encode(data.x, data.edge_index)\n\n    # Cluster embedded values using k-means.\n    kmeans_input = z.cpu().numpy()\n    kmeans = KMeans(n_clusters=7, random_state=0,\n                    n_init='auto').fit(kmeans_input)\n    pred = kmeans.predict(kmeans_input)\n\n    labels = data.y.cpu().numpy()\n    completeness = completeness_score(labels, pred)\n    hm = homogeneity_score(labels, pred)\n    nmi = v_measure_score(labels, pred)\n\n    auc, ap = model.test(z, data.pos_edge_label_index,\n                         data.neg_edge_label_index)\n\n    return auc, ap, completeness, hm, nmi\n\n\nfor epoch in range(1, 151):\n    loss = train()\n    auc, ap, completeness, hm, nmi = test(test_data)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.3f}, AUC: {auc:.3f}, '\n          f'AP: {ap:.3f}, Completeness: {completeness:.3f}, '\n  
        f'Homogeneity: {hm:.3f}, NMI: {nmi:.3f}')\n\n\n@torch.no_grad()\ndef plot_points(data, colors):\n    model.eval()\n    z = model.encode(data.x, data.edge_index)\n    z = TSNE(n_components=2).fit_transform(z.cpu().numpy())\n    y = data.y.cpu().numpy()\n\n    plt.figure(figsize=(8, 8))\n    for i in range(dataset.num_classes):\n        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])\n    plt.axis('off')\n    plt.show()\n\n\ncolors = [\n    '#ffc0cb', '#bada55', '#008080', '#420420', '#7fe5f0', '#065535', '#ffd700'\n]\nplot_points(test_data, colors)\n"
  },
  {
    "path": "examples/arma.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import ARMAConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n\n        self.conv1 = ARMAConv(in_channels, hidden_channels, num_stacks=3,\n                              num_layers=2, shared_weights=True, dropout=0.25)\n\n        self.conv2 = ARMAConv(hidden_channels, out_channels, num_stacks=3,\n                              num_layers=2, shared_weights=True, dropout=0.25,\n                              act=lambda x: x)\n\n    def forward(self, x, edge_index):\n        x = F.dropout(x, training=self.training)\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel, data = Net(dataset.num_features, 16,\n                  dataset.num_classes).to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n\n\ndef test():\n    model.eval()\n    out, accs = model(data.x, data.edge_index), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = 
pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 401):\n    train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '\n          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/attentive_fp.py",
    "content": "import os.path as osp\nfrom math import sqrt\n\nimport torch\nimport torch.nn.functional as F\nfrom rdkit import Chem\n\nfrom torch_geometric.datasets import MoleculeNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn.models import AttentiveFP\n\n\nclass GenFeatures:\n    def __init__(self):\n        self.symbols = [\n            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',\n            'Te', 'I', 'At', 'other'\n        ]\n\n        self.hybridizations = [\n            Chem.rdchem.HybridizationType.SP,\n            Chem.rdchem.HybridizationType.SP2,\n            Chem.rdchem.HybridizationType.SP3,\n            Chem.rdchem.HybridizationType.SP3D,\n            Chem.rdchem.HybridizationType.SP3D2,\n            'other',\n        ]\n\n        self.stereos = [\n            Chem.rdchem.BondStereo.STEREONONE,\n            Chem.rdchem.BondStereo.STEREOANY,\n            Chem.rdchem.BondStereo.STEREOZ,\n            Chem.rdchem.BondStereo.STEREOE,\n        ]\n\n    def __call__(self, data):\n        # Generate AttentiveFP features according to Table 1.\n        mol = Chem.MolFromSmiles(data.smiles)\n\n        xs = []\n        for atom in mol.GetAtoms():\n            symbol = [0.] * len(self.symbols)\n            symbol[self.symbols.index(atom.GetSymbol())] = 1.\n            degree = [0.] * 6\n            degree[atom.GetDegree()] = 1.\n            formal_charge = atom.GetFormalCharge()\n            radical_electrons = atom.GetNumRadicalElectrons()\n            hybridization = [0.] * len(self.hybridizations)\n            hybridization[self.hybridizations.index(\n                atom.GetHybridization())] = 1.\n            aromaticity = 1. if atom.GetIsAromatic() else 0.\n            hydrogens = [0.] * 5\n            hydrogens[atom.GetTotalNumHs()] = 1.\n            chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.\n            chirality_type = [0.] 
* 2\n            if atom.HasProp('_CIPCode'):\n                chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1.\n\n            x = torch.tensor(symbol + degree + [formal_charge] +\n                             [radical_electrons] + hybridization +\n                             [aromaticity] + hydrogens + [chirality] +\n                             chirality_type)\n            xs.append(x)\n\n        data.x = torch.stack(xs, dim=0)\n\n        edge_indices = []\n        edge_attrs = []\n        for bond in mol.GetBonds():\n            edge_indices += [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]]\n            edge_indices += [[bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]]\n\n            bond_type = bond.GetBondType()\n            single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.\n            double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.\n            triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.\n            aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.\n            conjugation = 1. if bond.GetIsConjugated() else 0.\n            ring = 1. if bond.IsInRing() else 0.\n            stereo = [0.] 
* 4\n            stereo[self.stereos.index(bond.GetStereo())] = 1.\n\n            edge_attr = torch.tensor(\n                [single, double, triple, aromatic, conjugation, ring] + stereo)\n\n            edge_attrs += [edge_attr, edge_attr]\n\n        if len(edge_attrs) == 0:\n            data.edge_index = torch.zeros((2, 0), dtype=torch.long)\n            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)\n        else:\n            data.edge_index = torch.tensor(edge_indices).t().contiguous()\n            data.edge_attr = torch.stack(edge_attrs, dim=0)\n\n        return data\n\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'AFP_Mol')\ndataset = MoleculeNet(path, name='ESOL', pre_transform=GenFeatures()).shuffle()\n\nN = len(dataset) // 10\nval_dataset = dataset[:N]\ntest_dataset = dataset[N:2 * N]\ntrain_dataset = dataset[2 * N:]\n\ntrain_loader = DataLoader(train_dataset, batch_size=200, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=200)\ntest_loader = DataLoader(test_dataset, batch_size=200)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = AttentiveFP(in_channels=39, hidden_channels=200, out_channels=1,\n                    edge_dim=10, num_layers=2, num_timesteps=2,\n                    dropout=0.2).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=10**-2.5,\n                             weight_decay=10**-5)\n\n\ndef train():\n    model.train()  # Enable dropout during training.\n    total_loss = total_examples = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.edge_attr, data.batch)\n        loss = F.mse_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * data.num_graphs\n        total_examples += data.num_graphs\n    return sqrt(total_loss / total_examples)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()  # Disable dropout during evaluation.\n    mse = []\n    for data in loader:\n        data = data.to(device)\n 
       out = model(data.x, data.edge_index, data.edge_attr, data.batch)\n        mse.append(F.mse_loss(out, data.y, reduction='none').cpu())\n    return float(torch.cat(mse, dim=0).mean().sqrt())\n\n\nfor epoch in range(1, 201):\n    train_rmse = train()\n    val_rmse = test(val_loader)\n    test_rmse = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Loss: {train_rmse:.4f} Val: {val_rmse:.4f} '\n          f'Test: {test_rmse:.4f}')\n"
  },
  {
    "path": "examples/autoencoder.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GAE, VGAE, GCNConv\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--variational', action='store_true')\nparser.add_argument('--linear', action='store_true')\nparser.add_argument('--dataset', type=str, default='Cora',\n                    choices=['Cora', 'CiteSeer', 'PubMed'])\nparser.add_argument('--epochs', type=int, default=400)\nargs = parser.parse_args()\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\ntransform = T.Compose([\n    T.NormalizeFeatures(),\n    T.ToDevice(device),\n    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n                      split_labels=True, add_negative_train_samples=False),\n])\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, args.dataset, transform=transform)\ntrain_data, val_data, test_data = dataset[0]\n\n\nclass GCNEncoder(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, 2 * out_channels)\n        self.conv2 = GCNConv(2 * out_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        return self.conv2(x, edge_index)\n\n\nclass VariationalGCNEncoder(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, 2 * out_channels)\n        self.conv_mu = GCNConv(2 * out_channels, out_channels)\n        self.conv_logstd = GCNConv(2 * out_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, 
edge_index).relu()\n        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)\n\n\nclass LinearEncoder(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = GCNConv(in_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        return self.conv(x, edge_index)\n\n\nclass VariationalLinearEncoder(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv_mu = GCNConv(in_channels, out_channels)\n        self.conv_logstd = GCNConv(in_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)\n\n\nin_channels, out_channels = dataset.num_features, 16\n\nif not args.variational and not args.linear:\n    model = GAE(GCNEncoder(in_channels, out_channels))\nelif not args.variational and args.linear:\n    model = GAE(LinearEncoder(in_channels, out_channels))\nelif args.variational and not args.linear:\n    model = VGAE(VariationalGCNEncoder(in_channels, out_channels))\nelif args.variational and args.linear:\n    model = VGAE(VariationalLinearEncoder(in_channels, out_channels))\n\nmodel = model.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    z = model.encode(train_data.x, train_data.edge_index)\n    loss = model.recon_loss(z, train_data.pos_edge_label_index)\n    if args.variational:\n        loss = loss + (1 / train_data.num_nodes) * model.kl_loss()\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    z = model.encode(data.x, data.edge_index)\n    return model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)\n\n\ntimes = []\nfor epoch in range(1, args.epochs + 1):\n    start = time.time()\n    loss = train()\n    auc, ap = test(test_data)\n    print(f'Epoch: 
{epoch:03d}, AUC: {auc:.4f}, AP: {ap:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/cluster_gcn_ppi.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import f1_score\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.datasets import PPI\nfrom torch_geometric.loader import ClusterData, ClusterLoader, DataLoader\nfrom torch_geometric.nn import BatchNorm, SAGEConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')\ntrain_dataset = PPI(path, split='train')\nval_dataset = PPI(path, split='val')\ntest_dataset = PPI(path, split='test')\n\ntrain_data = Batch.from_data_list(train_dataset)\ncluster_data = ClusterData(train_data, num_parts=50, recursive=False,\n                           save_dir=train_dataset.processed_dir)\ntrain_loader = ClusterLoader(cluster_data, batch_size=1, shuffle=True,\n                             num_workers=12)\n\nval_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)\ntest_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):\n        super().__init__()\n        self.convs = torch.nn.ModuleList()\n        self.batch_norms = torch.nn.ModuleList()\n        self.convs.append(SAGEConv(in_channels, hidden_channels))\n        self.batch_norms.append(BatchNorm(hidden_channels))\n        for _ in range(num_layers - 2):\n            self.convs.append(SAGEConv(hidden_channels, hidden_channels))\n            self.batch_norms.append(BatchNorm(hidden_channels))\n        self.convs.append(SAGEConv(hidden_channels, out_channels))\n\n    def forward(self, x, edge_index):\n        for conv, batch_norm in zip(self.convs[:-1], self.batch_norms):\n            x = conv(x, edge_index)\n            x = batch_norm(x)\n            x = F.relu(x)\n            x = F.dropout(x, p=0.2, training=self.training)\n        return self.convs[-1](x, edge_index)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 
'cpu')\nmodel = Net(in_channels=train_dataset.num_features, hidden_channels=1024,\n            out_channels=train_dataset.num_classes, num_layers=6).to(device)\nloss_op = torch.nn.BCEWithLogitsLoss()\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = loss_op(model(data.x, data.edge_index), data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item() * data.num_nodes\n    return total_loss / train_data.num_nodes\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    ys, preds = [], []\n    for data in loader:\n        ys.append(data.y)\n        out = model(data.x.to(device), data.edge_index.to(device))\n        preds.append((out > 0).float().cpu())\n\n    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()\n    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0\n\n\ntimes = []\nfor epoch in range(1, 201):\n    start = time.time()\n    loss = train()\n    val_f1 = test(val_loader)\n    test_f1 = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '\n          f'Test: {test_f1:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/cluster_gcn_reddit.py",
    "content": "import time\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import ModuleList\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import ClusterData, ClusterLoader, NeighborLoader\nfrom torch_geometric.nn import SAGEConv\n\ndataset = Reddit('../data/Reddit')\ndata = dataset[0]\n\ncluster_data = ClusterData(data, num_parts=1500, recursive=False,\n                           save_dir=dataset.processed_dir)\ntrain_loader = ClusterLoader(cluster_data, batch_size=20, shuffle=True,\n                             num_workers=12)\n\nsubgraph_loader = NeighborLoader(data, num_neighbors=[-1], batch_size=1024,\n                                 shuffle=False, num_workers=12)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.convs = ModuleList(\n            [SAGEConv(in_channels, 128),\n             SAGEConv(128, out_channels)])\n\n    def forward(self, x, edge_index):\n        for i, conv in enumerate(self.convs):\n            x = conv(x, edge_index)\n            if i != len(self.convs) - 1:\n                x = F.relu(x)\n                x = F.dropout(x, p=0.5, training=self.training)\n        return F.log_softmax(x, dim=-1)\n\n    def inference(self, x_all):\n        pbar = tqdm(total=x_all.size(0) * len(self.convs))\n        pbar.set_description('Evaluating')\n\n        # Compute representations of nodes layer by layer, using *all*\n        # available edges. 
This leads to faster computation in contrast to\n        # immediately computing the final representations of each batch.\n        for i, conv in enumerate(self.convs):\n            xs = []\n            for batch in subgraph_loader:\n                edge_index = batch.edge_index.to(device)\n                x = x_all[batch.n_id].to(device)\n                x_target = x[:batch.batch_size]\n                x = conv((x, x_target), edge_index)\n                if i != len(self.convs) - 1:\n                    x = F.relu(x)\n                xs.append(x.cpu())\n\n                pbar.update(batch.batch_size)\n\n            x_all = torch.cat(xs, dim=0)\n\n        pbar.close()\n\n        return x_all\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(dataset.num_features, dataset.num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n\n\ndef train():\n    model.train()\n\n    total_loss = total_nodes = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n        out = model(batch.x, batch.edge_index)\n        loss = F.nll_loss(out[batch.train_mask], batch.y[batch.train_mask])\n        loss.backward()\n        optimizer.step()\n\n        nodes = batch.train_mask.sum().item()\n        total_loss += loss.item() * nodes\n        total_nodes += nodes\n\n    return total_loss / total_nodes\n\n\n@torch.no_grad()\ndef test():  # Inference should be performed on the full graph.\n    model.eval()\n\n    out = model.inference(data.x)\n    y_pred = out.argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        correct = y_pred[mask].eq(data.y[mask]).sum().item()\n        accs.append(correct / mask.sum().item())\n    return accs\n\n\ntimes = []\nfor epoch in range(1, 31):\n    start = time.time()\n    loss = train()\n    if epoch % 5 == 0:\n        train_acc, val_acc, test_acc = test()\n        print(f'Epoch: {epoch:02d}, 
Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n    else:\n        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/colors_topk_pool.py",
    "content": "import copy\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GINConv, TopKPooling, global_add_pool\nfrom torch_geometric.utils import scatter\n\n\nclass HandleNodeAttention:\n    def __call__(self, data):\n        data = copy.copy(data)\n        data.attn = torch.softmax(data.x[:, 0], dim=0)\n        data.x = data.x[:, 1:]\n        return data\n\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'COLORS-3')\ndataset = TUDataset(path, 'COLORS-3', use_node_attr=True,\n                    transform=HandleNodeAttention())\n\ntrain_loader = DataLoader(dataset[:500], batch_size=60, shuffle=True)\nval_loader = DataLoader(dataset[500:3000], batch_size=60)\ntest_loader = DataLoader(dataset[3000:], batch_size=60)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n\n        self.conv1 = GINConv(Seq(Lin(in_channels, 64), ReLU(), Lin(64, 64)))\n        self.pool1 = TopKPooling(in_channels, min_score=0.05)\n        self.conv2 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64)))\n\n        self.lin = torch.nn.Linear(64, 1)\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n\n        out = F.relu(self.conv1(x, edge_index))\n\n        out, edge_index, _, batch, perm, score = self.pool1(\n            out, edge_index, None, batch, attn=x)\n        ratio = out.size(0) / x.size(0)\n\n        out = F.relu(self.conv2(out, edge_index))\n        out = global_add_pool(out, batch)\n        out = self.lin(out).view(-1)\n\n        attn_loss = F.kl_div(torch.log(score + 1e-14), data.attn[perm],\n                             reduction='none')\n        attn_loss = scatter(attn_loss, batch, reduce='mean')\n\n        
return out, attn_loss, ratio\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(dataset.num_features).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n# Initialize to optimal attention weights:\n# model.pool1.weight.data = torch.tensor([0., 1., 0., 0.]).view(1,4).to(device)\n\n\ndef train(epoch):\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out, attn_loss, _ = model(data)\n        loss = ((out - data.y).pow(2) + 100 * attn_loss).mean()\n        loss.backward()\n        total_loss += loss.item() * data.num_graphs\n        optimizer.step()\n\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    corrects, total_ratio = [], 0\n    for data in loader:\n        data = data.to(device)\n        out, _, ratio = model(data)\n        pred = out.round().to(torch.long)\n        corrects.append(pred.eq(data.y.to(torch.long)))\n        total_ratio += ratio\n    return torch.cat(corrects, dim=0), total_ratio / len(loader)\n\n\nfor epoch in range(1, 301):\n    loss = train(epoch)\n    train_correct, train_ratio = test(train_loader)\n    val_correct, val_ratio = test(val_loader)\n    test_correct, test_ratio = test(test_loader)\n\n    train_acc = train_correct.sum().item() / train_correct.size(0)\n    val_acc = val_correct.sum().item() / val_correct.size(0)\n\n    test_acc1 = test_correct[:2500].sum().item() / 2500\n    test_acc2 = test_correct[2500:5000].sum().item() / 2500\n    test_acc3 = test_correct[5000:].sum().item() / 2500\n\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.3f}, '\n          f'Val: {val_acc:.3f}, Test Orig: {test_acc1:.3f}, '\n          f'Test Large: {test_acc2:.3f}, Test LargeC: {test_acc3:.3f}, '\n          f'Train/Val/Test Ratio='\n          f'{train_ratio:.3f}/{val_ratio:.3f}/{test_ratio:.3f}')\n"
  },
  {
    "path": "examples/compile/gcn.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCNConv\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\npath = osp.dirname(osp.realpath(__file__))\npath = osp.join(path, '..', '..', 'data', 'Planetoid')\ndataset = Planetoid(\n    path, name='Cora', transform=T.Compose([\n        T.NormalizeFeatures(),\n        T.GCNNorm(),\n    ]))\ndata = dataset[0].to(device)\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        # Pre-process normalization to avoid CPU communication/graph breaks:\n        self.conv1 = GCNConv(in_channels, hidden_channels, normalize=False)\n        self.conv2 = GCNConv(hidden_channels, out_channels, normalize=False)\n\n    def forward(self, x, edge_index, edge_weight):\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv1(x, edge_index, edge_weight).relu()\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv2(x, edge_index)\n        return x\n\n\nmodel = GCN(\n    in_channels=dataset.num_features,\n    hidden_channels=16,\n    out_channels=dataset.num_classes,\n).to(device)\n\n# Compile the model into an optimized version:\nmodel = torch.compile(model, dynamic=False)\n\noptimizer = torch.optim.Adam([\n    dict(params=model.conv1.parameters(), weight_decay=5e-4),\n    dict(params=model.conv2.parameters(), weight_decay=0)\n], lr=0.01)  # Only perform weight-decay on first convolution.\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index, data.edge_weight)\n    loss = F.cross_entropy(out[data.train_mask], 
data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss.detach())\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index, data.edge_weight).argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\ntimes = []\nfor epoch in range(1, 201):\n    start = time.time()\n    loss = train()\n    train_acc, val_acc, test_acc = test()\n    times.append(time.time() - start)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\nprint(f'Median time per epoch: {torch.tensor(times).median():.4f}s')\n"
  },
  {
    "path": "examples/compile/gin.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP, GINConv, global_add_pool\n\nif not torch_geometric.typing.WITH_PT21:\n    quit('Dynamic shape compilation requires PyTorch >= 2.1.0')\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    # MPS is currently slower than CPU due to missing int64 min/max ops\n    device = torch.device('cpu')\nelse:\n    device = torch.device('cpu')\n\npath = osp.dirname(osp.realpath(__file__))\npath = osp.join(path, '..', '..', 'data', 'TU')\ndataset = TUDataset(path, name='MUTAG').shuffle()\n\ntrain_loader = DataLoader(dataset[:0.9], batch_size=128, shuffle=True)\ntest_loader = DataLoader(dataset[0.9:], batch_size=128)\n\n\nclass GIN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            mlp = MLP([in_channels, hidden_channels, hidden_channels])\n            self.convs.append(GINConv(nn=mlp, train_eps=False))\n            in_channels = hidden_channels\n\n        self.mlp = MLP([hidden_channels, hidden_channels, out_channels],\n                       norm=None, dropout=0.5)\n\n    def forward(self, x, edge_index, batch, batch_size):\n        for conv in self.convs:\n            x = conv(x, edge_index).relu()\n        # Pass the batch size to avoid CPU communication/graph breaks:\n        x = global_add_pool(x, batch, size=batch_size)\n        return self.mlp(x)\n\n\nmodel = GIN(\n    in_channels=dataset.num_features,\n    hidden_channels=32,\n    out_channels=dataset.num_classes,\n    num_layers=5,\n).to(device)\n\n# Compile the model into an optimized version:\nmodel = 
torch.compile(model, dynamic=True)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.batch, data.batch_size)\n        loss = F.cross_entropy(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss.detach()) * data.num_graphs\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_correct = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.batch, data.batch_size)\n        pred = out.argmax(dim=-1)\n        total_correct += int((pred == data.y).sum())\n    return total_correct / len(loader.dataset)\n\n\ntimes = []\nfor epoch in range(1, 101):\n    start = time.time()\n    loss = train()\n    train_acc = test(train_loader)\n    test_acc = test(test_loader)\n    times.append(time.time() - start)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\nprint(f'Median time per epoch: {torch.tensor(times).median():.4f}s')\n"
  },
  {
    "path": "examples/contrib/README.md",
    "content": "# Examples for External Contributions\n\nThis directory contains examples demonstrating functionality included in the `torch_geometric.contrib` package.\nThe `contrib` package of PyG is a staging area for early-stage, experimental code.\nModules included here might be moved to the main library in the future.\n\n| Example                                                                            | Description                                                                                 |\n| ---------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- |\n| [`rbcd_attack.py`](./rbcd_attack.py)                                               | An example of the RBCD (Resource-based Critical Data) attack                                |\n| [`rbcd_attack_poisoning.py`](./rbcd_attack_poisoning.py)                           | An example of the RBCD (Resource-Based Critical Data) attack with data poisoning strategies |\n| [`pgm_explainer_node_classification.py`](./pgm_explainer_node_classification.py)   | An example of the PGM (Probabilistic Graphical Model) explainer for node classification     |\n| [`pgm_explainer_graph_classification.py`](./pgm_explainer_graph_classification.py) | An example of the PGM (Probabilistic Graphical Model) explainer for graph classification    |\n"
  },
  {
    "path": "examples/contrib/pgm_explainer_graph_classification.py",
    "content": "\"\"\"This is an example of using the PGM explainer algorithm on a graph\nclassification task.\n\"\"\"\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear, ReLU, Sequential\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.contrib.explain import PGMExplainer\nfrom torch_geometric.datasets import MNISTSuperpixels\nfrom torch_geometric.explain import Explainer\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import (\n    NNConv,\n    global_mean_pool,\n    graclus,\n    max_pool,\n    max_pool_x,\n)\nfrom torch_geometric.utils import normalized_cut\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')\ntransform = T.Cartesian(cat=False)\ntrain_dataset = MNISTSuperpixels(path, True, transform=transform)\ntest_dataset = MNISTSuperpixels(path, False, transform=transform)\ntrain_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)\ntest_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)\nd = train_dataset\n\n\ndef normalized_cut_2d(edge_index, pos):\n    row, col = edge_index\n    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)\n    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        nn1 = Sequential(\n            Linear(2, 25),\n            ReLU(),\n            Linear(25, d.num_features * 32),\n        )\n        self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean')\n\n        nn2 = Sequential(\n            Linear(2, 25),\n            ReLU(),\n            Linear(25, 32 * 64),\n        )\n        self.conv2 = NNConv(32, 64, nn2, aggr='mean')\n\n        self.fc1 = torch.nn.Linear(64, 128)\n        self.fc2 = torch.nn.Linear(128, d.num_classes)\n\n    def forward(self, x, edge_index, **kwargs):\n        data = kwargs.get('data')\n        data = data.detach().clone()\n        x = F.elu(self.conv1(x, 
edge_index, data.edge_attr))\n        weight = normalized_cut_2d(edge_index, data.pos)\n        cluster = graclus(edge_index, weight, x.size(0))\n        data.edge_attr = None\n        data.x = x\n        data.edge_index = edge_index\n        data = max_pool(cluster, data, transform=transform)\n\n        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))\n        weight = normalized_cut_2d(data.edge_index, data.pos)\n        cluster = graclus(data.edge_index, weight, data.x.size(0))\n        x, batch = max_pool_x(cluster, data.x, data.batch)\n\n        x = global_mean_pool(x, batch)\n        x = F.elu(self.fc1(x))\n        x = F.dropout(x, training=self.training)\n        return F.log_softmax(self.fc2(x), dim=1)\n\n\ndef train(model, dataloader):\n    model.train()\n\n    for data in dataloader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        # `forward` reads the full batch from the `data` keyword argument:\n        out = model(data.x, data.edge_index, data=data)\n        F.nll_loss(out, data.y).backward()\n        optimizer.step()\n\n\nif __name__ == \"__main__\":\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    print(f'current device: {device}')\n    model = Net().to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    for _ in range(2):\n        train(model, train_loader)\n\n    explainer = Explainer(\n        model=model, algorithm=PGMExplainer(perturb_feature_list=[0],\n                                            perturbation_mode=\"mean\"),\n        explanation_type='phenomenon', node_mask_type=\"object\",\n        model_config=dict(mode=\"multiclass_classification\", task_level=\"graph\",\n                          return_type=\"raw\"))\n    i = 0\n\n    for explain_dataset in test_loader:\n        explain_dataset.to(device)\n        explanation = explainer(x=explain_dataset.x,\n                                edge_index=explain_dataset.edge_index,\n                                target=explain_dataset.y,\n                                
edge_attr=explain_dataset.edge_attr,\n                                data=explain_dataset)\n        for k in explanation.available_explanations:\n            print(explanation[k])\n        i += 1\n        if i > 2:\n            break\n"
  },
  {
    "path": "examples/contrib/pgm_explainer_node_classification.py",
    "content": "\"\"\"This is an example of using the PGM explainer algorithm on a node\nclassification task.\n\"\"\"\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.contrib.explain import PGMExplainer\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.explain import Explainer, ModelConfig\nfrom torch_geometric.nn import GCNConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ntransform = T.Compose([T.GCNNorm(), T.NormalizeFeatures()])\ndataset = Planetoid(path, dataset, transform=transform)\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, 16, normalize=False)\n        self.conv2 = GCNConv(16, dataset.num_classes, normalize=False)\n\n    def forward(self, x, edge_index, edge_weight):\n        x = F.relu(self.conv1(x, edge_index, edge_weight))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index, edge_weight)\n        return F.log_softmax(x, dim=1)\n\n\nif __name__ == \"__main__\":\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = Net().to(device)\n    data = data.to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,\n                                 weight_decay=5e-4)\n    x, edge_index, edge_weight, target = \\\n        data.x, data.edge_index, data.edge_weight, data.y\n\n    model.train()\n    for _ in range(1, 500):\n        optimizer.zero_grad()\n        log_logits = model(x, edge_index, edge_weight)\n        loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])\n        loss.backward()\n        optimizer.step()\n\n    model.eval()\n    log_logits = model(x, edge_index, edge_weight)\n    predicted_target = log_logits.argmax(dim=1)\n\n    explainer = Explainer(\n        model=model, 
algorithm=PGMExplainer(), node_mask_type='attributes',\n        explanation_type='phenomenon',\n        model_config=ModelConfig(mode='multiclass_classification',\n                                 task_level='node', return_type='raw'))\n    node_idx = 100\n    explanation = explainer(x=data.x, edge_index=edge_index, index=node_idx,\n                            target=predicted_target, edge_weight=edge_weight)\n    print(f'Significance of relevant neighbors: {explanation.pgm_stats}')\n"
  },
  {
    "path": "examples/contrib/rbcd_attack.py",
    "content": "import copy\nimport os.path as osp\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.optim import Adam\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.contrib.nn import GRBCDAttack, PRBCDAttack\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GATConv, GCNConv\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.utils import softmax\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.norm = gcn_norm\n        self.conv1 = GCNConv(in_channels, hidden_channels, normalize=False)\n        self.conv2 = GCNConv(hidden_channels, out_channels, normalize=False)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        self.conv2.reset_parameters()\n\n    def forward(self, x, edge_index, edge_weight=None, **kwargs):\n        # Normalize edge indices only once:\n        if not kwargs.get('skip_norm', False):\n            edge_index, edge_weight = self.norm(\n                edge_index,\n                edge_weight,\n                num_nodes=x.size(0),\n                add_self_loops=True,\n            )\n\n        x = self.conv1(x, edge_index, edge_weight).relu()\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index, edge_weight)\n        return x\n\n\nclass WeightedGATConv(GATConv):\n    \"\"\"Extended GAT to allow for weighted edges.\"\"\"\n    def edge_update(self, alpha_j: Tensor, alpha_i: Optional[Tensor],\n                    edge_attr: Optional[Tensor], index: Tensor,\n                    ptr: Optional[Tensor], size_i: Optional[int]) -> Tensor:\n        # Given edge-level attention coefficients for source and target nodes,\n        # we simply need to sum them up to \"emulate\" concatenation:\n   
     alpha = alpha_j if alpha_i is None else alpha_j + alpha_i\n        alpha = F.leaky_relu(alpha, self.negative_slope)\n\n        if edge_attr is not None:\n            assert edge_attr.dim() == 1, 'Only scalar edge weights supported'\n            edge_attr = edge_attr.view(-1, 1)\n            # `alpha` unchanged if edge_attr == 1 and -Inf if edge_attr == 0;\n            # We choose log to counteract underflow in subsequent exp/softmax\n            alpha = alpha + torch.log2(edge_attr)\n\n        alpha = softmax(alpha, index, ptr, size_i)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        return alpha\n\n\nclass GAT(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        # Initialize edge weights of self-loops with 1:\n        self.conv1 = WeightedGATConv(in_channels, hidden_channels,\n                                     fill_value=1.)\n        self.conv2 = WeightedGATConv(hidden_channels, out_channels,\n                                     fill_value=1.)\n\n    def reset_parameters(self):\n        self.conv1.reset_parameters()\n        self.conv2.reset_parameters()\n\n    def forward(self, x, edge_index, edge_weight=None):\n        x = self.conv1(x, edge_index, edge_weight).relu()\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index, edge_weight)\n        return x\n\n\ndef train(model, data, epochs=200, lr=0.01, weight_decay=5e-4):\n    model.train()\n    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)\n    for _ in range(epochs):\n        optimizer.zero_grad()\n        pred = model(data.x, data.edge_index, data.edge_weight)\n        loss = F.cross_entropy(pred[data.train_mask], data.y[data.train_mask])\n        loss.backward()\n        optimizer.step()\n\n\ndef accuracy(pred, y, mask):\n    return (pred.argmax(-1)[mask] == y[mask]).float().mean()\n\n\n@torch.no_grad()\ndef test(model, data):\n    
model.eval()\n    pred = model(data.x, data.edge_index, data.edge_weight)\n    return float(accuracy(pred, data.y, data.test_mask))\n\n\n# The metric in PRBCD is assumed to be best if lower (like a loss).\ndef metric(*args, **kwargs):\n    return -accuracy(*args, **kwargs)\n\n\nif __name__ == '__main__':\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    dataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())\n    data = dataset[0].to(device)\n\n    gcn = GCN(dataset.num_features, 16, dataset.num_classes).to(device)\n    gat = GAT(dataset.num_features, 16, dataset.num_classes).to(device)\n\n    train(gcn, data)\n    gcn.eval()\n\n    train(gat, data)\n    gat.eval()\n\n    node_idx = 42\n    local_budget = 2  # Degree of (training) node 42 is 2.\n\n    # Perturb 5% of edges:\n    global_budget = int(0.05 * data.edge_index.size(1) / 2)\n\n    print('------------- GAT: Local Evasion -------------')\n    # Note: GRBCD is faster than PRBCD for small budgets but not as consistent\n\n    grbcd = GRBCDAttack(gat, block_size=250_000)\n    # The learning rate is one of the most important parameters for PRBCD and a\n    # good heuristic is to choose it s.t. the budget is exhausted within a few\n    # steps. Moreover, a high learning rate mitigates the impact of the\n    # relaxation gap ({0, 1} -> [0, 1]) of the edge weights. 
See poisoning\n    #  example for a debug plot.\n    prbcd = PRBCDAttack(gat, block_size=250_000, metric=metric, lr=2_000)\n\n    clean_acc = test(gat, data)\n    print(f'Clean accuracy: {clean_acc:.3f}')\n\n    # GRBCD: Attack a single node:\n    pert_edge_index, perts = grbcd.attack(\n        data.x,\n        data.edge_index,\n        data.y,\n        budget=local_budget,\n        idx_attack=[node_idx],\n    )\n\n    clean_margin = -PRBCDAttack._probability_margin_loss(\n        gat(data.x, data.edge_index), data.y, [node_idx])\n    pert_margin = -PRBCDAttack._probability_margin_loss(\n        gat(data.x, pert_edge_index), data.y, [node_idx])\n    print(f'GRBCD: Confidence margin of target to best non-target dropped '\n          f'from {clean_margin:.3f} to {pert_margin:.3f}')\n    adv_edges = ', '.join(str((u, v)) for u, v in perts.T.tolist())\n    print(f'Adv. edges: {adv_edges}')\n\n    # PRBCD: Attack single node:\n    pert_edge_index, perts = prbcd.attack(\n        data.x,\n        data.edge_index,\n        data.y,\n        budget=local_budget,\n        idx_attack=[node_idx],\n    )\n    clean_margin = -PRBCDAttack._probability_margin_loss(\n        gat(data.x, data.edge_index), data.y, [node_idx])\n    pert_margin = -PRBCDAttack._probability_margin_loss(\n        gat(data.x, pert_edge_index), data.y, [node_idx])\n    print(f'PRBCD: Confidence margin of target to best non-target dropped '\n          f'from {clean_margin:.3f} to {pert_margin:.3f}')\n    adv_edges = ', '.join(str((u, v)) for u, v in perts.T.tolist())\n    print(f'Adv. 
edges: {adv_edges}\\n')\n\n    print('------------- GCN: Global Evasion -------------')\n\n    grbcd = GRBCDAttack(gcn, block_size=250_000)\n    prbcd = PRBCDAttack(gcn, block_size=250_000, metric=metric, lr=2_000)\n\n    clean_acc = test(gcn, data)\n\n    # GRBCD: Attack test set:\n    pert_edge_index, perts = grbcd.attack(\n        data.x,\n        data.edge_index,\n        data.y,\n        budget=global_budget,\n        idx_attack=data.test_mask,\n    )\n\n    pert_data = copy.copy(data)\n    pert_data.edge_index = pert_edge_index\n    pert_acc = test(gcn, pert_data)\n    print(f'GRBCD: Accuracy dropped from {clean_acc:.3f} to {pert_acc:.3f}')\n\n    # PRBCD: Attack test set:\n    pert_edge_index, perts = prbcd.attack(\n        data.x,\n        data.edge_index,\n        data.y,\n        budget=global_budget,\n        idx_attack=data.test_mask,\n    )\n\n    pert_data = copy.copy(data)\n    pert_data.edge_index = pert_edge_index\n    pert_acc = test(gcn, pert_data)\n    print(f'PRBCD: Accuracy dropped from {clean_acc:.3f} to {pert_acc:.3f}')\n"
  },
  {
    "path": "examples/contrib/rbcd_attack_poisoning.py",
    "content": "import copy\nimport os.path as osp\nimport sys\nfrom typing import Optional, Tuple\n\nimport matplotlib.pyplot as plt\nimport torch\nimport torch.nn.functional as F\nfrom rbcd_attack import GCN, metric, test, train\nfrom torch import Tensor\nfrom torch.optim import Adam\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.contrib.nn import PRBCDAttack\nfrom torch_geometric.datasets import Planetoid\n\ntry:\n    import higher\nexcept ImportError:\n    sys.exit('Install `higher` via `pip install higher` for poisoning example')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n# IMPORTANT: Edge weights are being ignored later and most adjacency matrix\n# preprocessing should be part of the model (part of backpropagation):\ndataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())\ndata = dataset[0].to(device)\n\ngcn = GCN(dataset.num_features, 16, dataset.num_classes).to(device)\ntrain(gcn, data)\n\nprint('------------- GCN: Global Poisoning -------------')\n\nclean_acc = test(gcn, data)\nprint(f'Clean accuracy: {clean_acc:.3f}')\n\nn_epochs = 50\nlr = 0.04\nweight_decay = 5e-4\n\n\nclass PoisoningPRBCDAttack(PRBCDAttack):\n    def _forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,\n                 **kwargs) -> Tensor:\n        \"\"\"Forward model.\"\"\"\n        self.model.reset_parameters()\n\n        with torch.enable_grad():\n            ped = copy.copy(data)\n            ped.x, ped.edge_index, ped.edge_weight = x, edge_index, edge_weight\n            train(self.model, ped, n_epochs, lr, weight_decay)\n\n        self.model.eval()\n        return self.model(x, edge_index, edge_weight)\n\n    def _forward_and_gradient(self, x: Tensor, labels: Tensor,\n                              idx_attack: Optional[Tensor] = None,\n                              **kwargs) -> Tuple[Tensor, Tensor]:\n        \"\"\"Forward and 
update edge weights.\"\"\"\n        self.block_edge_weight.requires_grad = True\n\n        self.model.reset_parameters()\n\n        self.model.train()\n        opt = Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)\n\n        with higher.innerloop_ctx(self.model, opt) as (fmodel, diffopt):\n            edge_index, edge_weight = self._get_modified_adj(\n                self.edge_index, self.edge_weight, self.block_edge_index,\n                self.block_edge_weight)\n\n            # Normalize only once (only relevant if model normalizes adj)\n            if hasattr(fmodel, 'norm'):\n                edge_index, edge_weight = fmodel.norm(\n                    edge_index,\n                    edge_weight,\n                    num_nodes=x.size(0),\n                    add_self_loops=True,\n                )\n\n            for _ in range(n_epochs):\n                pred = fmodel.forward(x, edge_index, edge_weight,\n                                      skip_norm=True)\n                loss = F.cross_entropy(pred[data.train_mask],\n                                       data.y[data.train_mask])\n                diffopt.step(loss)\n\n            pred = fmodel(x, edge_index, edge_weight)\n            loss = self.loss(pred, labels, idx_attack)\n\n            gradient = torch.autograd.grad(loss, self.block_edge_weight)[0]\n\n        # Clip gradient for stability:\n        clip_norm = 0.5\n        grad_len_sq = gradient.square().sum()\n        if grad_len_sq > clip_norm:\n            gradient *= clip_norm / grad_len_sq.sqrt()\n\n        self.model.eval()\n\n        return loss, gradient\n\n\nprbcd = PoisoningPRBCDAttack(gcn, block_size=250_000, metric=metric, lr=100)\n\n# PRBCD: Attack test set:\nglobal_budget = int(0.05 * data.edge_index.size(1) / 2)  # Perturb 5% of edges\n\npert_edge_index, perts = prbcd.attack(\n    data.x,\n    data.edge_index,\n    data.y,\n    budget=global_budget,\n    idx_attack=data.test_mask,\n)\n\ngcn.reset_parameters()\npert_data = 
copy.copy(data)\npert_data.edge_index = pert_edge_index\ntrain(gcn, pert_data)\npert_acc = test(gcn, pert_data)\n# Note that the values here are a bit noisier than in the evasion case:\nprint(f'PRBCD: Accuracy dropped from {clean_acc:.3f} to {pert_acc:.3f}')\n\nfig, ax1 = plt.subplots()\nplt.title('Global Poisoning GCN')\ncolor = 'tab:red'\nax1.plot(prbcd.attack_statistics['loss'], color=color, label='Loss')\nax1.tick_params(axis='y', labelcolor=color)\nax1.set_ylabel('Loss')\nax1.set_xlabel('Steps')\n\n# It is best practice to choose the learning rate s.t. the budget is exhausted:\nax2 = ax1.twinx()\ncolor = 'tab:blue'\nax2.plot(prbcd.attack_statistics['prob_mass_after_update'], color=color,\n         linestyle='--', label='Before projection')\nax2.plot(prbcd.attack_statistics['prob_mass_after_projection'], color=color,\n         label='After projection')\nax2.tick_params(axis='y', labelcolor=color)\nax2.set_ylabel('Used budget')\nplt.legend()\nfig.show()\n"
  },
  {
    "path": "examples/cora.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import SplineConv\nfrom torch_geometric.typing import WITH_SPLINE\n\nif not WITH_SPLINE:\n    quit(\"This example requires 'pyg-lib>=0.6.0'\")\n\ndataset = 'Cora'\ntransform = T.Compose([\n    T.RandomNodeSplit(num_val=500, num_test=500),\n    T.TargetIndegree(),\n])\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset, transform=transform)\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SplineConv(dataset.num_features, 16, dim=1, kernel_size=2)\n        self.conv2 = SplineConv(16, dataset.num_classes, dim=1, kernel_size=2)\n\n    def forward(self):\n        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr\n        x = F.dropout(x, training=self.training)\n        x = F.elu(self.conv1(x, edge_index, edge_attr))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index, edge_attr)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-3)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    log_probs, accs = model(), []\n    for _, mask in data('train_mask', 'test_mask'):\n        pred = log_probs[mask].max(1)[1]\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nfor epoch in range(1, 201):\n    train()\n    train_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, 
Train: {train_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/correct_and_smooth.py",
    "content": "import os.path as osp\n\nimport torch\nfrom ogb.nodeproppred import Evaluator, PygNodePropPredDataset\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.nn import MLP, CorrectAndSmooth\nfrom torch_geometric.typing import WITH_TORCH_SPARSE\n\nif not WITH_TORCH_SPARSE:\n    quit(\"This example requires 'torch-sparse'\")\n\nroot = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')\ndataset = PygNodePropPredDataset('ogbn-products', root,\n                                 transform=T.ToSparseTensor())\nevaluator = Evaluator(name='ogbn-products')\nsplit_idx = dataset.get_idx_split()\ndata = dataset[0]\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = MLP([dataset.num_features, 200, 200, dataset.num_classes], dropout=0.5,\n            norm=\"batch_norm\", act_first=True).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\ncriterion = torch.nn.CrossEntropyLoss()\n\nx, y = data.x.to(device), data.y.to(device)\ntrain_idx = split_idx['train'].to(device)\nval_idx = split_idx['valid'].to(device)\ntest_idx = split_idx['test'].to(device)\nx_train, y_train = x[train_idx], y[train_idx]\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(x_train)\n    loss = criterion(out, y_train.view(-1))\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test(out=None):\n    model.eval()\n    out = model(x) if out is None else out\n    pred = out.argmax(dim=-1, keepdim=True)\n    train_acc = evaluator.eval({\n        'y_true': y[train_idx],\n        'y_pred': pred[train_idx]\n    })['acc']\n    val_acc = evaluator.eval({\n        'y_true': y[val_idx],\n        'y_pred': pred[val_idx]\n    })['acc']\n    test_acc = evaluator.eval({\n        'y_true': y[test_idx],\n        'y_pred': pred[test_idx]\n    })['acc']\n    return train_acc, val_acc, test_acc, out\n\n\nbest_val_acc = 0\nfor epoch in range(1, 301):\n    loss = train()\n    
train_acc, val_acc, test_acc, out = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        y_soft = out.softmax(dim=-1)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '\n          f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n\nadj_t = data.adj_t.to(device)\ndeg = adj_t.sum(dim=1).to(torch.float)\ndeg_inv_sqrt = deg.pow_(-0.5)\ndeg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\nDAD = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1)\nDA = deg_inv_sqrt.view(-1, 1) * deg_inv_sqrt.view(-1, 1) * adj_t\n\npost = CorrectAndSmooth(num_correction_layers=50, correction_alpha=1.0,\n                        num_smoothing_layers=50, smoothing_alpha=0.8,\n                        autoscale=False, scale=20.)\n\nprint('Correct and smooth...')\ny_soft = post.correct(y_soft, y_train, train_idx, DAD)\ny_soft = post.smooth(y_soft, y_train, train_idx, DA)\nprint('Done!')\ntrain_acc, val_acc, test_acc, _ = test(y_soft)\nprint(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/cpp/CMakeLists.txt",
    "content": "cmake_minimum_required(VERSION 3.10)\nproject(hello-world)\n\n# The first thing do is to tell cmake to find the TorchScatter\n# and TorchSparse libraries. The package pulls in all the necessary\n# torch libraries, so there is no need to add `find_package(Torch)`.\nfind_package(TorchScatter REQUIRED)\nfind_package(TorchSparse REQUIRED)\n\nfind_package(Python3 COMPONENTS Development)\n\nadd_executable(hello-world main.cpp)\n\n# We now need to link the TorchScatter and TorchSparse libraries\n# to our executable. We can do that by using the\n# TorchScatter::TorchScatter and TorchSparse::TorchSparse targets,\n# which also adds all the necessary torch dependencies.\ntarget_compile_features(hello-world PUBLIC cxx_range_for)\ntarget_link_libraries(hello-world TorchScatter::TorchScatter)\ntarget_link_libraries(hello-world TorchSparse::TorchSparse)\ntarget_link_libraries(hello-world ${CUDA_cusparse_LIBRARY})\nset_property(TARGET hello-world PROPERTY CXX_STANDARD 14)\n"
  },
  {
    "path": "examples/cpp/README.md",
    "content": "# PyG in C++\n\nThis is a minimal example of getting PyG to work in C++ with CMake.\n\nIn order to successfully compile this example, make sure you have both the C++ APIs of [`TorchScatter`](https://github.com/rusty1s/pytorch_scatter#c-api) and [`TorchSparse`](https://github.com/rusty1s/pytorch_sparse/#c-api) installed.\n\nFor this, we need to add `TorchLib` to the `-DCMAKE_PREFIX_PATH` (run `import torch; print(torch.utils.cmake_prefix_path)` to obtain it).\nThen, *e.g.*, to install `TorchScatter`, run:\n\n```\ngit clone https://github.com/rusty1s/pytorch_scatter.git\ncd pytorch_scatter\nmkdir build && cd build\ncmake -DWITH_CUDA=on -DCMAKE_PREFIX_PATH=\"...\" ..\nmake\n(sudo) make install\n```\n\nOnce both dependencies are sorted, we can start the CMake fun:\n\n1. Run `save_model.py` to create and save a PyG GNN model.\n1. Create a `build` directory inside the current one.\n1. From within the `build` directory, run the following commands:\n   - `cmake -DCMAKE_PREFIX_PATH=\"<PATH_TO_LIBTORCH>;<PATH_TO_TORCHSCATTER>;<PATH_TO_TORCH_SPARSE>\" ..`\n   - `cmake --build .`\n\nThat's it!\nYou should now have a `hello-world` executable in your `build` folder.\nRun it via:\n\n```\n./hello-world ../model.pt\n```\n"
  },
  {
    "path": "examples/cpp/main.cpp",
    "content": "#include <torch/script.h>\n#include <torchscatter/scatter.h>\n#include <torchsparse/sparse.h>\n\n#include <iostream>\n\nint main(int argc, const char *argv[]) {\n  if (argc != 2) {\n    std::cerr << \"usage: hello-world <path-to-exported-script-module>\\n\";\n    return -1;\n  }\n\n  torch::jit::script::Module model;\n  try {\n    model = torch::jit::load(argv[1]);\n  } catch (const c10::Error &e) {\n    std::cerr << \"error loading the model\\n\";\n    return -1;\n  }\n\n  auto x = torch::randn({5, 32});\n  auto edge_index = torch::tensor({\n      {0, 1, 1, 2, 2, 3, 3, 4},\n      {1, 0, 2, 1, 3, 2, 4, 3},\n  });\n\n  std::vector<torch::jit::IValue> inputs;\n  inputs.push_back(x);\n  inputs.push_back(edge_index);\n\n  auto out = model.forward(inputs).toTensor();\n  std::cout << \"output tensor shape: \" << out.sizes() << std::endl;\n}\n"
  },
  {
    "path": "examples/cpp/save_model.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import BatchNorm1d, Linear, ReLU, Sequential\n\nfrom torch_geometric.nn import GINConv, global_mean_pool\n\n\nclass GIN(torch.nn.Module):\n    def __init__(self, in_channels: int, hidden_channels: int,\n                 out_channels: int, num_layers: int):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            mlp = Sequential(\n                Linear(in_channels, hidden_channels),\n                BatchNorm1d(hidden_channels),\n                ReLU(),\n                Linear(hidden_channels, hidden_channels),\n            )\n            self.convs.append(GINConv(mlp))\n            in_channels = hidden_channels\n\n        self.lin1 = Linear(hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, out_channels)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Optional[Tensor] = None,\n    ) -> Tensor:\n\n        for conv in self.convs:\n            x = conv(x, edge_index).relu()\n\n        x = global_mean_pool(x, batch)\n\n        x = self.lin1(x).relu()\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n\n        return x\n\n\nmodel = GIN(32, 64, 16, num_layers=3)\nmodel = torch.jit.script(model)\nmodel.save('model.pt')\n"
  },
  {
    "path": "examples/datapipe.py",
    "content": "# In this example, you will find data loading implementations using PyTorch\n# DataPipes (https://pytorch.org/data/) across various tasks:\n\n# (1) molecular graph data loading pipe\n# (2) mesh/point cloud data loading pipe\n\n# In particular, we make use of PyG's built-in DataPipes, e.g., for batching\n# multiple PyG data objects together or for converting SMILES strings into\n# molecular graph representations. We also showcase how to write your own\n# DataPipe (i.e. for loading and parsing mesh data into PyG data objects).\n\nimport argparse\nimport csv\nimport os.path as osp\nimport time\nfrom itertools import chain, tee\n\nimport torch\nfrom torch.utils.data import IterDataPipe\nfrom torch.utils.data.datapipes.iter import (\n    FileLister,\n    FileOpener,\n    IterableWrapper,\n)\n\nfrom torch_geometric.data import Data, download_url, extract_zip\n\n\ndef molecule_datapipe() -> IterDataPipe:\n    # Download HIV dataset from MoleculeNet:\n    url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets'\n    root_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')\n    path = download_url(f'{url}/HIV.csv', root_dir)\n\n    datapipe = FileOpener([path], mode=\"rt\")\n    # Convert CSV rows into dictionaries, skipping the header row\n    datapipe = datapipe.map(lambda file: (\n        dict(zip([\"smiles\", \"activity\", \"HIV_active\"], row))\n        for i, row in enumerate(csv.reader(file[1])) if i > 0 and row))\n\n    datapipe = IterableWrapper(chain.from_iterable(datapipe))\n    datapipe = datapipe.parse_smiles(target_key='HIV_active')\n    datapipe, = tee(datapipe, 1)\n    return IterableWrapper(datapipe)\n\n\n@torch.utils.data.functional_datapipe('read_mesh')\nclass MeshOpener(IterDataPipe):\n    # A custom DataPipe to load and parse mesh data into PyG data objects.\n    def __init__(self, dp: IterDataPipe) -> None:\n        try:\n            import meshio  # noqa: F401\n            import torch_cluster  # noqa: F401\n      
  except ImportError as e:\n            raise ImportError(\n                \"To run this example, please install required packages:\\n\"\n                \"pip install meshio torch-cluster\") from e\n\n        super().__init__()\n        self.dp = dp\n\n    def __iter__(self):\n        import meshio\n\n        for path in self.dp:\n            category = osp.basename(path).split('_')[0]\n            try:\n                mesh = meshio.read(path)\n            except UnicodeDecodeError:\n                # Failed to read the file because it is not in the expected OFF\n                # format.\n                continue\n\n            pos = torch.from_numpy(mesh.points).to(torch.float)\n            face = torch.from_numpy(mesh.cells[0].data).t().contiguous()\n\n            yield Data(pos=pos, face=face, category=category)\n\n\ndef mesh_datapipe() -> IterDataPipe:\n    # Download ModelNet10 dataset from Princeton:\n    url = 'http://vision.princeton.edu/projects/2014/3DShapeNets'\n    root_dir = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data')\n    path = download_url(f'{url}/ModelNet10.zip', root_dir)\n    root_dir = osp.join(root_dir, 'ModelNet10')\n    if not osp.exists(root_dir):\n        extract_zip(path, root_dir)\n\n    def is_train(path: str) -> bool:\n        return 'train' in path\n\n    datapipe = FileLister([root_dir], masks='*.off', recursive=True)\n    datapipe = datapipe.filter(is_train)\n    datapipe = datapipe.read_mesh()\n    datapipe, = tee(datapipe, 1)\n    datapipe = IterableWrapper(datapipe)\n    datapipe = datapipe.sample_points(1024)  # Use PyG transforms from here.\n    datapipe = datapipe.knn_graph(k=8)\n    return datapipe\n\n\nDATAPIPES = {\n    'molecule': molecule_datapipe,\n    'mesh': mesh_datapipe,\n}\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--task', default='molecule', choices=DATAPIPES.keys())\n\n    args = parser.parse_args()\n\n    datapipe = 
DATAPIPES[args.task]()\n\n    print('Example output:')\n    print(next(iter(datapipe)))\n\n    # Shuffling + Batching support:\n    datapipe = datapipe.shuffle()\n    datapipe = datapipe.batch_graphs(batch_size=32)\n\n    # The first epoch will take longer than the remaining ones...\n    print('Iterating over all data...')\n    t = time.perf_counter()\n    for _ in datapipe:\n        pass\n    print(f'Done! [{time.perf_counter() - t:.2f}s]')\n\n    print('Iterating over all data a second time...')\n    t = time.perf_counter()\n    for _ in datapipe:\n        pass\n    print(f'Done! [{time.perf_counter() - t:.2f}s]')\n"
  },
  {
    "path": "examples/dgcnn_classification.py",
    "content": "import argparse\nimport os.path as osp\nimport random\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MedShapeNet, ModelNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP, DynamicEdgeConv, global_max_pool\n\nparser = argparse.ArgumentParser(\n    formatter_class=argparse.ArgumentDefaultsHelpFormatter, )\nparser.add_argument(\n    '--dataset',\n    type=str,\n    default='modelnet10',\n    choices=['modelnet10', 'modelnet40', 'medshapenet'],\n    help='Dataset name.',\n)\nparser.add_argument(\n    '--dataset_dir',\n    type=str,\n    default='./data',\n    help='Root directory of dataset.',\n)\nparser.add_argument('--batch_size', type=int, default=32)\nparser.add_argument('--num_workers', type=int, default=6)\nparser.add_argument('--epochs', type=int, default=201)\n\nargs = parser.parse_args()\n\nnum_epochs = args.epochs\nnum_workers = args.num_workers\nbatch_size = args.batch_size\nroot = osp.join(args.dataset_dir, args.dataset)\n\nprint('The root is: ', root)\n\npre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)\n\nprint('The Dataset is: ', args.dataset)\nif args.dataset == 'modelnet40':\n    print('Loading training data')\n    train_dataset = ModelNet(root, '40', True, transform, pre_transform)\n    print('Loading test data')\n    test_dataset = ModelNet(root, '40', False, transform, pre_transform)\nelif args.dataset == 'medshapenet':\n    print('Loading dataset')\n    dataset = MedShapeNet(root=root, size=50, pre_transform=pre_transform,\n                          transform=transform, force_reload=False)\n\n    random.seed(42)\n\n    train_indices = []\n    test_indices = []\n    for label in range(dataset.num_classes):\n        by_class = [\n            i for i, data in enumerate(dataset) if int(data.y) == label\n        ]\n        random.shuffle(by_class)\n\n        split_point = 
int(0.7 * len(by_class))\n        train_indices.extend(by_class[:split_point])\n        test_indices.extend(by_class[split_point:])\n\n    train_dataset = dataset[train_indices]\n    test_dataset = dataset[test_indices]\n\nelif args.dataset == 'modelnet10':\n    print('Loading training data')\n    train_dataset = ModelNet(root, '10', True, transform, pre_transform)\n    print('Loading test data')\n    test_dataset = ModelNet(root, '10', False, transform, pre_transform)\n\nelse:\n    raise ValueError(\n        f\"Unknown dataset name '{args.dataset}'. \"\n        f\"Available options: 'modelnet10', 'modelnet40', 'medshapenet'.\")\n\ntrain_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,\n                          num_workers=num_workers)\ntest_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,\n                         num_workers=num_workers)\n\nprint('Running model')\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, out_channels, k=20, aggr='max'):\n        super().__init__()\n\n        self.conv1 = DynamicEdgeConv(MLP([2 * 3, 64, 64, 64]), k, aggr)\n        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 128]), k, aggr)\n        self.lin1 = Linear(128 + 64, 1024)\n\n        self.mlp = MLP([1024, 512, 256, out_channels], dropout=0.5, norm=None)\n\n    def forward(self, data):\n        pos, batch = data.pos, data.batch\n        x1 = self.conv1(pos, batch)\n        x2 = self.conv2(x1, batch)\n        out = self.lin1(torch.cat([x1, x2], dim=1))\n        out = global_max_pool(out, batch)\n        out = self.mlp(out)\n        return F.log_softmax(out, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(train_dataset.num_classes, k=20).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        
data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        total_loss += loss.item() * data.num_graphs\n        optimizer.step()\n    return total_loss / len(train_dataset)\n\n\ndef test(loader):\n    model.eval()\n\n    correct = 0\n    for data in loader:\n        data = data.to(device)\n        with torch.no_grad():\n            pred = model(data).max(dim=1)[1]\n        correct += pred.eq(data.y).sum().item()\n    return correct / len(loader.dataset)\n\n\nfor epoch in range(1, num_epochs):\n    loss = train()\n    test_acc = test(test_loader)\n    print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Test: {test_acc:.4f}')\n    scheduler.step()\n"
  },
  {
    "path": "examples/dgcnn_segmentation.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torchmetrics.functional import jaccard_index\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ShapeNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP, DynamicEdgeConv\nfrom torch_geometric.utils import scatter\n\ncategory = 'Airplane'  # Pass in `None` to train on all categories.\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')\ntransform = T.Compose([\n    T.RandomJitter(0.01),\n    T.RandomRotate(15, axis=0),\n    T.RandomRotate(15, axis=1),\n    T.RandomRotate(15, axis=2)\n])\npre_transform = T.NormalizeScale()\ntrain_dataset = ShapeNet(path, category, split='trainval', transform=transform,\n                         pre_transform=pre_transform)\ntest_dataset = ShapeNet(path, category, split='test',\n                        pre_transform=pre_transform)\ntrain_loader = DataLoader(train_dataset, batch_size=10, shuffle=True,\n                          num_workers=6)\ntest_loader = DataLoader(test_dataset, batch_size=10, shuffle=False,\n                         num_workers=6)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, out_channels, k=30, aggr='max'):\n        super().__init__()\n\n        self.conv1 = DynamicEdgeConv(MLP([2 * 6, 64, 64]), k, aggr)\n        self.conv2 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr)\n        self.conv3 = DynamicEdgeConv(MLP([2 * 64, 64, 64]), k, aggr)\n\n        self.mlp = MLP([3 * 64, 1024, 256, 128, out_channels], dropout=0.5,\n                       norm=None)\n\n    def forward(self, data):\n        x, pos, batch = data.x, data.pos, data.batch\n        x0 = torch.cat([x, pos], dim=-1)\n        x1 = self.conv1(x0, batch)\n        x2 = self.conv2(x1, batch)\n        x3 = self.conv3(x2, batch)\n        out = self.mlp(torch.cat([x1, x2, x3], dim=1))\n        return F.log_softmax(out, dim=1)\n\n\ndevice = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu')\nmodel = Net(train_dataset.num_classes, k=30).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.8)\n\n\ndef train():\n    model.train()\n\n    total_loss = correct_nodes = total_nodes = 0\n    for i, data in enumerate(train_loader):\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item()\n        correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()\n        total_nodes += data.num_nodes\n\n        if (i + 1) % 10 == 0:\n            print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '\n                  f'Train Acc: {correct_nodes / total_nodes:.4f}')\n            total_loss = correct_nodes = total_nodes = 0\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    ious, categories = [], []\n    y_map = torch.empty(loader.dataset.num_classes, device=device).long()\n    for data in loader:\n        data = data.to(device)\n        outs = model(data)\n\n        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()\n        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),\n                                    data.category.tolist()):\n            category = list(ShapeNet.seg_classes.keys())[category]\n            part = ShapeNet.seg_classes[category]\n            part = torch.tensor(part, device=device)\n\n            y_map[part] = torch.arange(part.size(0), device=device)\n\n            iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],\n                                num_classes=part.size(0), absent_score=1.0)\n            ious.append(iou)\n\n        categories.append(data.category)\n\n    iou = torch.tensor(ious, device=device)\n    category = torch.cat(categories, dim=0)\n\n    mean_iou = scatter(iou, category, reduce='mean')  
# Per-category IoU.\n    return float(mean_iou.mean())  # Global IoU.\n\n\nfor epoch in range(1, 31):\n    train()\n    iou = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}')\n"
  },
  {
    "path": "examples/dir_gnn.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import WikipediaNetwork\nfrom torch_geometric.nn import DirGNNConv, GCNConv, SAGEConv\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='chameleon')\nparser.add_argument('--hidden_channels', type=int, default=128)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--epochs', type=int, default=1000)\nparser.add_argument('--alpha', type=float, default=1)\nparser.add_argument('--conv', type=str, default='gcn')\nargs = parser.parse_args()\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Wikipedia')\ndataset = WikipediaNetwork(\n    root=path,\n    name=args.dataset,\n    transform=T.NormalizeFeatures(),\n)\n\ndata = dataset[0].to(device)\ndata.train_mask = data.train_mask[:, 0]\ndata.val_mask = data.val_mask[:, 0]\ndata.test_mask = data.test_mask[:, 0]\n\nif args.conv == 'gcn':\n    Conv = GCNConv\nelif args.conv == 'sage':\n    Conv = SAGEConv\nelse:\n    raise NotImplementedError\n\n\nclass DirGNN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, alpha):\n        super().__init__()\n        self.conv1 = Conv(in_channels, hidden_channels)\n        self.conv1 = DirGNNConv(self.conv1, alpha, root_weight=False)\n\n        self.conv2 = Conv(hidden_channels, out_channels)\n        self.conv2 = DirGNNConv(self.conv2, alpha, root_weight=False)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        return x\n\n\nmodel = DirGNN(\n    dataset.num_features,\n    args.hidden_channels,\n    dataset.num_classes,\n    alpha=args.alpha,\n).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n\n\ndef train():\n    
model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index).argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\n# Track the test accuracy at the epoch with the best validation accuracy.\nbest_val_acc = test_acc = 0\nfor epoch in range(1, args.epochs + 1):\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n\n    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/distributed/README.md",
    "content": "# Examples for Distributed Graph Learning\n\nThis directory contains examples for distributed graph learning.\nThe examples are organized into two subdirectories:\n\n1. [`graphlearn_for_pytorch`](./graphlearn_for_pytorch): Distributed training via the external [GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch) package.\n1. [`kuzu`](./kuzu): Remote backend via the [Kùzu](https://kuzudb.com/) graph database.\n"
  },
  {
    "path": "examples/distributed/graphlearn_for_pytorch/README.md",
    "content": "# Using GraphLearn-for-PyTorch (GLT) for Distributed Training with PyG\n\n**[GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch)** is a graph learning library for PyTorch that makes distributed GNN training easy and efficient.\nGLT leverages GPUs to accelerate graph sampling and utilizes UVA and GPU caches to reduce the data conversion and transferring costs during graph sampling and model training.\nMost of the APIs of GLT are compatible with PyG, so PyG users only need to modify a few lines of their PyG code to train their model with GLT.\n\n## Requirements\n\n- `python >= 3.6`\n- `torch >= 1.12`\n- `graphlearn-torch`\n\n## Distributed (Multi-Node) Example\n\nThis example shows how to leverage [GraphLearn-for-PyTorch (GLT)](https://github.com/alibaba/graphlearn-for-pytorch) to train PyG models in a distributed scenario with GPUs. The dataset in this example is `ogbn-products` from the [Open Graph Benchmark](https://ogb.stanford.edu/), but you can also train on `ogbn-papers100M` with only minor modifications.\n\nTo run this example, you can run the example as described below or directly make use of our [`launch.py`](launch.py) script.\nThe training results will be generated and saved in `dist_sage_sup.txt`.\n\n### Running the Example\n\n#### Step 1: Prepare and partition the data\n\nHere, we use `ogbn-products` and partition it into two partitions:\n\n```bash\npython partition_ogbn_dataset.py --dataset=ogbn-products --root_dir=../../../data/ogbn-products --num_partitions=2\n```\n\n#### Step 2: Run the example in each training node\n\nFor example, running the example in two nodes each with two GPUs:\n\n```bash\n# Node 0:\nCUDA_VISIBLE_DEVICES=0,1 python dist_train_sage_supervised.py \\\n  --num_nodes=2 --node_rank=0 --master_addr=localhost \\\n  --dataset=ogbn-products --dataset_root_dir=../../../data/ogbn-products \\\n  --in_channel=100 --out_channel=47\n\n# Node 1:\nCUDA_VISIBLE_DEVICES=2,3 python 
dist_train_sage_supervised.py \\\n  --num_nodes=2 --node_rank=1 --master_addr=localhost \\\n  --dataset=ogbn-products --dataset_root_dir=../../../data/ogbn-products \\\n  --in_channel=100 --out_channel=47\n```\n\n**Notes:**\n\n1. You should change the `master_addr` to the IP of `node#0`.\n1. Since there is randomness during data partitioning, please ensure all nodes are using the same partitioned data when running `dist_train_sage_supervised.py`.\n\n### Using the `launch.py` Script\n\n#### Step 1: Set up a distributed file system\n\n**Note**: You may skip this step if you have already set up folder(s) synchronized across machines.\n\nTo perform distributed sampling, files and code need to be accessible across multiple machines.\nA distributed file system (*e.g.*, [NFS](https://wiki.archlinux.org/index.php/NFS), [SSHFS](https://www.digitalocean.com/community/tutorials/how-to-use-sshfs-to-mount-remote-file-systems-over-ssh), [Ceph](https://docs.ceph.com/en/latest/install), ...) saves you from manually synchronizing files such as partition information.\n\n#### Step 2: Prepare and partition the data\n\nIn distributed training (under the worker mode), each node in the cluster holds a partition of the graph.\nThus, before the training starts, we partition the `ogbn-products` dataset into multiple partitions, each of which corresponds to a specific training worker.\n\nThe partitioning occurs in three steps:\n\n1. Run the partition algorithm to assign nodes to partitions.\n1. Construct the partitioned graph structure based on the node assignment.\n1. 
Split the node features and edge features into partitions.\n\nGLT supports caching graph topology and frequently accessed features on the GPU to accelerate GPU sampling and feature collection.\nFor feature caching, we adopt a pre-sampling-based approach to determine the hotness of nodes, and cache features for nodes with higher hotness while loading the graph.\nThe uncached features are stored in pinned memory for efficient access via UVA.\n\nFor further information about partitioning, please refer to the [official tutorial](https://github.com/alibaba/graphlearn-for-pytorch/blob/main/docs/tutorial/dist.md).\n\nHere, we use `ogbn-products` and partition it into two partitions:\n\n```bash\npython partition_ogbn_dataset.py --dataset=ogbn-products --root_dir=../../../data/ogbn-products --num_partitions=2\n```\n\n#### Step 3: Set up the configuration file\n\nAn example configuration file is given in [`dist_train_sage_sup_config.yml`](dist_train_sage_sup_config.yml).\n\n#### Step 4: Launch the distributed training\n\n```bash\npip install paramiko\npip install click\napt install tmux\npython launch.py --config=dist_train_sage_sup_config.yml --master_addr=0.0.0.0 --master_port=11234\n```\n\nHere, `master_addr` is the master RPC address, and `master_port` is the port used for PyTorch's process group initialization across training processes.\nNote that you should change the `master_addr` to the IP of `node#0`.\n"
  },
  {
    "path": "examples/distributed/graphlearn_for_pytorch/dist_train_sage_sup_config.yml",
    "content": "# IP addresses for all nodes.\n# Note: The first 3 params are expected to form usernames@nodes:ports.\nnodes:\n  - 0.0.0.0\n  - 1.1.1.1\n\n# SSH ports for each node:\nports: [22, 22]\n\n# Username for remote IPs:\nusernames:\n  - your_username_for_node_0\n  - your_username_for_node_1\n\n# Path to Python with GLT environment for each node:\npython_bins:\n  - /path/to/python\n  - /path/to/python\n\n# The dataset name, e.g., ogbn-products, ogbn-papers100M.\n# Note: make sure the name of dataset_root_dir is the same as the dataset name.\ndataset: ogbn-products\n\n# `in_channel` and `out_channel` of the dataset, e.g.,:\n# - ogbn-products: in_channel=100, out_channel=47\n# - ogbn-papers100M: in_channel=128, out_channel=172\nin_channel: 100\nout_channel: 47\n\n# Path to the pytorch_geometric directory:\ndst_paths:\n  - /path/to/pytorch_geometric\n  - /path/to/pytorch_geometric\n\n# Setup visible CUDA devices for each node:\nvisible_devices:\n  - 0,1,2,3\n  - 0,1,2,3\n"
  },
  {
    "path": "examples/distributed/graphlearn_for_pytorch/dist_train_sage_supervised.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport graphlearn_torch as glt\nimport torch\nimport torch.distributed\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import Evaluator\nfrom torch import Tensor\nfrom torch.nn.parallel import DistributedDataParallel\n\nfrom torch_geometric.io import fs\nfrom torch_geometric.nn import GraphSAGE\n\n\n@torch.no_grad()\ndef test(model, test_loader, dataset_name):\n    evaluator = Evaluator(name=dataset_name)\n    model.eval()\n    xs = []\n    y_true = []\n    for i, batch in enumerate(test_loader):\n        if i == 0:\n            device = batch.x.device\n        x = model(batch.x, batch.edge_index)[:batch.batch_size]\n        xs.append(x.cpu())\n        y_true.append(batch.y[:batch.batch_size].cpu())\n\n    xs = [t.to(device) for t in xs]\n    y_true = [t.to(device) for t in y_true]\n    y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True)\n    y_true = torch.cat(y_true, dim=0).unsqueeze(-1)\n    test_acc = evaluator.eval({\n        'y_true': y_true,\n        'y_pred': y_pred,\n    })['acc']\n    return test_acc\n\n\ndef run_training_proc(\n    local_proc_rank: int,\n    num_nodes: int,\n    node_rank: int,\n    num_training_procs_per_node: int,\n    dataset_name: str,\n    in_channels: int,\n    out_channels: int,\n    dataset: glt.distributed.DistDataset,\n    train_idx: Tensor,\n    test_idx: Tensor,\n    epochs: int,\n    batch_size: int,\n    master_addr: str,\n    training_pg_master_port: int,\n    train_loader_master_port: int,\n    test_loader_master_port: int,\n):\n    # Initialize graphlearn_torch distributed worker group context:\n    glt.distributed.init_worker_group(\n        world_size=num_nodes * num_training_procs_per_node,\n        rank=node_rank * num_training_procs_per_node + local_proc_rank,\n        group_name='distributed-sage-supervised-trainer')\n\n    current_ctx = glt.distributed.get_context()\n    current_device = torch.device(local_proc_rank % 
torch.cuda.device_count())\n\n    # Initialize training process group of PyTorch:\n    torch.distributed.init_process_group(\n        backend='nccl',  # or choose 'gloo' if 'nccl' is not supported.\n        rank=current_ctx.rank,\n        world_size=current_ctx.world_size,\n        init_method=f'tcp://{master_addr}:{training_pg_master_port}',\n    )\n\n    # Create distributed neighbor loader for training.\n    # We replace PyG's NeighborLoader with GLT's DistNeighborLoader.\n    # GLT parameters for sampling are quite similar to PyG.\n    # We only need to configure additional network and device parameters:\n    train_idx = train_idx.split(\n        train_idx.size(0) // num_training_procs_per_node)[local_proc_rank]\n    train_loader = glt.distributed.DistNeighborLoader(\n        data=dataset,\n        num_neighbors=[15, 10, 5],\n        input_nodes=train_idx,\n        batch_size=batch_size,\n        shuffle=True,\n        collect_features=True,\n        to_device=current_device,\n        worker_options=glt.distributed.MpDistSamplingWorkerOptions(\n            num_workers=1,\n            worker_devices=[current_device],\n            worker_concurrency=4,\n            master_addr=master_addr,\n            master_port=train_loader_master_port,\n            channel_size='1GB',\n            pin_memory=True,\n        ),\n    )\n\n    # Create distributed neighbor loader for testing.\n    test_idx = test_idx.split(test_idx.size(0) //\n                              num_training_procs_per_node)[local_proc_rank]\n    test_loader = glt.distributed.DistNeighborLoader(\n        data=dataset,\n        num_neighbors=[15, 10, 5],\n        input_nodes=test_idx,\n        batch_size=batch_size,\n        shuffle=False,\n        collect_features=True,\n        to_device=current_device,\n        worker_options=glt.distributed.MpDistSamplingWorkerOptions(\n            num_workers=2,\n            worker_devices=[\n                torch.device('cuda', i % torch.cuda.device_count())\n      
          for i in range(2)\n            ],\n            worker_concurrency=4,\n            master_addr=master_addr,\n            master_port=test_loader_master_port,\n            channel_size='2GB',\n            pin_memory=True,\n        ),\n    )\n\n    # Define the model and optimizer.\n    torch.cuda.set_device(current_device)\n    model = GraphSAGE(\n        in_channels=in_channels,\n        hidden_channels=256,\n        num_layers=3,\n        out_channels=out_channels,\n    ).to(current_device)\n    model = DistributedDataParallel(model, device_ids=[current_device.index])\n\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    # Train and test:\n    f = open('dist_sage_sup.txt', 'a+')\n    for epoch in range(0, epochs):\n        model.train()\n        start = time.time()\n        for batch in train_loader:\n            optimizer.zero_grad()\n            out = model(batch.x, batch.edge_index)[:batch.batch_size]\n            loss = F.cross_entropy(out, batch.y[:batch.batch_size].long())\n            loss.backward()\n            optimizer.step()\n        f.write(f'-- [Trainer {current_ctx.rank}] Epoch: {epoch:03d}, '\n                f'Loss: {loss:.4f}, Epoch Time: {time.time() - start}\\n')\n\n        torch.cuda.synchronize()\n        torch.distributed.barrier()\n\n        if epoch == 0 or epoch > (epochs // 2):\n            test_acc = test(model, test_loader, dataset_name)\n            f.write(f'-- [Trainer {current_ctx.rank}] '\n                    f'Test Acc: {test_acc:.4f}\\n')\n            torch.cuda.synchronize()\n            torch.distributed.barrier()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--dataset',\n        type=str,\n        default='ogbn-products',\n        help='The name of the dataset',\n    )\n    parser.add_argument(\n        '--in_channel',\n        type=int,\n        default=100,\n        help='Number of input features of the dataset',\n    )\n    
parser.add_argument(\n        '--out_channel',\n        type=int,\n        default=47,\n        help='Number of classes of the dataset',\n    )\n    parser.add_argument(\n        '--num_dataset_partitions',\n        type=int,\n        default=2,\n        help='The number of partitions',\n    )\n    parser.add_argument(\n        '--dataset_root_dir',\n        type=str,\n        default='../../../data/products',\n        help='The root directory (relative path) of the partitioned dataset',\n    )\n    parser.add_argument(\n        '--num_nodes',\n        type=int,\n        default=2,\n        help='Number of distributed nodes',\n    )\n    parser.add_argument(\n        '--node_rank',\n        type=int,\n        default=0,\n        help='The current node rank',\n    )\n    parser.add_argument(\n        '--num_training_procs',\n        type=int,\n        default=2,\n        help='The number of training processes per node',\n    )\n    parser.add_argument(\n        '--epochs',\n        type=int,\n        default=10,\n        help='The number of training epochs',\n    )\n    parser.add_argument(\n        '--batch_size',\n        type=int,\n        default=512,\n        help='The batch size for the training and testing data loaders',\n    )\n    parser.add_argument(\n        '--master_addr',\n        type=str,\n        default='localhost',\n        help='The master address for RPC initialization',\n    )\n    parser.add_argument(\n        '--training_pg_master_port',\n        type=int,\n        default=11111,\n        help=\"The port used for PyTorch's process group initialization\",\n    )\n    parser.add_argument(\n        '--train_loader_master_port',\n        type=int,\n        default=11112,\n        help='The port used for RPC initialization for training',\n    )\n    parser.add_argument(\n        '--test_loader_master_port',\n        type=int,\n        default=11113,\n        help='The port used for RPC initialization for testing',\n    )\n    args = 
parser.parse_args()\n\n    # Record configuration information for debugging\n    f = open('dist_sage_sup.txt', 'a+')\n    f.write('--- Distributed training example of supervised SAGE ---\\n')\n    f.write(f'* dataset: {args.dataset}\\n')\n    f.write(f'* dataset root dir: {args.dataset_root_dir}\\n')\n    f.write(f'* number of dataset partitions: {args.num_dataset_partitions}\\n')\n    f.write(f'* total nodes: {args.num_nodes}\\n')\n    f.write(f'* node rank: {args.node_rank}\\n')\n    f.write(f'* number of training processes per node: '\n            f'{args.num_training_procs}\\n')\n    f.write(f'* epochs: {args.epochs}\\n')\n    f.write(f'* batch size: {args.batch_size}\\n')\n    f.write(f'* master addr: {args.master_addr}\\n')\n    f.write(f'* training process group master port: '\n            f'{args.training_pg_master_port}\\n')\n    f.write(f'* training loader master port: '\n            f'{args.train_loader_master_port}\\n')\n    f.write(f'* testing loader master port: {args.test_loader_master_port}\\n')\n\n    f.write('--- Loading data partition ...\\n')\n    root_dir = osp.join(osp.dirname(osp.realpath(__file__)),\n                        args.dataset_root_dir)\n    data_pidx = args.node_rank % args.num_dataset_partitions\n    dataset = glt.distributed.DistDataset()\n\n    label_file = osp.join(root_dir, f'{args.dataset}-label', 'label.pt')\n    dataset.load(\n        root_dir=osp.join(root_dir, f'{args.dataset}-partitions'),\n        partition_idx=data_pidx,\n        graph_mode='ZERO_COPY',\n        whole_node_label_file=label_file,\n    )\n    train_file = osp.join(root_dir, f'{args.dataset}-train-partitions',\n                          f'partition{data_pidx}.pt')\n    train_idx = fs.torch_load(train_file)\n    test_file = osp.join(root_dir, f'{args.dataset}-test-partitions',\n                         f'partition{data_pidx}.pt')\n    test_idx = fs.torch_load(test_file)\n    train_idx.share_memory_()\n    test_idx.share_memory_()\n\n    f.write('--- 
Launching training processes ...\\n')\n    torch.multiprocessing.spawn(\n        run_training_proc,\n        args=(\n            args.num_nodes,\n            args.node_rank,\n            args.num_training_procs,\n            args.dataset,\n            args.in_channel,\n            args.out_channel,\n            dataset,\n            train_idx,\n            test_idx,\n            args.epochs,\n            args.batch_size,\n            args.master_addr,\n            args.training_pg_master_port,\n            args.train_loader_master_port,\n            args.test_loader_master_port,\n        ),\n        nprocs=args.num_training_procs,\n        join=True,\n    )\n"
  },
  {
    "path": "examples/distributed/graphlearn_for_pytorch/launch.py",
    "content": "import argparse\n\nimport click\nimport paramiko\nimport yaml\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--config',\n        type=str,\n        default='dist_train_sage_sup_config.yml',\n        help='The path to the configuration file',\n    )\n    parser.add_argument(\n        '--epochs',\n        type=int,\n        default=10,\n        help='The number of training epochs',\n    )\n    parser.add_argument(\n        '--batch_size',\n        type=int,\n        default=512,\n        help='The batch size for the training and testing data loaders',\n    )\n    parser.add_argument(\n        '--master_addr',\n        type=str,\n        default='0.0.0.0',\n        help='Master IP address for synchronization across all training nodes',\n    )\n    parser.add_argument(\n        '--master_port',\n        type=str,\n        default='11345',\n        help='The port for synchronization across all training nodes',\n    )\n    args = parser.parse_args()\n\n    config = open(args.config)\n    config = yaml.safe_load(config)\n    dataset = config['dataset']\n    ip_list = config['nodes']\n    port_list = config['ports']\n    username_list = config['usernames']\n    dst_path_list = config['dst_paths']\n    node_ranks = list(range(len(ip_list)))\n    num_nodes = len(node_ranks)\n    visible_devices = config['visible_devices']\n    python_bins = config['python_bins']\n    num_cores = len(str(visible_devices[0]).split(','))\n    in_channel = str(config['in_channel'])\n    out_channel = str(config['out_channel'])\n\n    dataset_path = '../../../data/'\n    passwd_dict = {}\n    for username, ip in zip(username_list, ip_list):\n        passwd_dict[ip + username] = click.prompt(\n            f'Password for {username}@{ip}', hide_input=True)\n    for username, ip, port, dst, noderk, device, pythonbin in zip(\n            username_list,\n            ip_list,\n            port_list,\n            dst_path_list,\n  
          node_ranks,\n            visible_devices,\n            python_bins,\n    ):\n        trans = paramiko.Transport((ip, port))\n        trans.connect(username=username, password=passwd_dict[ip + username])\n        ssh = paramiko.SSHClient()\n        ssh._transport = trans\n\n        to_dist_dir = 'cd ' + dst + \\\n            '/examples/distributed/graphlearn_for_pytorch/ '\n        exec_example = \"tmux new -d 'CUDA_VISIBLE_DEVICES=\" + str(device) + \\\n            \" \" + pythonbin + \" dist_train_sage_supervised.py --dataset=\" + \\\n            dataset + \" --dataset_root_dir=\" + dataset_path + dataset + \\\n            \" --in_channel=\" + in_channel + \" --out_channel=\" + out_channel + \\\n            \" --node_rank=\" + str(noderk) + \" --num_dataset_partitions=\" + \\\n            str(num_nodes) + \" --num_nodes=\" + str(num_nodes) + \\\n            \" --num_training_procs=\" + str(num_cores) + \" --master_addr=\" + \\\n            args.master_addr + \" --training_pg_master_port=\" + \\\n            args.master_port + \" --train_loader_master_port=\" + \\\n            str(int(args.master_port) + 1) + \" --test_loader_master_port=\" + \\\n            str(int(args.master_port) + 2) + \" --batch_size=\" + \\\n            str(args.batch_size) + \" --epochs=\" + str(args.epochs)\n\n        print(to_dist_dir + ' && ' + exec_example + \" '\")\n        stdin, stdout, stderr = ssh.exec_command(\n            to_dist_dir + ' && ' + exec_example + \" '\", bufsize=1)\n        print(stdout.read().decode())\n        print(stderr.read().decode())\n        ssh.close()\n"
  },
  {
    "path": "examples/distributed/graphlearn_for_pytorch/partition_ogbn_dataset.py",
    "content": "import argparse\nimport ast\nimport os.path as osp\n\nimport graphlearn_torch as glt\nimport torch\nfrom ogb.nodeproppred import PygNodePropPredDataset\n\n\ndef partition_dataset(\n    ogbn_dataset: str,\n    root_dir: str,\n    num_partitions: int,\n    num_nbrs: glt.NumNeighbors,\n    chunk_size: int,\n    cache_ratio: float,\n):\n    ###########################################################################\n    # In distributed training (under the worker mode), each node in the cluster\n    # holds a partition of the graph. Thus before the training starts, we\n    # partition the dataset into multiple partitions, each of which corresponds\n    # to a specific training worker.\n    # The partitioning occurs in three steps:\n    #   1. Run a partition algorithm to assign nodes to partitions.\n    #   2. Construct partition graph structure based on the node assignment.\n    #   3. Split the node features and edge features based on the partition\n    # result.\n    ###########################################################################\n\n    print(f'-- Loading {ogbn_dataset} ...')\n    dataset = PygNodePropPredDataset(ogbn_dataset, root_dir)\n    data = dataset[0]\n    print(f'* node count: {data.num_nodes}')\n    print(f'* edge count: {data.num_edges}')\n    split_idx = dataset.get_idx_split()\n\n    print('-- Saving label ...')\n    label_dir = osp.join(root_dir, f'{ogbn_dataset}-label')\n    glt.utils.ensure_dir(label_dir)\n    torch.save(data.y.squeeze(), osp.join(label_dir, 'label.pt'))\n\n    print('-- Partitioning training idx ...')\n    train_idx = split_idx['train']\n    train_idx = train_idx.split(train_idx.size(0) // num_partitions)\n    train_idx_partitions_dir = osp.join(\n        root_dir,\n        f'{ogbn_dataset}-train-partitions',\n    )\n    glt.utils.ensure_dir(train_idx_partitions_dir)\n    for pidx in range(num_partitions):\n        torch.save(\n            train_idx[pidx],\n            osp.join(train_idx_partitions_dir, 
f'partition{pidx}.pt'),\n        )\n\n    print('-- Partitioning test idx ...')\n    test_idx = split_idx['test']\n    test_idx = test_idx.split(test_idx.size(0) // num_partitions)\n    test_idx_partitions_dir = osp.join(\n        root_dir,\n        f'{ogbn_dataset}-test-partitions',\n    )\n    glt.utils.ensure_dir(test_idx_partitions_dir)\n    for pidx in range(num_partitions):\n        torch.save(\n            test_idx[pidx],\n            osp.join(test_idx_partitions_dir, f'partition{pidx}.pt'),\n        )\n\n    print('-- Initializing graph ...')\n    csr_topo = glt.data.Topology(edge_index=data.edge_index,\n                                 input_layout='COO')\n    graph = glt.data.Graph(csr_topo, mode='ZERO_COPY')\n\n    print('-- Sampling hotness ...')\n    glt_sampler = glt.sampler.NeighborSampler(graph, num_nbrs)\n    node_probs = []\n    for pidx in range(num_partitions):\n        seeds = train_idx[pidx]\n        prob = glt_sampler.sample_prob(seeds, data.num_nodes)\n        node_probs.append(prob.cpu())\n\n    print('-- Partitioning graph and features ...')\n    partitions_dir = osp.join(root_dir, f'{ogbn_dataset}-partitions')\n    freq_partitioner = glt.partition.FrequencyPartitioner(\n        output_dir=partitions_dir,\n        num_parts=num_partitions,\n        num_nodes=data.num_nodes,\n        edge_index=data.edge_index,\n        probs=node_probs,\n        node_feat=data.x,\n        chunk_size=chunk_size,\n        cache_ratio=cache_ratio,\n    )\n    freq_partitioner.partition()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--dataset',\n        type=str,\n        default='ogbn-products',\n        help='The name of the dataset',\n    )\n    parser.add_argument(\n        '--num_partitions',\n        type=int,\n        default=2,\n        help='The Number of partitions',\n    )\n    parser.add_argument(\n        '--root_dir',\n        type=str,\n        
default='../../../data/ogbn-products',\n        help='The root directory (relative path) of the partitioned dataset',\n    )\n    parser.add_argument(\n        '--num_nbrs',\n        type=ast.literal_eval,\n        default='[15,10,5]',\n        help='The number of neighbors to sample hotness for feature caching',\n    )\n    parser.add_argument(\n        '--chunk_size',\n        type=int,\n        default=10000,\n        help='The chunk size for feature partitioning',\n    )\n    parser.add_argument(\n        '--cache_ratio',\n        type=float,\n        default=0.2,\n        help='The proportion to cache features per partition',\n    )\n    args = parser.parse_args()\n\n    partition_dataset(\n        ogbn_dataset=args.dataset,\n        root_dir=osp.join(osp.dirname(osp.realpath(__file__)), args.root_dir),\n        num_partitions=args.num_partitions,\n        num_nbrs=args.num_nbrs,\n        chunk_size=args.chunk_size,\n        cache_ratio=args.cache_ratio,\n    )\n"
  },
  {
    "path": "examples/distributed/kuzu/README.md",
    "content": "# Using Kùzu as a Remote Backend for PyG\n\n[Kùzu](https://kuzudb.com/) is an in-process property graph database management system built for query speed and scalability.\nIt provides an integration with PyG via the [remote backend interface](https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html) of PyG.\nThe Python API of Kùzu outputs a [`torch_geometric.data.FeatureStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.FeatureStore.html) and a [`torch_geometric.data.GraphStore`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.GraphStore.html) that can be plugged directly into existing familiar PyG interfaces such as [`NeighborLoader`](https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/loader/neighbor_loader.html) and enables training GNNs directly on graphs stored in Kùzu.\nThis is particularly useful if you would like to train graphs that don't fit on your CPU's memory.\n\n## Installation\n\nYou can install Kùzu as follows:\n\n```bash\npip install kuzu\n```\n\n## Usage\n\nThe API and design documentation of Kùzu can be found at [https://kuzudb.com/docs/](https://kuzudb.com/docs/).\n\n## Examples\n\nWe provide the following examples to showcase the usage of Kùzu remote backend within PyG:\n\n### PubMed\n\n<a target=\"_blank\" href=\"https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6\">\n  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n</a>\n\nThe PubMed example is hosted on [Google Colab](https://colab.research.google.com/drive/12fOSqPm1HQTz_m9caRW7E_92vaeD9xq6).\nIn this example, we work on a small dataset for demonstrative purposes.\nThe [PubMed](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.datasets.Planetoid.html) dataset consists of 19,717 papers as nodes and 88,648 citation relationships between them.\n\n### `papers_100M`\n\nThis 
example shows how to use the remote backend feature of Kùzu to work with a large graph of papers and citations on a single machine.\nThe data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/).\nThe dataset contains approximately 111 million nodes and 1.6 billion edges.\n"
  },
  {
    "path": "examples/distributed/kuzu/papers_100M/README.md",
    "content": "# `papers_100M` Example\n\nThis example shows how to use the remote backend feature of [Kùzu](https://kuzudb.com) to work with a large graph of papers and citations on a single machine.\nThe data used in this example is `ogbn-papers100M` from the [Open Graph Benchmark](https://ogb.stanford.edu/).\nThe dataset contains approximately 100 million nodes and 1.6 billion edges.\n\n## Prepare the data\n\n1. Download the dataset from [`http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip`](http://snap.stanford.edu/ogb/data/nodeproppred/papers100M-bin.zip) and put the `*.zip` file into this directory.\n1. Run `python prepare_data.py`.\n   The script will automatically extract the data and convert it to the format that Kùzu can read.\n   A Kùzu database instance is then created under `papers_100M` and the data is loaded into the it.\n\n## Train a Model\n\nAfterwards, run `python train.py` to train a three-layer [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) model on this dataset.\n"
  },
  {
    "path": "examples/distributed/kuzu/papers_100M/prepare_data.py",
    "content": "from multiprocessing import cpu_count\nfrom os import path\nfrom zipfile import ZipFile\n\nimport kuzu\nimport numpy as np\nfrom tqdm import tqdm\n\nwith ZipFile(\"papers100M-bin.zip\", 'r') as papers100M_zip:\n    print('Extracting papers100M-bin.zip...')\n    papers100M_zip.extractall()\n\nwith ZipFile(\"papers100M-bin/raw/data.npz\", 'r') as data_zip:\n    print('Extracting data.npz...')\n    data_zip.extractall()\n\nwith ZipFile(\"papers100M-bin/raw/node-label.npz\", 'r') as node_label_zip:\n    print('Extracting node-label.npz...')\n    node_label_zip.extractall()\n\nprint(\"Converting edge_index to CSV...\")\nedge_index = np.load('edge_index.npy', mmap_mode='r')\ncsvfile = open('edge_index.csv', 'w')\ncsvfile.write('src,dst\\n')\nfor i in tqdm(range(edge_index.shape[1])):\n    csvfile.write(str(edge_index[0, i]) + ',' + str(edge_index[1, i]) + '\\n')\ncsvfile.close()\n\nprint(\"Generating IDs for nodes...\")\nnode_year = np.load('node_year.npy', mmap_mode='r')\nlength = node_year.shape[0]\nids = np.arange(length)\nnp.save('ids.npy', ids)\n\nids_path = path.abspath(path.join('.', 'ids.npy'))\nedge_index_path = path.abspath(path.join('.', 'edge_index.csv'))\nnode_label_path = path.abspath(path.join('.', 'node_label.npy'))\nnode_feature_path = path.abspath(path.join('.', 'node_feat.npy'))\nnode_year_path = path.abspath(path.join('.', 'node_year.npy'))\n\nprint(\"Creating Kùzu database...\")\ndb = kuzu.Database('papers100M')\nconn = kuzu.Connection(db, num_threads=cpu_count())\nprint(\"Creating Kùzu tables...\")\nconn.execute(\n    \"CREATE NODE TABLE paper(id INT64, x FLOAT[128], year INT64, y FLOAT, \"\n    \"PRIMARY KEY (id));\")\nconn.execute(\"CREATE REL TABLE cites(FROM paper TO paper, MANY_MANY);\")\nprint(\"Copying nodes to Kùzu tables...\")\nconn.execute('COPY paper FROM (\"%s\",  \"%s\",  \"%s\", \"%s\") BY COLUMN;' %\n             (ids_path, node_feature_path, node_year_path, node_label_path))\nprint(\"Copying edges to Kùzu 
tables...\")\nconn.execute('COPY cites FROM \"%s\";' % (edge_index_path))\nprint(\"All done!\")\n"
  },
  {
    "path": "examples/distributed/kuzu/papers_100M/train.py",
    "content": "import multiprocessing as mp\nimport os.path as osp\n\nimport kuzu\nimport pandas as pd\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import MLP, BatchNorm, SAGEConv\n\nNUM_EPOCHS = 1\nLOADER_BATCH_SIZE = 1024\n\nprint('Batch size:', LOADER_BATCH_SIZE)\nprint('Number of epochs:', NUM_EPOCHS)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nprint('Using device:', device)\n\n# Load the train set:\ntrain_path = osp.join('.', 'papers100M-bin', 'split', 'time', 'train.csv.gz')\ntrain_df = pd.read_csv(\n    osp.abspath(train_path),\n    compression='gzip',\n    header=None,\n)\ninput_nodes = torch.tensor(train_df[0].values, dtype=torch.long)\n\n########################################################################\n# The below code sets up the remote backend of Kùzu for PyG.\n# Please refer to: https://kuzudb.com/docs/client-apis/python-api/overview.html\n# for how to use the Python API of Kùzu.\n########################################################################\n\n# The buffer pool size of Kùzu is set to 40GB. You can change it to a smaller\n# value if you have less memory.\nKUZU_BM_SIZE = 40 * 1024**3\n\n# Create Kùzu database:\ndb = kuzu.Database(osp.abspath(osp.join('.', 'papers100M')), KUZU_BM_SIZE)\n\n# Get remote backend for PyG:\nfeature_store, graph_store = db.get_torch_geometric_remote_backend(\n    mp.cpu_count())\n\n# Plug the graph store and feature store into the `NeighborLoader`.\n# Note that `filter_per_worker` is set to `False`. 
This is because the Kùzu\n# database is already using multi-threading to scan the features in parallel\n# and the database object is not fork-safe.\nloader = NeighborLoader(\n    data=(feature_store, graph_store),\n    num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},\n    batch_size=LOADER_BATCH_SIZE,\n    input_nodes=('paper', input_nodes),\n    num_workers=4,\n    filter_per_worker=False,\n)\n\n\nclass GraphSAGE(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,\n                 dropout=0.2):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        self.norms = torch.nn.ModuleList()\n\n        self.convs.append(SAGEConv(in_channels, hidden_channels))\n        self.norms.append(BatchNorm(hidden_channels))\n        for _ in range(1, num_layers):\n            self.convs.append(SAGEConv(hidden_channels, hidden_channels))\n            self.norms.append(BatchNorm(hidden_channels))\n\n        self.mlp = MLP(\n            in_channels=in_channels + num_layers * hidden_channels,\n            hidden_channels=2 * out_channels,\n            out_channels=out_channels,\n            num_layers=2,\n            norm='batch_norm',\n            act='leaky_relu',\n        )\n\n        self.dropout = dropout\n\n    def forward(self, x, edge_index):\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        xs = [x]\n        for conv, norm in zip(self.convs, self.norms):\n            x = conv(x, edge_index)\n            x = norm(x)\n            x = x.relu()\n            x = F.dropout(x, p=self.dropout, training=self.training)\n            xs.append(x)\n        return self.mlp(torch.cat(xs, dim=-1))\n\n\nmodel = GraphSAGE(in_channels=128, hidden_channels=1024, out_channels=172,\n                  num_layers=3).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\nfor epoch in range(1, NUM_EPOCHS + 1):\n    total_loss = total_examples = 0\n    for 
batch in tqdm(loader):\n        batch = batch.to(device)\n        batch_size = batch['paper'].batch_size\n\n        optimizer.zero_grad()\n        out = model(\n            batch['paper'].x,\n            batch['paper', 'cites', 'paper'].edge_index,\n        )[:batch_size]\n        y = batch['paper'].y[:batch_size].long().view(-1)\n        loss = F.cross_entropy(out, y)\n\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * y.numel()\n        total_examples += y.numel()\n\n    print(f'Epoch: {epoch:02d}, Loss: {total_loss / total_examples:.4f}')\n"
  },
  {
    "path": "examples/distributed/pyg/README.md",
    "content": "# Distributed Training with PyG\n\n> **Deprecated:** `torch_geometric.distributed` is deprecated.\n> Please refer to [NVIDIA cuGraph-GNN](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html#accelerating-pyg-with-nvidia-cugraph-gnn) for scalable distributed GNN training with NVIDIA GPUs.\n"
  },
  {
    "path": "examples/dna.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.model_selection import StratifiedKFold\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import DNAConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset)\ndata = dataset[0]\ndata.train_mask = data.val_mask = data.test_mask = None\n\n\ndef gen_uniform_20_20_60_split(data):\n    skf = StratifiedKFold(5, shuffle=True, random_state=55)\n    idx = [torch.from_numpy(i) for _, i in skf.split(data.y, data.y)]\n    data.train_idx = idx[0].to(torch.long)\n    data.val_idx = idx[1].to(torch.long)\n    data.test_idx = torch.cat(idx[2:], dim=0).to(torch.long)\n    return data\n\n\ndata = gen_uniform_20_20_60_split(data)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,\n                 heads=1, groups=1):\n        super().__init__()\n        self.hidden_channels = hidden_channels\n        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            self.convs.append(\n                DNAConv(hidden_channels, heads, groups, dropout=0.8))\n        self.lin2 = torch.nn.Linear(hidden_channels, out_channels)\n\n    def reset_parameters(self):\n        self.lin1.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, x, edge_index):\n        x = F.relu(self.lin1(x))\n        x = F.dropout(x, p=0.5, training=self.training)\n        x_all = x.view(-1, 1, self.hidden_channels)\n        for conv in self.convs:\n            x = F.relu(conv(x_all, edge_index))\n            x = x.view(-1, 1, self.hidden_channels)\n            x_all = torch.cat([x_all, x], dim=1)\n        x = x_all[:, -1]\n        x = F.dropout(x, p=0.5, 
training=self.training)\n        x = self.lin2(x)\n        return torch.log_softmax(x, dim=1)\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel = Net(in_channels=dataset.num_features, hidden_channels=128,\n            out_channels=dataset.num_classes, num_layers=5, heads=8, groups=16)\nmodel, data = model.to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.0005)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_idx], data.y[data.train_idx])\n    loss.backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out, accs = model(data.x, data.edge_index), []\n    for _, idx in data('train_idx', 'val_idx', 'test_idx'):\n        pred = out[idx].argmax(1)\n        acc = pred.eq(data.y[idx]).sum().item() / idx.numel()\n        accs.append(acc)\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 201):\n    train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '\n          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/egc.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom ogb.graphproppred import Evaluator\nfrom ogb.graphproppred import PygGraphPropPredDataset as OGBG\nfrom ogb.graphproppred.mol_encoder import AtomEncoder\nfrom torch.nn import BatchNorm1d, Linear, ReLU, Sequential\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import EGConv, global_mean_pool\nfrom torch_geometric.typing import WITH_TORCH_SPARSE\n\nif not WITH_TORCH_SPARSE:\n    quit(\"This example requires 'torch-sparse'\")\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--use_multi_aggregators', action='store_true',\n                    help='Switch between EGC-S and EGC-M')\nargs = parser.parse_args()\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')\ndataset = OGBG('ogbg-molhiv', path, pre_transform=T.ToSparseTensor())\nevaluator = Evaluator('ogbg-molhiv')\n\nsplit_idx = dataset.get_idx_split()\ntrain_dataset = dataset[split_idx['train']]\nval_dataset = dataset[split_idx['valid']]\ntest_dataset = dataset[split_idx['test']]\n\ntrain_loader = DataLoader(train_dataset, batch_size=32, num_workers=4,\n                          shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=256)\ntest_loader = DataLoader(test_dataset, batch_size=256)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, hidden_channels, num_layers, num_heads, num_bases):\n        super().__init__()\n        if args.use_multi_aggregators:\n            aggregators = ['sum', 'mean', 'max']\n        else:\n            aggregators = ['symnorm']\n\n        self.encoder = AtomEncoder(hidden_channels)\n\n        self.convs = torch.nn.ModuleList()\n        self.norms = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            self.convs.append(\n                EGConv(hidden_channels, hidden_channels, 
aggregators,\n                       num_heads, num_bases))\n            self.norms.append(BatchNorm1d(hidden_channels))\n\n        self.mlp = Sequential(\n            Linear(hidden_channels, hidden_channels // 2, bias=False),\n            BatchNorm1d(hidden_channels // 2),\n            ReLU(inplace=True),\n            Linear(hidden_channels // 2, hidden_channels // 4, bias=False),\n            BatchNorm1d(hidden_channels // 4),\n            ReLU(inplace=True),\n            Linear(hidden_channels // 4, 1),\n        )\n\n    def forward(self, x, adj_t, batch):\n        adj_t = adj_t.set_value(None)  # EGConv works without any edge features\n\n        x = self.encoder(x)\n\n        for conv, norm in zip(self.convs, self.norms):\n            h = conv(x, adj_t)\n            h = norm(h)\n            h = h.relu_()\n            x = x + h\n\n        x = global_mean_pool(x, batch)\n\n        return self.mlp(x)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(hidden_channels=236, num_layers=4, num_heads=4,\n            num_bases=4).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\nscheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20,\n                              min_lr=1e-5)\n\n\ndef train():\n    model.train()\n\n    total_loss = total_examples = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n\n        out = model(data.x, data.adj_t, data.batch)\n        loss = F.binary_cross_entropy_with_logits(out, data.y.to(torch.float))\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * data.num_graphs\n        total_examples += data.num_graphs\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef evaluate(loader):\n    model.eval()\n\n    y_pred, y_true = [], []\n    for data in loader:\n        data = data.to(device)\n        pred = model(data.x, data.adj_t, data.batch)\n        
y_pred.append(pred.cpu())\n        y_true.append(data.y.cpu())\n\n    y_true = torch.cat(y_true, dim=0)\n    y_pred = torch.cat(y_pred, dim=0)\n    return evaluator.eval({'y_true': y_true, 'y_pred': y_pred})['rocauc']\n\n\nfor epoch in range(1, 31):\n    loss = train()\n    val_rocauc = evaluate(val_loader)\n    test_rocauc = evaluate(test_loader)\n    scheduler.step(val_rocauc)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_rocauc:.4f}, '\n          f'Test: {test_rocauc:.4f}')\n"
  },
  {
    "path": "examples/equilibrium_median.py",
    "content": "r\"\"\"Replicates the experiment from `\"Deep Graph Infomax\"\n<https://arxiv.org/abs/1809.10341>`_ to try and teach `EquilibriumAggregation`\nto learn to take the median of a set of numbers.\n\nThis example converges slowly to being able to predict the\nmedian similar to what is observed in the paper.\n\"\"\"\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.nn import EquilibriumAggregation\n\ninput_size = 100\nsteps = 10000000\nembedding_size = 10\neval_each = 1000\n\nmodel = EquilibriumAggregation(1, 10, [256, 256], 1)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\nnorm = torch.distributions.normal.Normal(0.5, 0.4)\ngamma = torch.distributions.gamma.Gamma(0.2, 0.5)\nuniform = torch.distributions.uniform.Uniform(0, 1)\ntotal_loss = 0\nn_loss = 0\n\nfor i in range(1, steps + 1):\n    optimizer.zero_grad()\n    dist = np.random.choice([norm, gamma, uniform])\n    x = dist.sample((input_size, 1))\n    y = model(x)\n    loss = (y - x.median()).norm(2) / input_size\n    loss.backward()\n    optimizer.step()\n    total_loss += loss\n    n_loss += 1\n    if i % eval_each == 0:\n        print(f\"Epoch: {i}, Loss {total_loss / n_loss:.6f}\")\n"
  },
  {
    "path": "examples/explain/README.md",
    "content": "# Examples for Generating Explanations of Graph Neural Networks\n\nThis directory contains examples demonstrating the use of the `torch_geometric.explain` package.\nThe `explain` package of PyG provides a set of tools to explain the predictions of a GNN model or to explain the underlying phenomenon of a dataset.\n\n| Example                                                                | Description                                             |\n| ---------------------------------------------------------------------- | ------------------------------------------------------- |\n| [`gnn_explainer.py`](./gnn_explainer.py)                               | `GNNExplainer` for node classification                  |\n| [`gnn_explainer_link_pred.py`](./gnn_explainer_link_pred.py)           | `GNNExplainer` for link prediction                      |\n| [`gnn_explainer_ba_shapes.py`](./gnn_explainer_ba_shapes.py)           | `GNNExplainer` applied on the `BAShapes` dataset        |\n| [`captum_explainer.py`](./captum_explainer.py)                         | Captum-based explainer for node classification          |\n| [`captum_explainer_hetero_link.py`](./captum_explainer_hetero_link.py) | Captum-based explainer for heterogenous link prediction |\n| [`graphmask_explainer.py`](./graphmask_explainer.py)                   | `GraphMaskExplainer` for node classification            |\n"
  },
  {
    "path": "examples/explain/captum_explainer.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.explain import CaptumExplainer, Explainer\nfrom torch_geometric.nn import GCNConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, dataset)\ndata = dataset[0]\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, 16)\n        self.conv2 = GCNConv(16, dataset.num_classes)\n\n    def forward(self, x, edge_index):\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = GCN().to(device)\ndata = data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\nfor _ in range(1, 201):\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n\nexplainer = Explainer(\n    model=model,\n    algorithm=CaptumExplainer('IntegratedGradients'),\n    explanation_type='model',\n    model_config=dict(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type='log_probs',\n    ),\n    node_mask_type='attributes',\n    edge_mask_type='object',\n    threshold_config=dict(\n        threshold_type='topk',\n        value=200,\n    ),\n)\n\nnode_index = 10\nexplanation = explainer(data.x, data.edge_index, index=node_index)\nprint(f'Generated explanations in {explanation.available_explanations}')\n\npath = 'feature_importance.png'\nexplanation.visualize_feature_importance(path, top_k=10)\nprint(f\"Feature importance plot has been saved to '{path}'\")\n\npath = 
'subgraph.pdf'\nexplanation.visualize_graph(path)\nprint(f\"Subgraph plot has been saved to '{path}'\")\n"
  },
  {
    "path": "examples/explain/captum_explainer_hetero_link.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MovieLens\nfrom torch_geometric.explain import CaptumExplainer, Explainer\nfrom torch_geometric.nn import SAGEConv, to_hetero\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')\ndataset = MovieLens(path, model_name='all-MiniLM-L6-v2')\ndata = dataset[0].to(device)\n\n# Add user node features for message passing:\ndata['user'].x = torch.eye(data['user'].num_nodes, device=device)\ndel data['user'].num_nodes\n\n# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:\ndata = T.ToUndirected()(data)\ndata['user', 'movie'].edge_label = data['user',\n                                        'movie'].edge_label.to(torch.float)\ndel data['movie', 'rev_rates', 'user'].edge_label  # Remove \"reverse\" label.\n\n# Perform a link-level split into training, validation, and test edges:\ndata, _, _ = T.RandomLinkSplit(\n    num_val=0.1,\n    num_test=0.1,\n    neg_sampling_ratio=0.0,\n    edge_types=[('user', 'rates', 'movie')],\n    rev_edge_types=[('movie', 'rev_rates', 'user')],\n)(data)\n\n\nclass GNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        return x\n\n\nclass EdgeDecoder(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, 1)\n\n    def forward(self, z_dict, edge_label_index):\n        row, col = 
edge_label_index\n        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)\n\n        z = self.lin1(z).relu()\n        z = self.lin2(z)\n        return z.view(-1)\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.encoder = GNNEncoder(hidden_channels, hidden_channels)\n        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')\n        self.decoder = EdgeDecoder(hidden_channels)\n\n    def forward(self, x_dict, edge_index_dict, edge_label_index):\n        z_dict = self.encoder(x_dict, edge_index_dict)\n        return self.decoder(z_dict, edge_label_index)\n\n\nmodel = Model(hidden_channels=32).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\nfor _ in range(1, 10):\n    model.train()\n    optimizer.zero_grad()\n    pred = model(\n        data.x_dict,\n        data.edge_index_dict,\n        data['user', 'movie'].edge_label_index,\n    )\n    loss = F.mse_loss(pred, data['user', 'movie'].edge_label)\n    loss.backward()\n    optimizer.step()\n\nexplainer = Explainer(\n    model=model,\n    algorithm=CaptumExplainer('IntegratedGradients'),\n    explanation_type='model',\n    model_config=dict(\n        mode='regression',\n        task_level='edge',\n        return_type='raw',\n    ),\n    node_mask_type='attributes',\n    edge_mask_type='object',\n    threshold_config=dict(\n        threshold_type='topk',\n        value=200,\n    ),\n)\n\nindex = torch.tensor([2, 10])  # Explain edge labels with index 2 and 10.\nexplanation = explainer(\n    data.x_dict,\n    data.edge_index_dict,\n    index=index,\n    edge_label_index=data['user', 'movie'].edge_label_index,\n)\nprint(f'Generated explanations in {explanation.available_explanations}')\n\npath = 'feature_importance.png'\nexplanation.visualize_feature_importance(path, top_k=10)\nprint(f\"Feature importance plot has been saved to '{path}'\")\n"
  },
  {
    "path": "examples/explain/gnn_explainer.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.explain import Explainer, GNNExplainer\nfrom torch_geometric.nn import GCNConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, dataset)\ndata = dataset[0]\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, 16)\n        self.conv2 = GCNConv(16, dataset.num_classes)\n\n    def forward(self, x, edge_index):\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = GCN().to(device)\ndata = data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\nfor _ in range(1, 201):\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n\nexplainer = Explainer(\n    model=model,\n    algorithm=GNNExplainer(epochs=200),\n    explanation_type='model',\n    node_mask_type='attributes',\n    edge_mask_type='object',\n    model_config=dict(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type='log_probs',\n    ),\n)\nnode_index = 10\nexplanation = explainer(data.x, data.edge_index, index=node_index)\nprint(f'Generated explanations in {explanation.available_explanations}')\n\npath = 'feature_importance.png'\nexplanation.visualize_feature_importance(path, top_k=10)\nprint(f\"Feature importance plot has been saved to '{path}'\")\n\npath = 'subgraph.pdf'\nexplanation.visualize_graph(path)\nprint(f\"Subgraph visualization plot has been saved to 
'{path}'\")\n"
  },
  {
    "path": "examples/explain/gnn_explainer_ba_shapes.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import roc_auc_score\nfrom sklearn.model_selection import train_test_split\nfrom tqdm import tqdm\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ExplainerDataset\nfrom torch_geometric.datasets.graph_generator import BAGraph\nfrom torch_geometric.explain import Explainer, GNNExplainer\nfrom torch_geometric.nn import GCN\nfrom torch_geometric.utils import k_hop_subgraph\n\ndataset = ExplainerDataset(\n    graph_generator=BAGraph(num_nodes=300, num_edges=5),\n    motif_generator='house',\n    num_motifs=80,\n    transform=T.Constant(),\n)\ndata = dataset[0]\n\nidx = torch.arange(data.num_nodes)\ntrain_idx, test_idx = train_test_split(idx, train_size=0.8, stratify=data.y)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\ndata = data.to(device)\nmodel = GCN(data.num_node_features, hidden_channels=20, num_layers=3,\n            out_channels=dataset.num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.005)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.cross_entropy(out[train_idx], data.y[train_idx])\n    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index).argmax(dim=-1)\n\n    train_correct = int((pred[train_idx] == data.y[train_idx]).sum())\n    train_acc = train_correct / train_idx.size(0)\n\n    test_correct = int((pred[test_idx] == data.y[test_idx]).sum())\n    test_acc = test_correct / test_idx.size(0)\n\n    return train_acc, test_acc\n\n\npbar = tqdm(range(1, 2001))\nfor epoch in pbar:\n    loss = train()\n    if epoch == 1 or epoch % 200 == 0:\n        train_acc, test_acc = test()\n        pbar.set_description(f'Loss: {loss:.4f}, Train: 
{train_acc:.4f}, '\n                             f'Test: {test_acc:.4f}')\npbar.close()\nmodel.eval()\n\nfor explanation_type in ['phenomenon', 'model']:\n    explainer = Explainer(\n        model=model,\n        algorithm=GNNExplainer(epochs=300),\n        explanation_type=explanation_type,\n        node_mask_type='attributes',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='multiclass_classification',\n            task_level='node',\n            return_type='raw',\n        ),\n    )\n\n    # Explanation ROC AUC over all test nodes:\n    targets, preds = [], []\n    node_indices = range(400, data.num_nodes, 5)\n    for node_index in tqdm(node_indices, leave=False, desc='Train Explainer'):\n        target = data.y if explanation_type == 'phenomenon' else None\n        explanation = explainer(data.x, data.edge_index, index=node_index,\n                                target=target)\n\n        _, _, _, hard_edge_mask = k_hop_subgraph(node_index, num_hops=3,\n                                                 edge_index=data.edge_index)\n\n        targets.append(data.edge_mask[hard_edge_mask].cpu())\n        preds.append(explanation.edge_mask[hard_edge_mask].cpu())\n\n    auc = roc_auc_score(torch.cat(targets), torch.cat(preds))\n    print(f'Mean ROC AUC (explanation type {explanation_type:10}): {auc:.4f}')\n"
  },
  {
    "path": "examples/explain/gnn_explainer_link_pred.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import roc_auc_score\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.explain import Explainer, GNNExplainer, ModelConfig\nfrom torch_geometric.nn import GCNConv\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ntransform = T.Compose([\n    T.NormalizeFeatures(),\n    T.ToDevice(device),\n    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True),\n])\ndataset = Planetoid(path, dataset, transform=transform)\ntrain_data, val_data, test_data = dataset[0]\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        self.conv2 = GCNConv(hidden_channels, out_channels)\n\n    def encode(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        return x\n\n    def decode(self, z, edge_label_index):\n        src, dst = edge_label_index\n        return (z[src] * z[dst]).sum(dim=-1)\n\n    def forward(self, x, edge_index, edge_label_index):\n        z = model.encode(x, edge_index)\n        return model.decode(z, edge_label_index).view(-1)\n\n\nmodel = GCN(dataset.num_features, 128, 64).to(device)\noptimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n\n    out = model(train_data.x, train_data.edge_index,\n                train_data.edge_label_index)\n    loss = F.binary_cross_entropy_with_logits(out, train_data.edge_label)\n    loss.backward()\n    
optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    out = model(data.x, data.edge_index, data.edge_label_index).sigmoid()\n    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())\n\n\nfor epoch in range(1, 201):\n    loss = train()\n    if epoch % 20 == 0:\n        val_auc = test(val_data)\n        test_auc = test(test_data)\n        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '\n              f'Test: {test_auc:.4f}')\n\nmodel_config = ModelConfig(\n    mode='binary_classification',\n    task_level='edge',\n    return_type='raw',\n)\n\n# Explain model output for a single edge:\nedge_label_index = val_data.edge_label_index[:, 0]\n\nexplainer = Explainer(\n    model=model,\n    explanation_type='model',\n    algorithm=GNNExplainer(epochs=200),\n    node_mask_type='attributes',\n    edge_mask_type='object',\n    model_config=model_config,\n)\nexplanation = explainer(\n    x=train_data.x,\n    edge_index=train_data.edge_index,\n    edge_label_index=edge_label_index,\n)\nprint(f'Generated model explanations in {explanation.available_explanations}')\n\n# Explain a selected target (phenomenon) for a single edge:\nedge_label_index = val_data.edge_label_index[:, 0]\ntarget = val_data.edge_label[0].unsqueeze(dim=0).long()\n\nexplainer = Explainer(\n    model=model,\n    explanation_type='phenomenon',\n    algorithm=GNNExplainer(epochs=200),\n    node_mask_type='attributes',\n    edge_mask_type='object',\n    model_config=model_config,\n)\nexplanation = explainer(\n    x=train_data.x,\n    edge_index=train_data.edge_index,\n    target=target,\n    edge_label_index=edge_label_index,\n)\navailable_explanations = explanation.available_explanations\nprint(f'Generated phenomenon explanations in {available_explanations}')\n"
  },
  {
    "path": "examples/explain/graphmask_explainer.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.explain import Explainer, GraphMaskExplainer\nfrom torch_geometric.nn import GATConv, GCNConv\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora')\ndata = dataset[0].to(device)\n\n# GCN Node Classification =====================================================\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, 16)\n        self.conv2 = GCNConv(16, dataset.num_classes)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\nmodel = GCN().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\nfor _ in range(1, 201):\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n\nexplainer = Explainer(\n    model=model,\n    algorithm=GraphMaskExplainer(2, epochs=5),\n    explanation_type='model',\n    node_mask_type='attributes',\n    edge_mask_type='object',\n    model_config=dict(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type='log_probs',\n    ),\n)\n\nnode_index = 10\nexplanation = explainer(data.x, data.edge_index, index=node_index)\nprint(f'Generated explanations in {explanation.available_explanations}')\n\n# GAT Node Classification =====================================================\n\n\nclass GAT(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = 
GATConv(dataset.num_features, 8, heads=8)\n        self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=False)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\nmodel = GAT().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\nfor _ in range(1, 201):\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n\nexplainer = Explainer(\n    model=model,\n    algorithm=GraphMaskExplainer(2, epochs=5),\n    explanation_type='model',\n    node_mask_type='attributes',\n    edge_mask_type='object',\n    model_config=dict(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type='log_probs',\n    ),\n)\n\nnode_index = torch.tensor([10, 20])\nexplanation = explainer(data.x, data.edge_index, index=node_index)\nprint(f'Generated explanations in {explanation.available_explanations}')\n"
  },
  {
    "path": "examples/faust.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import FAUST\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import SplineConv\nfrom torch_geometric.typing import WITH_SPLINE\n\nif not WITH_SPLINE:\n    quit(\"This example requires 'pyg-lib>=0.6.0'\")\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FAUST')\npre_transform = T.Compose([T.FaceToEdge(), T.Constant(value=1)])\ntrain_dataset = FAUST(path, True, T.Cartesian(), pre_transform)\ntest_dataset = FAUST(path, False, T.Cartesian(), pre_transform)\ntrain_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)\ntest_loader = DataLoader(test_dataset, batch_size=1)\nd = train_dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SplineConv(1, 32, dim=3, kernel_size=5, aggr='add')\n        self.conv2 = SplineConv(32, 64, dim=3, kernel_size=5, aggr='add')\n        self.conv3 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')\n        self.conv4 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')\n        self.conv5 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')\n        self.conv6 = SplineConv(64, 64, dim=3, kernel_size=5, aggr='add')\n        self.lin1 = torch.nn.Linear(64, 256)\n        self.lin2 = torch.nn.Linear(256, d.num_nodes)\n\n    def forward(self, data):\n        x, edge_index, pseudo = data.x, data.edge_index, data.edge_attr\n        x = F.elu(self.conv1(x, edge_index, pseudo))\n        x = F.elu(self.conv2(x, edge_index, pseudo))\n        x = F.elu(self.conv3(x, edge_index, pseudo))\n        x = F.elu(self.conv4(x, edge_index, pseudo))\n        x = F.elu(self.conv5(x, edge_index, pseudo))\n        x = F.elu(self.conv6(x, edge_index, pseudo))\n        x = F.elu(self.lin1(x))\n        x = F.dropout(x, training=self.training)\n        x = self.lin2(x)\n        return 
F.log_softmax(x, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\ntarget = torch.arange(d.num_nodes, dtype=torch.long, device=device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train(epoch):\n    model.train()\n\n    if epoch == 61:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = 0.001\n\n    for data in train_loader:\n        optimizer.zero_grad()\n        F.nll_loss(model(data.to(device)), target).backward()\n        optimizer.step()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    correct = 0\n\n    for data in test_loader:\n        pred = model(data.to(device)).max(1)[1]\n        correct += pred.eq(target).sum().item()\n    return correct / (len(test_dataset) * d.num_nodes)\n\n\nfor epoch in range(1, 101):\n    train(epoch)\n    test_acc = test()\n    print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/film.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import f1_score\nfrom torch.nn import BatchNorm1d\n\nfrom torch_geometric.datasets import PPI\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import FiLMConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')\ntrain_dataset = PPI(path, split='train')\nval_dataset = PPI(path, split='val')\ntest_dataset = PPI(path, split='test')\ntrain_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)\ntest_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,\n                 dropout=0.0):\n        super().__init__()\n        self.dropout = dropout\n\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(FiLMConv(in_channels, hidden_channels))\n        for _ in range(num_layers - 2):\n            self.convs.append(FiLMConv(hidden_channels, hidden_channels))\n        self.convs.append(FiLMConv(hidden_channels, out_channels, act=None))\n\n        self.norms = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.norms.append(BatchNorm1d(hidden_channels))\n\n    def forward(self, x, edge_index):\n        for conv, norm in zip(self.convs[:-1], self.norms):\n            x = norm(conv(x, edge_index))\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.convs[-1](x, edge_index)\n        return x\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel = Net(in_channels=train_dataset.num_features, hidden_channels=320,\n            out_channels=train_dataset.num_classes, num_layers=4,\n     
       dropout=0.1).to(device)\ncriterion = torch.nn.BCEWithLogitsLoss()\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = criterion(model(data.x, data.edge_index), data.y)\n        total_loss += loss.item() * data.num_graphs\n        loss.backward()\n        optimizer.step()\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    ys, preds = [], []\n    for data in loader:\n        ys.append(data.y)\n        out = model(data.x.to(device), data.edge_index.to(device))\n        preds.append((out > 0).float().cpu())\n\n    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()\n    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0\n\n\nfor epoch in range(1, 501):\n    loss = train()\n    val_f1 = test(val_loader)\n    test_f1 = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '\n          f'Test: {test_f1:.4f}')\n"
  },
  {
    "path": "examples/gat.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.logging import init_wandb, log\nfrom torch_geometric.nn import GATConv\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='Cora')\nparser.add_argument('--hidden_channels', type=int, default=8)\nparser.add_argument('--heads', type=int, default=8)\nparser.add_argument('--lr', type=float, default=0.005)\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--wandb', action='store_true', help='Track experiment')\nargs = parser.parse_args()\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\n\ninit_wandb(name=f'GAT-{args.dataset}', heads=args.heads, epochs=args.epochs,\n           hidden_channels=args.hidden_channels, lr=args.lr, device=device)\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures())\ndata = dataset[0].to(device)\n\n\nclass GAT(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, heads):\n        super().__init__()\n        self.conv1 = GATConv(in_channels, hidden_channels, heads, dropout=0.6)\n        # On the Pubmed dataset, use `heads` output heads in `conv2`.\n        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1,\n                             concat=False, dropout=0.6)\n\n    def forward(self, x, edge_index):\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = F.elu(self.conv1(x, edge_index))\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = self.conv2(x, edge_index)\n        return x\n\n\nmodel = 
GAT(dataset.num_features, args.hidden_channels, dataset.num_classes,\n            args.heads).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=5e-4)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss.detach())\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index).argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\ntimes = []\nbest_val_acc = test_acc = 0\nfor epoch in range(1, args.epochs + 1):\n    start = time.time()\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/gcn.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.logging import init_wandb, log\nfrom torch_geometric.nn import GCNConv\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='Cora')\nparser.add_argument('--hidden_channels', type=int, default=16)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--epochs', type=int, default=200)\nparser.add_argument('--use_gdc', action='store_true', help='Use GDC')\nparser.add_argument('--wandb', action='store_true', help='Track experiment')\nargs = parser.parse_args()\n\ndevice = torch_geometric.device('auto')\n\ninit_wandb(\n    name=f'GCN-{args.dataset}',\n    lr=args.lr,\n    epochs=args.epochs,\n    hidden_channels=args.hidden_channels,\n    device=device,\n)\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures())\ndata = dataset[0].to(device)\n\nif args.use_gdc:\n    transform = T.GDC(\n        self_loop_weight=1,\n        normalization_in='sym',\n        normalization_out='col',\n        diffusion_kwargs=dict(method='ppr', alpha=0.05),\n        sparsification_kwargs=dict(method='topk', k=128, dim=0),\n        exact=True,\n    )\n    data = transform(data)\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, hidden_channels,\n                             normalize=not args.use_gdc)\n        self.conv2 = GCNConv(hidden_channels, out_channels,\n                             normalize=not args.use_gdc)\n\n    def forward(self, x, edge_index, edge_weight=None):\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv1(x, 
edge_index, edge_weight).relu()\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv2(x, edge_index, edge_weight)\n        return x\n\n\nmodel = GCN(\n    in_channels=dataset.num_features,\n    hidden_channels=args.hidden_channels,\n    out_channels=dataset.num_classes,\n).to(device)\n\noptimizer = torch.optim.Adam([\n    dict(params=model.conv1.parameters(), weight_decay=5e-4),\n    dict(params=model.conv2.parameters(), weight_decay=0)\n], lr=args.lr)  # Only perform weight-decay on first convolution.\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index, data.edge_attr)\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss.detach())\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index, data.edge_attr).argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\nbest_val_acc = test_acc = 0\ntimes = []\nfor epoch in range(1, args.epochs + 1):\n    start = time.time()\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc, Test=test_acc)\n    times.append(time.time() - start)\nprint(f'Median time per epoch: {torch.tensor(times).median():.4f}s')\n"
  },
  {
    "path": "examples/gcn2_cora.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCN2Conv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ntransform = T.Compose([T.NormalizeFeatures(), T.GCNNorm(), T.ToSparseTensor()])\ndataset = Planetoid(path, dataset, transform=transform)\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, hidden_channels, num_layers, alpha, theta,\n                 shared_weights=True, dropout=0.0):\n        super().__init__()\n\n        self.lins = torch.nn.ModuleList()\n        self.lins.append(Linear(dataset.num_features, hidden_channels))\n        self.lins.append(Linear(hidden_channels, dataset.num_classes))\n\n        self.convs = torch.nn.ModuleList()\n        for layer in range(num_layers):\n            self.convs.append(\n                GCN2Conv(hidden_channels, alpha, theta, layer + 1,\n                         shared_weights, normalize=False))\n\n        self.dropout = dropout\n\n    def forward(self, x, adj_t):\n        x = F.dropout(x, self.dropout, training=self.training)\n        x = x_0 = self.lins[0](x).relu()\n\n        for conv in self.convs:\n            x = F.dropout(x, self.dropout, training=self.training)\n            x = conv(x, x_0, adj_t)\n            x = x.relu()\n\n        x = F.dropout(x, self.dropout, training=self.training)\n        x = self.lins[1](x)\n\n        return x.log_softmax(dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(hidden_channels=64, num_layers=64, alpha=0.1, theta=0.5,\n            shared_weights=True, dropout=0.6).to(device)\ndata = data.to(device)\noptimizer = torch.optim.Adam([\n    dict(params=model.convs.parameters(), weight_decay=0.01),\n    dict(params=model.lins.parameters(), weight_decay=5e-4)\n], 
lr=0.01)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.adj_t)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred, accs = model(data.x, data.adj_t).argmax(dim=-1), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\nbest_val_acc = test_acc = 0\ntimes = []\nfor epoch in range(1, 1001):\n    start = time.time()\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f} Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {tmp_test_acc:.4f}, '\n          f'Final Test: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/gcn2_ppi.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import f1_score\nfrom torch.nn import Linear\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import PPI\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GCN2Conv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'GCN2_PPI')\npre_transform = T.Compose([T.GCNNorm(), T.ToSparseTensor()])\ntrain_dataset = PPI(path, split='train', pre_transform=pre_transform)\nval_dataset = PPI(path, split='val', pre_transform=pre_transform)\ntest_dataset = PPI(path, split='test', pre_transform=pre_transform)\ntrain_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)\ntest_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, hidden_channels, num_layers, alpha, theta,\n                 shared_weights=True, dropout=0.0):\n        super().__init__()\n\n        self.lins = torch.nn.ModuleList()\n        self.lins.append(Linear(train_dataset.num_features, hidden_channels))\n        self.lins.append(Linear(hidden_channels, train_dataset.num_classes))\n\n        self.convs = torch.nn.ModuleList()\n        for layer in range(num_layers):\n            self.convs.append(\n                GCN2Conv(hidden_channels, alpha, theta, layer + 1,\n                         shared_weights, normalize=False))\n\n        self.dropout = dropout\n\n    def forward(self, x, adj_t):\n        x = F.dropout(x, self.dropout, training=self.training)\n        x = x_0 = self.lins[0](x).relu()\n\n        for conv in self.convs:\n            h = F.dropout(x, self.dropout, training=self.training)\n            h = conv(h, x_0, adj_t)\n            x = h + x\n            x = x.relu()\n\n        x = F.dropout(x, self.dropout, training=self.training)\n        x = self.lins[1](x)\n\n        
return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(hidden_channels=2048, num_layers=9, alpha=0.5, theta=1.0,\n            shared_weights=False, dropout=0.2).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\ncriterion = torch.nn.BCEWithLogitsLoss()\n\n\ndef train():\n    model.train()\n\n    total_loss = total_examples = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = criterion(model(data.x, data.adj_t), data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item() * data.num_nodes\n        total_examples += data.num_nodes\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    ys, preds = [], []\n    for data in loader:\n        ys.append(data.y)\n        out = model(data.x.to(device), data.adj_t.to(device))\n        preds.append((out > 0).float().cpu())\n\n    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()\n    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0\n\n\ntimes = []\nfor epoch in range(1, 2001):\n    start = time.time()\n    loss = train()\n    val_f1 = test(val_loader)\n    test_f1 = test(test_loader)\n    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '\n          f'Test: {test_f1:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/geniepath.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\nfrom sklearn.metrics import f1_score\n\nfrom torch_geometric.datasets import PPI\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GATConv\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--model', type=str, default='GeniePathLazy')\nargs = parser.parse_args()\nassert args.model in ['GeniePath', 'GeniePathLazy']\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'PPI')\ntrain_dataset = PPI(path, split='train')\nval_dataset = PPI(path, split='val')\ntest_dataset = PPI(path, split='test')\ntrain_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)\ntest_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)\n\ndim = 256\nlstm_hidden = 256\nlayer_num = 4\n\n\nclass Breadth(torch.nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.gatconv = GATConv(in_dim, out_dim, heads=1)\n\n    def forward(self, x, edge_index):\n        x = torch.tanh(self.gatconv(x, edge_index))\n        return x\n\n\nclass Depth(torch.nn.Module):\n    def __init__(self, in_dim, hidden):\n        super().__init__()\n        self.lstm = torch.nn.LSTM(in_dim, hidden, 1, bias=False)\n\n    def forward(self, x, h, c):\n        x, (h, c) = self.lstm(x, (h, c))\n        return x, (h, c)\n\n\nclass GeniePathLayer(torch.nn.Module):\n    def __init__(self, in_dim):\n        super().__init__()\n        self.breadth_func = Breadth(in_dim, dim)\n        self.depth_func = Depth(dim, lstm_hidden)\n\n    def forward(self, x, edge_index, h, c):\n        x = self.breadth_func(x, edge_index)\n        x = x[None, :]\n        x, (h, c) = self.depth_func(x, h, c)\n        x = x[0]\n        return x, (h, c)\n\n\nclass GeniePath(torch.nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.lin1 = torch.nn.Linear(in_dim, 
dim)\n        self.gplayers = torch.nn.ModuleList(\n            [GeniePathLayer(dim) for i in range(layer_num)])\n        self.lin2 = torch.nn.Linear(dim, out_dim)\n\n    def forward(self, x, edge_index):\n        x = self.lin1(x)\n        h = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)\n        c = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)\n        for i, _ in enumerate(self.gplayers):\n            x, (h, c) = self.gplayers[i](x, edge_index, h, c)\n        x = self.lin2(x)\n        return x\n\n\nclass GeniePathLazy(torch.nn.Module):\n    def __init__(self, in_dim, out_dim):\n        super().__init__()\n        self.lin1 = torch.nn.Linear(in_dim, dim)\n        self.breadths = torch.nn.ModuleList(\n            [Breadth(dim, dim) for i in range(layer_num)])\n        self.depths = torch.nn.ModuleList(\n            [Depth(dim * 2, lstm_hidden) for i in range(layer_num)])\n        self.lin2 = torch.nn.Linear(dim, out_dim)\n\n    def forward(self, x, edge_index):\n        x = self.lin1(x)\n        h = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)\n        c = torch.zeros(1, x.shape[0], lstm_hidden, device=x.device)\n        h_tmps = []\n        for i, _ in enumerate(self.breadths):\n            h_tmps.append(self.breadths[i](x, edge_index))\n        x = x[None, :]\n        for i, _ in enumerate(self.depths):\n            in_cat = torch.cat((h_tmps[i][None, :], x), -1)\n            x, (h, c) = self.depths[i](in_cat, h, c)\n        x = self.lin2(x[0])\n        return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nkwargs = {'GeniePath': GeniePath, 'GeniePathLazy': GeniePathLazy}\nmodel = kwargs[args.model](train_dataset.num_features,\n                           train_dataset.num_classes).to(device)\nloss_op = torch.nn.BCEWithLogitsLoss()\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        num_graphs 
= data.num_graphs\n        data.batch = None\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = loss_op(model(data.x, data.edge_index), data.y)\n        total_loss += loss.item() * num_graphs\n        loss.backward()\n        optimizer.step()\n    return total_loss / len(train_loader.dataset)\n\n\ndef test(loader):\n    model.eval()\n\n    ys, preds = [], []\n    for data in loader:\n        ys.append(data.y)\n        with torch.no_grad():\n            out = model(data.x.to(device), data.edge_index.to(device))\n        preds.append((out > 0).float().cpu())\n\n    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()\n    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0\n\n\ntimes = []\nfor epoch in range(1, 101):\n    start = time.time()\n    loss = train()\n    val_f1 = test(val_loader)\n    test_f1 = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '\n          f'Test: {test_f1:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/glnn.py",
    "content": "# Implementation of:\n# Graph-less Neural Networks: Teaching Old MLPs New Tricks via Distillation\n\nimport argparse\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCN, MLP\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--lamb', type=float, default=0.0,\n                    help='Balances loss from hard labels and teacher outputs')\nargs = parser.parse_args()\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())\ndata = dataset[0].to(device)\n\ngnn = GCN(dataset.num_node_features, hidden_channels=16,\n          out_channels=dataset.num_classes, num_layers=2).to(device)\nmlp = MLP([dataset.num_node_features, 64, dataset.num_classes], dropout=0.5,\n          norm=None).to(device)\n\ngnn_optimizer = torch.optim.Adam(gnn.parameters(), lr=0.01, weight_decay=5e-4)\nmlp_optimizer = torch.optim.Adam(mlp.parameters(), lr=0.01, weight_decay=5e-4)\n\n\ndef train_teacher():\n    gnn.train()\n    gnn_optimizer.zero_grad()\n    out = gnn(data.x, data.edge_index)\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    gnn_optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test_teacher():\n    gnn.eval()\n    pred = gnn(data.x, data.edge_index).argmax(dim=-1)\n    accs = []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\ntimes = []\nprint('Training Teacher GNN:')\nfor epoch in range(1, 201):\n    start 
= time.time()\n    loss = train_teacher()\n    if epoch % 20 == 0:\n        train_acc, val_acc, test_acc = test_teacher()\n        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n    times.append(time.time() - start)\n    start = time.time()\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n\nwith torch.no_grad():  # Obtain soft labels from the GNN:\n    y_soft = gnn(data.x, data.edge_index).log_softmax(dim=-1)\n\n\ndef train_student():\n    mlp.train()\n    mlp_optimizer.zero_grad()\n    out = mlp(data.x)\n    loss1 = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss2 = F.kl_div(out.log_softmax(dim=-1), y_soft, reduction='batchmean',\n                     log_target=True)\n    loss = args.lamb * loss1 + (1 - args.lamb) * loss2\n    loss.backward()\n    mlp_optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test_student():\n    mlp.eval()\n    pred = mlp(data.x).argmax(dim=-1)\n    accs = []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\ntimes = []\nprint('Training Student MLP:')\nfor epoch in range(1, 501):\n    start = time.time()\n    loss = train_student()\n    if epoch % 20 == 0:\n        train_acc, val_acc, test_acc = test_student()\n        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n    times.append(time.time() - start)\n    start = time.time()\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/gpse.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import ZINC\nfrom torch_geometric.graphgym.models.encoder import AtomEncoder\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.logging import log\nfrom torch_geometric.nn import (\n    GPSE,\n    MLP,\n    GCNConv,\n    GINConv,\n    GPSENodeEncoder,\n    Linear,\n    global_mean_pool,\n)\nfrom torch_geometric.nn.models.gpse import precompute_GPSE\nfrom torch_geometric.transforms import AddGPSE\n\n\ndef load_ZINC(args):\n    \"\"\"Load the ZINC dataset, and generate GPSE encodings for the graphs if\n    args.gpse is not None.\n    \"\"\"\n    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',\n                    'ZINC_subset')\n    gpse_model = GPSE.from_pretrained(\n        name=args.gpse,\n        root=osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',\n                      'GPSE_pretrained')) if args.gpse else None\n\n    if args.gpse and args.as_transform:\n        # WARNING: Using a pre_transform will save the encodings to disk,\n        # meaning any future runs will use the saved encodings. This is useful\n        # for speeding up computation, but may not be desirable, e.g. when\n        # experimenting with different pre-trained GPSE models. 
Alternatively,\n        # AddGPSE can be used as a regular transform, which will compute the\n        # encodings on-the-fly, but this will slow down the data loading\n        # process.\n        train_dataset = ZINC(\n            path, subset=True, split='train',\n            pre_transform=AddGPSE(gpse_model, use_vn=True,\n                                  rand_type='NormalSE'))\n        test_dataset = ZINC(\n            path, subset=True, split='val',\n            pre_transform=AddGPSE(gpse_model, use_vn=True,\n                                  rand_type='NormalSE'))\n    else:\n        train_dataset = ZINC(path, subset=True, split='train')\n        test_dataset = ZINC(path, subset=True, split='val')\n\n        if args.gpse:\n            precompute_GPSE(gpse_model, train_dataset)\n            precompute_GPSE(gpse_model, test_dataset)\n\n    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)\n    test_loader = DataLoader(test_dataset, batch_size=256)\n\n    return train_loader, test_loader\n\n\nclass IdentityNodeEncoder(torch.nn.Module):\n    def __init__(self, emb_dim):\n        super().__init__()\n\n    def forward(self, batch):\n        return batch\n\n\nclass LinearNodeEncoder(torch.nn.Module):\n    def __init__(self, emb_dim, emb_pe_out, bias=True):\n        super().__init__()\n        self.encoder = Linear(emb_dim - emb_pe_out, emb_dim, bias=bias)\n\n    def forward(self, batch):\n        batch.x = self.encoder(batch.x)\n\n        return batch\n\n\nclass TypeDictNodeEncoder(torch.nn.Module):\n    def __init__(self, emb_dim, num_types=28):\n        super().__init__()\n\n        if num_types < 1:\n            raise ValueError(f\"Invalid 'node_encoder_num_types': {num_types}\")\n\n        self.encoder = torch.nn.Embedding(num_embeddings=num_types,\n                                          embedding_dim=emb_dim)\n\n    def forward(self, batch):\n        # Encode just the first dimension if more exist\n        batch.x = 
self.encoder(batch.x[:, 0])\n\n        return batch\n\n\nclass GNNStackStage(torch.nn.Module):\n    \"\"\"Simple Staging mechanism that stacks an arbitrary number of GNN layers\n    with skip connections and L2 normalization.\n\n    Args:\n        dim_in (int): Input dimension\n        dim_out (int): Output dimension\n        num_layers (int): Number of GNN layers\n        conv_type (str): Type of graph convolution in GNN\n        stage_type (str): Type of skip connections. Options: 'skipsum' or\n        'skipconcat', any other value means no skip connections.\n        l2norm (bool): Whether to apply L2 normalization to outputs\n    \"\"\"\n    def __init__(self, dim_in, dim_out, num_layers, conv_type='gcn',\n                 stage_type='skipsum', l2norm=True):\n        super().__init__()\n        self.num_layers = num_layers\n        self.stage_type = stage_type\n        self.l2norm = l2norm\n        conv_dict = {'gcn': GCNConv, 'gin': GINConv}\n\n        for i in range(num_layers):\n            if stage_type == 'skipconcat':\n                d_in = dim_in if i == 0 else dim_in + i * dim_out\n            else:\n                d_in = dim_in if i == 0 else dim_out\n            layer = conv_dict[conv_type](d_in, dim_out)\n            self.add_module(f'layer{i}', layer)\n\n    def forward(self, batch):\n        for i, layer in enumerate(self.children()):\n            x = batch.x\n            batch.x = layer(batch.x, batch.edge_index)\n            if self.stage_type == 'skipsum':\n                batch.x = x + batch.x\n            elif self.stage_type == 'skipconcat' and \\\n                    i < self.num_layers - 1:\n                batch.x = torch.cat([x, batch.x], dim=1)\n        if self.l2norm:\n            batch.x = F.normalize(batch.x, p=2, dim=-1)\n        return batch\n\n\nclass GPSEPlusGNN(torch.nn.Module):\n    \"\"\"A GPSE encoder paired with a GNN module. 
Consists of:\n    - encoder1: An optional encoder that is used to encode raw node features,\n        common practice for biochemistry datasets. ZINC uses\n        :class:`TypeDictNodeEncoder`, while ogbg-mol* datasets typically use\n        :class:`~torch_geometric.graphgym.models.encoder.AtomEncoder`. If\n        'none', an :class:`IdentityNodeEncoder` is passed that returns the\n        inputs as-is.\n    - encoder2: GPSE encoder that adds precomputed GPSE encodings in the\n        dataset to node features if :obj:`gpse` is :obj:`True`. Otherwise is\n        replaced by a linear layer that maps the :obj:`encoder1` outputs to\n        the correct dimension.\n    - premp: 2-layer MLP before message-passing.\n    - gnn: Stacked <num_layers> message-passing layers of :obj:`conv_type`.\n    - postmp: 1-layer MLP after message-passing to map GNN node states to a\n        single output (for ZINC regression task). For classification tasks,\n        :obj:`num_classes` outputs with softmax activation would be required.\n\n    Args:\n        dim_emb (int): Dimension of embedding outputs. Equals dimension of\n            :obj:`encoder1` outputs (dim_emb - dim_pe_out) and\n            :class:`~torch_geometric.nn.GPSENodeEncoder` outputs (dim_pe_out).\n        dim_conv (int): Dimension of GNN message-passing layers.\n        conv_type (str): Type of graph convolution in GNN.\n        num_layers (int): Number of GNN layers.\n        dim_pe_in (int): Original dimension of posenc_GPSE, i.e. 
the\n            precomputed GPSE encodings.\n        dim_pe_out (int): Desired dimension of GPSE-derived node features,\n            mapped from the original GPSE encodings via GPSENodeEncoder.\n        encoder (str): Encoding applied to raw node features.\n        gpse (bool): Whether to use GPSE encodings.\n    \"\"\"\n    def __init__(self, dim_emb, dim_conv, conv_type, num_layers, dim_pe_in,\n                 dim_pe_out, encoder='none', gpse=True):\n        super().__init__()\n        encoder_dict = {\n            'none': IdentityNodeEncoder,\n            'Atom': AtomEncoder,\n            'TypeDict': TypeDictNodeEncoder\n        }\n\n        self.encoder1 = encoder_dict[encoder](dim_emb - dim_pe_out)\n        self.encoder2 = GPSENodeEncoder(\n            dim_emb, dim_pe_in, dim_pe_out, expand_x=False) if gpse else (\n                LinearNodeEncoder(dim_emb, dim_pe_out, bias=True))\n        self.premp = MLP([dim_emb, dim_emb, dim_conv])\n        self.gnn = GNNStackStage(dim_conv, dim_conv, num_layers, conv_type)\n        self.postmp = MLP([dim_conv, 1])\n\n    def forward(self, batch):\n        batch = self.encoder1(batch)\n        batch.x = self.encoder2(batch.x, batch.pestat_GPSE)\n        batch.x = self.premp(batch.x)\n        batch = self.gnn(batch)\n        batch = global_mean_pool(batch.x, batch.batch)\n        batch = F.dropout(batch, p=0.5, training=self.training)\n        batch = self.postmp(batch)\n        return batch\n\n\ndef train(loader):\n    model.train()\n\n    total_loss = 0\n    for data in loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data)\n\n        pred = out.squeeze(-1) if out.ndim > 1 else out\n        true = data.y.squeeze(-1) if data.y.ndim > 1 else data.y\n\n        loss = F.mse_loss(pred, true)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * data.num_graphs\n    return total_loss / len(loader.dataset)\n\n\n@torch.no_grad()\ndef 
test(loader):\n    model.eval()\n\n    total_loss = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data)\n\n        pred = out.squeeze(-1) if out.ndim > 1 else out\n        true = data.y.squeeze(-1) if data.y.ndim > 1 else data.y\n\n        loss = F.mse_loss(pred, true)\n        total_loss += float(loss) * data.num_graphs\n    return total_loss / len(loader.dataset)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser(description='GPSE Example')\n\n    parser.add_argument(\n        '--gpse', type=str, default=None, const='molpcba', nargs='?',\n        choices=['molpcba', 'zinc', 'pcqm4mv2', 'geom',\n                 'chembl'], help='which model weights to use '\n        '(default: %(default)s)')\n    parser.add_argument(\n        '--as_transform', action='store_true',\n        help='Whether to apply GPSE as a pre_transform to the '\n        'dataset or not')\n\n    args = parser.parse_args()\n    train_loader, test_loader = load_ZINC(args)\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = GPSEPlusGNN(dim_emb=64, dim_conv=128, conv_type='gcn',\n                        num_layers=8, dim_pe_in=512, dim_pe_out=32,\n                        encoder='TypeDict', gpse=args.gpse).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,\n                                 weight_decay=5e-4)\n\n    num_epochs = 100\n    times = []\n    for epoch in range(1, num_epochs + 1):\n        start = time.time()\n        loss = train(train_loader)\n        train_acc = test(train_loader)\n        test_acc = test(test_loader)\n        log(Epoch=epoch, Loss=loss, Train=train_acc, Test=test_acc)\n        times.append(time.time() - start)\n    print(f'Median time per epoch: {torch.tensor(times).median():.4f}s')\n"
  },
  {
    "path": "examples/graph_gps.py",
    "content": "import argparse\nimport os.path as osp\nfrom typing import Any, Dict, Optional\n\nimport torch\nfrom torch.nn import (\n    BatchNorm1d,\n    Embedding,\n    Linear,\n    ModuleList,\n    ReLU,\n    Sequential,\n)\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ZINC\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GINEConv, GPSConv, global_add_pool\nfrom torch_geometric.nn.attention import PerformerAttention\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC-PE')\ntransform = T.AddRandomWalkPE(walk_length=20, attr_name='pe')\ntrain_dataset = ZINC(path, subset=True, split='train', pre_transform=transform)\nval_dataset = ZINC(path, subset=True, split='val', pre_transform=transform)\ntest_dataset = ZINC(path, subset=True, split='test', pre_transform=transform)\n\ntrain_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=64)\ntest_loader = DataLoader(test_dataset, batch_size=64)\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    '--attn_type', default='multihead',\n    help=\"Global attention type such as 'multihead' or 'performer'.\")\nargs = parser.parse_args()\n\n\nclass GPS(torch.nn.Module):\n    def __init__(self, channels: int, pe_dim: int, num_layers: int,\n                 attn_type: str, attn_kwargs: Dict[str, Any]):\n        super().__init__()\n\n        self.node_emb = Embedding(28, channels - pe_dim)\n        self.pe_lin = Linear(20, pe_dim)\n        self.pe_norm = BatchNorm1d(20)\n        self.edge_emb = Embedding(4, channels)\n\n        self.convs = ModuleList()\n        for _ in range(num_layers):\n            nn = Sequential(\n                Linear(channels, channels),\n                ReLU(),\n                Linear(channels, channels),\n            )\n            conv = GPSConv(channels, GINEConv(nn), heads=4,\n          
                 attn_type=attn_type, attn_kwargs=attn_kwargs)\n            self.convs.append(conv)\n\n        self.mlp = Sequential(\n            Linear(channels, channels // 2),\n            ReLU(),\n            Linear(channels // 2, channels // 4),\n            ReLU(),\n            Linear(channels // 4, 1),\n        )\n        self.redraw_projection = RedrawProjection(\n            self.convs,\n            redraw_interval=1000 if attn_type == 'performer' else None)\n\n    def forward(self, x, pe, edge_index, edge_attr, batch):\n        x_pe = self.pe_norm(pe)\n        x = torch.cat((self.node_emb(x.squeeze(-1)), self.pe_lin(x_pe)), 1)\n        edge_attr = self.edge_emb(edge_attr)\n\n        for conv in self.convs:\n            x = conv(x, edge_index, batch, edge_attr=edge_attr)\n        x = global_add_pool(x, batch)\n        return self.mlp(x)\n\n\nclass RedrawProjection:\n    def __init__(self, model: torch.nn.Module,\n                 redraw_interval: Optional[int] = None):\n        self.model = model\n        self.redraw_interval = redraw_interval\n        self.num_last_redraw = 0\n\n    def redraw_projections(self):\n        if not self.model.training or self.redraw_interval is None:\n            return\n        if self.num_last_redraw >= self.redraw_interval:\n            fast_attentions = [\n                module for module in self.model.modules()\n                if isinstance(module, PerformerAttention)\n            ]\n            for fast_attention in fast_attentions:\n                fast_attention.redraw_projection_matrix()\n            self.num_last_redraw = 0\n            return\n        self.num_last_redraw += 1\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nattn_kwargs = {'dropout': 0.5}\nmodel = GPS(channels=64, pe_dim=8, num_layers=10, attn_type=args.attn_type,\n            attn_kwargs=attn_kwargs).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)\nscheduler = 
ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,\n                              min_lr=0.00001)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        model.redraw_projection.redraw_projections()\n        out = model(data.x, data.pe, data.edge_index, data.edge_attr,\n                    data.batch)\n        loss = (out.squeeze() - data.y).abs().mean()\n        loss.backward()\n        total_loss += loss.item() * data.num_graphs\n        optimizer.step()\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_error = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.pe, data.edge_index, data.edge_attr,\n                    data.batch)\n        total_error += (out.squeeze() - data.y).abs().sum().item()\n    return total_error / len(loader.dataset)\n\n\nfor epoch in range(1, 101):\n    loss = train()\n    val_mae = test(val_loader)\n    test_mae = test(test_loader)\n    scheduler.step(val_mae)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '\n          f'Test: {test_mae:.4f}')\n"
  },
  {
    "path": "examples/graph_sage_unsup.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.linear_model import LogisticRegression\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.loader import LinkNeighborLoader\nfrom torch_geometric.nn import GraphSAGE\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\ndata = dataset[0]\n\ntrain_loader = LinkNeighborLoader(\n    data,\n    batch_size=256,\n    shuffle=True,\n    neg_sampling_ratio=1.0,\n    num_neighbors=[10, 10],\n)\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\ndata = data.to(device, 'x', 'edge_index')\n\nmodel = GraphSAGE(\n    data.num_node_features,\n    hidden_channels=64,\n    num_layers=2,\n).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n        h = model(batch.x, batch.edge_index)\n        h_src = h[batch.edge_label_index[0]]\n        h_dst = h[batch.edge_label_index[1]]\n        pred = (h_src * h_dst).sum(dim=-1)\n        loss = F.binary_cross_entropy_with_logits(pred, batch.edge_label)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * pred.size(0)\n\n    return total_loss / data.num_nodes\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out = model(data.x, data.edge_index).cpu()\n\n    clf = LogisticRegression()\n    clf.fit(out[data.train_mask], data.y[data.train_mask])\n\n    val_acc = clf.score(out[data.val_mask], data.y[data.val_mask])\n    test_acc = clf.score(out[data.test_mask], data.y[data.test_mask])\n\n    
return val_acc, test_acc\n\n\ntimes = []\nfor epoch in range(1, 51):\n    start = time.time()\n    loss = train()\n    val_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/graph_sage_unsup_ppi.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nimport tqdm\nfrom sklearn.linear_model import SGDClassifier\nfrom sklearn.metrics import f1_score\nfrom sklearn.multioutput import MultiOutputClassifier\n\nimport torch_geometric\nfrom torch_geometric.data import Batch\nfrom torch_geometric.datasets import PPI\nfrom torch_geometric.loader import DataLoader, LinkNeighborLoader\nfrom torch_geometric.nn import GraphSAGE\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')\ntrain_dataset = PPI(path, split='train')\nval_dataset = PPI(path, split='val')\ntest_dataset = PPI(path, split='test')\n\n# Group all training graphs into a single graph to perform sampling:\ntrain_data = Batch.from_data_list(train_dataset)\nloader = LinkNeighborLoader(train_data, batch_size=2048, shuffle=True,\n                            neg_sampling_ratio=1.0, num_neighbors=[10, 10],\n                            num_workers=6, persistent_workers=True)\n\n# Evaluation loaders (one datapoint corresponds to a graph)\ntrain_loader = DataLoader(train_dataset, batch_size=2)\nval_loader = DataLoader(val_dataset, batch_size=2)\ntest_loader = DataLoader(test_dataset, batch_size=2)\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\nmodel = GraphSAGE(\n    in_channels=train_dataset.num_features,\n    hidden_channels=64,\n    num_layers=2,\n    out_channels=64,\n).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n\n\ndef train():\n    model.train()\n\n    total_loss = total_examples = 0\n    for data in tqdm.tqdm(loader):\n        data = data.to(device)\n        optimizer.zero_grad()\n        h = model(data.x, data.edge_index)\n\n        h_src = h[data.edge_label_index[0]]\n        h_dst = h[data.edge_label_index[1]]\n        link_pred = (h_src * h_dst).sum(dim=-1)  # Inner 
product.\n\n        loss = F.binary_cross_entropy_with_logits(link_pred, data.edge_label)\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * link_pred.numel()\n        total_examples += link_pred.numel()\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef encode(loader):\n    model.eval()\n\n    xs, ys = [], []\n    for data in loader:\n        data = data.to(device)\n        xs.append(model(data.x, data.edge_index).cpu())\n        ys.append(data.y.cpu())\n    return torch.cat(xs, dim=0), torch.cat(ys, dim=0)\n\n\n@torch.no_grad()\ndef test():\n    # Train classifier on training set:\n    x, y = encode(train_loader)\n\n    clf = MultiOutputClassifier(SGDClassifier(loss='log_loss', penalty='l2'))\n    clf.fit(x, y)\n\n    train_f1 = f1_score(y, clf.predict(x), average='micro')\n\n    # Evaluate on validation set:\n    x, y = encode(val_loader)\n    val_f1 = f1_score(y, clf.predict(x), average='micro')\n\n    # Evaluate on test set:\n    x, y = encode(test_loader)\n    test_f1 = f1_score(y, clf.predict(x), average='micro')\n\n    return train_f1, val_f1, test_f1\n\n\ntimes = []\nfor epoch in range(1, 6):\n    start = time.time()\n    loss = train()\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')\n    train_f1, val_f1, test_f1 = test()\n    print(f'Train F1: {train_f1:.4f}, Val F1: {val_f1:.4f}, '\n          f'Test F1: {test_f1:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/graph_saint.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Flickr\nfrom torch_geometric.loader import GraphSAINTRandomWalkSampler\nfrom torch_geometric.nn import GraphConv\nfrom torch_geometric.typing import WITH_TORCH_SPARSE\nfrom torch_geometric.utils import degree\n\nif not WITH_TORCH_SPARSE:\n    quit(\"This example requires 'torch-sparse'\")\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr')\ndataset = Flickr(path)\ndata = dataset[0]\nrow, col = data.edge_index\ndata.edge_weight = 1. / degree(col, data.num_nodes)[col]  # Norm by in-degree.\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--use_normalization', action='store_true')\nargs = parser.parse_args()\n\nloader = GraphSAINTRandomWalkSampler(data, batch_size=6000, walk_length=2,\n                                     num_steps=5, sample_coverage=100,\n                                     save_dir=dataset.processed_dir,\n                                     num_workers=4)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        in_channels = dataset.num_node_features\n        out_channels = dataset.num_classes\n        self.conv1 = GraphConv(in_channels, hidden_channels)\n        self.conv2 = GraphConv(hidden_channels, hidden_channels)\n        self.conv3 = GraphConv(hidden_channels, hidden_channels)\n        self.lin = torch.nn.Linear(3 * hidden_channels, out_channels)\n\n    def set_aggr(self, aggr):\n        self.conv1.aggr = aggr\n        self.conv2.aggr = aggr\n        self.conv3.aggr = aggr\n\n    def forward(self, x0, edge_index, edge_weight=None):\n        x1 = F.relu(self.conv1(x0, edge_index, edge_weight))\n        x1 = F.dropout(x1, p=0.2, training=self.training)\n        x2 = F.relu(self.conv2(x1, edge_index, edge_weight))\n        x2 = F.dropout(x2, p=0.2, training=self.training)\n        x3 = F.relu(self.conv3(x2, 
edge_index, edge_weight))\n        x3 = F.dropout(x3, p=0.2, training=self.training)\n        x = torch.cat([x1, x2, x3], dim=-1)\n        x = self.lin(x)\n        return x.log_softmax(dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(hidden_channels=256).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    model.train()\n    model.set_aggr('add' if args.use_normalization else 'mean')\n\n    total_loss = total_examples = 0\n    for data in loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n\n        if args.use_normalization:\n            edge_weight = data.edge_norm * data.edge_weight\n            out = model(data.x, data.edge_index, edge_weight)\n            loss = F.nll_loss(out, data.y, reduction='none')\n            loss = (loss * data.node_norm)[data.train_mask].sum()\n        else:\n            out = model(data.x, data.edge_index)\n            loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item() * data.num_nodes\n        total_examples += data.num_nodes\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    model.set_aggr('mean')\n\n    out = model(data.x.to(device), data.edge_index.to(device))\n    pred = out.argmax(dim=-1)\n    correct = pred.eq(data.y.to(device))\n\n    accs = []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        accs.append(correct[mask].sum().item() / mask.sum().item())\n    return accs\n\n\nfor epoch in range(1, 51):\n    loss = train()\n    accs = test()\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {accs[0]:.4f}, '\n          f'Val: {accs[1]:.4f}, Test: {accs[2]:.4f}')\n"
  },
  {
    "path": "examples/graph_unet.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GraphUNet\nfrom torch_geometric.utils import dropout_edge\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset)\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        pool_ratios = [2000 / data.num_nodes, 0.5]\n        self.unet = GraphUNet(dataset.num_features, 32, dataset.num_classes,\n                              depth=3, pool_ratios=pool_ratios)\n\n    def forward(self):\n        edge_index, _ = dropout_edge(data.edge_index, p=0.2,\n                                     force_undirected=True,\n                                     training=self.training)\n        x = F.dropout(data.x, p=0.92, training=self.training)\n\n        x = self.unet(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.001)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out, accs = model(), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 201):\n    train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '\n          f'Val: {best_val_acc:.4f}, Test: 
{test_acc:.4f}')\n"
  },
  {
    "path": "examples/hetero/README.md",
    "content": "# Examples for Heterogeneous Data\n\n| Example                                                | Description                                                                                                                            |\n| ------------------------------------------------------ | -------------------------------------------------------------------------------------------------------------------------------------- |\n| [`hetero_conv_dblp.py`](./hetero_conv_dblp.py)         | Shows how to use the `HeteroConv(...)` wrapper; Trains it for node classification on the `DBLP` dataset.                               |\n| [`to_hetero_mag.py`](./to_hetero_mag.py)               | Shows how to use the `to_hetero(...)` functionality; Trains it for node classification on the `ogb-mag` dataset.                       |\n| [`hetero_link_pred.py`](./hetero_link_pred.py)         | Shows how to use the `to_hetero(...)` functionality; Trains it for link prediction on the `MovieLens` dataset.                         |\n| [`hgt_dblp.py`](./hgt_dblp.py)                         | Trains a Heterogeneous Graph Transformer (HGT) model for node classification on the `DBLP` dataset.                                    |\n| [`hierarchical_sage.py`](./hierarchical_sage.py)       | Shows how to perform hierarchical sampling; Trains a heterogeneous `GraphSAGE` model for node classification on the `ogb-mag` dataset. |\n| [`load_csv.py`](./load_csv.py)                         | Shows how to create heterogeneous graphs from raw `*.csv` data.                                                                        |\n| [`metapath2vec.py`](./metapath2vec.py)                 | Train an unsupervised `MetaPath2Vec` model; Tests embeddings for node classification on the `AMiner` dataset.                          |\n| [`temporal_link_pred.py`](./temporal_link_pred.py)     | Trains a heterogeneous `GraphSAGE` model for temporal link prediction on the `MovieLens` dataset.                  
                    |\n| [`bipartite_sage.py`](./bipartite_sage.py)             | Trains a GNN via metapaths for link prediction on the `MovieLens` dataset.                                                             |\n| [`bipartite_sage_unsup.py`](./bipartite_sage_unsup.py) | Trains a GNN via metapaths for link prediction on the large-scale `Taobao` dataset.                                                    |\n| [`dmgi_unsup.py`](./dmgi_unsup.py)                     | Shows how to learn embeddings on the `IMDB` dataset using the `DMGI` model.                                                            |\n| [`han_imdb.py`](./han_imdb.py)                         | Shows how to train a heterogeneous Graph Attention Network (HAN) for node classification on the `IMDB` dataset.                        |\n| [`recommender_system.py`](./recommender_system.py)     | Shows how to train a temporal GNN-based recommender system on the `MovieLens` dataset.                                                 |\n"
  },
  {
    "path": "examples/hetero/bipartite_sage.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Embedding, Linear\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MovieLens\nfrom torch_geometric.nn import SAGEConv\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')\ndataset = MovieLens(path, model_name='all-MiniLM-L6-v2')\ndata = dataset[0]\ndata['user'].x = torch.arange(data['user'].num_nodes)\ndata['user', 'movie'].edge_label = data['user', 'movie'].edge_label.float()\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\ndata = data.to(device)\n\n# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:\ndata = T.ToUndirected()(data)\ndel data['movie', 'rev_rates', 'user'].edge_label  # Remove \"reverse\" label.\n\n# Perform a link-level split into training, validation, and test edges:\ntrain_data, val_data, test_data = T.RandomLinkSplit(\n    num_val=0.1,\n    num_test=0.1,\n    neg_sampling_ratio=0.0,\n    edge_types=[('user', 'rates', 'movie')],\n    rev_edge_types=[('movie', 'rev_rates', 'user')],\n)(data)\n\n# Generate the co-occurrence matrix of movies<>movies:\nmetapath = [('movie', 'rev_rates', 'user'), ('user', 'rates', 'movie')]\ntrain_data = T.AddMetaPaths(metapaths=[metapath])(train_data)\n\n# Apply normalization to filter the metapath:\n_, edge_weight = gcn_norm(\n    train_data['movie', 'movie'].edge_index,\n    num_nodes=train_data['movie'].num_nodes,\n    add_self_loops=False,\n)\nedge_index = train_data['movie', 'movie'].edge_index[:, edge_weight > 0.002]\n\ntrain_data['movie', 'metapath_0', 'movie'].edge_index = edge_index\nval_data['movie', 'metapath_0', 'movie'].edge_index = edge_index\ntest_data['movie', 'metapath_0', 'movie'].edge_index = edge_index\n\n\nclass MovieGNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        
super().__init__()\n\n        self.conv1 = SAGEConv(-1, hidden_channels)\n        self.conv2 = SAGEConv(hidden_channels, hidden_channels)\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index).relu()\n        return self.lin(x)\n\n\nclass UserGNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), hidden_channels)\n        self.conv3 = SAGEConv((-1, -1), hidden_channels)\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x_dict, edge_index_dict):\n        movie_x = self.conv1(\n            x_dict['movie'],\n            edge_index_dict[('movie', 'metapath_0', 'movie')],\n        ).relu()\n\n        user_x = self.conv2(\n            (x_dict['movie'], x_dict['user']),\n            edge_index_dict[('movie', 'rev_rates', 'user')],\n        ).relu()\n\n        user_x = self.conv3(\n            (movie_x, user_x),\n            edge_index_dict[('movie', 'rev_rates', 'user')],\n        ).relu()\n\n        return self.lin(user_x)\n\n\nclass EdgeDecoder(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, 1)\n\n    def forward(self, z_src, z_dst, edge_label_index):\n        row, col = edge_label_index\n        z = torch.cat([z_src[row], z_dst[col]], dim=-1)\n\n        z = self.lin1(z).relu()\n        z = self.lin2(z)\n        return z.view(-1)\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, num_users, hidden_channels, out_channels):\n        super().__init__()\n        self.user_emb = Embedding(num_users, hidden_channels)\n        self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)\n        
self.movie_encoder = MovieGNNEncoder(hidden_channels, out_channels)\n        self.decoder = EdgeDecoder(out_channels)\n\n    def forward(self, x_dict, edge_index_dict, edge_label_index):\n        z_dict = {}\n        x_dict['user'] = self.user_emb(x_dict['user'])\n        z_dict['user'] = self.user_encoder(x_dict, edge_index_dict)\n        z_dict['movie'] = self.movie_encoder(\n            x_dict['movie'],\n            edge_index_dict[('movie', 'metapath_0', 'movie')],\n        )\n        return self.decoder(z_dict['user'], z_dict['movie'], edge_label_index)\n\n\nmodel = Model(data['user'].num_nodes, hidden_channels=64, out_channels=64)\nmodel = model.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0003)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(\n        train_data.x_dict,\n        train_data.edge_index_dict,\n        train_data['user', 'movie'].edge_label_index,\n    )\n    loss = F.mse_loss(out, train_data['user', 'movie'].edge_label)\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    out = model(\n        data.x_dict,\n        data.edge_index_dict,\n        data['user', 'movie'].edge_label_index,\n    ).clamp(min=0, max=5)\n    rmse = F.mse_loss(out, data['user', 'movie'].edge_label).sqrt()\n    return float(rmse)\n\n\nfor epoch in range(1, 701):\n    loss = train()\n    train_rmse = test(train_data)\n    val_rmse = test(val_data)\n    test_rmse = test(test_data)\n    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '\n          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')\n"
  },
  {
    "path": "examples/hetero/bipartite_sage_unsup.py",
    "content": "# An implementation of unsupervised bipartite GraphSAGE using the Alibaba\n# Taobao dataset.\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nimport tqdm\nfrom sklearn.metrics import roc_auc_score\nfrom torch.nn import Embedding, Linear\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Taobao\nfrom torch_geometric.loader import LinkNeighborLoader\nfrom torch_geometric.nn import SAGEConv\nfrom torch_geometric.utils.convert import to_scipy_sparse_matrix\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/Taobao')\n\ndataset = Taobao(path)\ndata = dataset[0]\n\ndata['user'].x = torch.arange(0, data['user'].num_nodes)\ndata['item'].x = torch.arange(0, data['item'].num_nodes)\n\n# Only consider user<>item relationships for simplicity:\ndel data['category']\ndel data['item', 'category']\ndel data['user', 'item'].time\ndel data['user', 'item'].behavior\n\n# Add a reverse ('item', 'rev_to', 'user') relation for message passing:\ndata = T.ToUndirected()(data)\n\n# Perform a link-level split into training, validation, and test edges:\nprint('Computing data splits...')\ntrain_data, val_data, test_data = T.RandomLinkSplit(\n    num_val=0.1,\n    num_test=0.1,\n    neg_sampling_ratio=1.0,\n    add_negative_train_samples=False,\n    edge_types=[('user', 'to', 'item')],\n    rev_edge_types=[('item', 'rev_to', 'user')],\n)(data)\nprint('Done!')\n\n# Compute sparsified item<>item relationships through users:\nprint('Computing item<>item relationships...')\nmat = to_scipy_sparse_matrix(data['user', 'item'].edge_index).tocsr()\nmat = mat[:data['user'].num_nodes, :data['item'].num_nodes]\ncomat = mat.T @ mat\ncomat.setdiag(0)\ncomat = comat >= 3.\ncomat = comat.tocoo()\nrow = 
torch.from_numpy(comat.row).to(torch.long)\ncol = torch.from_numpy(comat.col).to(torch.long)\nitem_to_item_edge_index = torch.stack([row, col], dim=0)\n\n# Add the generated item<>item relationships for high-order information:\ntrain_data['item', 'item'].edge_index = item_to_item_edge_index\nval_data['item', 'item'].edge_index = item_to_item_edge_index\ntest_data['item', 'item'].edge_index = item_to_item_edge_index\nprint('Done!')\n\ntrain_loader = LinkNeighborLoader(\n    data=train_data,\n    num_neighbors=[8, 4],\n    edge_label_index=('user', 'to', 'item'),\n    neg_sampling='binary',\n    batch_size=2048,\n    shuffle=True,\n    num_workers=16,\n    drop_last=True,\n)\n\nval_loader = LinkNeighborLoader(\n    data=val_data,\n    num_neighbors=[8, 4],\n    edge_label_index=(\n        ('user', 'to', 'item'),\n        val_data[('user', 'to', 'item')].edge_label_index,\n    ),\n    edge_label=val_data[('user', 'to', 'item')].edge_label,\n    batch_size=2048,\n    shuffle=False,\n    num_workers=16,\n)\n\ntest_loader = LinkNeighborLoader(\n    data=test_data,\n    num_neighbors=[8, 4],\n    edge_label_index=(\n        ('user', 'to', 'item'),\n        test_data[('user', 'to', 'item')].edge_label_index,\n    ),\n    edge_label=test_data[('user', 'to', 'item')].edge_label,\n    batch_size=2048,\n    shuffle=False,\n    num_workers=16,\n)\n\n\nclass ItemGNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv(-1, hidden_channels)\n        self.conv2 = SAGEConv(hidden_channels, hidden_channels)\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index).relu()\n        return self.lin(x)\n\n\nclass UserGNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), 
hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), hidden_channels)\n        self.conv3 = SAGEConv((-1, -1), hidden_channels)\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x_dict, edge_index_dict):\n        item_x = self.conv1(\n            x_dict['item'],\n            edge_index_dict[('item', 'to', 'item')],\n        ).relu()\n\n        user_x = self.conv2(\n            (x_dict['item'], x_dict['user']),\n            edge_index_dict[('item', 'rev_to', 'user')],\n        ).relu()\n\n        user_x = self.conv3(\n            (item_x, user_x),\n            edge_index_dict[('item', 'rev_to', 'user')],\n        ).relu()\n\n        return self.lin(user_x)\n\n\nclass EdgeDecoder(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, 1)\n\n    def forward(self, z_src, z_dst, edge_label_index):\n        row, col = edge_label_index\n        z = torch.cat([z_src[row], z_dst[col]], dim=-1)\n\n        z = self.lin1(z).relu()\n        z = self.lin2(z)\n        return z.view(-1)\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, num_users, num_items, hidden_channels, out_channels):\n        super().__init__()\n        self.user_emb = Embedding(num_users, hidden_channels, device=device)\n        self.item_emb = Embedding(num_items, hidden_channels, device=device)\n        self.item_encoder = ItemGNNEncoder(hidden_channels, out_channels)\n        self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)\n        self.decoder = EdgeDecoder(out_channels)\n\n    def forward(self, x_dict, edge_index_dict, edge_label_index):\n        z_dict = {}\n        x_dict['user'] = self.user_emb(x_dict['user'])\n        x_dict['item'] = self.item_emb(x_dict['item'])\n        z_dict['item'] = self.item_encoder(\n            x_dict['item'],\n            edge_index_dict[('item', 'to', 'item')],\n     
   )\n        z_dict['user'] = self.user_encoder(x_dict, edge_index_dict)\n\n        return self.decoder(z_dict['user'], z_dict['item'], edge_label_index)\n\n\nmodel = Model(\n    num_users=data['user'].num_nodes,\n    num_items=data['item'].num_nodes,\n    hidden_channels=64,\n    out_channels=64,\n).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    model.train()\n\n    total_loss = total_examples = 0\n    for batch in tqdm.tqdm(train_loader):\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        pred = model(\n            batch.x_dict,\n            batch.edge_index_dict,\n            batch['user', 'item'].edge_label_index,\n        )\n        loss = F.binary_cross_entropy_with_logits(\n            pred, batch['user', 'item'].edge_label)\n\n        loss.backward()\n        optimizer.step()\n        # The loss is a per-batch mean, so weight it by the batch size:\n        total_loss += float(loss) * pred.numel()\n        total_examples += pred.numel()\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    preds, targets = [], []\n    for batch in tqdm.tqdm(loader):\n        batch = batch.to(device)\n\n        pred = model(\n            batch.x_dict,\n            batch.edge_index_dict,\n            batch['user', 'item'].edge_label_index,\n        ).sigmoid().view(-1).cpu()\n        target = batch['user', 'item'].edge_label.long().cpu()\n\n        preds.append(pred)\n        targets.append(target)\n\n    pred = torch.cat(preds, dim=0).numpy()\n    target = torch.cat(targets, dim=0).numpy()\n\n    return roc_auc_score(target, pred)\n\n\nfor epoch in range(1, 21):\n    loss = train()\n    val_auc = test(val_loader)\n    test_auc = test(test_loader)\n\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '\n          f'Test: {test_auc:.4f}')\n"
  },
  {
    "path": "examples/hetero/dmgi_unsup.py",
    "content": "# An implementation of \"Unsupervised Attributed Multiplex Network\n# Embedding\" (DMGI) for unsupervised learning on  heterogeneous graphs:\n# * Paper: <https://arxiv.org/abs/1911.06750> (AAAI 2020)\n\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.linear_model import LogisticRegression\nfrom torch.optim import Adam\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import IMDB\nfrom torch_geometric.nn import GCNConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/IMDB')\ndataset = IMDB(path)\n\nmetapaths = [\n    [('movie', 'actor'), ('actor', 'movie')],  # MAM\n    [('movie', 'director'), ('director', 'movie')],  # MDM\n]\ndata = T.AddMetaPaths(metapaths, drop_orig_edge_types=True)(dataset[0])\n\n\nclass DMGI(torch.nn.Module):\n    def __init__(self, num_nodes, in_channels, out_channels, num_relations):\n        super().__init__()\n        self.convs = torch.nn.ModuleList(\n            [GCNConv(in_channels, out_channels) for _ in range(num_relations)])\n        self.M = torch.nn.Bilinear(out_channels, out_channels, 1)\n        self.Z = torch.nn.Parameter(torch.empty(num_nodes, out_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for conv in self.convs:\n            conv.reset_parameters()\n        torch.nn.init.xavier_uniform_(self.M.weight)\n        self.M.bias.data.zero_()\n        torch.nn.init.xavier_uniform_(self.Z)\n\n    def forward(self, x, edge_indices):\n        pos_hs, neg_hs, summaries = [], [], []\n        for conv, edge_index in zip(self.convs, edge_indices):\n            pos_h = F.dropout(x, p=0.5, training=self.training)\n            pos_h = conv(pos_h, edge_index).relu()\n            pos_hs.append(pos_h)\n\n            neg_h = F.dropout(x, p=0.5, training=self.training)\n            neg_h = neg_h[torch.randperm(neg_h.size(0), device=neg_h.device)]\n            neg_h = conv(neg_h, 
edge_index).relu()\n            neg_hs.append(neg_h)\n\n            summaries.append(pos_h.mean(dim=0, keepdim=True))\n\n        return pos_hs, neg_hs, summaries\n\n    def loss(self, pos_hs, neg_hs, summaries):\n        loss = 0.\n        for pos_h, neg_h, s in zip(pos_hs, neg_hs, summaries):\n            s = s.expand_as(pos_h)\n            loss += -torch.log(self.M(pos_h, s).sigmoid() + 1e-15).mean()\n            loss += -torch.log(1 - self.M(neg_h, s).sigmoid() + 1e-15).mean()\n\n        pos_mean = torch.stack(pos_hs, dim=0).mean(dim=0)\n        neg_mean = torch.stack(neg_hs, dim=0).mean(dim=0)\n\n        pos_reg_loss = (self.Z - pos_mean).pow(2).sum()\n        neg_reg_loss = (self.Z - neg_mean).pow(2).sum()\n        loss += 0.001 * (pos_reg_loss - neg_reg_loss)\n\n        return loss\n\n\nmodel = DMGI(data['movie'].num_nodes, data['movie'].x.size(-1),\n             out_channels=64, num_relations=len(data.edge_types))\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\ndata, model = data.to(device), model.to(device)\n\noptimizer = Adam(model.parameters(), lr=0.0005, weight_decay=0.0001)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    x = data['movie'].x\n    edge_indices = data.edge_index_dict.values()\n    pos_hs, neg_hs, summaries = model(x, edge_indices)\n    loss = model.loss(pos_hs, neg_hs, summaries)\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    train_emb = model.Z[data['movie'].train_mask].cpu()\n    val_emb = model.Z[data['movie'].val_mask].cpu()\n    test_emb = model.Z[data['movie'].test_mask].cpu()\n\n    train_y = data['movie'].y[data['movie'].train_mask].cpu()\n    val_y = data['movie'].y[data['movie'].val_mask].cpu()\n    test_y = data['movie'].y[data['movie'].test_mask].cpu()\n\n    clf = LogisticRegression().fit(train_emb, train_y)\n   
 return clf.score(val_emb, val_y), clf.score(test_emb, test_y)\n\n\nfor epoch in range(1, 1001):\n    loss = train()\n    if epoch % 50 == 0:\n        val_acc, test_acc = test()\n        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '\n              f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/hetero/han_imdb.py",
    "content": "import os.path as osp\nfrom typing import Dict, List, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import IMDB\nfrom torch_geometric.nn import HANConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/IMDB')\nmetapaths = [[('movie', 'actor'), ('actor', 'movie')],\n             [('movie', 'director'), ('director', 'movie')]]\ntransform = T.AddMetaPaths(metapaths=metapaths, drop_orig_edge_types=True,\n                           drop_unconnected_node_types=True)\ndataset = IMDB(path, transform=transform)\ndata = dataset[0]\n\n\nclass HAN(nn.Module):\n    def __init__(self, in_channels: Union[int, Dict[str, int]],\n                 out_channels: int, hidden_channels=128, heads=8):\n        super().__init__()\n        self.han_conv = HANConv(in_channels, hidden_channels, heads=heads,\n                                dropout=0.6, metadata=data.metadata())\n        self.lin = nn.Linear(hidden_channels, out_channels)\n\n    def forward(self, x_dict, edge_index_dict):\n        out = self.han_conv(x_dict, edge_index_dict)\n        out = self.lin(out['movie'])\n        return out\n\n\nmodel = HAN(in_channels=-1, out_channels=3)\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\ndata, model = data.to(device), model.to(device)\n\nwith torch.no_grad():  # Initialize lazy modules.\n    out = model(data.x_dict, data.edge_index_dict)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)\n\n\ndef train() -> float:\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x_dict, data.edge_index_dict)\n    mask = data['movie'].train_mask\n    loss = F.cross_entropy(out[mask], data['movie'].y[mask])\n    loss.backward()\n    optimizer.step()\n    return 
float(loss)\n\n\n@torch.no_grad()\ndef test() -> List[float]:\n    model.eval()\n    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)\n\n    accs = []\n    for split in ['train_mask', 'val_mask', 'test_mask']:\n        mask = data['movie'][split]\n        acc = (pred[mask] == data['movie'].y[mask]).sum() / mask.sum()\n        accs.append(float(acc))\n    return accs\n\n\nbest_val_acc = 0\nstart_patience = patience = 100\nfor epoch in range(1, 200):\n\n    loss = train()\n    train_acc, val_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n\n    if best_val_acc <= val_acc:\n        patience = start_patience\n        best_val_acc = val_acc\n    else:\n        patience -= 1\n\n    if patience <= 0:\n        print('Stopping training as validation accuracy did not improve '\n              f'for {start_patience} epochs')\n        break\n"
  },
  {
    "path": "examples/hetero/hetero_conv_dblp.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import DBLP\nfrom torch_geometric.nn import HeteroConv, Linear, SAGEConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP')\n# We initialize conference node features with a single one-vector as feature:\ndataset = DBLP(path, transform=T.Constant(node_types='conference'))\ndata = dataset[0]\n\n\nclass HeteroGNN(torch.nn.Module):\n    def __init__(self, metadata, hidden_channels, out_channels, num_layers):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            conv = HeteroConv({\n                edge_type: SAGEConv((-1, -1), hidden_channels)\n                for edge_type in metadata[1]\n            })\n            self.convs.append(conv)\n\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x_dict, edge_index_dict):\n        for conv in self.convs:\n            x_dict = conv(x_dict, edge_index_dict)\n            x_dict = {key: F.leaky_relu(x) for key, x in x_dict.items()}\n        return self.lin(x_dict['author'])\n\n\nmodel = HeteroGNN(data.metadata(), hidden_channels=64, out_channels=4,\n                  num_layers=2)\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\ndata, model = data.to(device), model.to(device)\n\nwith torch.no_grad():  # Initialize lazy modules.\n    out = model(data.x_dict, data.edge_index_dict)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x_dict, data.edge_index_dict)\n    mask = data['author'].train_mask\n    loss = F.cross_entropy(out[mask], data['author'].y[mask])\n    
loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)\n\n    accs = []\n    for split in ['train_mask', 'val_mask', 'test_mask']:\n        mask = data['author'][split]\n        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()\n        accs.append(float(acc))\n    return accs\n\n\nfor epoch in range(1, 101):\n    loss = train()\n    train_acc, val_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/hetero/hetero_link_pred.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MovieLens\nfrom torch_geometric.nn import SAGEConv, to_hetero\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--use_weighted_loss', action='store_true',\n                    help='Whether to use weighted MSE loss.')\nargs = parser.parse_args()\n\ndevice = torch_geometric.device('auto')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')\ndataset = MovieLens(path, model_name='all-MiniLM-L6-v2')\ndata = dataset[0].to(device)\n\n# Add user node features for message passing:\ndata['user'].x = torch.eye(data['user'].num_nodes, device=device)\ndel data['user'].num_nodes\n\n# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:\ndata = T.ToUndirected()(data)\ndel data['movie', 'rev_rates', 'user'].edge_label  # Remove \"reverse\" label.\n\n# Perform a link-level split into training, validation, and test edges:\ntrain_data, val_data, test_data = T.RandomLinkSplit(\n    num_val=0.1,\n    num_test=0.1,\n    neg_sampling_ratio=0.0,\n    edge_types=[('user', 'rates', 'movie')],\n    rev_edge_types=[('movie', 'rev_rates', 'user')],\n)(data)\n\n# We have an unbalanced dataset with many labels for rating 3 and 4, and very\n# few for 0 and 1. Therefore we use a weighted MSE loss.\nif args.use_weighted_loss:\n    weight = torch.bincount(train_data['user', 'movie'].edge_label)\n    weight = weight.max() / weight\nelse:\n    weight = None\n\n\ndef weighted_mse_loss(pred, target, weight=None):\n    weight = 1. 
if weight is None else weight[target].to(pred.dtype)\n    return (weight * (pred - target.to(pred.dtype)).pow(2)).mean()\n\n\nclass GNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        return x\n\n\nclass EdgeDecoder(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, 1)\n\n    def forward(self, z_dict, edge_label_index):\n        row, col = edge_label_index\n        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)\n\n        z = self.lin1(z).relu()\n        z = self.lin2(z)\n        return z.view(-1)\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.encoder = GNNEncoder(hidden_channels, hidden_channels)\n        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')\n        self.decoder = EdgeDecoder(hidden_channels)\n\n    def forward(self, x_dict, edge_index_dict, edge_label_index):\n        z_dict = self.encoder(x_dict, edge_index_dict)\n        return self.decoder(z_dict, edge_label_index)\n\n\nmodel = Model(hidden_channels=32).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    pred = model(train_data.x_dict, train_data.edge_index_dict,\n                 train_data['user', 'movie'].edge_label_index)\n    target = train_data['user', 'movie'].edge_label\n    loss = weighted_mse_loss(pred, target, weight)\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    
pred = model(data.x_dict, data.edge_index_dict,\n                 data['user', 'movie'].edge_label_index)\n    pred = pred.clamp(min=0, max=5)\n    target = data['user', 'movie'].edge_label.float()\n    rmse = F.mse_loss(pred, target).sqrt()\n    return float(rmse)\n\n\nfor epoch in range(1, 301):\n    loss = train()\n    train_rmse = test(train_data)\n    val_rmse = test(val_data)\n    test_rmse = test(test_data)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_rmse:.4f}, '\n          f'Val: {val_rmse:.4f}, Test: {test_rmse:.4f}')\n"
  },
  {
    "path": "examples/hetero/hgt_dblp.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import DBLP\nfrom torch_geometric.nn import HGTConv, Linear\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/DBLP')\n# We initialize conference node features with a single one-vector as feature:\ndataset = DBLP(path, transform=T.Constant(node_types='conference'))\ndata = dataset[0]\n\n\nclass HGT(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels, num_heads, num_layers):\n        super().__init__()\n\n        self.lin_dict = torch.nn.ModuleDict()\n        for node_type in data.node_types:\n            self.lin_dict[node_type] = Linear(-1, hidden_channels)\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            conv = HGTConv(hidden_channels, hidden_channels, data.metadata(),\n                           num_heads)\n            self.convs.append(conv)\n\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x_dict, edge_index_dict):\n        x_dict = {\n            node_type: self.lin_dict[node_type](x).relu_()\n            for node_type, x in x_dict.items()\n        }\n\n        for conv in self.convs:\n            x_dict = conv(x_dict, edge_index_dict)\n\n        return self.lin(x_dict['author'])\n\n\nmodel = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1)\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\ndata, model = data.to(device), model.to(device)\n\nwith torch.no_grad():  # Initialize lazy modules.\n    out = model(data.x_dict, data.edge_index_dict)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0.001)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x_dict, 
data.edge_index_dict)\n    mask = data['author'].train_mask\n    loss = F.cross_entropy(out[mask], data['author'].y[mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x_dict, data.edge_index_dict).argmax(dim=-1)\n\n    accs = []\n    for split in ['train_mask', 'val_mask', 'test_mask']:\n        mask = data['author'][split]\n        acc = (pred[mask] == data['author'].y[mask]).sum() / mask.sum()\n        accs.append(float(acc))\n    return accs\n\n\nfor epoch in range(1, 101):\n    loss = train()\n    train_acc, val_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/hetero/hierarchical_sage.py",
    "content": "import argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nimport torch_geometric\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import OGB_MAG\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import HeteroConv, Linear, SAGEConv\nfrom torch_geometric.utils import trim_to_layer\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--use-sparse-tensor', action='store_true')\nargs = parser.parse_args()\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\n\ntransforms = [T.ToUndirected(merge=True)]\nif args.use_sparse_tensor:\n    transforms.append(T.ToSparseTensor())\ndataset = OGB_MAG(root='../../data', preprocess='metapath2vec',\n                  transform=T.Compose(transforms))\ndata = dataset[0].to(device, 'x', 'y')\n\n\nclass HierarchicalHeteroGraphSage(torch.nn.Module):\n    def __init__(self, edge_types, hidden_channels, out_channels, num_layers):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            conv = HeteroConv(\n                {\n                    edge_type: SAGEConv((-1, -1), hidden_channels)\n                    for edge_type in edge_types\n                }, aggr='sum')\n            self.convs.append(conv)\n\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x_dict, edge_index_dict, num_sampled_edges_dict,\n                num_sampled_nodes_dict):\n\n        for i, conv in enumerate(self.convs):\n            x_dict, edge_index_dict, _ = trim_to_layer(\n                layer=i,\n                num_sampled_nodes_per_hop=num_sampled_nodes_dict,\n                num_sampled_edges_per_hop=num_sampled_edges_dict,\n                x=x_dict,\n                edge_index=edge_index_dict,\n            )\n\n            
x_dict = conv(x_dict, edge_index_dict)\n            x_dict = {key: x.relu() for key, x in x_dict.items()}\n\n        return self.lin(x_dict['paper'])\n\n\nmodel = HierarchicalHeteroGraphSage(\n    edge_types=data.edge_types,\n    hidden_channels=64,\n    out_channels=dataset.num_classes,\n    num_layers=2,\n).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\nkwargs = {'batch_size': 1024, 'num_workers': 0}\ntrain_loader = NeighborLoader(\n    data,\n    num_neighbors=[10] * 2,\n    shuffle=True,\n    input_nodes=('paper', data['paper'].train_mask),\n    **kwargs,\n)\n\nval_loader = NeighborLoader(\n    data,\n    num_neighbors=[10] * 2,\n    shuffle=False,\n    input_nodes=('paper', data['paper'].val_mask),\n    **kwargs,\n)\n\n\ndef train():\n    model.train()\n\n    total_examples = total_loss = 0\n    for batch in tqdm(train_loader):\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        out = model(\n            batch.x_dict,\n            batch.adj_t_dict\n            if args.use_sparse_tensor else batch.edge_index_dict,\n            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,\n            num_sampled_edges_dict=batch.num_sampled_edges_dict,\n        )\n\n        batch_size = batch['paper'].batch_size\n        loss = F.cross_entropy(out[:batch_size], batch['paper'].y[:batch_size])\n        loss.backward()\n        optimizer.step()\n\n        total_examples += batch_size\n        total_loss += float(loss) * batch_size\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_examples = total_correct = 0\n    for batch in tqdm(loader):\n        batch = batch.to(device)\n        out = model(\n            batch.x_dict,\n            batch.adj_t_dict\n            if args.use_sparse_tensor else batch.edge_index_dict,\n            num_sampled_nodes_dict=batch.num_sampled_nodes_dict,\n            num_sampled_edges_dict=batch.num_sampled_edges_dict,\n        
)\n\n        batch_size = batch['paper'].batch_size\n        pred = out[:batch_size].argmax(dim=-1)\n        total_examples += batch_size\n        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())\n\n    return total_correct / total_examples\n\n\nfor epoch in range(1, 6):\n    loss = train()\n    val_acc = test(val_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')\n"
  },
  {
    "path": "examples/hetero/load_csv.py",
    "content": "import os.path as osp\n\nimport pandas as pd\nimport torch\nfrom sentence_transformers import SentenceTransformer\n\nfrom torch_geometric.data import HeteroData, download_url, extract_zip\nfrom torch_geometric.transforms import RandomLinkSplit, ToUndirected\n\nurl = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'\nroot = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')\nextract_zip(download_url(url, root), root)\nmovie_path = osp.join(root, 'ml-latest-small', 'movies.csv')\nrating_path = osp.join(root, 'ml-latest-small', 'ratings.csv')\n\n\ndef load_node_csv(path, index_col, encoders=None, **kwargs):\n    df = pd.read_csv(path, index_col=index_col, **kwargs)\n    mapping = {index: i for i, index in enumerate(df.index.unique())}\n\n    x = None\n    if encoders is not None:\n        xs = [encoder(df[col]) for col, encoder in encoders.items()]\n        x = torch.cat(xs, dim=-1)\n\n    return x, mapping\n\n\ndef load_edge_csv(path, src_index_col, src_mapping, dst_index_col, dst_mapping,\n                  encoders=None, **kwargs):\n    df = pd.read_csv(path, **kwargs)\n\n    src = [src_mapping[index] for index in df[src_index_col]]\n    dst = [dst_mapping[index] for index in df[dst_index_col]]\n    edge_index = torch.tensor([src, dst])\n\n    edge_attr = None\n    if encoders is not None:\n        edge_attrs = [encoder(df[col]) for col, encoder in encoders.items()]\n        edge_attr = torch.cat(edge_attrs, dim=-1)\n\n    return edge_index, edge_attr\n\n\nclass SequenceEncoder:\n    # The 'SequenceEncoder' encodes raw column strings into embeddings.\n    def __init__(self, model_name='all-MiniLM-L6-v2', device=None):\n        self.device = device\n        self.model = SentenceTransformer(model_name, device=device)\n\n    @torch.no_grad()\n    def __call__(self, df):\n        x = self.model.encode(df.values, show_progress_bar=True,\n                              convert_to_tensor=True, device=self.device)\n      
  return x.cpu()\n\n\nclass GenresEncoder:\n    # The 'GenresEncoder' splits the raw column strings by 'sep' and converts\n    # individual elements to categorical labels.\n    def __init__(self, sep='|'):\n        self.sep = sep\n\n    def __call__(self, df):\n        genres = {g for col in df.values for g in col.split(self.sep)}\n        mapping = {genre: i for i, genre in enumerate(genres)}\n\n        x = torch.zeros(len(df), len(mapping))\n        for i, col in enumerate(df.values):\n            for genre in col.split(self.sep):\n                x[i, mapping[genre]] = 1\n        return x\n\n\nclass IdentityEncoder:\n    # The 'IdentityEncoder' takes the raw column values and converts them to\n    # PyTorch tensors.\n    def __init__(self, dtype=None):\n        self.dtype = dtype\n\n    def __call__(self, df):\n        return torch.from_numpy(df.values).view(-1, 1).to(self.dtype)\n\n\nuser_x, user_mapping = load_node_csv(rating_path, index_col='userId')\n\nmovie_x, movie_mapping = load_node_csv(\n    movie_path, index_col='movieId', encoders={\n        'title': SequenceEncoder(),\n        'genres': GenresEncoder()\n    })\n\nedge_index, edge_label = load_edge_csv(\n    rating_path,\n    src_index_col='userId',\n    src_mapping=user_mapping,\n    dst_index_col='movieId',\n    dst_mapping=movie_mapping,\n    encoders={'rating': IdentityEncoder(dtype=torch.long)},\n)\n\ndata = HeteroData()\ndata['user'].num_nodes = len(user_mapping)  # Users do not have any features.\ndata['movie'].x = movie_x\ndata['user', 'rates', 'movie'].edge_index = edge_index\ndata['user', 'rates', 'movie'].edge_label = edge_label\nprint(data)\n\n# We can now convert `data` into an appropriate format for training a\n# graph-based machine learning model:\n\n# 1. Add a reverse ('movie', 'rev_rates', 'user') relation for message passing.\ndata = ToUndirected()(data)\ndel data['movie', 'rev_rates', 'user'].edge_label  # Remove \"reverse\" label.\n\n# 2. 
Perform a link-level split into training, validation, and test edges.\ntransform = RandomLinkSplit(\n    num_val=0.05,\n    num_test=0.1,\n    neg_sampling_ratio=0.0,\n    edge_types=[('user', 'rates', 'movie')],\n    rev_edge_types=[('movie', 'rev_rates', 'user')],\n)\ntrain_data, val_data, test_data = transform(data)\nprint(train_data)\nprint(val_data)\nprint(test_data)\n"
  },
  {
    "path": "examples/hetero/metapath2vec.py",
    "content": "# Reaches around 91.8% Micro-F1 after 5 epochs.\n\nimport os.path as osp\n\nimport torch\n\nimport torch_geometric\nfrom torch_geometric.datasets import AMiner\nfrom torch_geometric.nn import MetaPath2Vec\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/AMiner')\ndataset = AMiner(path)\ndata = dataset[0]\n\nmetapath = [\n    ('author', 'writes', 'paper'),\n    ('paper', 'published_in', 'venue'),\n    ('venue', 'publishes', 'paper'),\n    ('paper', 'written_by', 'author'),\n]\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\nmodel = MetaPath2Vec(data.edge_index_dict, embedding_dim=128,\n                     metapath=metapath, walk_length=50, context_size=7,\n                     walks_per_node=5, num_negative_samples=5,\n                     sparse=True).to(device)\n\nloader = model.loader(batch_size=128, shuffle=True, num_workers=6)\noptimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)\n\n\ndef train(epoch, log_steps=100, eval_steps=2000):\n    model.train()\n\n    total_loss = 0\n    for i, (pos_rw, neg_rw) in enumerate(loader):\n        optimizer.zero_grad()\n        loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n        loss.backward()\n        optimizer.step()\n\n        total_loss += loss.item()\n        if (i + 1) % log_steps == 0:\n            print(f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '\n                  f'Loss: {total_loss / log_steps:.4f}')\n            total_loss = 0\n\n        if (i + 1) % eval_steps == 0:\n            acc = test()\n            print(f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '\n                  f'Acc: {acc:.4f}')\n\n\n@torch.no_grad()\ndef test(train_ratio=0.1):\n    model.eval()\n\n    z = model('author', batch=data['author'].y_index.to(device))\n    y = data['author'].y\n\n    perm = torch.randperm(z.size(0))\n    
train_perm = perm[:int(z.size(0) * train_ratio)]\n    test_perm = perm[int(z.size(0) * train_ratio):]\n\n    return model.test(z[train_perm], y[train_perm], z[test_perm], y[test_perm],\n                      max_iter=150)\n\n\nfor epoch in range(1, 6):\n    train(epoch)\n    acc = test()\n    print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')\n"
  },
  {
    "path": "examples/hetero/recommender_system.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nimport torch_geometric.transforms as T\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.datasets import MovieLens\nfrom torch_geometric.loader import LinkNeighborLoader, NeighborLoader\nfrom torch_geometric.metrics import (\n    LinkPredMAP,\n    LinkPredPrecision,\n    LinkPredRecall,\n)\nfrom torch_geometric.nn import MIPSKNNIndex, SAGEConv, to_hetero\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--k', type=int, default=20, help='Number of predictions')\nargs = parser.parse_args()\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')\ndata = MovieLens(path, model_name='all-MiniLM-L6-v2')[0]\n\n# Add user node features for message passing:\ndata['user'].x = torch.eye(data['user'].num_nodes)\ndel data['user'].num_nodes\n\n# Only use edges with high ratings (>= 4):\nmask = data['user', 'rates', 'movie'].edge_label >= 4\ndata['user', 'movie'].edge_index = data['user', 'movie'].edge_index[:, mask]\ndata['user', 'movie'].time = data['user', 'movie'].time[mask]\ndel data['user', 'movie'].edge_label  # Drop rating information from graph.\n\n# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:\ndata = T.ToUndirected()(data)\n\n# Perform a temporal link-level split into training and test edges:\nedge_label_index = data['user', 'movie'].edge_index\ntime = data['user', 'movie'].time\n\nperm = time.argsort()\ntrain_index = perm[:int(0.8 * perm.numel())]\ntest_index = perm[int(0.8 * perm.numel()):]\n\nkwargs = dict(  # Shared data loader arguments:\n    data=data,\n    num_neighbors=[5, 5, 5],\n    batch_size=256,\n    time_attr='time',\n    num_workers=4,\n    persistent_workers=True,\n    temporal_strategy='last',\n)\n\ntrain_loader = LinkNeighborLoader(\n    edge_label_index=(('user', 'movie'), 
edge_label_index[:, train_index]),\n    edge_label_time=time[train_index] - 1,  # No leakage.\n    neg_sampling=dict(mode='binary', amount=2),\n    shuffle=True,\n    **kwargs,\n)\n\n# During testing, we sample node-level subgraphs from both endpoints to\n# retrieve their embeddings.\n# This allows us to do efficient k-NN search on top of embeddings:\nsrc_loader = NeighborLoader(\n    input_nodes='user',\n    input_time=(time[test_index].min() - 1).repeat(data['user'].num_nodes),\n    **kwargs,\n)\ndst_loader = NeighborLoader(\n    input_nodes='movie',\n    input_time=(time[test_index].min() - 1).repeat(data['movie'].num_nodes),\n    **kwargs,\n)\n\n# Save test edges and the edges we want to exclude when evaluating:\nsparse_size = (data['user'].num_nodes, data['movie'].num_nodes)\ntest_edge_label_index = EdgeIndex(\n    edge_label_index[:, test_index].to(device),\n    sparse_size=sparse_size,\n).sort_by('row')[0]\ntest_exclude_links = EdgeIndex(\n    edge_label_index[:, train_index].to(device),\n    sparse_size=sparse_size,\n).sort_by('row')[0]\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), hidden_channels)\n        self.conv3 = SAGEConv((-1, -1), hidden_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index).relu()\n        x = self.conv3(x, edge_index)\n        return x\n\n\nclass InnerProductDecoder(torch.nn.Module):\n    def forward(self, x_dict, edge_label_index):\n        x_src = x_dict['user'][edge_label_index[0]]\n        x_dst = x_dict['movie'][edge_label_index[1]]\n        return (x_src * x_dst).sum(dim=-1)\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.encoder = GNN(hidden_channels)\n        self.encoder = to_hetero(self.encoder, data.metadata(), 
aggr='sum')\n        self.decoder = InnerProductDecoder()\n\n    def forward(self, x_dict, edge_index_dict, edge_label_index):\n        x_dict = self.encoder(x_dict, edge_index_dict)\n        return self.decoder(x_dict, edge_label_index)\n\n\nmodel = Model(hidden_channels=64).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    model.train()\n\n    total_loss = total_examples = 0\n    for batch in tqdm(train_loader):\n        batch = batch.to(device)\n        optimizer.zero_grad()\n\n        out = model(\n            batch.x_dict,\n            batch.edge_index_dict,\n            batch['user', 'movie'].edge_label_index,\n        )\n        y = batch['user', 'movie'].edge_label\n\n        loss = F.binary_cross_entropy_with_logits(out, y)\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * y.numel()\n        total_examples += y.numel()\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(edge_label_index, exclude_links):\n    model.eval()\n\n    dst_embs = []\n    for batch in dst_loader:  # Collect destination node/movie embeddings:\n        batch = batch.to(device)\n        emb = model.encoder(batch.x_dict, batch.edge_index_dict)['movie']\n        emb = emb[:batch['movie'].batch_size]\n        dst_embs.append(emb)\n    dst_emb = torch.cat(dst_embs, dim=0)\n    del dst_embs\n\n    # Instantiate k-NN index based on maximum inner product search (MIPS):\n    mips = MIPSKNNIndex(dst_emb)\n\n    # Initialize metrics:\n    map_metric = LinkPredMAP(k=args.k).to(device)\n    precision_metric = LinkPredPrecision(k=args.k).to(device)\n    recall_metric = LinkPredRecall(k=args.k).to(device)\n\n    num_processed = 0\n    for batch in src_loader:  # Collect source node/user embeddings:\n        batch = batch.to(device)\n\n        # Compute user embeddings:\n        emb = model.encoder(batch.x_dict, batch.edge_index_dict)['user']\n        emb = emb[:batch['user'].batch_size]\n\n     
   # Filter labels/exclusion by current batch:\n        _edge_label_index = edge_label_index.sparse_narrow(\n            dim=0,\n            start=num_processed,\n            length=emb.size(0),\n        )\n        _exclude_links = exclude_links.sparse_narrow(\n            dim=0,\n            start=num_processed,\n            length=emb.size(0),\n        )\n        num_processed += emb.size(0)\n\n        # Perform MIPS search:\n        _, pred_index_mat = mips.search(emb, args.k, _exclude_links)\n\n        # Update retrieval metrics:\n        map_metric.update(pred_index_mat, _edge_label_index)\n        precision_metric.update(pred_index_mat, _edge_label_index)\n        recall_metric.update(pred_index_mat, _edge_label_index)\n\n    return (\n        float(map_metric.compute()),\n        float(precision_metric.compute()),\n        float(recall_metric.compute()),\n    )\n\n\nfor epoch in range(1, 16):\n    train_loss = train()\n    print(f'Epoch: {epoch:02d}, Loss: {train_loss:.4f}')\n    val_map, val_precision, val_recall = test(\n        test_edge_label_index,\n        test_exclude_links,\n    )\n    print(f'Test MAP@{args.k}: {val_map:.4f}, '\n          f'Test Precision@{args.k}: {val_precision:.4f}, '\n          f'Test Recall@{args.k}: {val_recall:.4f}')\n"
  },
  {
    "path": "examples/hetero/temporal_link_pred.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MovieLens\nfrom torch_geometric.loader import LinkNeighborLoader\nfrom torch_geometric.nn import SAGEConv, to_hetero\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/MovieLens')\ndataset = MovieLens(path, model_name='all-MiniLM-L6-v2')\ndata = dataset[0]\n\n# Add user node features for message passing:\ndata['user'].x = torch.eye(data['user'].num_nodes, device=device)\ndel data['user'].num_nodes\n\n# Add a reverse ('movie', 'rev_rates', 'user') relation for message passing:\ndata = T.ToUndirected()(data)\ndel data['movie', 'rev_rates', 'user'].edge_label  # Remove \"reverse\" label.\n\n# Perform a 80/10/10 temporal link-level split:\nperm = torch.argsort(data['user', 'movie'].time)\ntrain_idx = perm[:int(0.8 * perm.size(0))]\nval_idx = perm[int(0.8 * perm.size(0)):int(0.9 * perm.size(0))]\ntest_idx = perm[int(0.9 * perm.size(0)):]\n\nedge_index = data['user', 'movie'].edge_index\nkwargs = dict(\n    data=data,\n    num_neighbors=[20, 10],\n    batch_size=1024,\n    time_attr='time',\n    temporal_strategy='last',\n    num_workers=4,\n    persistent_workers=True,\n)\ntrain_loader = LinkNeighborLoader(\n    edge_label_index=(('user', 'movie'), edge_index[:, train_idx]),\n    edge_label=data['user', 'movie'].edge_label[train_idx],\n    edge_label_time=data['user', 'movie'].time[train_idx] - 1,\n    shuffle=True,\n    **kwargs,\n)\nval_loader = LinkNeighborLoader(\n    edge_label_index=(('user', 'movie'), edge_index[:, val_idx]),\n    edge_label=data['user', 'movie'].edge_label[val_idx],\n    edge_label_time=data['user', 'movie'].time[val_idx] - 1,\n    **kwargs,\n)\ntest_loader = LinkNeighborLoader(\n    edge_label_index=(('user', 'movie'), edge_index[:, test_idx]),\n    
edge_label=data['user', 'movie'].edge_label[test_idx],\n    edge_label_time=data['user', 'movie'].time[test_idx] - 1,\n    **kwargs,\n)\n\n\nclass GNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        return x\n\n\nclass EdgeDecoder(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, 1)\n\n    def forward(self, z_dict, edge_label_index):\n        row, col = edge_label_index\n        z = torch.cat([z_dict['user'][row], z_dict['movie'][col]], dim=-1)\n\n        z = self.lin1(z).relu()\n        z = self.lin2(z)\n        return z.view(-1)\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.encoder = GNNEncoder(hidden_channels, hidden_channels)\n        self.encoder = to_hetero(self.encoder, data.metadata(), aggr='sum')\n        self.decoder = EdgeDecoder(hidden_channels)\n\n    def forward(self, x_dict, edge_index_dict, edge_label_index):\n        z_dict = self.encoder(x_dict, edge_index_dict)\n        return self.decoder(z_dict, edge_label_index)\n\n\nmodel = Model(hidden_channels=32).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    model.train()\n    total_loss = total_examples = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n        pred = model(\n            batch.x_dict,\n            batch.edge_index_dict,\n            batch['user', 'movie'].edge_label_index,\n        )\n        target = batch['user', 'movie'].edge_label.float()\n        loss = 
F.mse_loss(pred, target)\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss * pred.size(0))\n        total_examples += pred.size(0)\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n    preds, targets = [], []\n    for batch in loader:\n        batch = batch.to(device)\n        pred = model(\n            batch.x_dict,\n            batch.edge_index_dict,\n            batch['user', 'movie'].edge_label_index,\n        ).clamp(min=0, max=5)\n        preds.append(pred)\n        targets.append(batch['user', 'movie'].edge_label.float())\n\n    pred = torch.cat(preds, dim=0)\n    target = torch.cat(targets, dim=0)\n    rmse = (pred - target).pow(2).mean().sqrt()\n    return float(rmse)\n\n\nfor epoch in range(1, 11):\n    loss = train()\n    val_rmse = test(val_loader)\n    test_rmse = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val RMSE: {val_rmse:.4f}, '\n          f'Test RMSE: {test_rmse:.4f}')\n"
  },
  {
    "path": "examples/hetero/to_hetero_mag.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import ReLU\nfrom tqdm import tqdm\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import OGB_MAG\nfrom torch_geometric.loader import HGTLoader, NeighborLoader\nfrom torch_geometric.nn import Linear, SAGEConv, Sequential, to_hetero\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--use_hgt_loader', action='store_true')\nargs = parser.parse_args()\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../../data/OGB')\ntransform = T.ToUndirected(merge=True)\ndataset = OGB_MAG(path, preprocess='metapath2vec', transform=transform)\n\n# Already send node features/labels to GPU for faster access during sampling:\ndata = dataset[0].to(device, 'x', 'y')\n\ntrain_input_nodes = ('paper', data['paper'].train_mask)\nval_input_nodes = ('paper', data['paper'].val_mask)\nkwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}\n\nif not args.use_hgt_loader:\n    train_loader = NeighborLoader(data, num_neighbors=[10] * 2, shuffle=True,\n                                  input_nodes=train_input_nodes, **kwargs)\n    val_loader = NeighborLoader(data, num_neighbors=[10] * 2,\n                                input_nodes=val_input_nodes, **kwargs)\nelse:\n    train_loader = HGTLoader(data, num_samples=[1024] * 4, shuffle=True,\n                             input_nodes=train_input_nodes, **kwargs)\n    val_loader = HGTLoader(data, num_samples=[1024] * 4,\n                           input_nodes=val_input_nodes, **kwargs)\n\nmodel = Sequential('x, edge_index', [\n    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),\n    ReLU(inplace=True),\n    (SAGEConv((-1, -1), 64), 'x, edge_index -> x'),\n    ReLU(inplace=True),\n    (Linear(-1, dataset.num_classes), 'x -> x'),\n])\nmodel = to_hetero(model, data.metadata(), 
aggr='sum').to(device)\n\n\n@torch.no_grad()\ndef init_params():\n    # Initialize lazy parameters via forwarding a single batch to the model:\n    batch = next(iter(train_loader))\n    batch = batch.to(device, 'edge_index')\n    model(batch.x_dict, batch.edge_index_dict)\n\n\ndef train():\n    model.train()\n\n    total_examples = total_loss = 0\n    for batch in tqdm(train_loader):\n        optimizer.zero_grad()\n        batch = batch.to(device, 'edge_index')\n        batch_size = batch['paper'].batch_size\n        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]\n        loss = F.cross_entropy(out, batch['paper'].y[:batch_size])\n        loss.backward()\n        optimizer.step()\n\n        total_examples += batch_size\n        total_loss += float(loss) * batch_size\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_examples = total_correct = 0\n    for batch in tqdm(loader):\n        batch = batch.to(device, 'edge_index')\n        batch_size = batch['paper'].batch_size\n        out = model(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]\n        pred = out.argmax(dim=-1)\n\n        total_examples += batch_size\n        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())\n\n    return total_correct / total_examples\n\n\ninit_params()  # Initialize parameters.\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\nfor epoch in range(1, 21):\n    loss = train()\n    val_acc = test(val_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_acc:.4f}')\n"
  },
  {
    "path": "examples/hierarchical_sampling.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn.models.basic_gnn import GraphSAGE\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')\ndataset = Reddit(path)\n\n# Already send node features/labels to GPU for faster access during sampling:\ndata = dataset[0].to(device, 'x', 'y')\n\nkwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}\nloader = NeighborLoader(data, input_nodes=data.train_mask,\n                        num_neighbors=[20, 10, 5], shuffle=True, **kwargs)\n\nmodel = GraphSAGE(\n    dataset.num_features,\n    hidden_channels=64,\n    out_channels=dataset.num_classes,\n    num_layers=3,\n).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train(trim=False):\n    for batch in tqdm(loader):\n        optimizer.zero_grad()\n        batch = batch.to(device)\n\n        if not trim:\n            out = model(batch.x, batch.edge_index)\n        else:\n            out = model(\n                batch.x,\n                batch.edge_index,\n                num_sampled_nodes_per_hop=batch.num_sampled_nodes,\n                num_sampled_edges_per_hop=batch.num_sampled_edges,\n            )\n\n        out = out[:batch.batch_size]\n        y = batch.y[:batch.batch_size]\n\n        loss = F.cross_entropy(out, y)\n        loss.backward()\n        optimizer.step()\n\n\nprint('One epoch training without Hierarchical Graph Sampling:')\ntrain(trim=False)\n\nprint('One epoch training with Hierarchical Graph Sampling:')\ntrain(trim=True)\n"
  },
  {
    "path": "examples/infomax_inductive.py",
    "content": "import os.path as osp\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import DeepGraphInfomax, SAGEConv\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')\ndataset = Reddit(path)\ndata = dataset[0].to(device, 'x', 'edge_index')\n\ntrain_loader = NeighborLoader(data, num_neighbors=[10, 10, 25], batch_size=256,\n                              shuffle=True, num_workers=12)\ntest_loader = NeighborLoader(data, num_neighbors=[10, 10, 25], batch_size=256,\n                             num_workers=12)\n\n\nclass Encoder(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels):\n        super().__init__()\n        self.convs = torch.nn.ModuleList([\n            SAGEConv(in_channels, hidden_channels),\n            SAGEConv(hidden_channels, hidden_channels),\n            SAGEConv(hidden_channels, hidden_channels)\n        ])\n\n        self.activations = torch.nn.ModuleList()\n        self.activations.extend([\n            torch.nn.PReLU(hidden_channels),\n            torch.nn.PReLU(hidden_channels),\n            torch.nn.PReLU(hidden_channels)\n        ])\n\n    def forward(self, x, edge_index, batch_size):\n        for conv, act in zip(self.convs, self.activations):\n            x = conv(x, edge_index)\n            x = act(x)\n        return x[:batch_size]\n\n\ndef corruption(x, edge_index, batch_size):\n    return x[torch.randperm(x.size(0))], edge_index, batch_size\n\n\nmodel = DeepGraphInfomax(\n    hidden_channels=512, encoder=Encoder(dataset.num_features, 512),\n    summary=lambda z, *args, **kwargs: torch.sigmoid(z.mean(dim=0)),\n    corruption=corruption).to(device)\n\nmodel = model.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n\n\ndef train(epoch):\n    model.train()\n\n    total_loss = 
total_examples = 0\n    for batch in tqdm(train_loader, desc=f'Epoch {epoch:02d}'):\n        optimizer.zero_grad()\n        pos_z, neg_z, summary = model(batch.x, batch.edge_index,\n                                      batch.batch_size)\n        loss = model.loss(pos_z, neg_z, summary)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * pos_z.size(0)\n        total_examples += pos_z.size(0)\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n\n    zs = []\n    for batch in tqdm(test_loader, desc='Evaluating'):\n        pos_z, _, _ = model(batch.x, batch.edge_index, batch.batch_size)\n        zs.append(pos_z.cpu())\n    z = torch.cat(zs, dim=0)\n    train_val_mask = data.train_mask | data.val_mask\n    acc = model.test(z[train_val_mask], data.y[train_val_mask],\n                     z[data.test_mask], data.y[data.test_mask], max_iter=10000)\n    return acc\n\n\nfor epoch in range(1, 31):\n    loss = train(epoch)\n    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}')\n\ntest_acc = test()\nprint(f'Test Accuracy: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/infomax_transductive.py",
    "content": "import os.path as osp\n\nimport torch\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import DeepGraphInfomax, GCNConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset)\n\n\nclass Encoder(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels):\n        super().__init__()\n        self.conv = GCNConv(in_channels, hidden_channels)\n        self.prelu = torch.nn.PReLU(hidden_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv(x, edge_index)\n        x = self.prelu(x)\n        return x\n\n\ndef corruption(x, edge_index):\n    return x[torch.randperm(x.size(0), device=x.device)], edge_index\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel = DeepGraphInfomax(\n    hidden_channels=512,\n    encoder=Encoder(dataset.num_features, 512),\n    summary=lambda z, *args, **kwargs: z.mean(dim=0).sigmoid(),\n    corruption=corruption,\n).to(device)\ndata = dataset[0].to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    pos_z, neg_z, summary = model(data.x, data.edge_index)\n    loss = model.loss(pos_z, neg_z, summary)\n    loss.backward()\n    optimizer.step()\n    return loss.item()\n\n\ndef test():\n    model.eval()\n    z, _, _ = model(data.x, data.edge_index)\n    acc = model.test(z[data.train_mask], data.y[data.train_mask],\n                     z[data.test_mask], data.y[data.test_mask], max_iter=150)\n    return acc\n\n\nfor epoch in range(1, 301):\n    loss = train()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')\nacc = test()\nprint(f'Accuracy: {acc:.4f}')\n"
  },
  {
    "path": "examples/jit/README.md",
    "content": "# JIT Examples\n\nThis directory contains examples demonstrating the use of Just-In-Time (JIT) compilation in different GNN models.\n\n| Example                | Description                                                       |\n| ---------------------- | ----------------------------------------------------------------- |\n| [`gcn.py`](./gcn.py)   | JIT compilation in `GCN`                                          |\n| [`gat.py`](./gat.py)   | JIT compilation in `GAT`                                          |\n| [`gin.py`](./gin.py)   | JIT compilation in `GIN`                                          |\n| [`film.py`](./film.py) | JIT compilation in [`GNN-FiLM`](https://arxiv.org/abs/1906.12192) |\n"
  },
  {
    "path": "examples/jit/film.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import f1_score\nfrom torch import Tensor\nfrom torch.nn import BatchNorm1d\n\nfrom torch_geometric.datasets import PPI\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import FiLMConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'PPI')\ntrain_dataset = PPI(path, split='train')\nval_dataset = PPI(path, split='val')\ntest_dataset = PPI(path, split='test')\ntrain_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)\ntest_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)\n\n\nclass FiLM(torch.nn.Module):\n    def __init__(self, in_channels: int, hidden_channels: int,\n                 out_channels: int, num_layers: int, dropout: float = 0.0):\n        super().__init__()\n        self.dropout = dropout\n\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(FiLMConv(in_channels, hidden_channels))\n        for _ in range(num_layers - 2):\n            conv = FiLMConv(hidden_channels, hidden_channels)\n            self.convs.append(conv)\n        self.last_conv = FiLMConv(hidden_channels, out_channels, act=None)\n\n        self.norms = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.norms.append(BatchNorm1d(hidden_channels))\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for conv, norm in zip(self.convs, self.norms):\n            x = norm(conv(x, edge_index))\n            x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.last_conv(x, edge_index)\n        return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = FiLM(train_dataset.num_features, 320, train_dataset.num_classes,\n             num_layers=4, dropout=0.1)\nmodel = torch.jit.script(model).to(device)\ncriterion = 
torch.nn.BCEWithLogitsLoss()\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = criterion(model(data.x, data.edge_index), data.y)\n        total_loss += loss.item() * data.num_graphs\n        loss.backward()\n        optimizer.step()\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    ys, preds = [], []\n    for data in loader:\n        ys.append(data.y)\n        out = model(data.x.to(device), data.edge_index.to(device))\n        preds.append((out > 0).float().cpu())\n\n    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()\n    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0\n\n\nfor epoch in range(1, 501):\n    loss = train()\n    val_f1 = test(val_loader)\n    test_f1 = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '\n          f'Test: {test_f1:.4f}')\n"
  },
  {
    "path": "examples/jit/gat.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GATConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data',\n                dataset)\ndataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\n\n\nclass GAT(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GATConv(dataset.num_features, 8, heads=8, dropout=0.6)\n\n        self.conv2 = GATConv(64, dataset.num_classes, heads=1, concat=True,\n                             dropout=0.6)\n\n    def forward(self, x, edge_index):\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = F.elu(self.conv1(x, edge_index))\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel, data = GAT().to(device), dataset[0].to(device)\nmodel = torch.jit.script(model)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return loss.item()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out, accs = model(data.x, data.edge_index), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nfor epoch in range(1, 201):\n    loss = train()\n    train_acc, val_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '\n          f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: 
{test_acc:.4f}')\n"
  },
  {
    "path": "examples/jit/gcn.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCNConv\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\npath = osp.join(osp.dirname(osp.realpath(__file__)), 'data', 'Planetoid')\ndataset = Planetoid(path, 'Cora', transform=T.NormalizeFeatures())\ndata = dataset[0]\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels: int, hidden_channels: int,\n                 out_channels: int):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        self.conv2 = GCNConv(hidden_channels, out_channels)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv1(x, edge_index).relu()\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv2(x, edge_index)\n        return x\n\n\nmodel = GCN(dataset.num_features, 16, dataset.num_classes)\nmodel = torch.jit.script(model).to(device)\ndata = data.to(device)\noptimizer = torch.optim.Adam([\n    dict(params=model.conv1.parameters(), weight_decay=5e-4),\n    dict(params=model.conv2.parameters(), weight_decay=0)\n], lr=0.01)  # Only perform weight-decay on first convolution.\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index).argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\nbest_val_acc = final_test_acc = 0\nfor epoch in range(1, 
201):\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        final_test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '\n          f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {final_test_acc:.4f}')\n"
  },
  {
    "path": "examples/jit/gin.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import BatchNorm1d as BatchNorm\nfrom torch.nn import Linear, ReLU, Sequential\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GINConv, global_add_pool\nfrom torch_geometric.transforms import OneHotDegree\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', '..', 'data', 'TU')\ndataset = TUDataset(path, name='IMDB-BINARY', transform=OneHotDegree(135))\ndataset = dataset.shuffle()\ntest_dataset = dataset[:len(dataset) // 10]\ntrain_dataset = dataset[len(dataset) // 10:]\ntest_loader = DataLoader(test_dataset, batch_size=128)\ntrain_loader = DataLoader(train_dataset, batch_size=128)\n\n\nclass GIN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        self.batch_norms = torch.nn.ModuleList()\n\n        for _ in range(num_layers):\n            mlp = Sequential(\n                Linear(in_channels, 2 * hidden_channels),\n                BatchNorm(2 * hidden_channels),\n                ReLU(),\n                Linear(2 * hidden_channels, hidden_channels),\n            )\n            conv = GINConv(mlp, train_eps=True)\n\n            self.convs.append(conv)\n            self.batch_norms.append(BatchNorm(hidden_channels))\n\n            in_channels = hidden_channels\n\n        self.lin1 = Linear(hidden_channels, hidden_channels)\n        self.batch_norm1 = BatchNorm(hidden_channels)\n        self.lin2 = Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index, batch):\n        for conv, batch_norm in zip(self.convs, self.batch_norms):\n            x = F.relu(batch_norm(conv(x, edge_index)))\n        x = global_add_pool(x, batch)\n        x = F.relu(self.batch_norm1(self.lin1(x)))\n        x = F.dropout(x, p=0.5, training=self.training)\n   
     x = self.lin2(x)\n        return F.log_softmax(x, dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = GIN(dataset.num_features, 64, dataset.num_classes, num_layers=3)\nmodel = model.to(device)\nmodel = torch.jit.script(model)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0.\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.batch)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        total_loss += loss.item() * data.num_graphs\n        optimizer.step()\n    return total_loss / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_correct = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.batch)\n        pred = out.max(dim=1)[1]\n        total_correct += pred.eq(data.y).sum().item()\n    return total_correct / len(loader.dataset)\n\n\nfor epoch in range(1, 101):\n    loss = train()\n    train_acc = test(train_loader)\n    test_acc = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, '\n          f'Train: {train_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/kge_fb15k_237.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.optim as optim\n\nfrom torch_geometric.datasets import FB15k_237\nfrom torch_geometric.nn import ComplEx, DistMult, RotatE, TransE\n\nmodel_map = {\n    'transe': TransE,\n    'complex': ComplEx,\n    'distmult': DistMult,\n    'rotate': RotatE,\n}\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--model', choices=model_map.keys(), type=str.lower,\n                    required=True)\nargs = parser.parse_args()\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'FB15k')\n\ntrain_data = FB15k_237(path, split='train')[0].to(device)\nval_data = FB15k_237(path, split='val')[0].to(device)\ntest_data = FB15k_237(path, split='test')[0].to(device)\n\nmodel_arg_map = {'rotate': {'margin': 9.0}}\nmodel = model_map[args.model](\n    num_nodes=train_data.num_nodes,\n    num_relations=train_data.num_edge_types,\n    hidden_channels=50,\n    **model_arg_map.get(args.model, {}),\n).to(device)\n\nloader = model.loader(\n    head_index=train_data.edge_index[0],\n    rel_type=train_data.edge_type,\n    tail_index=train_data.edge_index[1],\n    batch_size=1000,\n    shuffle=True,\n)\n\noptimizer_map = {\n    'transe': optim.Adam(model.parameters(), lr=0.01),\n    'complex': optim.Adagrad(model.parameters(), lr=0.001, weight_decay=1e-6),\n    'distmult': optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-6),\n    'rotate': optim.Adam(model.parameters(), lr=1e-3),\n}\noptimizer = optimizer_map[args.model]\n\n\ndef train():\n    model.train()\n    total_loss = total_examples = 0\n    for head_index, rel_type, tail_index in loader:\n        optimizer.zero_grad()\n        loss = model.loss(head_index, rel_type, tail_index)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * head_index.numel()\n        total_examples += head_index.numel()\n    return total_loss / 
total_examples\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    return model.test(\n        head_index=data.edge_index[0],\n        rel_type=data.edge_type,\n        tail_index=data.edge_index[1],\n        batch_size=20000,\n        k=10,\n    )\n\n\nfor epoch in range(1, 501):\n    loss = train()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')\n    if epoch % 25 == 0:\n        rank, mrr, hits = test(val_data)\n        print(f'Epoch: {epoch:03d}, Val Mean Rank: {rank:.2f}, '\n              f'Val MRR: {mrr:.4f}, Val Hits@10: {hits:.4f}')\n\nrank, mrr, hits_at_10 = test(test_data)\nprint(f'Test Mean Rank: {rank:.2f}, Test MRR: {mrr:.4f}, '\n      f'Test Hits@10: {hits_at_10:.4f}')\n"
  },
  {
    "path": "examples/label_prop.py",
    "content": "import os.path as osp\n\nfrom ogb.nodeproppred import Evaluator, PygNodePropPredDataset\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.nn import LabelPropagation\n\nroot = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')\ndataset = PygNodePropPredDataset(\n    'ogbn-arxiv', root, transform=T.Compose([\n        T.ToUndirected(),\n        T.ToSparseTensor(),\n    ]))\nsplit_idx = dataset.get_idx_split()\nevaluator = Evaluator(name='ogbn-arxiv')\ndata = dataset[0]\n\nmodel = LabelPropagation(num_layers=3, alpha=0.9)\nout = model(data.y, data.adj_t, mask=split_idx['train'])\n\ny_pred = out.argmax(dim=-1, keepdim=True)\n\nval_acc = evaluator.eval({\n    'y_true': data.y[split_idx['valid']],\n    'y_pred': y_pred[split_idx['valid']],\n})['acc']\ntest_acc = evaluator.eval({\n    'y_true': data.y[split_idx['test']],\n    'y_pred': y_pred[split_idx['test']],\n})['acc']\n\nprint(f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/lcm_aggr_2nd_min.py",
    "content": "# Final validation accuracy: ~95%\nimport argparse\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import LCMAggregation\nfrom torch_geometric.transforms import BaseTransform\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--num_bits', type=int, default=8)\nargs = parser.parse_args()\n\n\nclass RandomPermutation(BaseTransform):\n    def forward(self, data: Data) -> Data:\n        data.x = torch.x[torch.randperm(data.x.size(0))]\n        return data\n\n\nclass Random2ndMinimumDataset(InMemoryDataset):\n    r\"\"\"\"A labeled dataset, where each sample is a multiset of integers\n    encoded as bit-vectors, and the label is the second smallest integer\n    in the multiset.\n    \"\"\"\n    def __init__(\n        self,\n        num_examples: int,\n        num_bits: int,\n        min_num_elems: int,\n        max_num_elems: int,\n    ):\n        super().__init__(transform=RandomPermutation())\n\n        self.data, self.slices = self.collate([\n            self.get_data(num_bits, min_num_elems, max_num_elems)\n            for _ in range(num_examples)\n        ])\n\n    def get_data(\n        self,\n        num_bits: int,\n        min_num_elems: int,\n        max_num_elems: int,\n    ) -> Data:\n\n        num_elems = int(torch.randint(min_num_elems, max_num_elems + 1, (1, )))\n\n        x = torch.randint(0, 2, (num_elems, num_bits))\n\n        power = torch.pow(2, torch.arange(num_bits)).flip([0])\n        ints = (x * power.view(1, -1)).sum(dim=-1)\n        y = x[ints.topk(k=2, largest=False).indices[-1:]].to(torch.float)\n\n        return Data(x=x, y=y)\n\n\ntrain_dataset = Random2ndMinimumDataset(\n    num_examples=2**16,  # 65,536\n    num_bits=args.num_bits,\n    min_num_elems=2,\n    max_num_elems=16,\n)\n# Validate on multi sets of size 32, larger than observed during 
training:\nval_dataset = Random2ndMinimumDataset(\n    num_examples=2**10,  # 1024\n    num_bits=args.num_bits,\n    min_num_elems=32,\n    max_num_elems=32,\n)\n\ntrain_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=128)\n\n\nclass BitwiseEmbedding(torch.nn.Module):\n    def __init__(self, emb_dim: int):\n        super().__init__()\n        self.embs = torch.nn.ModuleList(\n            [torch.nn.Embedding(2, emb_dim) for _ in range(args.num_bits)])\n\n    def forward(self, x: Tensor) -> Tensor:\n        xs = [emb(b) for emb, b in zip(self.embs, x.t())]\n        return torch.stack(xs, dim=0).sum(0)\n\n\nclass LCM(torch.nn.Module):\n    def __init__(self, emb_dim: int, dropout: float = 0.25):\n        super().__init__()\n\n        self.encoder = torch.nn.Sequential(\n            BitwiseEmbedding(emb_dim),\n            torch.nn.Linear(emb_dim, emb_dim),\n            torch.nn.Dropout(),\n            torch.nn.GELU(),\n        )\n\n        self.aggr = LCMAggregation(emb_dim, emb_dim, project=False)\n\n        self.decoder = torch.nn.Sequential(\n            torch.nn.Linear(emb_dim, emb_dim),\n            torch.nn.Dropout(dropout),\n            torch.nn.GELU(),\n            torch.nn.Linear(emb_dim, args.num_bits),\n        )\n\n    def forward(self, x: Tensor, batch: Tensor) -> Tensor:\n        x = self.encoder(x)\n        x = self.aggr(x, batch)\n        x = self.decoder(x)\n        return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = LCM(emb_dim=128).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n\n\ndef train():\n    total_loss = total_examples = 0\n    for batch in train_loader:\n        batch = batch.to(device)\n        optimizer.zero_grad()\n        out = model(batch.x, batch.batch)\n        loss = F.binary_cross_entropy_with_logits(out, batch.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += 
batch.num_graphs * float(loss)\n        total_examples += batch.num_graphs\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    total_correct = total_examples = 0\n    for batch in loader:\n        batch = batch.to(device)\n        pred = model(batch.x, batch.batch).sigmoid().round()\n        num_mistakes = (pred != batch.y).sum(dim=-1)\n        total_correct += int((num_mistakes == 0).sum())\n        total_examples += batch.num_graphs\n    return total_correct / total_examples\n\n\nfor epoch in range(1, 1001):\n    loss = train()\n    val_acc = test(val_loader)\n    print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}')\n"
  },
  {
    "path": "examples/lightgcn.py",
    "content": "import os.path as osp\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import AmazonBook\nfrom torch_geometric.nn import LightGCN\nfrom torch_geometric.utils import degree\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Amazon')\ndataset = AmazonBook(path)\ndata = dataset[0]\nnum_users, num_books = data['user'].num_nodes, data['book'].num_nodes\ndata = data.to_homogeneous().to(device)\n\n# Use all message passing edges as training labels:\nbatch_size = 8192\nmask = data.edge_index[0] < data.edge_index[1]\ntrain_edge_label_index = data.edge_index[:, mask]\ntrain_loader = torch.utils.data.DataLoader(\n    range(train_edge_label_index.size(1)),\n    shuffle=True,\n    batch_size=batch_size,\n)\n\nmodel = LightGCN(\n    num_nodes=data.num_nodes,\n    embedding_dim=64,\n    num_layers=2,\n).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    total_loss = total_examples = 0\n\n    for index in tqdm(train_loader):\n        # Sample positive and negative labels.\n        pos_edge_label_index = train_edge_label_index[:, index]\n        neg_edge_label_index = torch.stack([\n            pos_edge_label_index[0],\n            torch.randint(num_users, num_users + num_books,\n                          (index.numel(), ), device=device)\n        ], dim=0)\n        edge_label_index = torch.cat([\n            pos_edge_label_index,\n            neg_edge_label_index,\n        ], dim=1)\n\n        optimizer.zero_grad()\n        pos_rank, neg_rank = model(data.edge_index, edge_label_index).chunk(2)\n\n        loss = model.recommendation_loss(\n            pos_rank,\n            neg_rank,\n            node_id=edge_label_index.unique(),\n        )\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * pos_rank.numel()\n        total_examples += pos_rank.numel()\n\n    return 
total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(k: int):\n    emb = model.get_embedding(data.edge_index)\n    user_emb, book_emb = emb[:num_users], emb[num_users:]\n\n    precision = recall = total_examples = 0\n    for start in range(0, num_users, batch_size):\n        end = start + batch_size\n        logits = user_emb[start:end] @ book_emb.t()\n\n        # Exclude training edges:\n        mask = ((train_edge_label_index[0] >= start) &\n                (train_edge_label_index[0] < end))\n        logits[train_edge_label_index[0, mask] - start,\n               train_edge_label_index[1, mask] - num_users] = float('-inf')\n\n        # Computing precision and recall:\n        ground_truth = torch.zeros_like(logits, dtype=torch.bool)\n        mask = ((data.edge_label_index[0] >= start) &\n                (data.edge_label_index[0] < end))\n        ground_truth[data.edge_label_index[0, mask] - start,\n                     data.edge_label_index[1, mask] - num_users] = True\n        node_count = degree(data.edge_label_index[0, mask] - start,\n                            num_nodes=logits.size(0))\n\n        topk_index = logits.topk(k, dim=-1).indices\n        isin_mat = ground_truth.gather(1, topk_index)\n\n        precision += float((isin_mat.sum(dim=-1) / k).sum())\n        recall += float((isin_mat.sum(dim=-1) / node_count.clamp(1e-6)).sum())\n        total_examples += int((node_count > 0).sum())\n\n    return precision / total_examples, recall / total_examples\n\n\nfor epoch in range(1, 101):\n    loss = train()\n    precision, recall = test(k=20)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Precision@20: '\n          f'{precision:.4f}, Recall@20: {recall:.4f}')\n"
  },
  {
    "path": "examples/link_pred.py",
    "content": "import os.path as osp\n\nimport torch\nfrom sklearn.metrics import roc_auc_score\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.utils import negative_sampling\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\ntransform = T.Compose([\n    T.NormalizeFeatures(),\n    T.ToDevice(device),\n    T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True,\n                      add_negative_train_samples=False),\n])\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora', transform=transform)\n# After applying the `RandomLinkSplit` transform, the data is transformed from\n# a data object to a list of tuples (train_data, val_data, test_data), with\n# each element representing the corresponding split.\ntrain_data, val_data, test_data = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        self.conv2 = GCNConv(hidden_channels, out_channels)\n\n    def encode(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        return self.conv2(x, edge_index)\n\n    def decode(self, z, edge_label_index):\n        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)\n\n    def decode_all(self, z):\n        prob_adj = z @ z.t()\n        return (prob_adj > 0).nonzero(as_tuple=False).t()\n\n\nmodel = Net(dataset.num_features, 128, 64).to(device)\noptimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)\ncriterion = torch.nn.BCEWithLogitsLoss()\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    z = 
model.encode(train_data.x, train_data.edge_index)\n\n    # We perform a new round of negative sampling for every training epoch:\n    neg_edge_index = negative_sampling(\n        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,\n        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')\n\n    edge_label_index = torch.cat(\n        [train_data.edge_label_index, neg_edge_index],\n        dim=-1,\n    )\n    edge_label = torch.cat([\n        train_data.edge_label,\n        train_data.edge_label.new_zeros(neg_edge_index.size(1))\n    ], dim=0)\n\n    out = model.decode(z, edge_label_index).view(-1)\n    loss = criterion(out, edge_label)\n    loss.backward()\n    optimizer.step()\n    return loss\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    z = model.encode(data.x, data.edge_index)\n    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()\n    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())\n\n\nbest_val_auc = final_test_auc = 0\nfor epoch in range(1, 101):\n    loss = train()\n    val_auc = test(val_data)\n    test_auc = test(test_data)\n    if val_auc > best_val_auc:\n        best_val_auc = val_auc\n        final_test_auc = test_auc\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '\n          f'Test: {test_auc:.4f}')\n\nprint(f'Final Test: {final_test_auc:.4f}')\n\nz = model.encode(test_data.x, test_data.edge_index)\nfinal_edge_index = model.decode_all(z)\n"
  },
  {
    "path": "examples/linkx.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import LINKXDataset\nfrom torch_geometric.nn import LINKX\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'LINKX')\ndataset = LINKXDataset(path, name='Penn94')\ndata = dataset[0].to(device)\n\nmodel = LINKX(data.num_nodes, data.num_features, hidden_channels=32,\n              out_channels=dataset.num_classes, num_layers=1,\n              num_edge_layers=1, num_node_layers=1, dropout=0.5).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    mask = data.train_mask[:, 0]  # Use the first set of the five masks.\n    loss = F.cross_entropy(out[mask], data.y[mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    accs = []\n    model.eval()\n    pred = model(data.x, data.edge_index).argmax(dim=-1)\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        mask = mask[:, 0]  # Use the first set of the five masks.\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\nfor epoch in range(1, 201):\n    loss = train()\n    train_acc, val_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/llm/README.md",
    "content": "# Examples for Co-training LLMs and GNNs\n\n| Example                                | Description                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                               |\n| -------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |\n| [`g_retriever.py`](./g_retriever.py)   | Example helper functions for using the [G-retriever](https://arxiv.org/abs/2402.07630) GNN+LLM module in PyG. 
Includes an [example repo](https://github.com/neo4j-product-examples/neo4j-gnn-llm-example) for [Neo4j](https://neo4j.com) integration with an associated [blog post](https://developer.nvidia.com/blog/boosting-qa-accuracy-with-graphrag-using-pyg-and-graph-databases/) demonstrating 2x accuracy gains over LLMs on real medical data. For a complete end-to-end pipeline (KG Creation, Subgraph Retrieval, GNN+LLM Finetuning, Testing, LLM Judge Eval), see [`txt2kg_rag.py`](./txt2kg_rag.py). For a native PyG implementation without external graph databases, see [gretriever-stark-prime](https://github.com/puririshi98/gretriever-stark-prime/tree/main).      |\n| [`molecule_gpt.py`](./molecule_gpt.py) | Example for MoleculeGPT: Instruction Following Large Language Models for Molecular Property Prediction. Supports MoleculeGPT and InstructMol dataset                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                      |\n| [`glem.py`](./glem.py)                 | Example for [GLEM](https://arxiv.org/abs/2210.14709), a GNN+LLM co-training model via variational Expectation-Maximization (EM) framework on node classification tasks to achieve SOTA results                                                                                                                                                                                                                                                                                                                
                                                                                                                                                                                                                                                                            |\n| [`git_mol.py`](./git_mol.py)           | Example for GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 |\n| [`protein_mpnn.py`](./protein_mpnn.py) | Example for [Robust deep learning--based protein sequence design using ProteinMPNN](https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1)                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                          |\n| [`txt2kg_rag.py`](./txt2kg_rag.py)     | Full end 2 end RAG pipeline using TXT2KG and Vector and Graph RAG 
with a GNN to achieve state-of-the-art results. Uses the [TechQA dataset](https://paperswithcode.com/dataset/techqa) but can be extended to handle any RAG dataset with a corpus of documents and an associated set of Q+A pairs to be split for train/eval/test. See the [Stanford GNN+LLM Talk](https://www.nvidia.com/en-us/on-demand/session/other25-nv-0003/) for more details. Note that the TechQA data requires only a single document to answer each question, so it can be viewed as a toy example. To see significant accuracy boosts from GNN+LLM TXT2KG-based RAG, use data that requires multiple text chunks to answer a question. In cases where a single document can answer the question, basic RAG should be sufficient. |\n"
  },
  {
    "path": "examples/llm/g_retriever.py",
    "content": "\"\"\"This example provides helper functions for using the G-retriever model\n(https://arxiv.org/abs/2402.07630) in PyG.\n\nRequirements:\n`pip install datasets transformers pcst_fast sentencepiece accelerate`\n\n\nExample blog showing 2x accuracy over agentic graphRAG on real medical data\n(integration with Neo4j Graph DB):\nhttps://developer.nvidia.com/blog/boosting-qa-accuracy-with-graphrag-using-pyg-and-graph-databases/\n\nhttps://github.com/neo4j-product-examples/neo4j-gnn-llm-example\n\nSee examples/llm/txt2kg_rag.py for e2e pipeline in PyG including:\n- KG Creation\n- Subgraph Retrieval\n- GNN+LLM Finetuning\n- Testing\n- LLM Judge Eval\n\n\"\"\"\nimport math\n\nimport torch\nfrom torch import Tensor\n\n\ndef adjust_learning_rate(param_group: dict, LR: float, epoch: int,\n                         num_epochs: int):\n    \"\"\"Decay learning rate with half-cycle cosine after warmup.\n\n    Args:\n        param_group (dict): Parameter group.\n        LR (float): Learning rate.\n        epoch (int): current epoch\n        num_epochs (int): total epochs\n    Returns:\n        float: Adjusted learning rate.\n    \"\"\"\n    min_lr = 5e-6\n    warmup_epochs = 1\n    if epoch < warmup_epochs:\n        lr = LR\n    else:\n        lr = min_lr + (LR - min_lr) * 0.5 * (\n            1.0 + math.cos(math.pi * (epoch - warmup_epochs) /\n                           (num_epochs - warmup_epochs)))\n    param_group['lr'] = lr\n    return lr\n\n\ndef save_params_dict(model, save_path):\n    \"\"\"Saves a model's parameters, excluding non-trainable weights.\n\n    Args:\n        model (torch.nn.Module): The model to save parameters from.\n        save_path (str): The path to save the parameters to.\n    \"\"\"\n    # Get the model's state dictionary, which contains all its parameters\n    state_dict = model.state_dict()\n\n    # Create a dictionary mapping parameter names to their requires_grad status\n    param_grad_dict = {\n        k: v.requires_grad\n        
for (k, v) in model.named_parameters()\n    }\n\n    # Remove non-trainable parameters from the state dictionary\n    for k in list(state_dict.keys()):\n        if k in param_grad_dict.keys() and not param_grad_dict[k]:\n            del state_dict[k]  # Delete parameters that do not require gradient\n\n    # Save the filtered state dictionary to the specified path\n    torch.save(state_dict, save_path)\n\n\ndef load_params_dict(model, save_path):\n    # Load the saved model parameters from the specified file path\n    state_dict = torch.load(save_path)\n\n    # Update the model's parameters with the loaded state dictionary\n    model.load_state_dict(state_dict)\n\n    # Return the model with updated parameters\n    return model\n\n\ndef normalize_batch_dtype(batch):\n    batch.x = batch.x.float()\n    if hasattr(batch, \"edge_attr\") and batch.edge_attr is not None:\n        batch.edge_attr = batch.edge_attr.float()\n\n\ndef get_loss(model, batch, model_save_name=\"gnn+llm\") -> Tensor:\n    \"\"\"Compute the loss for a given model and batch of data.\n\n    Args:\n        model: The model to compute the loss for.\n        batch: The batch of data to compute the loss for.\n        model_save_name: The name of the model being used (e.g. 
'llm').\n\n    Returns:\n        Tensor: The computed loss.\n    \"\"\"\n    # Check the type of model being used to determine the input arguments\n    if model_save_name == 'llm':\n        # For LLM models\n        return model(batch.question, batch.label, batch.desc)\n    else:  # (GNN+LLM)\n        normalize_batch_dtype(batch)\n        return model(\n            batch.question,  # [\"list\", \"of\", \"questions\", \"here\"]\n            batch.x,  # [num_nodes, num_features]\n            batch.edge_index,  # [2, num_edges]\n            batch.batch,  # which node belongs to which batch index\n            batch.label,  # list answers (labels)\n            batch.edge_attr,  # edge attributes\n            batch.desc  # list of text graph descriptions\n        )\n\n\ndef inference_step(model, batch, model_save_name=\"gnn+llm\",\n                   max_out_tokens=128):\n    \"\"\"Performs inference on a given batch of data using the provided model.\n\n    Args:\n        model (nn.Module): The model to use for inference.\n        batch: The batch of data to process.\n        model_save_name (str): The name of the model (e.g. 'llm').\n        max_out_tokens (int): The maximum number of tokens\n            for our model to output.\n\n    Returns:\n        The output of the inference step.\n    \"\"\"\n    # Check the type of model being used to determine the input arguments\n    if model_save_name == 'llm':\n        # Perform inference on the question and textual graph description\n        return model.inference(batch.question, batch.desc,\n                               max_out_tokens=max_out_tokens)\n    else:  # (GNN+LLM)\n        normalize_batch_dtype(batch)\n        return model.inference(batch.question, batch.x, batch.edge_index,\n                               batch.batch, batch.edge_attr, batch.desc,\n                               max_out_tokens=max_out_tokens)\n"
  },
  {
    "path": "examples/llm/git_mol.py",
    "content": "\"\"\"This example implements the GIT-Mol model\n(https://arxiv.org/abs/2308.06911) using PyG.\n\"\"\"\nimport argparse\nimport os.path as osp\n\nimport torch\nfrom accelerate import Accelerator\nfrom accelerate.utils import DistributedDataParallelKwargs\nfrom torch.optim.lr_scheduler import StepLR\nfrom tqdm import tqdm\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.datasets import GitMolDataset\nfrom torch_geometric.llm.models import GITMol\nfrom torch_geometric.loader import DataLoader\n\n\n@torch.no_grad()\ndef eval(model, data_loader):\n    model.eval()\n    loss = 0\n\n    for batch in data_loader:\n        batch_loss = model(batch.x, batch.edge_index, batch.batch,\n                           batch.edge_attr, batch.smiles, batch.image,\n                           batch.caption)\n        loss += batch_loss.item() / len(data_loader)\n    return loss\n\n\ndef train(\n    num_epochs: int,\n    lr: float,\n    weight_decay: float,\n    batch_size: int,\n    checkpointing: bool,\n):\n    # Load dataset ================================================\n    path = osp.dirname(osp.realpath(__file__))\n    path = osp.join(path, '..', '..', 'data', 'GITMol')\n    train_dataset = GitMolDataset(path, split=0)\n    val_dataset = GitMolDataset(path, split=1)\n    test_dataset = GitMolDataset(path, split=2)\n\n    seed_everything(42)\n\n    train_loader = DataLoader(train_dataset, batch_size=batch_size,\n                              drop_last=True, pin_memory=True, shuffle=True)\n    val_loader = DataLoader(val_dataset, batch_size=batch_size,\n                            drop_last=False, pin_memory=True, shuffle=False)\n    test_loader = DataLoader(test_dataset, batch_size=batch_size,\n                             drop_last=False, pin_memory=True, shuffle=False)\n\n    # Create model ===============================================\n    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)\n    accelerator = 
Accelerator(kwargs_handlers=[ddp_kwargs])\n    device = accelerator.device\n    model = GITMol().to(device)\n    optimizer = torch.optim.AdamW(\n        [p for p in model.parameters() if p.requires_grad], lr=lr,\n        weight_decay=weight_decay)\n    scheduler = StepLR(optimizer, step_size=1, gamma=0.1)\n    model, optimizer, train_loader, scheduler = accelerator.prepare(\n        model, optimizer, train_loader, scheduler)\n    val_loader = accelerator.prepare_data_loader(val_loader,\n                                                 device_placement=True)\n    test_loader = accelerator.prepare_data_loader(test_loader,\n                                                  device_placement=True)\n\n    # Train and eval ============================================\n    best_epoch = 0\n    best_val_loss = float('inf')\n    for epoch in range(num_epochs):\n        # Train\n        model.train()\n        epoch_loss = 0\n        if epoch == 0:\n            print(\"Training beginning...\")\n        epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'\n\n        for batch in tqdm(train_loader, desc=epoch_str):\n            optimizer.zero_grad()\n            loss = model(batch.x, batch.edge_index, batch.batch,\n                         batch.edge_attr, batch.smiles, batch.image,\n                         batch.caption)\n            accelerator.backward(loss)\n\n            optimizer.step()\n            epoch_loss += loss.item()\n\n        train_loss = epoch_loss / len(train_loader)\n\n        # Eval\n        val_loss = eval(model, val_loader)\n        print(\n            f'{epoch_str}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}'  # noqa: E501\n        )\n\n        if checkpointing and val_loss < best_val_loss:\n            best_val_loss = val_loss\n            best_epoch = epoch\n            torch.save(\n                {\n                    'model_state_dict':\n                    accelerator.unwrap_model(model).state_dict(),\n                    'best_loss':\n          
          best_val_loss\n                },\n                f'gitmol_pretrain_epoch{best_epoch}_val_loss{best_val_loss:.4f}_ckpt.pt'  # noqa: E501\n            )\n    if torch.cuda.is_available():\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n\n    # Test\n    test_loss = eval(model, test_loader)\n    print(f'Test loss: {test_loss:.4f}')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--epochs', type=int, default=3)\n    parser.add_argument('--lr', type=float, default=1e-5)\n    parser.add_argument('--batch_size', type=int, default=4)\n    parser.add_argument(\"--weight_decay\", type=float, default=0.01)\n    # `type=bool` would parse any non-empty string (even 'False') as True.\n    parser.add_argument('--checkpointing',\n                        action=argparse.BooleanOptionalAction, default=True)\n    args = parser.parse_args()\n\n    train(\n        args.epochs,\n        args.lr,\n        args.weight_decay,\n        args.batch_size,\n        args.checkpointing,\n    )\n"
  },
  {
    "path": "examples/llm/glem.py",
    "content": "\"\"\"This example run GLEM model using PyG.\nOriginal Paper: https://arxiv.org/abs/2210.14709\n“Learning on Large-scale Text-attributed Graphs via Variational Inference“.\nRequirements on top of basic PyG:\n`pip install ogb transformers peft tqdm`.\nGLEM is a data augmentation co-training strategy for LM and GNN, our\nimplementation extended original implementation from LM to LLM and opt for LoRA\nfrom peft.\n\n``note::\n    use additional trick, please add your external prediction by assigning\n    `ext_pred_path` and combine it into pretraining phase and node features\n\"\"\"\n\nimport argparse\nimport os\nimport os.path as osp\nimport time\n\nimport psutil\nimport torch\nfrom ogb.nodeproppred import Evaluator, PygNodePropPredDataset\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.data import download_google_url\nfrom torch_geometric.datasets import TAGDataset\nfrom torch_geometric.llm import GLEM\nfrom torch_geometric.loader import DataLoader, NeighborLoader\nfrom torch_geometric.nn.models import GAT, GCN, GraphSAGE\n\n\ndef get_n_params(model):\n    pp = 0\n    for p in list(model.parameters()):\n        nn = 1\n        for s in list(p.size()):\n            nn = nn * s\n        pp += nn\n    return pp\n\n\ndef main(args):\n    gpu = args.gpu\n    dataset_name = args.dataset\n    text_type = args.text_type if args.dataset == 'arxiv' else 'raw_text'\n    root = osp.join('data', 'ogb')\n    hf_model = args.hf_model\n    pl_ratio = args.pl_ratio\n    gnn_lr = args.gnn_lr\n    lm_lr = args.lm_lr\n    em_order = args.em_order\n    gnn_epochs = args.gnn_epochs\n    lm_epochs = args.lm_epochs\n    patience = args.patience\n    verbose = args.verbose\n    out_dir = args.out_dir\n    lm_batch_size = args.lm_batch_size\n    gnn_batch_size = args.gnn_batch_size\n    lm_use_lora = args.lm_use_lora\n    token_on_disk = args.token_on_disk\n    num_em_iters = args.num_em_iters\n    start_time = time.time()\n    train_with_ext_pred = not 
args.train_without_ext_pred and \\\n        dataset_name == 'products'\n    ext_pred = None\n    pretrain_augmented = False\n    ext_pseudo_labels = None\n    device = torch.device(\n        f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu')\n    print(f'Running on: {torch.cuda.get_device_name(gpu)}')\n    torch.cuda.empty_cache()\n\n    if train_with_ext_pred:\n        ext_pred_path = download_google_url(\n            id='15sO2m7BeW7C1Upmdw3Cx1JS__6nxTAzY',\n            folder='data/ogb/ogbn_products/ext_preds',\n            filename='giant_sagn_scr.pt', log=True)\n        ext_pred = torch.load(ext_pred_path, map_location=device)\n        ext_pseudo_labels = ext_pred.argmax(dim=-1)\n        pretrain_augmented = True\n\n    seed_everything(42)\n\n    dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root)\n    split_idx = dataset.get_idx_split()\n    data = dataset._data\n\n    tag_dataset = TAGDataset(root, dataset, hf_model,\n                             token_on_disk=token_on_disk)\n    text_dataset = tag_dataset.to_text_dataset(text_type)\n    print(tag_dataset.num_classes, tag_dataset.raw_file_names)\n\n    num_classes = tag_dataset.num_classes\n    num_features = data.num_features\n    # =========================== LM Data split ===============================\n    split_idx = tag_dataset.get_idx_split()\n\n    # GLEM trains with augmented data; the original train data is the gold data.\n    gold_idx = split_idx['train']\n    test_idx = split_idx['test']\n\n    # Randomly sample pseudo-label nodes from the test set:\n    num_pseudo_labels = int(gold_idx.numel() * pl_ratio)\n    idx_to_select = torch.randperm(test_idx.numel())[:num_pseudo_labels]\n    pseudo_labels_idx = test_idx[idx_to_select]\n    train_idx = torch.cat(\n        (gold_idx, pseudo_labels_idx))  # augmented train_idx\n\n    print(f'train_idx: {train_idx.size(0)}, '\n          f'gold_idx: {gold_idx.size(0)}, '\n          f'pseudo labels ratio: 
{pl_ratio}, '\n          f'{train_idx.size(0)/gold_idx.size(0) - 1.0}')\n    gold_dataset = torch.utils.data.Subset(dataset=text_dataset,\n                                           indices=gold_idx)\n    train_dataset = torch.utils.data.Subset(dataset=text_dataset,\n                                            indices=train_idx)\n    # ========================== LM Data Loader ===============================\n\n    print('Building language model dataloader...', end='-->')\n\n    # if set train_without_ext_pred == True, use this for pretrain\n    text_pretrain_loader = DataLoader(gold_dataset, batch_size=lm_batch_size,\n                                      drop_last=False, pin_memory=True,\n                                      shuffle=True)\n    # training with augmented data,\n    text_train_loader = DataLoader(train_dataset, batch_size=lm_batch_size,\n                                   drop_last=False, pin_memory=True,\n                                   shuffle=True)\n    text_test_loader = DataLoader(text_dataset, batch_size=lm_batch_size * 4,\n                                  drop_last=False, pin_memory=True,\n                                  shuffle=False)\n    print('done')\n\n    # =========================== GNN Data Loader =============================\n    initial_memory = torch.cuda.memory_allocated()\n    data = data.to(device)\n    if ext_pred is not None:\n        data.x = torch.cat((data.x, ext_pred), dim=1)\n        num_features += ext_pred.size(1)\n    current_memory_1 = torch.cuda.max_memory_allocated()\n    # 1 GB = 1073741824 Byte\n    gpu_usage = float(current_memory_1 - initial_memory) / 1073741824\n    # Print the maximum memory usage after running the model\n    print(f'GPU memory usage -- data to gpu: {gpu_usage:.2f} GB')\n\n    print('build GNN dataloader(GraphSAGE NeighborLoader)', end='-->')\n\n    # train on gold data w/o pseudo labels\n    graph_pretrain_loader = NeighborLoader(\n        data,\n        input_nodes=gold_idx,\n       
 num_neighbors=[15, 10, 5],\n        batch_size=gnn_batch_size,\n        shuffle=True,\n        num_workers=12,\n        persistent_workers=True,\n    )\n\n    # graph data loader w/ pseudo labels in M-step\n    graph_train_loader = NeighborLoader(\n        data,\n        input_nodes=train_idx,\n        num_neighbors=[15, 10, 5],\n        batch_size=gnn_batch_size,\n        shuffle=True,\n        num_workers=12,\n        persistent_workers=True,\n    )\n\n    # for gnn inference\n    subgraph_loader = NeighborLoader(\n        data,\n        input_nodes=None,\n        num_neighbors=[-1],\n        batch_size=gnn_batch_size * 4,\n        num_workers=12,\n        persistent_workers=True,\n    )\n    # =========================== internal function ===========================\n\n    evaluator = Evaluator(name=f'ogbn-{dataset_name}')\n\n    def evaluate(out, split):\n        y_true = data.y.cpu()\n        y_pred = out.argmax(dim=-1, keepdim=True)\n        train_acc, val_acc, test_acc = None, None, None\n        if 'train' in split:\n            train_acc = evaluator.eval({\n                'y_true': y_true[split_idx['train']],\n                'y_pred': y_pred[split_idx['train']],\n            })['acc']\n        if 'valid' in split:\n            val_acc = evaluator.eval({\n                'y_true': y_true[split_idx['valid']],\n                'y_pred': y_pred[split_idx['valid']],\n            })['acc']\n        if 'test' in split:\n            test_acc = evaluator.eval({\n                'y_true': y_true[split_idx['test']],\n                'y_pred': y_pred[split_idx['test']],\n            })['acc']\n\n        return train_acc, val_acc, test_acc\n\n    # =========================== Build GNN Model =============================\n    gnn = None\n    if args.gnn_model == 'SAGE':\n        gnn = GraphSAGE(\n            in_channels=num_features,\n            hidden_channels=args.gnn_hidden_channels,\n            num_layers=args.gnn_num_layers,\n            
out_channels=dataset.num_classes,\n        )\n    elif args.gnn_model == 'GAT':\n        gnn = GAT(in_channels=num_features,\n                  hidden_channels=args.gnn_hidden_channels,\n                  num_layers=args.gnn_num_layers,\n                  out_channels=dataset.num_classes, heads=args.gat_heads)\n    else:\n        gnn = GCN(\n            in_channels=num_features,\n            hidden_channels=args.gnn_hidden_channels,\n            num_layers=args.gnn_num_layers,\n            out_channels=dataset.num_classes,\n        )\n\n    print(\"# GNN Params:\", get_n_params(gnn))\n    # =========================== Build LM Model ==============================\n\n    model = GLEM(lm_to_use=hf_model, gnn_to_use=gnn, out_channels=num_classes,\n                 lm_use_lora=lm_use_lora, device=device)\n    lm = model.lm\n    print(\"# LM Params:\", get_n_params(lm))\n    gnn_opt = torch.optim.Adam(gnn.parameters(), lr=gnn_lr)\n    lm_opt = torch.optim.Adam(lm.parameters(), lr=lm_lr)\n\n    def load_model(em_phase):\n        print(f'Move {em_phase} model from cpu memory')\n        if em_phase == 'lm':\n            model.lm = model.lm.to(device, non_blocking=True)\n            optimizer = torch.optim.Adam(model.lm.parameters(), lr=lm_lr)\n        if em_phase == 'gnn':\n            model.gnn = model.gnn.to(device, non_blocking=True)\n            optimizer = torch.optim.Adam(model.gnn.parameters(), lr=gnn_lr)\n        return optimizer\n\n    # ================================= Run GLEM ==============================\n    preds_filename = 'lm_pretrain'\n    preds_dir = f'{out_dir}preds/{dataset_name}/'\n    gnn_test_acc = 0.0\n    lm_test_acc = 0.0\n    # =============================== GLEM pretraining ========================\n    pretrain_phase = 'lm'\n    if em_order == 'lm':\n        pretrain_phase = 'gnn'\n    pretrain_start_time = time.time()\n    # pretraining\n    pretrain_loader = graph_pretrain_loader\n    test_loader = subgraph_loader\n    pretrain_num_epochs 
= gnn_epochs\n    pretrain_opt = gnn_opt\n    if pretrain_phase == 'gnn':\n        model.gnn = model.gnn.to(device)\n        print('pretraining gnn to generate pseudo labels')\n        if train_with_ext_pred:\n            pretrain_loader = graph_train_loader\n        preds_filename = 'gnn_pretrain'\n    elif pretrain_phase == 'lm':\n        model.lm = model.lm.to(device)\n        print('pretraining lm to generate pseudo labels')\n        pretrain_num_epochs = lm_epochs\n        pretrain_loader = text_pretrain_loader\n        test_loader = text_test_loader\n        pretrain_opt = lm_opt\n        if train_with_ext_pred:\n            pretrain_loader = text_train_loader\n        preds_filename = 'lm_pretrain'\n\n    early_stopping = 0\n    best_val_acc = 0.0\n    for epoch in range(1, pretrain_num_epochs + 1):\n        acc, loss = model.train(pretrain_phase, pretrain_loader, pretrain_opt,\n                                ext_pseudo_labels, epoch, pretrain_augmented,\n                                verbose)\n        if epoch >= 5 or epoch == pretrain_num_epochs:\n            pretrain_preds = model.inference(pretrain_phase, test_loader,\n                                             verbose=verbose)\n            train_acc, val_acc, _ = evaluate(pretrain_preds,\n                                             ['train', 'valid'])\n\n            print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}')\n\n            if val_acc <= best_val_acc:\n                early_stopping += 1\n                if early_stopping > patience:\n                    print(f'Pretrain Early stopped by Epoch: {epoch}')\n                    break\n            else:\n                best_val_acc = val_acc\n    preds = model.inference(pretrain_phase, test_loader, verbose=verbose)\n    train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test'])\n    if pretrain_phase == 'gnn':\n        gnn_test_acc = max(gnn_test_acc, test_acc)\n        model.gnn = model.gnn.to('cpu', non_blocking=True)\n    
else:\n        lm_test_acc = max(lm_test_acc, test_acc)\n        model.lm = model.lm.to('cpu', non_blocking=True)\n    torch.cuda.empty_cache()\n\n    pretrain_phase_time = time.time() - pretrain_start_time\n    print(f'Pretrain {pretrain_phase} time: {pretrain_phase_time:.2f}s')\n    os.makedirs(osp.dirname(preds_dir), exist_ok=True)\n    torch.save(preds, osp.join(preds_dir, f'{preds_filename}.pt'))\n    print(\n        f'Saved predictions to {osp.join(preds_dir, f\"{preds_filename}.pt\")}')\n    train_acc, val_acc, test_acc = evaluate(preds, ['train', 'valid', 'test'])\n    print(f'Pretraining acc: {train_acc:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n\n    # EM iterations\n\n    em_phase = em_order\n    \"\"\"\n    We run E-step(LM training) and M-Step(GNN training) alternatively in each\n    em iterations, so the total number of iterations is num_em_iter * 2 and\n    we switch the em_phase at end of each iteration in following loop\n    \"\"\"\n    gnn_val_acc = lm_val_acc = 0.0\n    for em_it in range(1, num_em_iters * 2 + 1):\n        pseudo_labels = preds.argmax(dim=-1)\n        best_val_acc = 0.0\n        print(f'EM iteration: {em_it}, EM phase: {em_phase}')\n        optimizer = load_model(em_phase)\n        num_epochs = lm_epochs\n        train_loader = text_train_loader\n        test_loader = text_test_loader\n        early_stopping = 0\n        if em_phase == 'gnn':\n            train_loader = graph_train_loader\n            num_epochs = gnn_epochs\n            test_loader = subgraph_loader\n        for epoch in range(1, num_epochs + 1):\n            acc, loss = model.train(em_phase, train_loader, optimizer,\n                                    pseudo_labels, epoch, True, verbose)\n            if epoch >= 5 or epoch == num_epochs:\n                cur_preds = model.inference(em_phase, test_loader,\n                                            verbose=verbose)\n                train_acc, val_acc, _ = evaluate(cur_preds, ['train', 
'valid'])\n\n                print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}')\n\n                if val_acc <= best_val_acc:\n                    early_stopping += 1\n                    if early_stopping > patience:\n                        print(f'Early stopped by Epoch: {epoch}, '\n                              f'Best acc: {best_val_acc}')\n                        break\n                else:\n                    best_val_acc = val_acc\n\n        preds = model.inference(em_phase, test_loader, verbose=verbose)\n        if em_phase == 'gnn':\n            gnn_val_acc = max(gnn_val_acc, best_val_acc)\n            model.gnn = model.gnn.to('cpu', non_blocking=True)\n            em_phase = 'lm'\n        else:\n            lm_val_acc = max(lm_val_acc, best_val_acc)\n            model.lm = model.lm.to('cpu', non_blocking=True)\n            em_phase = 'gnn'\n        torch.cuda.empty_cache()\n    print(f'Best GNN validation acc: {gnn_val_acc}, '\n          f'LM validation acc: {lm_val_acc}')\n    print('============================')\n    if gnn_val_acc > lm_val_acc:\n        em_phase = 'gnn'\n        model.gnn = model.gnn.to(device, non_blocking=True)\n        test_loader = subgraph_loader\n    else:\n        em_phase = 'lm'\n        model.lm = model.lm.to(device, non_blocking=True)\n        test_loader = text_test_loader\n    test_preds = model.inference(em_phase, test_loader, verbose=verbose)\n    train_acc, val_acc, test_acc = evaluate(test_preds,\n                                            ['train', 'valid', 'test'])\n    final_test_acc = max(gnn_test_acc, max(lm_test_acc, test_acc))\n    print(f'Best test acc: {final_test_acc}, model: {em_phase}')\n    end_time = time.time()\n    running_time = (end_time - start_time) / 3600\n    print(f'Total running time: {running_time:.2f} hours')\n\n\nif __name__ == '__main__':\n    available_gb = psutil.virtual_memory().available / (1024**3)\n    if available_gb < 80:\n        print(f\"  WARNING: This test may require more 
RAM than available.\\n\"\n              f\"    Estimated RAM needed: ~80 GB\\n\"\n              f\"    Detected available RAM: {available_gb:.2f} GB\\n\"\n              \"    If the program crashes or is killed, consider upgrading \"\n              \"system memory.\")\n\n    parser = argparse.ArgumentParser(description='GLEM Example:')\n    parser.add_argument('--gpu', type=int, default=0)\n    parser.add_argument('--num_runs', type=int, default=10,\n                        help='number of runs')\n    parser.add_argument('--num_em_iters', type=int, default=1,\n                        help='number of iterations')\n    parser.add_argument(\"--dataset\", type=str, default='products',\n                        help='arxiv or products')\n    parser.add_argument(\n        \"--text_type\", type=str, default='raw_text',\n        help=\"type of text, support raw_text, llm_explanation, \"\n        \"all for arxiv and raw_text for products\")\n    parser.add_argument(\"--pl_ratio\", type=float, default=0.5,\n                        help=\"pseudo labels ratio\")\n    parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny',\n                        help='huggingface model repo id')\n    parser.add_argument(\n        '--gnn_model', type=str, default='SAGE',\n        help='gnn model for node classification, '\n        'options: SAGE, GAT, GCN')\n    parser.add_argument('--gnn_hidden_channels', type=int, default=256)\n    parser.add_argument('--gnn_num_layers', type=int, default=3)\n    parser.add_argument('--gat_heads', type=int, default=4,\n                        help='Number of attention heads for GAT')\n    parser.add_argument('--lm_batch_size', type=int, default=256)\n    parser.add_argument('--gnn_batch_size', type=int, default=1024)\n    parser.add_argument(\n        '--external_pred_path', type=str, default=None,\n        help=\"Other model's output logits during the \"\n        \"pretraining phase or simply concatenate it with \"\n        \"node 
features as augmented data for gnn\")\n    parser.add_argument('--alpha', type=float, default=0.5,\n                        help='pseudo label weight in E-step')\n    parser.add_argument('--beta', type=float, default=0.5,\n                        help='pseudo label weight in M-step')\n    parser.add_argument('--lm_epochs', type=int, default=10)\n    parser.add_argument('--gnn_epochs', type=int, default=50)\n    parser.add_argument('--gnn_lr', type=float, default=0.002)\n    parser.add_argument('--lm_lr', type=float, default=0.001)\n    parser.add_argument('--patience', type=int, default=3,\n                        help='Patience for early stopping')\n    parser.add_argument('--verbose', action='store_true',\n                        help='show progress bar during training or not')\n    parser.add_argument('--em_order', type=str, default='lm',\n                        help='decide whether to train LM first or GNN first')\n    parser.add_argument('--lm_use_lora', action='store_true',\n                        help='use LoRA to fine-tune model or not')\n    parser.add_argument(\n        '--token_on_disk', action='store_true',\n        help='save tokens to disk and load them from disk '\n        'to avoid repeated tokenization')\n    parser.add_argument('--out_dir', type=str, default='output/',\n                        help='output directory')\n    parser.add_argument(\n        '--train_without_ext_pred', action='store_true',\n        help='train glem without using additional pseudo labels '\n        'for augmenting data only available for ogbn-products')\n    args = parser.parse_args()\n    print(args)\n    main(args)\n"
  },
  {
    "path": "examples/llm/molecule_gpt.py",
    "content": "\"\"\"This example implements the MoleculeGPT model\n(https://ai4d3.github.io/papers/34.pdf) using PyG.\n\"\"\"\nimport argparse\nimport math\nimport os.path as osp\nimport time\n\nimport torch\nfrom torch.nn.utils import clip_grad_norm_\nfrom tqdm import tqdm\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.datasets import InstructMolDataset, MoleculeGPTDataset\nfrom torch_geometric.llm.models import LLM, MoleculeGPT, SentenceTransformer\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GINEConv\n\n\ndef save_params_dict(model, save_path):\n    state_dict = model.state_dict()\n    param_grad_dict = {\n        k: v.requires_grad\n        for (k, v) in model.named_parameters()\n    }\n    for k in list(state_dict.keys()):\n        if k in param_grad_dict.keys() and not param_grad_dict[k]:\n            del state_dict[k]  # Delete parameters that do not require gradient\n    torch.save(state_dict, save_path)\n\n\n@torch.no_grad()\ndef eval(model, data_loader):\n    model.eval()\n    loss = 0\n\n    for batch in data_loader:\n        batch_loss = model(batch.x, batch.edge_index, batch.batch,\n                           batch.edge_attr, batch.smiles, batch.instruction,\n                           batch.y)\n        loss += batch_loss.item() / len(data_loader)\n    return loss\n\n\ndef train(\n    dataset_name: str,\n    num_epochs: int,\n    lr: float,\n    batch_size: int,\n    checkpointing: bool,\n):\n    def adjust_learning_rate(param_group, LR, epoch):\n        # Decay the learning rate with half-cycle cosine after warmup\n        min_lr = 5e-6\n        warmup_epochs = 1\n        if epoch < warmup_epochs:\n            lr = LR\n        else:\n            lr = min_lr + (LR - min_lr) * 0.5 * (\n                1.0 + math.cos(math.pi * (epoch - warmup_epochs) /\n                               (num_epochs - warmup_epochs)))\n        param_group['lr'] = lr\n        return lr\n\n    def 
get_clippable_params(params):\n        return [\n            p for p in params\n            if isinstance(p, torch.Tensor) and not hasattr(p, '_spec')\n        ]\n\n    start_time = time.time()\n    # Load dataset ================================================\n    path = osp.dirname(osp.realpath(__file__))\n    path = osp.join(path, '..', '..', 'data', dataset_name)\n    if dataset_name == 'MoleculeGPT':\n        dataset = MoleculeGPTDataset(path)\n    elif dataset_name == 'InstructMol':\n        dataset = InstructMolDataset(path)\n    train_size, val_size = int(0.8 * len(dataset)), int(0.1 * len(dataset))\n    train_dataset = dataset[:train_size]\n    val_dataset = dataset[train_size:train_size + val_size]\n    test_dataset = dataset[train_size + val_size:]\n\n    seed_everything(42)\n\n    train_loader = DataLoader(train_dataset, batch_size=batch_size,\n                              drop_last=True, pin_memory=True, shuffle=True)\n    val_loader = DataLoader(val_dataset, batch_size=batch_size,\n                            drop_last=False, pin_memory=True, shuffle=False)\n    test_loader = DataLoader(test_dataset, batch_size=batch_size,\n                             drop_last=False, pin_memory=True, shuffle=False)\n\n    # Create model ===============================================\n    llm = LLM(\n        # model_name='lmsys/vicuna-7b-v1.5',\n        model_name='Qwen/Qwen3-0.6B',\n        num_params=1,\n        dtype=torch.bfloat16,\n        sys_prompt='You are an agent, answer my questions.',\n    )\n\n    graph_encoder = GINEConv(\n        nn=torch.nn.Sequential(\n            torch.nn.Linear(6, 768),\n            torch.nn.ReLU(),\n            torch.nn.Linear(768, 768),\n        ),\n        train_eps=True,\n        edge_dim=4,\n    )\n\n    smiles_encoder = SentenceTransformer(\n        model_name='DeepChem/ChemBERTa-77M-MTR',\n        pooling_strategy='last_hidden_state',\n    )\n\n    model = MoleculeGPT(\n        llm=llm,\n        
graph_encoder=graph_encoder,\n        smiles_encoder=smiles_encoder,\n    )\n\n    # Train and eval ============================================\n    params = [p for _, p in model.named_parameters() if p.requires_grad]\n    optimizer = torch.optim.AdamW([\n        {\n            'params': params,\n            'lr': lr,\n            'weight_decay': 0.05,\n        },\n    ], betas=(0.9, 0.95))\n    grad_steps = 2\n\n    best_epoch = 0\n    best_val_loss = float('inf')\n    for epoch in range(num_epochs):\n        # Train\n        model.train()\n        epoch_loss = 0\n        if epoch == 0:\n            print(f\"Total Preparation Time: {time.time() - start_time:.2f}s\")\n            start_time = time.time()\n            print(\"Training beginning...\")\n        epoch_str = f'Epoch: {epoch + 1}|{num_epochs}'\n        loader = tqdm(train_loader, desc=epoch_str)\n\n        for step, batch in enumerate(loader):\n            optimizer.zero_grad()\n            loss = model(batch.x, batch.edge_index, batch.batch,\n                         batch.edge_attr, batch.smiles, batch.instruction,\n                         batch.y)\n            loss.backward()\n            clip_grad_norm_(\n                get_clippable_params(optimizer.param_groups[0]['params']), 0.1)\n\n            if (step + 1) % grad_steps == 0:\n                adjust_learning_rate(optimizer.param_groups[0], lr,\n                                     step / len(train_loader) + epoch)\n\n            optimizer.step()\n            epoch_loss += loss.detach().item()\n\n            if (step + 1) % grad_steps == 0:\n                lr = optimizer.param_groups[0]['lr']\n        train_loss = epoch_loss / len(train_loader)\n\n        # Eval\n        val_loss = eval(model, val_loader)\n        print(\n            f'{epoch_str}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}'  # noqa: E501\n        )\n\n        if checkpointing and val_loss < best_val_loss:\n            best_val_loss = val_loss\n            best_epoch 
= epoch\n            save_params_dict(\n                model,\n                f'moleculegpt_epoch{best_epoch}_val_loss{best_val_loss:.4f}_ckpt.pt'  # noqa: E501\n            )\n    torch.cuda.empty_cache()\n    torch.cuda.reset_peak_memory_stats()\n\n    print(f\"Total Training Time: {time.time() - start_time:.2f}s\")\n    # Test\n    test_loss = eval(model, test_loader)\n    print(f'Test loss: {test_loss:.4f}')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset_name\", type=str, default='MoleculeGPT',\n                        choices=['MoleculeGPT', 'InstructMol'],\n                        help='Support MoleculeGPT and InstructMol')\n    parser.add_argument('--epochs', type=int, default=3)\n    parser.add_argument('--lr', type=float, default=1e-5)\n    parser.add_argument('--batch_size', type=int, default=2)\n    parser.add_argument('--checkpointing', type=bool, default=True)\n    args = parser.parse_args()\n\n    start_time = time.time()\n    train(\n        args.dataset_name,\n        args.epochs,\n        args.lr,\n        args.batch_size,\n        args.checkpointing,\n    )\n    print(f'Total Time: {time.time() - start_time:.2f}s')\n"
  },
  {
    "path": "examples/llm/protein_mpnn.py",
    "content": "\"\"\"This example implements the ProteinMPNN model\n(https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1) using PyG.\n\"\"\"\nimport argparse\nimport time\n\nimport numpy as np\nimport psutil\nimport torch\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.datasets import ProteinMPNNDataset\nfrom torch_geometric.llm.models import ProteinMPNN\nfrom torch_geometric.loader import DataLoader\n\n\ndef loss_smoothed(y, logits, mask, weight=0.1):\n    \"\"\"Negative log probabilities.\"\"\"\n    y_onehot = torch.nn.functional.one_hot(y, 21).float()\n\n    # Label smoothing\n    y_onehot = y_onehot + weight / float(y_onehot.size(-1))\n    y_onehot = y_onehot / y_onehot.sum(-1, keepdim=True)\n\n    loss = -(y_onehot * logits).sum(-1)\n    loss_av = torch.sum(loss * mask) / 2000.0\n    return loss, loss_av\n\n\ndef loss_nll(y, logits, mask):\n    \"\"\"Negative log probabilities.\"\"\"\n    criterion = torch.nn.NLLLoss(reduction='none')\n    loss = criterion(logits.contiguous().view(-1, logits.size(-1)),\n                     y.contiguous().view(-1)).view(y.size())\n    y_argmaxed = torch.argmax(logits, -1)  # [B, L]\n    true_false = (y == y_argmaxed).float()\n    loss_av = torch.sum(loss * mask) / torch.sum(mask)\n    return loss, loss_av, true_false\n\n\nclass NoamOpt:\n    \"\"\"Optim wrapper that implements rate.\"\"\"\n    def __init__(self, model_size, factor, warmup, optimizer, step):\n        self.optimizer = optimizer\n        self._step = step\n        self.warmup = warmup\n        self.factor = factor\n        self.model_size = model_size\n        self._rate = 0\n\n    @property\n    def param_groups(self):\n        \"\"\"Return param_groups.\"\"\"\n        return self.optimizer.param_groups\n\n    def step(self):\n        \"\"\"Update parameters and rate.\"\"\"\n        self._step += 1\n        rate = self.rate()\n        for p in self.optimizer.param_groups:\n            p['lr'] = rate\n        self._rate = rate\n      
  self.optimizer.step()\n\n    def rate(self, step=None):\n        \"\"\"Implement learning rate above.\"\"\"\n        if step is None:\n            step = self._step\n        return self.factor * (self.model_size**(-0.5) *\n                              min(step**(-0.5), step * self.warmup**(-1.5)))\n\n    def zero_grad(self):\n        self.optimizer.zero_grad()\n\n\ndef train(model, optimizer, data_loader, device, scaler):\n    model.train()\n    train_sum = 0.0\n    train_acc = 0.0\n    train_weights = 0.0\n    for batch in data_loader:\n        optimizer.zero_grad()\n        batch = batch.to(device)\n        mask_for_loss = batch.mask * batch.chain_mask_all\n        y = batch.chain_seq_label\n\n        if torch.cuda.is_available() and args.mixed_precision:\n            with torch.amp.autocast('cuda'):\n                logits = model(batch.x, batch.chain_seq_label, batch.mask,\n                               batch.chain_mask_all, batch.residue_idx,\n                               batch.chain_encoding_all, batch.batch)\n                _, loss = loss_smoothed(y, logits, mask_for_loss)\n\n            scaler.scale(loss).backward()\n\n            if args.gradient_norm > 0:\n                torch.nn.utils.clip_grad_norm_(model.parameters(),\n                                               args.gradient_norm)\n\n            scaler.step(optimizer)\n            scaler.update()\n        else:\n            logits = model(batch.x, batch.chain_seq_label, batch.mask,\n                           batch.chain_mask_all, batch.residue_idx,\n                           batch.chain_encoding_all, batch.batch)\n\n            _, loss = loss_smoothed(y, logits, mask_for_loss)\n            loss.backward()\n\n            if args.gradient_norm > 0:\n                torch.nn.utils.clip_grad_norm_(model.parameters(),\n                                               args.gradient_norm)\n\n            optimizer.step()\n\n        loss, _, true_false = loss_nll(y, logits, mask_for_loss)\n\n        
train_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()\n        train_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()\n        train_weights += torch.sum(mask_for_loss).cpu().data.numpy()\n\n    train_loss = train_sum / train_weights\n    train_accuracy = train_acc / train_weights\n    train_perplexity = np.exp(train_loss)\n\n    return train_perplexity, train_accuracy\n\n\n@torch.no_grad()\ndef eval(model, data_loader, device):\n    model.eval()\n    valid_sum = 0.\n    valid_weights = 0.\n    valid_acc = 0.\n    for batch in data_loader:\n        batch = batch.to(device)\n        logits = model(batch.x, batch.chain_seq_label, batch.mask,\n                       batch.chain_mask_all, batch.residue_idx,\n                       batch.chain_encoding_all, batch.batch)\n\n        mask_for_loss = batch.mask * batch.chain_mask_all\n        y = batch.chain_seq_label\n        loss, _, true_false = loss_nll(y, logits, mask_for_loss)\n\n        valid_sum += torch.sum(loss * mask_for_loss).cpu().data.numpy()\n        valid_acc += torch.sum(true_false * mask_for_loss).cpu().data.numpy()\n        valid_weights += torch.sum(mask_for_loss).cpu().data.numpy()\n\n    valid_loss = valid_sum / valid_weights\n    valid_accuracy = valid_acc / valid_weights\n    valid_perplexity = np.exp(valid_loss)\n\n    return valid_perplexity, valid_accuracy\n\n\ndef main(args):\n    wall_clock_start = time.perf_counter()\n    seed_everything(123)\n    scaler = torch.amp.GradScaler()\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\n    if args.size == 'large' and psutil.virtual_memory().total < 64.1 * 1024**3:\n        print('Warning: may not have enough RAM to run this example.')\n        print('Consider upgrading RAM if an error occurs.')\n        print('Estimated RAM Needed: ~64.1GB.')\n\n    train_dataset = ProteinMPNNDataset(\n        root=args.data_path,\n        size=args.size,\n        split='train',\n        rescut=args.rescut,\n        
max_length=args.max_protein_length,\n    )\n    valid_dataset = ProteinMPNNDataset(\n        root=args.data_path,\n        size=args.size,\n        split='valid',\n        rescut=args.rescut,\n        max_length=args.max_protein_length,\n    )\n    test_dataset = ProteinMPNNDataset(\n        root=args.data_path,\n        size=args.size,\n        split='test',\n        rescut=args.rescut,\n        max_length=args.max_protein_length,\n    )\n    train_loader = DataLoader(train_dataset, batch_size=args.train_batch_size,\n                              shuffle=True, num_workers=6)\n    valid_loader = DataLoader(valid_dataset, batch_size=args.eval_batch_size,\n                              shuffle=False, num_workers=6)\n    test_loader = DataLoader(test_dataset, batch_size=args.eval_batch_size,\n                             shuffle=False, num_workers=6)\n\n    model = ProteinMPNN(\n        hidden_dim=args.hidden_dim,\n        num_encoder_layers=args.num_encoder_layers,\n        num_decoder_layers=args.num_decoder_layers,\n        num_neighbors=args.num_neighbors,\n        dropout=args.dropout,\n        augment_eps=args.backbone_noise,\n        num_positional_embedding=16,\n    ).to(device)\n\n    total_step = 0\n    optimizer = NoamOpt(\n        model_size=args.hidden_dim, factor=2, warmup=4000,\n        optimizer=torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98),\n                                   eps=1e-9), step=total_step)\n\n    times = []\n    for e in range(args.num_epochs):\n        start = time.perf_counter()\n        train_perplexity, train_accuracy = train(model, optimizer,\n                                                 train_loader, device, scaler)\n        valid_perplexity, valid_accuracy = eval(model, valid_loader, device)\n\n        print(\n            f'epoch: {e:03d}, step: {total_step}, '\n            f'train: {train_perplexity:.3f}, valid: {valid_perplexity:.3f}, '\n            f'train_acc: {train_accuracy:.3f}, valid_acc: 
{valid_accuracy:.3f}'\n        )\n        times.append(time.perf_counter() - start)\n\n    print(f'Average Epoch Time: {torch.tensor(times).mean():.4f}s')\n    print(f'Median Epoch Time: {torch.tensor(times).median():.4f}s')\n    print(f'Total Program Runtime: '\n          f'{time.perf_counter() - wall_clock_start:.4f}s')\n    # Test\n    test_perplexity, test_accuracy = eval(model, test_loader, device)\n    print(f'test: {test_perplexity:.3f}, test_acc: {test_accuracy:.3f}')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    # dataset config\n    parser.add_argument('--data_path', type=str, default='data/ProteinMPNN',\n                        help='path for loading training data')\n    parser.add_argument(\n        '--size', type=str, default='small', choices=['small', 'large'],\n        help='Use of \"small (229.4 MB)\" or \"large (64.1 GB)\" dataset')\n    parser.add_argument('--max_protein_length', type=int, default=10000,\n                        help='maximum length of the protein complex')\n    parser.add_argument('--rescut', type=float, default=3.5,\n                        help='PDB resolution cutoff')\n    # training config\n    parser.add_argument('--num_epochs', type=int, default=50,\n                        help='number of epochs to train for')\n    parser.add_argument('--train_batch_size', type=int, default=4,\n                        help='number of tokens for one train batch')\n    parser.add_argument('--eval_batch_size', type=int, default=8,\n                        help='number of tokens for one valid or test batch')\n    parser.add_argument(\n        '--gradient_norm', type=float, default=-1.0,\n        help='clip gradient norm, set to negative to omit clipping')\n    parser.add_argument('--mixed_precision', type=bool, default=True,\n                        help='train with mixed precision')\n    # model config\n    parser.add_argument('--hidden_dim', type=int, default=128,\n                        help='hidden model 
dimension')\n    parser.add_argument('--num_encoder_layers', type=int, default=3,\n                        help='number of encoder layers')\n    parser.add_argument('--num_decoder_layers', type=int, default=3,\n                        help='number of decoder layers')\n    parser.add_argument('--num_neighbors', type=int, default=30,\n                        help='number of neighbors for the sparse graph')\n    parser.add_argument('--num_rbf', type=int, default=16,\n                        help='number of radial basis functions')\n    parser.add_argument('--dropout', type=float, default=0.1,\n                        help='dropout level; 0.0 means no dropout')\n    parser.add_argument(\n        '--backbone_noise', type=float, default=0.2,\n        help='amount of noise added to backbone during training')\n\n    args = parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "examples/llm/txt2kg_rag.py",
    "content": "import argparse\nimport gc\nimport json\nimport os\nimport random\nimport re\nimport sys\nfrom datetime import datetime\nfrom glob import glob\nfrom itertools import chain\nfrom pathlib import Path\n\nimport yaml\n\ntry:\n    import wandb\n    wandb_available = True\nexcept ImportError:\n    wandb_available = False\n\nimport torch\nfrom g_retriever import (\n    adjust_learning_rate,\n    get_loss,\n    inference_step,\n    load_params_dict,\n    save_params_dict,\n)\nfrom huggingface_hub import hf_hub_download\nfrom torch.nn.utils import clip_grad_norm_\nfrom tqdm import tqdm\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.llm import RAGQueryLoader\nfrom torch_geometric.llm.models import (\n    LLM,\n    TXT2KG,\n    GRetriever,\n    LLMJudge,\n    SentenceTransformer,\n)\nfrom torch_geometric.llm.models.txt2kg import _chunk_text\nfrom torch_geometric.llm.utils.backend_utils import (\n    create_graph_from_triples,\n    create_remote_backend_from_graph_data,\n    make_pcst_filter,\n    preprocess_triplet,\n)\nfrom torch_geometric.llm.utils.feature_store import KNNRAGFeatureStore\nfrom torch_geometric.llm.utils.graph_store import NeighborSamplingRAGGraphStore\nfrom torch_geometric.llm.utils.vectorrag import DocumentRetriever\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GAT, SGFormer\n\n# Define constants for better readability\nNV_NIM_MODEL_DEFAULT = \"nvidia/llama-3.1-nemotron-ultra-253b-v1\"\nLLM_GENERATOR_NAME_DEFAULT = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\nENCODER_MODEL_NAME_DEFAULT = \"Alibaba-NLP/gte-modernbert-base\"\nKG_CHUNK_SIZE_DEFAULT = 512\nGNN_HID_CHANNELS_DEFAULT = 1024\nGNN_LAYERS_DEFAULT = 4\nLR_DEFAULT = 1e-5\nEPOCHS_DEFAULT = 2\nBATCH_SIZE_DEFAULT = 1\nEVAL_BATCH_SIZE_DEFAULT = 2\nLLM_GEN_MODE_DEFAULT = \"full\"\nDEFAULT_ENDPOINT_URL = \"https://integrate.api.nvidia.com/v1\"\nmax_chars_in_train_answer = 128\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser()\n  
  parser.add_argument('--gnn_model', type=str, default=\"GAT\",\n                        choices=[\"GAT\", \"SGFormer\"],\n                        help=\"The GNN model to use. Default is GAT.\")\n    parser.add_argument('--NV_NIM_MODEL', type=str,\n                        default=NV_NIM_MODEL_DEFAULT,\n                        help=\"The NIM LLM to use for TXT2KG for LLMJudge\")\n    parser.add_argument('--NV_NIM_KEY', type=str, help=\"NVIDIA API key\")\n    parser.add_argument(\n        '--ENDPOINT_URL', type=str, default=DEFAULT_ENDPOINT_URL,\n        help=\"The URL hosting your model, \\\n        in case you are not using the public NIM.\")\n    parser.add_argument(\n        '--kg_chunk_size', type=int, default=KG_CHUNK_SIZE_DEFAULT,\n        help=\"When splitting context documents for txt2kg,\\\n        the maximum number of characters per chunk.\")\n    parser.add_argument('--gnn_hidden_channels', type=int,\n                        default=GNN_HID_CHANNELS_DEFAULT,\n                        help=\"Hidden channels for GNN\")\n    parser.add_argument('--num_gnn_layers', type=int,\n                        default=GNN_LAYERS_DEFAULT,\n                        help=\"Number of GNN layers\")\n    parser.add_argument('--lr', type=float, default=LR_DEFAULT,\n                        help=\"Learning rate\")\n    parser.add_argument('--epochs', type=int, default=EPOCHS_DEFAULT,\n                        help=\"Number of epochs\")\n    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE_DEFAULT,\n                        help=\"Batch size\")\n    parser.add_argument('--eval_batch_size', type=int,\n                        default=EVAL_BATCH_SIZE_DEFAULT,\n                        help=\"Evaluation batch size\")\n    parser.add_argument('--llm_generator_name', type=str,\n                        default=LLM_GENERATOR_NAME_DEFAULT,\n                        help=\"The LLM to use for Generation\")\n    parser.add_argument(\n        '--llm_generator_mode', type=str, 
default=LLM_GEN_MODE_DEFAULT,\n        choices=[\"frozen\", \"lora\",\n                 \"full\"], help=\"Whether to freeze the Generator LLM,\\\n                        use LORA, or fully finetune\")\n    parser.add_argument('--dont_save_model', action=\"store_true\",\n                        help=\"Whether to skip model saving.\")\n    parser.add_argument('--log_steps', type=int, default=30,\n                        help=\"Log to wandb every N steps\")\n    parser.add_argument('--wandb_project', type=str, default=\"techqa\",\n                        help=\"Weights & Biases project name\")\n    parser.add_argument('--wandb', action=\"store_true\",\n                        help=\"Enable wandb logging\")\n    parser.add_argument(\n        '--num_gpus', type=int, default=None,\n        help=\"Number of GPUs to use. If not specified, \"\n        \"will determine automatically based on model size.\")\n    parser.add_argument('--regenerate_dataset', action=\"store_true\",\n                        help=\"Regenerate the dataset\")\n    parser.add_argument(\n        '--doc_parsing_mode', type=str, default=None,\n        choices=[\"paragraph\",\n                 \"file\"], help=\"How to parse documents: 'paragraph' splits \"\n        \"files by paragraphs, 'file' treats each file as \"\n        \"one document. \"\n        \"This will override any value set in the config file.\")\n    parser.add_argument(\n        '--k_for_docs', type=int, default=None,\n        help=\"Number of docs to retrieve for each question. \"\n        \"This will override any value set in the config file.\")\n    parser.add_argument(\n        '--doc_chunk_size', type=int, default=None,\n        help=\"The chunk size to use for VectorRAG (document retrieval). 
\"\n        \"This will override any value set in the config file.\")\n    parser.add_argument(\n        '--dataset', type=str, default=\"techqa\", help=\"Dataset folder name, \"\n        \"should contain corpus and train.json files.\"\n        \"extracted triples, processed dataset, \"\n        \"document retriever, and model checkpoints \"\n        \"will be saved in the dataset folder\")\n    parser.add_argument(\n        '--skip_graph_rag', action=\"store_true\",\n        help=\"Skip the graph RAG step. \"\n        \"Used to compare the performance of Vector+Graph RAG vs Vector RAG.\")\n    parser.add_argument(\n        '--use_x_percent_corpus', default=100.0, type=float,\n        help=\"Debug flag that allows user to only use a random percentage \"\n        \"of available knowledge base corpus for RAG\")\n    args = parser.parse_args()\n\n    assert args.NV_NIM_KEY, \"NVIDIA API key is required for TXT2KG and eval\"\n    assert args.use_x_percent_corpus <= 100 and \\\n        args.use_x_percent_corpus > 0, \"Please provide a value in (0,100]\"\n    if args.skip_graph_rag:\n        print(\"Skipping graph RAG step, setting GNN layers to 0...\")\n        args.num_gnn_layers = 0\n\n    config_path = os.path.join(args.dataset, \"config.yaml\")\n    if os.path.exists(config_path):\n        print(f\"Loading config from {config_path}...\")\n        with open(config_path) as config_file:\n            config = yaml.safe_load(config_file)\n\n        if config is not None:\n            # Use a loop to check and apply config values for each parameter\n            config_params = [\n                'doc_parsing_mode', 'doc_chunk_size', 'k_for_docs'\n            ]\n            for param in config_params:\n                if param in config and getattr(args, param) is None:\n                    setattr(args, param, config[param])\n                    print(f\"Using config value for {param}: {config[param]}\")\n    else:\n        print(\"Skipping config loading...\")\n        
if args.dataset == \"techqa\":\n            if args.doc_chunk_size is None:\n                args.doc_chunk_size = 1024\n            if args.k_for_docs is None:\n                args.k_for_docs = 14\n\n    assert args.doc_chunk_size is not None, \"doc_chunk_size has not been set\"\n    assert args.k_for_docs is not None, \"k_for_docs has not been set\"\n\n    return args\n\n\nsys_prompt = (\n    \"You are an expert assistant that can answer \"\n    \"any question from its knowledge, given a knowledge graph embedding and \"\n    \"its textualized context. Just give the answer, without explanation.\")\n\nprompt_template = \"\"\"\n    [QUESTION]\n    {question}\n    [END_QUESTION]\n\n    [RETRIEVED_CONTEXTS]\n    {context}\n    [END_RETRIEVED_CONTEXTS]\n    \"\"\"\n\n\ndef _process_and_chunk_text(text, chunk_size, doc_parsing_mode):\n    \"\"\"\n    Some corpora of docs are grouped into chunked files,\n    typically by paragraph.\n    Only split into individual documents\n    if multiple paragraphs are detected.\n    \"\"\"\n    full_chunks = []\n    if doc_parsing_mode == \"paragraph\":\n        paragraphs = re.split(r'\\n{2,}', text)\n    else:\n        # doc_parsing_mode == 'file' or doc_parsing_mode is None\n        paragraphs = [text]\n\n    for paragraph in paragraphs:\n        if chunk_size is not None:\n            chunks = _chunk_text(paragraph, chunk_size)\n        else:\n            # defaults to 512 in _chunk_text\n            chunks = _chunk_text(paragraph)\n        full_chunks.extend(chunks)\n    return full_chunks\n\n\ndef get_data(args):\n    # need a JSON dict of questions and answers; see below for how it's used\n\n    json_path = Path(args.dataset) / \"train.json\"\n    corpus_path = Path(args.dataset) / \"corpus\"\n\n    # techqa specified but neither corpus nor train.json exists\n    if \"techqa\" in args.dataset.lower() and not (json_path.exists()\n                                                 or corpus_path.exists()):\n        print(\"Could 
not find Q&A pairs and/or knowledge base corpus\")\n        print(\"Would you like to download the TechQA dataset for demo?\")\n        user_input = input(\"Y/N: \")\n        if user_input.lower() == \"y\" or user_input.lower() == \"yes\":\n            print(\"Downloading data...\")\n            # downloads\n            zip_path = hf_hub_download(\n                repo_id=\"nvidia/TechQA-RAG-Eval\",\n                repo_type=\"dataset\",\n                filename=\"corpus.zip\",\n            )\n            json_path = hf_hub_download(\n                repo_id=\"nvidia/TechQA-RAG-Eval\",\n                repo_type=\"dataset\",\n                filename=\"train.json\",\n            )\n            # move to working dir\n            if not os.path.exists(args.dataset):\n                os.mkdir(args.dataset)\n            import zipfile\n            with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n                zip_ref.extractall(args.dataset)\n            import shutil\n            shutil.copy(json_path, os.path.join(args.dataset, \"train.json\"))\n        elif user_input.lower() == \"n\" or user_input.lower() == \"no\":\n            sys.exit(\"No selected, no data to work with... 
exiting.\")\n        else:\n            sys.exit(\"Invalid user input, exiting.\")\n    with open(os.path.join(args.dataset, \"train.json\")) as file:\n        json_obj = json.load(file)\n    text_contexts = []\n\n    # Read corpus data to create the KG and for document retrieval (RAG).\n    # Prefer *.json files, fall back to txt files.\n    # TODO: add support for additional corpus file formats: PDF, CSV, XML,\n    # HTML, possibly others.\n    # corpus folder is simply a folder with context documents in it.\n    file_paths = glob(os.path.join(args.dataset, \"corpus\", \"*.json\"))\n    if len(file_paths) > 0:\n        for file_path in file_paths:\n            with open(file_path, \"r+\") as f:\n                data = json.load(f)\n            doc_type = data[0][\"document_type\"]\n            if doc_type != \"text\":\n                raise ValueError(f\"Bad extraction for {file_path}, expecting \"\n                                 f\"text only but got {doc_type}\")\n            text_contexts.extend(\n                _process_and_chunk_text(data[0][\"metadata\"][\"content\"],\n                                        args.doc_chunk_size,\n                                        args.doc_parsing_mode))\n    else:\n        for file_path in glob(os.path.join(args.dataset, \"corpus\", \"*\")):\n            with open(file_path, \"r+\") as f:\n                text_context = f.read()\n            text_contexts.extend(\n                _process_and_chunk_text(text_context, args.doc_chunk_size,\n                                        args.doc_parsing_mode))\n    if args.use_x_percent_corpus < 100:\n        random.shuffle(text_contexts)\n        text_contexts = text_contexts[\n            0:int(len(text_contexts) * args.use_x_percent_corpus / 100.0)]\n\n    return json_obj, text_contexts\n\n\ndef index_kg(args, context_docs):\n    kg_maker = TXT2KG(NVIDIA_NIM_MODEL=args.NV_NIM_MODEL,\n                      NVIDIA_API_KEY=args.NV_NIM_KEY,\n                      
ENDPOINT_URL=args.ENDPOINT_URL,\n                      chunk_size=args.kg_chunk_size)\n    print(\n        \"Note that if the TXT2KG process is too slow for your liking using \"\n        \"the public NIM, consider deploying yourself using local_lm flag of \"\n        \"TXT2KG or using https://build.nvidia.com/nvidia/llama-3_1-nemotron-70b-instruct \"  # noqa\n        \"to deploy to a private endpoint, which you can pass to this script \"\n        \"w/ --ENDPOINT_URL flag.\")\n    print(\n        \"Guide for deploying NIM: https://developer.nvidia.com/blog/a-simple-guide-to-deploying-generative-ai-with-nvidia-nim/\"  # noqa\n    )\n    total_tqdm_count = len(context_docs)\n    initial_tqdm_count = 0\n    checkpoint_file = list(Path(args.dataset).glob(\"*--*--checkpoint_kg.pt\"))\n    if len(checkpoint_file) > 1:\n        raise RuntimeError(\"Error: more than one checkpoint file found\")\n\n    if len(checkpoint_file) == 1:\n        print(\"Restoring KG from checkpoint\")\n        checkpoint_file = checkpoint_file[0]\n        checkpoint_model_name = checkpoint_file.name.split('--')[0]\n        # check that triple generation used the correct model\n        if args.NV_NIM_MODEL.split('/')[-1] != checkpoint_model_name:\n            raise RuntimeError(\n                \"Error: stored triples were generated using a different model\")\n        saved_relevant_triples = torch.load(checkpoint_file,\n                                            weights_only=False)\n        kg_maker.relevant_triples = saved_relevant_triples\n        kg_maker.doc_id_counter = len(saved_relevant_triples)\n        initial_tqdm_count = kg_maker.doc_id_counter\n        context_docs = context_docs[kg_maker.doc_id_counter:]\n\n    chkpt_interval = 10\n    chkpt_count = 0\n    for context_doc in tqdm(context_docs, total=total_tqdm_count,\n                            initial=initial_tqdm_count,\n                            desc=\"Extracting KG triples\"):\n        
kg_maker.add_doc_2_KG(txt=context_doc)\n        chkpt_count += 1\n        if chkpt_count == chkpt_interval:\n            chkpt_count = 0\n            path = args.dataset + \"/{m}--{t}--checkpoint_kg.pt\"\n            model = kg_maker.NIM_MODEL.split(\n                '/')[-1] if not kg_maker.local_LM else \"local\"\n            path = path.format(m=model,\n                               t=datetime.now().strftime(\"%Y%m%d_%H%M%S\"))\n            torch.save(kg_maker.relevant_triples, path)\n\n    relevant_triples = kg_maker.relevant_triples\n    triples = list(\n        chain.from_iterable(triple_set\n                            for triple_set in relevant_triples.values()))\n    triples = [preprocess_triplet(triplet) for triplet in triples]\n    triples = list(dict.fromkeys(triples))\n    raw_triples_path = args.dataset + \"/{m}--{t}--raw_triples.pt\"\n\n    model_name = kg_maker.NIM_MODEL.split(\n        '/')[-1] if not kg_maker.local_LM else \"local\"\n    torch.save(\n        triples,\n        raw_triples_path.format(m=model_name,\n                                t=datetime.now().strftime(\"%Y%m%d_%H%M%S\")))\n\n    for old_checkpoint_file in Path(\n            args.dataset).glob(\"*--*--checkpoint_kg.pt\"):\n        os.remove(old_checkpoint_file)\n\n    return triples\n\n\ndef update_data_lists(args, data_lists):\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    # creating the embedding model\n    sent_trans_batch_size = 256\n    model = SentenceTransformer(\n        model_name=ENCODER_MODEL_NAME_DEFAULT).to(device).eval()\n    model_kwargs = {\n        \"output_device\": device,\n        \"batch_size\": int(sent_trans_batch_size / 4),\n    }\n    doc_retriever_path = os.path.join(args.dataset, \"document_retriever.pt\")\n    if os.path.exists(doc_retriever_path):\n        print(\"Loading document retriever from checkpoint...\")\n        vector_retriever = DocumentRetriever.load(doc_retriever_path,\n                               
                   model=model.encode,\n                                                  model_kwargs=model_kwargs)\n        if args.k_for_docs != vector_retriever.k_for_docs:\n            vector_retriever.k_for_docs = args.k_for_docs\n        else:\n            return data_lists\n    else:\n        raise ValueError(\"Document retriever not found\")\n\n    print(\"k_for_docs changed, updating data lists...\")\n\n    total_points = sum(len(data_list) for data_list in data_lists.values())\n\n    progress_bar = tqdm(total=total_points, desc=\"Updating text contexts\")\n\n    for data_list in data_lists.values():\n        for data_point in data_list:\n            q = data_point[\"question\"]\n            data_point[\"text_context\"] = vector_retriever.query(q)\n            progress_bar.update(1)\n\n    progress_bar.close()\n\n    vector_retriever.save(doc_retriever_path)\n\n    del vector_retriever\n    gc.collect()\n    torch.cuda.empty_cache()\n\n    dataset_name = os.path.basename(args.dataset)\n    dataset_path = os.path.join(args.dataset, f\"{dataset_name}.pt\")\n    # Save in the same (data_lists, max_chars_in_train_answer) layout that\n    # the main block unpacks when re-loading the dataset.\n    torch.save((data_lists, max_chars_in_train_answer), dataset_path)\n    return data_lists\n\n\ndef make_dataset(args):\n    qa_pairs, context_docs = get_data(args)\n    print(\"Number of Docs in our VectorDB =\", len(context_docs))\n    data_lists = {\"train\": [], \"validation\": [], \"test\": []}\n\n    triples = []\n    raw_triples_file = list(Path(args.dataset).glob(\"*--*--raw_triples.pt\"))\n    if len(raw_triples_file) > 1:\n        raise RuntimeError(\"Error: multiple raw_triples files found\")\n    if len(raw_triples_file) == 1:\n        raw_triples_file = raw_triples_file[0]\n        stored_model_name = raw_triples_file.name.split('--')[0]\n\n        if args.NV_NIM_MODEL.split('/')[-1] != stored_model_name:\n            raise RuntimeError(\n                \"Error: stored triples were generated using a different model\")\n\n        print(f\" -> Saved triples generated with: {stored_model_name}\")\n        triples = 
torch.load(raw_triples_file)\n    else:\n        triples = index_kg(args, context_docs)\n\n    print(\"Number of triples in our GraphDB =\", len(triples))\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n    # creating the embedding model\n    sent_trans_batch_size = 256\n    model = SentenceTransformer(\n        model_name=ENCODER_MODEL_NAME_DEFAULT).to(device)\n\n    print(\"Creating the graph data from raw triples...\")\n    # create the graph data from raw triples\n    graph_data = create_graph_from_triples(\n        triples=triples, embedding_model=model.encode,\n        embedding_method_kwargs={\n            \"batch_size\": min(len(triples), sent_trans_batch_size),\n            \"verbose\": True\n        })\n\n    print(\"Creating the graph and feature stores...\")\n    # creating the graph and feature stores\n    fs, gs = create_remote_backend_from_graph_data(\n        graph_data=graph_data, path=\"backend\",\n        graph_db=NeighborSamplingRAGGraphStore,\n        feature_db=KNNRAGFeatureStore).load()\n    \"\"\"\n    NOTE: these retriever hyperparams are very important.\n    Tuning may be needed for custom data...\n    \"\"\"\n\n    model_kwargs = {\n        \"output_device\": device,\n        \"batch_size\": int(sent_trans_batch_size / 4),\n        \"verbose\": True\n    }\n\n    doc_retriever_path = os.path.join(args.dataset, \"document_retriever.pt\")\n    if os.path.exists(doc_retriever_path):\n        print(\"Loading document retriever from checkpoint...\")\n        vector_retriever = DocumentRetriever.load(doc_retriever_path,\n                                                  model=model.encode,\n                                                  model_kwargs=model_kwargs)\n        if args.k_for_docs != vector_retriever.k_for_docs:\n            vector_retriever.k_for_docs = args.k_for_docs\n    else:\n        print(\"Creating document retriever...\")\n        vector_retriever = DocumentRetriever(context_docs,\n        
                                     k_for_docs=args.k_for_docs,\n                                             model=model.encode,\n                                             model_kwargs=model_kwargs)\n        vector_retriever.save(doc_retriever_path)\n\n    subgraph_filter = make_pcst_filter(\n        triples,\n        model,\n        topk=5,  # nodes\n        topk_e=5,  # edges\n        cost_e=.5,  # edge cost\n        num_clusters=10)  # num clusters\n\n    # number of neighbors for each seed node selected by KNN\n    fanout = 100\n    # number of hops for neighborsampling\n    num_hops = 2\n\n    query_loader_config = {\n        \"k_nodes\": 1024,  # k for Graph KNN\n        \"num_neighbors\": [fanout] * num_hops,  # number of sampled neighbors\n        \"encoder_model\": model,\n    }\n\n    # GraphDB retrieval done with KNN+NeighborSampling+PCST\n    # PCST = Prize Collecting Steiner Tree\n    # VectorDB retrieval just vanilla vector RAG\n    print(\"Now to retrieve context for each query from \"\n          \"our Vector and Graph DBs...\")\n\n    query_loader = RAGQueryLoader(graph_data=(fs, gs),\n                                  subgraph_filter=subgraph_filter,\n                                  vector_retriever=vector_retriever,\n                                  config=query_loader_config)\n\n    # pre-process the dataset\n    total_data_list = []\n    extracted_triple_sizes = []\n    global max_chars_in_train_answer\n    for data_point in tqdm(qa_pairs, desc=\"Building un-split dataset\"):\n        if data_point[\"is_impossible\"]:\n            continue\n        QA_pair = (data_point[\"question\"], data_point[\"answer\"])\n        q = QA_pair[0]\n        max_chars_in_train_answer = max(len(QA_pair[1]),\n                                        max_chars_in_train_answer)\n        # (TODO) make this batch queries for retrieving w/ CuVS+CuGraph\n        subgraph = query_loader.query(q)\n        subgraph.label = QA_pair[1]\n        
total_data_list.append(subgraph)\n        extracted_triple_sizes.append(len(subgraph.triples))\n    random.shuffle(total_data_list)\n\n    # stats\n    print(\"Min # of Retrieved Triples =\", min(extracted_triple_sizes))\n    print(\"Max # of Retrieved Triples =\", max(extracted_triple_sizes))\n    print(\"Average # of Retrieved Triples =\",\n          sum(extracted_triple_sizes) / len(extracted_triple_sizes))\n\n    # 60:20:20 split\n    data_lists[\"train\"] = total_data_list[:int(.6 * len(total_data_list))]\n    data_lists[\"validation\"] = total_data_list[int(.6 * len(total_data_list)\n                                                   ):int(.8 *\n                                                         len(total_data_list))]\n    data_lists[\"test\"] = total_data_list[int(.8 * len(total_data_list)):]\n\n    dataset_name = os.path.basename(args.dataset)\n    dataset_path = os.path.join(args.dataset, f\"{dataset_name}.pt\")\n    torch.save((data_lists, max_chars_in_train_answer), dataset_path)\n    del model\n    gc.collect()\n    torch.cuda.empty_cache()\n    return data_lists\n\n\ndef train(args, train_loader, val_loader):\n    if args.wandb:\n        wandb.init(project=args.wandb_project,\n                   name=f\"run_{datetime.now().strftime('%Y-%m-%d_%H:%M')}\",\n                   config=vars(args))\n    hidden_channels = args.gnn_hidden_channels\n    num_gnn_layers = args.num_gnn_layers\n\n    if args.num_gnn_layers > 0:\n        if args.gnn_model == \"GAT\":\n            gnn = GAT(in_channels=768, hidden_channels=hidden_channels,\n                      out_channels=1024, num_layers=num_gnn_layers, heads=4)\n        elif args.gnn_model == \"SGFormer\":\n            gnn = SGFormer(in_channels=768, hidden_channels=hidden_channels,\n                           out_channels=1024, trans_num_heads=1,\n                           trans_dropout=0.5, gnn_num_layers=num_gnn_layers,\n                           gnn_dropout=0.5)\n        else:\n            raise 
ValueError(f\"Invalid GNN model: {args.gnn_model}\")\n    else:\n        gnn = None\n\n    if args.llm_generator_mode == \"full\":\n        llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt,\n                  n_gpus=args.num_gpus)\n    elif args.llm_generator_mode == \"lora\":\n        llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt,\n                  dtype=torch.float32, n_gpus=args.num_gpus)\n    else:\n        # frozen\n        llm = LLM(model_name=args.llm_generator_name, sys_prompt=sys_prompt,\n                  dtype=torch.float32, n_gpus=args.num_gpus).eval()\n        for _, p in llm.named_parameters():\n            p.requires_grad = False\n\n    model = GRetriever(llm=llm, gnn=gnn,\n                       use_lora=args.llm_generator_mode == \"lora\")\n\n    save_name = os.path.join(args.dataset, \"model.pt\")\n\n    if args.llm_generator_mode == \"frozen\" and args.num_gnn_layers == 0:\n        if not args.dont_save_model:\n            save_params_dict(model, save_path=save_name)\n        return model\n\n    if os.path.exists(save_name) and not args.regenerate_dataset:\n        print(\"Re-using saved G-retriever model for testing...\")\n        model = load_params_dict(model, save_name)\n    else:\n        params = [p for _, p in model.named_parameters() if p.requires_grad]\n        lr = args.lr\n        optimizer = torch.optim.AdamW([{\n            'params': params,\n            'lr': lr,\n            'weight_decay': 0.05\n        }], betas=(0.9, 0.95))\n\n        num_oom_errors = 0\n        for epoch in range(args.epochs):\n            model.train()\n            epoch_loss = 0\n            epoch_str = f'Epoch: {epoch + 1}|{args.epochs}'\n            loader = tqdm(train_loader, desc=epoch_str)\n            for step, batch in enumerate(loader):\n                new_qs = []\n                for i, q in enumerate(batch[\"question\"]):\n                    # insert VectorRAG context\n                    new_qs.append(\n  
                      prompt_template.format(\n                            question=q,\n                            context=\"\\n\".join(batch.text_context[i])))\n                batch.question = new_qs\n\n                if args.skip_graph_rag:\n                    batch.desc = \"\"\n\n                optimizer.zero_grad()\n                try:\n                    loss = get_loss(model, batch)\n                    loss.backward()\n                    clip_grad_norm_(optimizer.param_groups[0]['params'], 0.1)\n                    if (step + 1) % 2 == 0:\n                        adjust_learning_rate(optimizer.param_groups[0], lr,\n                                             step / len(train_loader) + epoch,\n                                             args.epochs)\n                    optimizer.step()\n                    epoch_loss += float(loss.detach())\n\n                    if args.wandb and (step + 1) % args.log_steps == 0:\n                        wandb.log({\n                            \"train/loss\": float(loss.detach()),\n                            \"train/lr\": optimizer.param_groups[0]['lr'],\n                        })\n\n                    if (step + 1) % 2 == 0:\n                        lr = optimizer.param_groups[0]['lr']\n                except torch.cuda.OutOfMemoryError:\n                    torch.cuda.empty_cache()\n                    print(\"Sequence length of last batch: \",\n                          model.seq_length_stats[-1])\n                    # TODO: Implement CPU fallback (WIP)\n                    num_oom_errors += 1\n            print(\"Sequence length stats: \")\n            print(\"seq_len avg: \",\n                  sum(model.seq_length_stats) / len(model.seq_length_stats))\n            print(\"seq_len min: \", min(model.seq_length_stats))\n            print(\"seq_len max: \", max(model.seq_length_stats))\n            print(\"Percent of OOM errors: \",\n                  num_oom_errors / len(train_loader))\n            
train_loss = epoch_loss / len(train_loader)\n            print(epoch_str + f', Train Loss: {train_loss:.4f}')\n\n            # Eval Step\n            val_loss = 0\n            model.eval()\n            with torch.no_grad():\n                for batch in val_loader:\n                    new_qs = []\n                    for i, q in enumerate(batch[\"question\"]):\n                        # insert VectorRAG context\n                        new_qs.append(\n                            prompt_template.format(\n                                question=q,\n                                context=\"\\n\".join(batch.text_context[i])))\n                    batch.question = new_qs\n                    if args.skip_graph_rag:\n                        batch.desc = \"\"\n                    loss = get_loss(model, batch)\n                    val_loss += loss.item()\n                val_loss = val_loss / len(val_loader)\n                print(epoch_str + f\", Val Loss: {val_loss:.4f}\")\n\n                if args.wandb:\n                    wandb.log({\n                        \"val/loss\": val_loss,\n                        \"train/epoch_loss\": train_loss,\n                        \"epoch\": epoch + 1\n                    })\n\n        if args.wandb:\n            wandb.finish()\n\n        torch.cuda.empty_cache()\n        torch.cuda.reset_peak_memory_stats()\n        model.eval()\n        if not args.dont_save_model:\n            save_params_dict(model, save_path=save_name)\n    return model\n\n\ndef test(model, test_loader, args):\n    llm_judge = LLMJudge(args.NV_NIM_MODEL, args.NV_NIM_KEY, args.ENDPOINT_URL)\n\n    def eval(question: str, pred: str, correct_answer: str):\n        # calculate the score based on pred and correct answer\n        return llm_judge.score(question, pred, correct_answer)\n\n    scores = []\n    eval_tuples = []\n    for test_batch in tqdm(test_loader, desc=\"Testing\"):\n        new_qs = []\n        raw_qs = test_batch[\"question\"]\n        for i, q in 
enumerate(test_batch[\"question\"]):\n            # insert VectorRAG context\n            new_qs.append(\n                prompt_template.format(\n                    question=q, context=\"\\n\".join(test_batch.text_context[i])))\n        test_batch.question = new_qs\n        if args.skip_graph_rag:\n            test_batch.desc = \"\"\n        preds = (inference_step(model, test_batch,\n                                max_out_tokens=max_chars_in_train_answer / 2))\n        for question, pred, label in zip(raw_qs, preds, test_batch.label):\n            eval_tuples.append((question, pred, label))\n    for question, pred, label in tqdm(eval_tuples, desc=\"Eval\"):\n        scores.append(eval(question, pred, label))\n    avg_scores = sum(scores) / len(scores)\n    print(\"Avg marlin accuracy=\", avg_scores)\n    print(\"*\" * 5 + \"NOTE\" + \"*\" * 5)\n    print(\"Marlin Accuracy is Estimated by LLM as a Judge!\")\n    print(\"Improvement of this estimation process is WIP...\")\n\n\nif __name__ == '__main__':\n    # for reproducibility\n    seed_everything(50)\n\n    args = parse_args()\n    if args.wandb and not wandb_available:\n        print(\"Error: wandb package not found but --wandb flag was used.\")\n        print(\"Please install wandb and rerun the script.\")\n        sys.exit(1)\n\n    # Need to sanitize sensitive keys\n    saved_NIM_KEY = args.NV_NIM_KEY\n    args.NV_NIM_KEY = \"********\"\n    print(f\"Starting {args.dataset} training with args: \", args)\n    args.NV_NIM_KEY = saved_NIM_KEY\n\n    dataset_name = os.path.basename(args.dataset)\n    dataset_path = os.path.join(args.dataset, f\"{dataset_name}.pt\")\n    if os.path.exists(dataset_path) and not args.regenerate_dataset:\n        print(f\"Re-using Saved {dataset_name} KG-RAG Dataset...\")\n        data_lists, max_chars_in_train_answer = torch.load(\n            dataset_path, weights_only=False)\n        doc_retriever_path = os.path.join(args.dataset,\n                                          
\"document_retriever.pt\")\n        if os.path.exists(doc_retriever_path):\n            print(\"Updating data lists with document retriever...\")\n            data_lists = update_data_lists(args, data_lists)\n    else:\n        data_lists = make_dataset(args)\n    batch_size = args.batch_size\n    eval_batch_size = args.eval_batch_size\n    train_loader = DataLoader(data_lists[\"train\"], batch_size=batch_size,\n                              drop_last=True, pin_memory=True, shuffle=True)\n    val_loader = DataLoader(data_lists[\"validation\"],\n                            batch_size=eval_batch_size, drop_last=False,\n                            pin_memory=True, shuffle=False)\n    test_loader = DataLoader(data_lists[\"test\"], batch_size=eval_batch_size,\n                             drop_last=False, pin_memory=True, shuffle=False)\n\n    model = train(args, train_loader, val_loader)\n    test(model, test_loader, args)\n"
  },
  {
    "path": "examples/lpformer.py",
    "content": "import random\nfrom argparse import ArgumentParser\nfrom collections import defaultdict\n\nimport numpy as np\nimport torch\nfrom ogb.linkproppred import Evaluator, PygLinkPropPredDataset\nfrom torch.utils.data import DataLoader\nfrom tqdm import tqdm\n\nfrom torch_geometric.nn.models import LPFormer\n\nparser = ArgumentParser()\nparser.add_argument('--data_name', type=str, default='ogbl-ppa')\nparser.add_argument('--lr', type=float, default=1e-3)\nparser.add_argument('--epochs', type=int, default=100)\nparser.add_argument('--runs', help=\"# random seeds to run over\", type=int,\n                    default=5)\nparser.add_argument('--batch_size', type=int, default=32768)\nparser.add_argument('--hidden_channels', type=int, default=64)\nparser.add_argument('--gnn_layers', type=int, default=3)\nparser.add_argument('--dropout', help=\"Applies to GNN and Transformer\",\n                    type=float, default=0.1)\nparser.add_argument('--device', type=str, default='cuda')\nparser.add_argument('--eps', help=\"PPR precision\", type=float, default=5e-5)\nparser.add_argument('--thresholds',\n                    help=\"List of cn, 1-hop, >1-hop (in that order)\",\n                    nargs=\"+\", default=[0, 1e-4, 1e-2])\nargs = parser.parse_args()\n\ndevice = torch.device(args.device)\n\ndataset = PygLinkPropPredDataset(name=args.data_name)\ndata = dataset[0].to(device)\ndata.edge_index = data.edge_index.to(device)\n\nif hasattr(data, 'x') and data.x is not None:\n    data.x = data.x.to(device).to(torch.float)\n\nsplit_edge = dataset.get_edge_split()\nsplit_data = {\n    \"train_pos\": split_edge['train']['edge'].to(device),\n    \"valid_pos\": split_edge['valid']['edge'].to(device),\n    \"valid_neg\": split_edge['valid']['edge_neg'].to(device),\n    \"test_pos\": split_edge['test']['edge'].to(device),\n    \"test_neg\": split_edge['test']['edge_neg'].to(device)\n}\n\nif hasattr(data, 'edge_weight') and data.edge_weight is not None:\n    edge_weight = 
data.edge_weight.to(torch.float)\n    data.edge_weight = data.edge_weight.view(-1).to(torch.float)\nelse:\n    edge_weight = torch.ones(data.edge_index.size(1)).to(device).float()\n\n# Convert edge_index to SparseTensor for efficiency\n# adj_prop = SparseTensor.from_edge_index(\n#     data.edge_index, edge_weight.squeeze(-1),\n#     [data.num_nodes, data.num_nodes]).to(device)\nadj_prop = torch.sparse_coo_tensor(data.edge_index, edge_weight.squeeze(-1),\n                                   [data.num_nodes, data.num_nodes]).to(device)\n\nevaluator_hit = Evaluator(name=args.data_name)\n\nmodel = LPFormer(data.x.size(-1), args.hidden_channels,\n                 num_gnn_layers=args.gnn_layers,\n                 ppr_thresholds=args.thresholds, gnn_dropout=args.dropout,\n                 transformer_dropout=args.dropout, gcn_cache=True).to(device)\n\n# Get PPR matrix in sparse format\nppr_matrix = model.calc_sparse_ppr(data.edge_index, data.num_nodes,\n                                   eps=args.eps).to(device)\n\n\ndef train_epoch():\n    model.train()\n    train_pos = split_data['train_pos'].to(device)\n    adjt_mask = torch.ones(train_pos.size(0), dtype=torch.bool, device=device)\n\n    total_loss = total_examples = 0\n    d = DataLoader(range(train_pos.size(0)), args.batch_size, shuffle=True)\n\n    for perm in tqdm(d, \"Epoch\"):\n        edges = train_pos[perm].t()\n\n        # Mask positive input samples - Common strategy during training\n        adjt_mask[perm] = 0\n        edge2keep = train_pos[adjt_mask, :].t()\n        # masked_adj_prop = SparseTensor.from_edge_index(\n        #     edge2keep.t(), sparse_sizes=(data['num_nodes'],\n        #                                  data['num_nodes'])).to_device(device)\n        # masked_adj_prop = masked_adj_prop.to_symmetric()\n\n        # Ensure symmetric\n        edge2keep = torch.cat((edge2keep, edge2keep[[1, 0]]), dim=1)\n        masked_adj_prop = torch.sparse_coo_tensor(\n            edge2keep,\n            
torch.ones(edge2keep.size(1)).to(device),\n            (data['num_nodes'], data['num_nodes'])).to(device)\n\n        # For next batch\n        adjt_mask[perm] = 1\n\n        pos_out = model(edges, data.x, masked_adj_prop, ppr_matrix)\n        pos_loss = -torch.log(torch.sigmoid(pos_out) + 1e-6).mean()\n\n        # Trivial random sampling\n        neg_edges = torch.randint(0, data['num_nodes'],\n                                  (edges.size(0), edges.size(1)),\n                                  dtype=torch.long, device=edges.device)\n\n        neg_out = model(neg_edges, data.x, adj_prop, ppr_matrix)\n        neg_loss = -torch.log(1 - torch.sigmoid(neg_out) + 1e-6).mean()\n\n        loss = pos_loss + neg_loss\n        loss.backward()\n\n        optimizer.step()\n        optimizer.zero_grad()\n\n        num_examples = pos_out.size(0)\n        total_loss += loss.item() * num_examples\n        total_examples += num_examples\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test():\n    # NOTE: Eval for ogbl-citation2 is different\n    # See `train.py` in https://github.com/HarryShomer/LPFormer/ for more\n    # Also see there for how to eval under the HeaRT setting\n    # HeaRT = https://arxiv.org/abs/2306.10453\n    model.eval()\n    all_preds = defaultdict(list)\n\n    for split_key, split_vals in split_data.items():\n        if \"train\" not in split_key:\n            preds = []\n            for perm in DataLoader(range(split_vals.size(0)), args.batch_size):\n                edges = split_vals[perm].t()\n                perm_logits = model(edges, data.x, adj_prop, ppr_matrix)\n                preds += [torch.sigmoid(perm_logits).cpu()]\n\n            all_preds[split_key] = torch.cat(preds, dim=0)\n\n    val_hits = evaluator_hit.eval({\n        'y_pred_pos': all_preds['valid_pos'],\n        'y_pred_neg': all_preds['valid_neg']\n    })[f'hits@{evaluator_hit.K}']\n    test_hits = evaluator_hit.eval({\n        'y_pred_pos': all_preds['test_pos'],\n      
  'y_pred_neg': all_preds['test_neg']\n    })[f'hits@{evaluator_hit.K}']\n\n    return val_hits, test_hits\n\n\ndef set_seeds(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n\n\n# Train over args.runs seeds and average results\n# Best result for each run chosen via validation\nval_perf_runs = []\ntest_perf_runs = []\nfor run in range(1, args.runs + 1):\n    print(\"=\" * 75)\n    print(f\"RUNNING run={run}\")\n    print(\"=\" * 75)\n\n    set_seeds(run)\n    model.reset_parameters()\n    optimizer = torch.optim.Adam(list(model.parameters()), lr=args.lr)\n\n    best_valid = 0\n    best_valid_test = 0\n\n    for epoch in range(1, 1 + args.epochs):\n        loss = train_epoch()\n        print(f\"Epoch {epoch} Loss: {loss:.4f}\\n\")\n\n        if epoch % 5 == 0:\n            print(\"Evaluating model...\\n\", flush=True)\n            eval_val, eval_test = test()\n\n            print(f\"Valid Hits@{evaluator_hit.K} = {eval_val}\")\n            print(f\"Test Hits@{evaluator_hit.K} = {eval_test}\")\n\n            if eval_val > best_valid:\n                best_valid = eval_val\n                best_valid_test = eval_test\n\n    print(\n        f\"\\nBest Performance:\\n  Valid={best_valid}\\n  Test={best_valid_test}\")\n    val_perf_runs.append(best_valid)\n    test_perf_runs.append(best_valid_test)\n\nif args.runs > 1:\n    print(\"\\n\\n\")\n    print(f\"Results over {args.runs} runs:\")\n    print(f\"  Valid = {np.mean(val_perf_runs)} +/- {np.std(val_perf_runs)}\")\n    print(f\"  Test = {np.mean(test_perf_runs)} +/- {np.std(test_perf_runs)}\")\n"
  },
  {
    "path": "examples/mem_pool.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import BatchNorm1d, LeakyReLU, Linear\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import DeepGCNLayer, GATConv, MemPooling\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TUD')\ndataset = TUDataset(path, name=\"PROTEINS_full\", use_node_attr=True)\ndataset._data.x = dataset._data.x[:, :-3]  # only use non-binary features.\ndataset = dataset.shuffle()\n\nn = (len(dataset)) // 10\ntest_dataset = dataset[:n]\nval_dataset = dataset[n:2 * n]\ntrain_dataset = dataset[2 * n:]\n\ntest_loader = DataLoader(test_dataset, batch_size=20)\nval_loader = DataLoader(val_dataset, batch_size=20)\ntrain_loader = DataLoader(train_dataset, batch_size=20)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, dropout):\n        super().__init__()\n        self.dropout = dropout\n\n        self.lin = Linear(in_channels, hidden_channels)\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(2):\n            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout)\n            norm = BatchNorm1d(hidden_channels)\n            act = LeakyReLU()\n            self.convs.append(\n                DeepGCNLayer(conv, norm, act, block='res+', dropout=dropout))\n\n        self.mem1 = MemPooling(hidden_channels, 80, heads=5, num_clusters=10)\n        self.mem2 = MemPooling(80, out_channels, heads=5, num_clusters=1)\n\n    def forward(self, x, edge_index, batch):\n        x = self.lin(x)\n        for conv in self.convs:\n            x = conv(x, edge_index)\n\n        x, S1 = self.mem1(x, batch)\n        x = F.leaky_relu(x)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        x, S2 = self.mem2(x)\n\n        return (\n            F.log_softmax(x.squeeze(1), dim=-1),\n            MemPooling.kl_loss(S1) + 
MemPooling.kl_loss(S2),\n        )\n\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\nmodel = Net(dataset.num_features, 32, dataset.num_classes, dropout=0.1)\nmodel = model.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=4e-5)\n\n\ndef train():\n    model.train()\n\n    model.mem1.k.requires_grad = False\n    model.mem2.k.requires_grad = False\n    for data in train_loader:\n        optimizer.zero_grad()\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.batch)[0]\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n\n    kl_loss = 0.\n    model.mem1.k.requires_grad = True\n    model.mem2.k.requires_grad = True\n    optimizer.zero_grad()\n    for data in train_loader:\n        data = data.to(device)\n        kl_loss += model(data.x, data.edge_index, data.batch)[1]\n    kl_loss /= len(train_loader.dataset)\n    kl_loss.backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n    total_correct = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.batch)[0]\n        total_correct += int((out.argmax(dim=-1) == data.y).sum())\n    return total_correct / len(loader.dataset)\n\n\ntimes = []\npatience = start_patience = 250\ntest_acc = best_val_acc = 0.\nfor epoch in range(1, 2001):\n    start = time.time()\n    train()\n    val_acc = test(val_loader)\n    if epoch % 500 == 0:\n        optimizer.param_groups[0]['lr'] *= 0.5\n    if best_val_acc < val_acc:\n        patience = start_patience\n        best_val_acc = val_acc\n        test_acc = test(test_loader)\n    else:\n        patience -= 1\n    print(f'Epoch {epoch:02d}, Val: {val_acc:.3f}, Test: {test_acc:.3f}')\n    if patience <= 0:\n        break\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/mixhop.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import BatchNorm, Linear, MixHopConv\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora')\ndata = dataset[0]\n\n\nclass MixHop(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = MixHopConv(dataset.num_features, 60, powers=[0, 1, 2])\n        self.norm1 = BatchNorm(3 * 60)\n\n        self.conv2 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])\n        self.norm2 = BatchNorm(3 * 60)\n\n        self.conv3 = MixHopConv(3 * 60, 60, powers=[0, 1, 2])\n        self.norm3 = BatchNorm(3 * 60)\n\n        self.lin = Linear(3 * 60, dataset.num_classes)\n\n    def forward(self, x, edge_index):\n        x = F.dropout(x, p=0.7, training=self.training)\n\n        x = self.conv1(x, edge_index)\n        x = self.norm1(x)\n        x = F.dropout(x, p=0.9, training=self.training)\n\n        x = self.conv2(x, edge_index)\n        x = self.norm2(x)\n        x = F.dropout(x, p=0.9, training=self.training)\n\n        x = self.conv3(x, edge_index)\n        x = self.norm3(x)\n        x = F.dropout(x, p=0.9, training=self.training)\n\n        return self.lin(x)\n\n\nmodel, data = MixHop().to(device), data.to(device)\noptimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=0.005)\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=40,\n                                            gamma=0.01)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    
loss.backward()\n    optimizer.step()\n    scheduler.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index).argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 101):\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/mnist_graclus.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MNISTSuperpixels\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import (\n    SplineConv,\n    global_mean_pool,\n    graclus,\n    max_pool,\n    max_pool_x,\n)\nfrom torch_geometric.typing import WITH_SPLINE, WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import normalized_cut\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\nif not WITH_SPLINE:\n    quit(\"This example requires 'pyg-lib>=0.6.0'\")\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')\ntransform = T.Cartesian(cat=False)\ntrain_dataset = MNISTSuperpixels(path, True, transform=transform)\ntest_dataset = MNISTSuperpixels(path, False, transform=transform)\ntrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\ntest_loader = DataLoader(test_dataset, batch_size=64)\nd = train_dataset\n\n\ndef normalized_cut_2d(edge_index, pos):\n    row, col = edge_index\n    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)\n    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SplineConv(d.num_features, 32, dim=2, kernel_size=5)\n        self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5)\n        self.fc1 = torch.nn.Linear(64, 128)\n        self.fc2 = torch.nn.Linear(128, d.num_classes)\n\n    def forward(self, data):\n        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))\n        weight = normalized_cut_2d(data.edge_index, data.pos)\n        cluster = graclus(data.edge_index, weight, data.x.size(0))\n        data.edge_attr = None\n        data = max_pool(cluster, data, transform=transform)\n\n        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))\n        weight = 
normalized_cut_2d(data.edge_index, data.pos)\n        cluster = graclus(data.edge_index, weight, data.x.size(0))\n        x, batch = max_pool_x(cluster, data.x, data.batch)\n\n        x = global_mean_pool(x, batch)\n        x = F.elu(self.fc1(x))\n        x = F.dropout(x, training=self.training)\n        return F.log_softmax(self.fc2(x), dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train(epoch):\n    model.train()\n\n    if epoch == 16:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = 0.001\n\n    if epoch == 26:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = 0.0001\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        F.nll_loss(model(data), data.y).backward()\n        optimizer.step()\n\n\ndef test():\n    model.eval()\n    correct = 0\n\n    for data in test_loader:\n        data = data.to(device)\n        pred = model(data).max(1)[1]\n        correct += pred.eq(data.y).sum().item()\n    return correct / len(test_dataset)\n\n\nfor epoch in range(1, 31):\n    train(epoch)\n    test_acc = test()\n    print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/mnist_nn_conv.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear, ReLU, Sequential\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MNISTSuperpixels\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import (\n    NNConv,\n    global_mean_pool,\n    graclus,\n    max_pool,\n    max_pool_x,\n)\nfrom torch_geometric.typing import WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import normalized_cut\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')\ntransform = T.Cartesian(cat=False)\ntrain_dataset = MNISTSuperpixels(path, True, transform=transform)\ntest_dataset = MNISTSuperpixels(path, False, transform=transform)\ntrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\ntest_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\nd = train_dataset\n\n\ndef normalized_cut_2d(edge_index, pos):\n    row, col = edge_index\n    edge_attr = torch.norm(pos[row] - pos[col], p=2, dim=1)\n    return normalized_cut(edge_index, edge_attr, num_nodes=pos.size(0))\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        nn1 = Sequential(\n            Linear(2, 25),\n            ReLU(),\n            Linear(25, d.num_features * 32),\n        )\n        self.conv1 = NNConv(d.num_features, 32, nn1, aggr='mean')\n\n        nn2 = Sequential(\n            Linear(2, 25),\n            ReLU(),\n            Linear(25, 32 * 64),\n        )\n        self.conv2 = NNConv(32, 64, nn2, aggr='mean')\n\n        self.fc1 = torch.nn.Linear(64, 128)\n        self.fc2 = torch.nn.Linear(128, d.num_classes)\n\n    def forward(self, data):\n        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))\n        weight = normalized_cut_2d(data.edge_index, data.pos)\n        cluster = graclus(data.edge_index, weight, 
data.x.size(0))\n        data.edge_attr = None\n        data = max_pool(cluster, data, transform=transform)\n\n        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))\n        weight = normalized_cut_2d(data.edge_index, data.pos)\n        cluster = graclus(data.edge_index, weight, data.x.size(0))\n        x, batch = max_pool_x(cluster, data.x, data.batch)\n\n        x = global_mean_pool(x, batch)\n        x = F.elu(self.fc1(x))\n        x = F.dropout(x, training=self.training)\n        return F.log_softmax(self.fc2(x), dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train(epoch):\n    model.train()\n\n    if epoch == 16:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = 0.001\n\n    if epoch == 26:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = 0.0001\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        F.nll_loss(model(data), data.y).backward()\n        optimizer.step()\n\n\ndef test():\n    model.eval()\n    correct = 0\n\n    for data in test_loader:\n        data = data.to(device)\n        pred = model(data).max(1)[1]\n        correct += pred.eq(data.y).sum().item()\n    return correct / len(test_dataset)\n\n\nfor epoch in range(1, 31):\n    train(epoch)\n    test_acc = test()\n    print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/mnist_voxel_grid.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import MNISTSuperpixels\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import SplineConv, max_pool, max_pool_x, voxel_grid\nfrom torch_geometric.typing import WITH_SPLINE\n\nif not WITH_SPLINE:\n    quit(\"This example requires 'pyg-lib>=0.6.0'\")\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'MNIST')\ntransform = T.Cartesian(cat=False)\ntrain_dataset = MNISTSuperpixels(path, True, transform=transform)\ntest_dataset = MNISTSuperpixels(path, False, transform=transform)\ntrain_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)\ntest_loader = DataLoader(test_dataset, batch_size=64)\nd = train_dataset\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SplineConv(d.num_features, 32, dim=2, kernel_size=5)\n        self.conv2 = SplineConv(32, 64, dim=2, kernel_size=5)\n        self.conv3 = SplineConv(64, 64, dim=2, kernel_size=5)\n        self.fc1 = torch.nn.Linear(4 * 64, 128)\n        self.fc2 = torch.nn.Linear(128, d.num_classes)\n\n    def forward(self, data):\n        data.x = F.elu(self.conv1(data.x, data.edge_index, data.edge_attr))\n        cluster = voxel_grid(data.pos, batch=data.batch, size=5, start=0,\n                             end=28)\n        data.edge_attr = None\n        data = max_pool(cluster, data, transform=transform)\n\n        data.x = F.elu(self.conv2(data.x, data.edge_index, data.edge_attr))\n        cluster = voxel_grid(data.pos, batch=data.batch, size=7, start=0,\n                             end=28)\n        data.edge_attr = None\n        data = max_pool(cluster, data, transform=transform)\n\n        data.x = F.elu(self.conv3(data.x, data.edge_index, data.edge_attr))\n        cluster = voxel_grid(data.pos, batch=data.batch, size=14, start=0,\n                           
  end=27.99)\n        x, _ = max_pool_x(cluster, data.x, data.batch, size=4)\n\n        x = x.view(-1, self.fc1.weight.size(1))\n        x = F.elu(self.fc1(x))\n        x = F.dropout(x, training=self.training)\n        x = self.fc2(x)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train(epoch):\n    model.train()\n\n    if epoch == 6:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = 0.001\n\n    if epoch == 16:\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = 0.0001\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        F.nll_loss(model(data), data.y).backward()\n        optimizer.step()\n\n\ndef test():\n    model.eval()\n    correct = 0\n\n    for data in test_loader:\n        data = data.to(device)\n        pred = model(data).max(1)[1]\n        correct += pred.eq(data.y).sum().item()\n    return correct / len(test_dataset)\n\n\nfor epoch in range(1, 21):\n    train(epoch)\n    test_acc = test()\n    print(f'Epoch: {epoch:02d}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/multi_gpu/README.md",
    "content": "# Examples for Distributed Training\n\n## Examples with NVIDIA GPUs\n\n### Examples with cuGraph\n\nFor the best performance with NVIDIA GPUs, we recommend using **cuGraph**.\nRefer to [our installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html#accelerating-pyg-with-nvidia-cugraph-gnn) for setup instructions and to the [cuGraph-PyG examples](https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples) for ready-to-run training scripts covering single-node, multi-node, and link-prediction workloads.\n\n### Examples with Pure PyTorch\n\n| Example                                                                            | Scalability | Description                                                                                                                    |\n| ---------------------------------------------------------------------------------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------ |\n| [`distributed_batching.py`](./distributed_batching.py)                             | single-node | Graph-level prediction on many small graphs (ogbg-molhiv) using `DataLoader` with `DistributedSampler`.                        |\n| [`distributed_sampling.py`](./distributed_sampling.py)                             | single-node | Node-level classification on a single large graph (Reddit) using `NeighborLoader` for multi-hop subgraph sampling.             |\n| [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py)         | multi-node  | Training GNNs on a homogeneous graph with neighbor sampling on multiple nodes.                                                 
|\n| [`distributed_sampling_multinode.sbatch`](./distributed_sampling_multinode.sbatch) | multi-node  | Submitting a training job to a Slurm cluster using [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py). |\n| [`papers100m_gcn.py`](./papers100m_gcn.py)                                         | single-node | Training GNNs on the `ogbn-papers100M` homogeneous graph w/ ~1.6B edges.                                                       |\n| [`papers100m_gcn_multinode.py`](./papers100m_gcn_multinode.py)                     | multi-node  | Training GNNs on a homogeneous graph on multiple nodes.                                                                        |\n| [`pcqm4m_ogb.py`](./pcqm4m_ogb.py)                                                 | single-node | Training GNNs for a graph-level regression task.                                                                               |\n| [`mag240m_graphsage.py`](./mag240m_graphsage.py)                                   | single-node | Training GNNs on a large heterogeneous graph.                                                                                  |\n| [`taobao.py`](./taobao.py)                                                         | single-node | Training link prediction GNNs on a heterogeneous graph.                                                                        |\n| [`model_parallel.py`](./model_parallel.py)                                         | single-node | Model parallelism by manually placing layers on each GPU.                                                                      
|\n\n## Examples with Intel GPUs (XPUs)\n\n| Example                                                        | Scalability            | Description                                                  |\n| -------------------------------------------------------------- | ---------------------- | ------------------------------------------------------------ |\n| [`distributed_sampling_xpu.py`](./distributed_sampling_xpu.py) | single-node, multi-gpu | Training GNNs on a homogeneous graph with neighbor sampling. |\n"
  },
  {
    "path": "examples/multi_gpu/distributed_batching.py",
    "content": "import os\nimport os.path as osp\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nfrom ogb.graphproppred import Evaluator, PygGraphPropPredDataset\nfrom ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder\nfrom torch.nn import BatchNorm1d as BatchNorm\nfrom torch.nn import Linear, ReLU, Sequential\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.utils.data.distributed import DistributedSampler\nfrom torch_sparse import SparseTensor\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GINEConv, global_mean_pool\n\n\nclass GIN(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int = 3,\n        dropout: float = 0.5,\n    ) -> None:\n        super().__init__()\n        self.dropout = dropout\n        self.atom_encoder = AtomEncoder(hidden_channels)\n        self.bond_encoder = BondEncoder(hidden_channels)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            nn = Sequential(\n                Linear(hidden_channels, 2 * hidden_channels),\n                BatchNorm(2 * hidden_channels),\n                ReLU(),\n                Linear(2 * hidden_channels, hidden_channels),\n                BatchNorm(hidden_channels),\n                ReLU(),\n            )\n            self.convs.append(GINEConv(nn, train_eps=True))\n\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        adj_t: SparseTensor,\n        batch: torch.Tensor,\n    ) -> torch.Tensor:\n        x = self.atom_encoder(x)\n        edge_attr = adj_t.coo()[2]\n        adj_t = adj_t.set_value(self.bond_encoder(edge_attr), layout='coo')\n\n        for conv in self.convs:\n            x = conv(x, adj_t)\n            x = F.dropout(x, p=self.dropout, 
training=self.training)\n\n        x = global_mean_pool(x, batch)\n        x = self.lin(x)\n        return x\n\n\ndef run(rank: int, world_size: int, dataset_name: str, root: str) -> None:\n    os.environ['MASTER_ADDR'] = 'localhost'\n    os.environ['MASTER_PORT'] = '12355'\n    dist.init_process_group('nccl', rank=rank, world_size=world_size)\n\n    dataset = PygGraphPropPredDataset(\n        dataset_name,\n        root=root,\n        pre_transform=T.ToSparseTensor(attr='edge_attr'),\n    )\n    split_idx = dataset.get_idx_split()\n    evaluator = Evaluator(dataset_name)\n\n    train_dataset = dataset[split_idx['train']]\n    train_loader = DataLoader(\n        train_dataset,\n        batch_size=128,\n        sampler=DistributedSampler(\n            train_dataset,\n            shuffle=True,\n            drop_last=True,\n        ),\n    )\n\n    torch.manual_seed(12345)\n    model = GIN(128, dataset.num_tasks, num_layers=3, dropout=0.5).to(rank)\n    model = DistributedDataParallel(model, device_ids=[rank])\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n    criterion = torch.nn.BCEWithLogitsLoss()\n\n    if rank == 0:\n        val_loader = DataLoader(dataset[split_idx['valid']], batch_size=256)\n        test_loader = DataLoader(dataset[split_idx['test']], batch_size=256)\n\n    for epoch in range(1, 51):\n        model.train()\n        train_loader.sampler.set_epoch(epoch)\n        total_loss = torch.zeros(2, device=rank)\n        for data in train_loader:\n            data = data.to(rank)\n            logits = model(data.x, data.adj_t, data.batch)\n            loss = criterion(logits, data.y.to(torch.float))\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n\n            with torch.no_grad():\n                total_loss[0] += loss * logits.size(0)\n                total_loss[1] += data.num_graphs\n\n        dist.all_reduce(total_loss, op=dist.ReduceOp.AVG)\n        train_loss = total_loss[0] / 
total_loss[1]\n\n        if rank == 0:  # We evaluate on a single GPU for now.\n            model.eval()\n\n            y_pred, y_true = [], []\n            for data in val_loader:\n                data = data.to(rank)\n                with torch.no_grad():\n                    y_pred.append(model.module(data.x, data.adj_t, data.batch))\n                    y_true.append(data.y)\n            val_rocauc = evaluator.eval({\n                'y_pred': torch.cat(y_pred, dim=0),\n                'y_true': torch.cat(y_true, dim=0),\n            })['rocauc']\n\n            y_pred, y_true = [], []\n            for data in test_loader:\n                data = data.to(rank)\n                with torch.no_grad():\n                    y_pred.append(model.module(data.x, data.adj_t, data.batch))\n                    y_true.append(data.y)\n            test_rocauc = evaluator.eval({\n                'y_pred': torch.cat(y_pred, dim=0),\n                'y_true': torch.cat(y_true, dim=0),\n            })['rocauc']\n\n            print(f'Epoch: {epoch:03d}, '\n                  f'Loss: {train_loss:.4f}, '\n                  f'Val: {val_rocauc:.4f}, '\n                  f'Test: {test_rocauc:.4f}')\n\n        dist.barrier()\n\n    dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    dataset_name = 'ogbg-molhiv'\n    root = osp.join(\n        osp.dirname(__file__),\n        '..',\n        '..',\n        'data',\n        'OGB',\n    )\n    # Download and process the dataset on main process.\n    PygGraphPropPredDataset(\n        dataset_name,\n        root,\n        pre_transform=T.ToSparseTensor(attr='edge_attr'),\n    )\n\n    world_size = torch.cuda.device_count()\n    print('Let\\'s use', world_size, 'GPUs!')\n    args = (world_size, dataset_name, root)\n    mp.spawn(run, args=args, nprocs=world_size, join=True)\n"
  },
  {
    "path": "examples/multi_gpu/distributed_sampling.py",
    "content": "import os\nimport os.path as osp\nfrom math import ceil\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn.parallel import DistributedDataParallel\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import SAGEConv\n\n\nclass SAGE(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int = 2,\n    ) -> None:\n        super().__init__()\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(SAGEConv(in_channels, hidden_channels))\n        for _ in range(num_layers - 2):\n            self.convs.append(SAGEConv(hidden_channels, hidden_channels))\n        self.convs.append(SAGEConv(hidden_channels, out_channels))\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for i, conv in enumerate(self.convs):\n            x = conv(x, edge_index)\n            if i < len(self.convs) - 1:\n                x = x.relu()\n                x = F.dropout(x, p=0.5, training=self.training)\n        return x\n\n\n@torch.no_grad()\ndef test(\n    loader: NeighborLoader,\n    model: DistributedDataParallel,\n    rank: int,\n) -> Tensor:\n    model.eval()\n    total_correct = torch.tensor(0, dtype=torch.long, device=rank)\n    total_examples = 0\n    for batch in loader:\n        out = model(batch.x, batch.edge_index.to(rank))\n        pred = out[:batch.batch_size].argmax(dim=-1)\n        y = batch.y[:batch.batch_size].to(rank)\n        total_correct += (pred == y).sum()\n        total_examples += batch.batch_size\n\n    return total_correct / total_examples\n\n\ndef run(rank: int, world_size: int, dataset: Reddit) -> None:\n    os.environ['MASTER_ADDR'] = 'localhost'\n    os.environ['MASTER_PORT'] = '12355'\n    
dist.init_process_group('nccl', rank=rank, world_size=world_size)\n\n    data = dataset[0]\n    data = data.to(rank, 'x', 'y')  # Move to device for faster feature fetch.\n\n    # Split indices into `world_size` many chunks:\n    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)\n    train_idx = train_idx.split(ceil(train_idx.size(0) / world_size))[rank]\n    val_idx = data.val_mask.nonzero(as_tuple=False).view(-1)\n    val_idx = val_idx.split(ceil(val_idx.size(0) / world_size))[rank]\n    test_idx = data.test_mask.nonzero(as_tuple=False).view(-1)\n    test_idx = test_idx.split(ceil(test_idx.size(0) / world_size))[rank]\n\n    kwargs = dict(\n        data=data,\n        batch_size=1024,\n        num_neighbors=[25, 10],\n        drop_last=True,\n        num_workers=4,\n        persistent_workers=True,\n    )\n    train_loader = NeighborLoader(\n        input_nodes=train_idx,\n        shuffle=True,\n        **kwargs,\n    )\n    val_loader = NeighborLoader(\n        input_nodes=val_idx,\n        shuffle=False,\n        **kwargs,\n    )\n    test_loader = NeighborLoader(\n        input_nodes=test_idx,\n        shuffle=False,\n        **kwargs,\n    )\n\n    torch.manual_seed(12345)\n    model = SAGE(dataset.num_features, 256, dataset.num_classes).to(rank)\n    model = DistributedDataParallel(model, device_ids=[rank])\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n    for epoch in range(1, 21):\n        model.train()\n        for batch in tqdm(\n                train_loader,\n                desc=f'Epoch {epoch:02d}',\n                disable=rank != 0,\n        ):\n            out = model(batch.x, batch.edge_index.to(rank))[:batch.batch_size]\n            loss = F.cross_entropy(out, batch.y[:batch.batch_size])\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n\n        if rank == 0:\n            print(f'Epoch {epoch:02d}: Train loss: {loss:.4f}')\n\n        if epoch % 5 == 0:\n            
train_acc = test(train_loader, model, rank)\n            val_acc = test(val_loader, model, rank)\n            test_acc = test(test_loader, model, rank)\n\n            if world_size > 1:\n                dist.all_reduce(train_acc, op=dist.ReduceOp.AVG)\n                dist.all_reduce(val_acc, op=dist.ReduceOp.AVG)\n                dist.all_reduce(test_acc, op=dist.ReduceOp.AVG)\n\n            if rank == 0:\n                print(f'Train acc: {train_acc:.4f}, '\n                      f'Val acc: {val_acc:.4f}, '\n                      f'Test acc: {test_acc:.4f}')\n\n    dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    path = osp.join(\n        osp.dirname(__file__),\n        '..',\n        '..',\n        'data',\n        'Reddit',\n    )\n    dataset = Reddit(path)\n    world_size = torch.cuda.device_count()\n    print(\"Let's use\", world_size, \"GPUs!\")\n    mp.spawn(run, args=(world_size, dataset), nprocs=world_size, join=True)\n"
  },
  {
    "path": "examples/multi_gpu/distributed_sampling_multinode.py",
    "content": "import copy\nimport os\nfrom math import ceil\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn.parallel import DistributedDataParallel\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import SAGEConv\n\n\nclass SAGE(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int = 2,\n    ):\n        super().__init__()\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(SAGEConv(in_channels, hidden_channels))\n        for _ in range(num_layers - 2):\n            self.convs.append(SAGEConv(hidden_channels, hidden_channels))\n        self.convs.append(SAGEConv(hidden_channels, out_channels))\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for i, conv in enumerate(self.convs):\n            x = conv(x, edge_index)\n            if i < len(self.convs) - 1:\n                x = x.relu_()\n                x = F.dropout(x, p=0.5, training=self.training)\n        return x\n\n    @torch.no_grad()\n    def inference(\n        self,\n        x_all: Tensor,\n        device: torch.device,\n        subgraph_loader: NeighborLoader,\n    ) -> Tensor:\n        pbar = tqdm(total=len(subgraph_loader) * len(self.convs))\n        pbar.set_description('Evaluating')\n\n        # Compute representations of nodes layer by layer, using *all*\n        # available edges. 
This leads to faster computation in contrast to\n        # immediately computing the final representations of each batch:\n        for i, conv in enumerate(self.convs):\n            xs = []\n            for batch in subgraph_loader:\n                x = x_all[batch.node_id.to(x_all.device)].to(device)\n                x = conv(x, batch.edge_index.to(device))\n                x = x[:batch.batch_size]\n                if i < len(self.convs) - 1:\n                    x = x.relu_()\n                xs.append(x.cpu())\n                pbar.update(1)\n            x_all = torch.cat(xs, dim=0)\n\n        pbar.close()\n        return x_all\n\n\ndef run(world_size: int, rank: int, local_rank: int):\n    # Will query the runtime environment for `MASTER_ADDR` and `MASTER_PORT`.\n    # Make sure, those are set!\n    dist.init_process_group('nccl', world_size=world_size, rank=rank)\n\n    # Download and unzip only with one process ...\n    if rank == 0:\n        dataset = Reddit('data/Reddit')\n    dist.barrier()\n    # ... 
and then read from all the other processes:\n    if rank != 0:\n        dataset = Reddit('data/Reddit')\n    dist.barrier()\n\n    data = dataset[0]\n\n    # Move to device for faster feature fetch.\n    data = data.to(local_rank, 'x', 'y')\n\n    # Split training indices into `world_size` many chunks:\n    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)\n    train_idx = train_idx.split(ceil(train_idx.size(0) / world_size))[rank]\n\n    kwargs = dict(batch_size=1024, num_workers=4, persistent_workers=True)\n    train_loader = NeighborLoader(\n        data,\n        input_nodes=train_idx,\n        num_neighbors=[25, 10],\n        shuffle=True,\n        drop_last=True,\n        **kwargs,\n    )\n\n    if rank == 0:  # Create single-hop evaluation neighbor loader:\n        subgraph_loader = NeighborLoader(\n            copy.copy(data),\n            num_neighbors=[-1],\n            shuffle=False,\n            **kwargs,\n        )\n        # No need to maintain these features during evaluation:\n        del subgraph_loader.data.x, subgraph_loader.data.y\n        # Add global node index information:\n        subgraph_loader.data.node_id = torch.arange(data.num_nodes)\n\n    torch.manual_seed(12345)\n    model = SAGE(dataset.num_features, 256, dataset.num_classes).to(local_rank)\n    model = DistributedDataParallel(model, device_ids=[local_rank])\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n    for epoch in range(1, 21):\n        model.train()\n        for batch in train_loader:\n            optimizer.zero_grad()\n            out = model(batch.x,\n                        batch.edge_index.to(local_rank))[:batch.batch_size]\n            loss = F.cross_entropy(out, batch.y[:batch.batch_size])\n            loss.backward()\n            optimizer.step()\n\n        dist.barrier()\n\n        if rank == 0:\n            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')\n\n        if rank == 0 and epoch % 5 == 0:  # We evaluate on a single GPU for 
now\n            model.eval()\n            with torch.no_grad():\n                out = model.module.inference(\n                    data.x,\n                    local_rank,\n                    subgraph_loader,\n                )\n            res = out.argmax(dim=-1) == data.y.to(out.device)\n            acc1 = int(res[data.train_mask].sum()) / int(data.train_mask.sum())\n            acc2 = int(res[data.val_mask].sum()) / int(data.val_mask.sum())\n            acc3 = int(res[data.test_mask].sum()) / int(data.test_mask.sum())\n            print(f'Train: {acc1:.4f}, Val: {acc2:.4f}, Test: {acc3:.4f}')\n\n        dist.barrier()\n\n    dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    # Get the world size from the WORLD_SIZE variable or directly from SLURM:\n    world_size = int(\n        os.environ.get('WORLD_SIZE', os.environ.get('SLURM_NTASKS')))\n    # Likewise for RANK and LOCAL_RANK:\n    rank = int(os.environ.get('RANK', os.environ.get('SLURM_PROCID')))\n    local_rank = int(\n        os.environ.get('LOCAL_RANK', os.environ.get('SLURM_LOCALID')))\n    run(world_size, rank, local_rank)\n"
  },
  {
    "path": "examples/multi_gpu/distributed_sampling_multinode.sbatch",
    "content": "#!/bin/bash\n#SBATCH --job-name=pyg-multinode-tutorial # identifier for the job listings\n#SBATCH --output=pyg-multinode.log        # output file\n#SBATCH --partition=gpucloud              # ADJUST this to your system\n#SBATCH -N 2                              # number of nodes you want to use\n#SBATCH --ntasks=4                        # number of processes to be run\n#SBATCH --gpus-per-task=1                 # every process wants one GPU!\n#SBATCH --gpu-bind=none                   # NCCL can't deal with task-binding...\n## Now you can add more stuff for your convenience\n#SBATCH --cpus-per-task=8                 # make sure more cpu-cores are available to each process to spawn workers (default=1 and this is a hard limit)\n#SBATCH --mem=100G                        # total amount of memory available per node (tensorflow need(ed) at least <GPU-memory> per GPU)\n#SBATCH --export=ALL                      # use your shell environment (PATHs, ...)\n\n# Thanks for shell-ideas to https://github.com/PrincetonUniversity/multi_gpu_training\nexport MASTER_PORT=$(expr 10000 + $(echo -n $SLURM_JOBID | tail -c 4))\nexport MASTER_ADDR=$(scontrol show hostnames \"$SLURM_JOB_NODELIST\" | head -n 1)\necho \"MASTER_ADDR:MASTER_PORT=\"${MASTER_ADDR}:${MASTER_PORT}\n\necho \"###########################################################################\"\necho \"We recommend you set up your environment here (conda/spack/pip/modulefiles)\"\necho \"then remove --export=ALL (allows running the sbatch from any shell)\"\necho \"###########################################################################\"\n\n# use --output=0 so that only the first task logs to the file!\nsrun --output=0 python distributed_sampling_multinode.py\n"
  },
  {
    "path": "examples/multi_gpu/distributed_sampling_xpu.py",
    "content": "\"\"\"Distributed GAT training, targeting XPU devices.\nPVC has 2 tiles, each of which reports itself as a separate\ndevice. The DDP approach allows us to employ both tiles.\n\nAdditional requirements:\n    IPEX (intel_extension_for_pytorch)\n    oneCCL (oneccl_bindings_for_pytorch)\n\n    We need to import both these modules, as they extend\n    torch module with XPU/oneCCL related functionality.\n\nRun with:\n    mpirun -np 2 python distributed_sampling_xpu.py\n\"\"\"\n\nimport copy\nimport os\nimport os.path as osp\nfrom typing import Any, Tuple, Union\n\nimport intel_extension_for_pytorch  # noqa\nimport oneccl_bindings_for_pytorch  # noqa\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import Evaluator, PygNodePropPredDataset\nfrom torch import Tensor\nfrom torch.nn import Linear as Lin\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom tqdm import tqdm\n\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import GATConv\nfrom torch_geometric.profile import get_stats_summary, profileit\n\n\nclass GAT(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int,\n        heads: int,\n    ):\n        super().__init__()\n\n        self.num_layers = num_layers\n\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(GATConv(in_channels, hidden_channels, heads))\n        for _ in range(num_layers - 2):\n            self.convs.append(\n                GATConv(heads * hidden_channels, hidden_channels, heads))\n        self.convs.append(\n            GATConv(heads * hidden_channels, out_channels, heads,\n                    concat=False))\n\n        self.skips = torch.nn.ModuleList()\n        self.skips.append(Lin(in_channels, hidden_channels * heads))\n        for _ in range(num_layers - 2):\n         
   self.skips.append(\n                Lin(hidden_channels * heads, hidden_channels * heads))\n        self.skips.append(Lin(hidden_channels * heads, out_channels))\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for i, (conv, skip) in enumerate(zip(self.convs, self.skips)):\n            x = conv(x, edge_index) + skip(x)\n            if i != self.num_layers - 1:\n                x = F.elu(x)\n                x = F.dropout(x, p=0.5, training=self.training)\n        return x\n\n    def inference(\n        self,\n        x_all: Tensor,\n        device: Union[str, torch.device],\n        subgraph_loader: NeighborLoader,\n    ) -> Tensor:\n        pbar = tqdm(total=x_all.size(0) * self.num_layers)\n        pbar.set_description(\"Evaluating\")\n\n        # Compute representations of nodes layer by layer, using *all*\n        # available edges. This leads to faster computation in contrast to\n        # immediately computing the final representations of each batch.\n        for i in range(self.num_layers):\n            xs = []\n            for batch in subgraph_loader:\n                x = x_all[batch.n_id].to(device)\n                edge_index = batch.edge_index.to(device)\n                x = self.convs[i](x, edge_index) + self.skips[i](x)\n                x = x[:batch.batch_size]\n                if i != self.num_layers - 1:\n                    x = F.elu(x)\n                xs.append(x.cpu())\n\n                pbar.update(batch.batch_size)\n\n            x_all = torch.cat(xs, dim=0)\n\n        pbar.close()\n\n        return x_all\n\n\n@profileit('xpu')\ndef train_step(model: Any, optimizer: Any, x: Tensor, edge_index: Tensor,\n               y: Tensor, bs: int) -> float:\n    optimizer.zero_grad()\n    out = model(x, edge_index)[:bs]\n    loss = F.cross_entropy(out, y[:bs].squeeze())\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\ndef run(rank: int, world_size: int, dataset: PygNodePropPredDataset):\n    device = 
f\"xpu:{rank}\"\n\n    split_idx = dataset.get_idx_split()\n    split_idx[\"train\"] = (split_idx[\"train\"].split(\n        split_idx[\"train\"].size(0) // world_size, dim=0)[rank].clone())\n    data = dataset[0].to(device, \"x\", \"y\")\n\n    kwargs = dict(batch_size=1024, num_workers=0, pin_memory=True)\n    train_loader = NeighborLoader(data, input_nodes=split_idx[\"train\"],\n                                  num_neighbors=[10, 10, 5], **kwargs)\n\n    if rank == 0:\n        subgraph_loader = NeighborLoader(copy.copy(data), num_neighbors=[-1],\n                                         **kwargs)\n        evaluator = Evaluator(name=\"ogbn-products\")\n\n    torch.manual_seed(12345)\n    model = GAT(dataset.num_features, 128, dataset.num_classes, num_layers=3,\n                heads=4).to(device)\n    model = DDP(model, device_ids=[device])\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n    for epoch in range(1, 21):\n        epoch_stats = []\n        model.train()\n        for batch in train_loader:\n            batch = batch.to(device)\n            loss, stats = train_step(model, optimizer, batch.x,\n                                     batch.edge_index, batch.y,\n                                     batch.batch_size)\n            epoch_stats.append(stats)\n\n        dist.barrier()\n\n        if rank == 0:\n            print(f\"Epoch: {epoch:02d}, Loss: {loss:.4f}\")\n\n        print(f\"Epoch: {epoch:02d}, Rank: {rank}, \"\n              f\"Stats: {get_stats_summary(epoch_stats)}\")\n\n        if rank == 0 and epoch % 5 == 0:  # Evaluation on a single GPU\n            model.eval()\n            with torch.no_grad():\n                out = model.module.inference(data.x, device, subgraph_loader)\n\n            y_true = data.y.to(out.device)\n            y_pred = out.argmax(dim=-1, keepdim=True)\n\n            train_acc = evaluator.eval({\n                \"y_true\": y_true[split_idx[\"train\"]],\n                \"y_pred\": 
y_pred[split_idx[\"train\"]],\n            })[\"acc\"]\n            val_acc = evaluator.eval({\n                \"y_true\": y_true[split_idx[\"valid\"]],\n                \"y_pred\": y_pred[split_idx[\"valid\"]],\n            })[\"acc\"]\n            test_acc = evaluator.eval({\n                \"y_true\": y_true[split_idx[\"test\"]],\n                \"y_pred\": y_pred[split_idx[\"test\"]],\n            })[\"acc\"]\n\n            print(f\"Train: {train_acc:.4f}, Val: {val_acc:.4f}, \"\n                  f\"Test: {test_acc:.4f}\")\n\n        dist.barrier()\n\n    dist.destroy_process_group()\n\n\ndef get_dist_params() -> Tuple[int, int, str]:\n    master_addr = \"127.0.0.1\"\n    master_port = \"29500\"\n    os.environ[\"MASTER_ADDR\"] = master_addr\n    os.environ[\"MASTER_PORT\"] = master_port\n\n    mpi_rank = int(os.environ.get(\"PMI_RANK\", -1))\n    mpi_world_size = int(os.environ.get(\"PMI_SIZE\", -1))\n    # Environment variables are strings, so cast the fallbacks to `int`:\n    rank = mpi_rank if mpi_world_size > 0 else int(os.environ.get(\"RANK\", 0))\n    world_size = (mpi_world_size if mpi_world_size > 0 else int(\n        os.environ.get(\"WORLD_SIZE\", 1)))\n\n    os.environ[\"RANK\"] = str(rank)\n    os.environ[\"WORLD_SIZE\"] = str(world_size)\n\n    init_method = f\"tcp://{master_addr}:{master_port}\"\n\n    return rank, world_size, init_method\n\n\nif __name__ == \"__main__\":\n    rank, world_size, init_method = get_dist_params()\n    dist.init_process_group(backend=\"ccl\", init_method=init_method,\n                            world_size=world_size, rank=rank)\n\n    path = osp.join(osp.dirname(osp.realpath(__file__)), \"../../data\",\n                    \"ogbn-products\")\n    dataset = PygNodePropPredDataset(\"ogbn-products\", path)\n\n    run(rank, world_size, dataset)\n"
  },
  {
    "path": "examples/multi_gpu/mag240m_graphsage.py",
    "content": "import argparse\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nfrom ogb.lsc import MAG240MDataset\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torchmetrics import Accuracy\nfrom tqdm import tqdm\n\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import BatchNorm, HeteroConv, SAGEConv\n\n\ndef common_step(batch, model):\n    batch_size = batch['paper'].batch_size\n    x_dict = model(batch.x_dict, batch.edge_index_dict)\n    y_hat = x_dict['paper'][:batch_size]\n    y = batch['paper'].y[:batch_size].to(torch.long)\n    return y_hat, y\n\n\ndef training_step(batch, acc, model):\n    y_hat, y = common_step(batch, model)\n    train_loss = F.cross_entropy(y_hat, y)\n    acc(y_hat, y)\n    return train_loss\n\n\ndef validation_step(batch, acc, model):\n    y_hat, y = common_step(batch, model)\n    acc(y_hat, y)\n\n\nclass HeteroSAGEConv(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, dropout, node_types,\n                 edge_types, is_output_layer=False):\n        super().__init__()\n        self.conv = HeteroConv({\n            edge_type: SAGEConv(in_channels, out_channels)\n            for edge_type in edge_types\n        })\n        if not is_output_layer:\n            self.dropout = torch.nn.Dropout(dropout)\n            self.norm_dict = torch.nn.ModuleDict({\n                node_type:\n                BatchNorm(out_channels)\n                for node_type in node_types\n            })\n\n        self.is_output_layer = is_output_layer\n\n    def forward(self, x_dict, edge_index_dict):\n        x_dict = self.conv(x_dict, edge_index_dict)\n        if not self.is_output_layer:\n            for node_type, x in x_dict.items():\n                x = self.dropout(x.relu())\n                x = self.norm_dict[node_type](x)\n                x_dict[node_type] = x\n        return x_dict\n\n\nclass 
HeteroGraphSAGE(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, num_layers, out_channels,\n                 dropout, node_types, edge_types):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        for i in range(num_layers):\n            # Since authors and institution do not come with features, we learn\n            # them via the GNN. However, this also means we need to exclude\n            # them as source types in the first two iterations:\n            if i == 0:\n                edge_types_of_layer = [\n                    edge_type for edge_type in edge_types\n                    if edge_type[0] == 'paper'\n                ]\n            elif i == 1:\n                edge_types_of_layer = [\n                    edge_type for edge_type in edge_types\n                    if edge_type[0] != 'institution'\n                ]\n            else:\n                edge_types_of_layer = edge_types\n\n            conv = HeteroSAGEConv(\n                in_channels if i == 0 else hidden_channels,\n                out_channels if i == num_layers - 1 else hidden_channels,\n                dropout=dropout,\n                node_types=node_types,\n                edge_types=edge_types_of_layer,\n                is_output_layer=i == num_layers - 1,\n            )\n            self.convs.append(conv)\n\n    def forward(self, x_dict, edge_index_dict):\n        for conv in self.convs:\n            x_dict = conv(x_dict, edge_index_dict)\n        return x_dict\n\n\ndef run(\n    rank,\n    data,\n    num_devices,\n    num_epochs,\n    num_steps_per_epoch,\n    log_every_n_steps,\n    batch_size,\n    num_neighbors,\n    hidden_channels,\n    dropout,\n    num_val_steps,\n    lr,\n):\n    if num_devices > 1:\n        if rank == 0:\n            print(\"Setting up distributed...\")\n        os.environ['MASTER_ADDR'] = 'localhost'\n        os.environ['MASTER_PORT'] = '12355'\n        dist.init_process_group('nccl', rank=rank, 
world_size=num_devices)\n\n    acc = Accuracy(task='multiclass', num_classes=data.num_classes)\n    model = HeteroGraphSAGE(\n        in_channels=-1,\n        hidden_channels=hidden_channels,\n        num_layers=len(num_neighbors),\n        out_channels=data.num_classes,\n        dropout=dropout,\n        node_types=data.node_types,\n        edge_types=data.edge_types,\n    )\n\n    train_idx = data['paper'].train_mask.nonzero(as_tuple=False).view(-1)\n    val_idx = data['paper'].val_mask.nonzero(as_tuple=False).view(-1)\n    if num_devices > 1:  # Split indices into `num_devices` many chunks:\n        train_idx = train_idx.split(train_idx.size(0) // num_devices)[rank]\n        val_idx = val_idx.split(val_idx.size(0) // num_devices)[rank]\n\n    # Delete unused tensors to not sample:\n    del data['paper'].train_mask\n    del data['paper'].val_mask\n    del data['paper'].test_mask\n    del data['paper'].year\n\n    kwargs = dict(\n        batch_size=batch_size,\n        num_workers=16,\n        persistent_workers=True,\n        num_neighbors=num_neighbors,\n        drop_last=True,\n    )\n\n    train_loader = NeighborLoader(\n        data,\n        input_nodes=('paper', train_idx),\n        shuffle=True,\n        **kwargs,\n    )\n    val_loader = NeighborLoader(data, input_nodes=('paper', val_idx), **kwargs)\n\n    if num_devices > 0:\n        model = model.to(rank)\n        acc = acc.to(rank)\n    if num_devices > 1:\n        model = DistributedDataParallel(model, device_ids=[rank])\n    optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n\n    for epoch in range(1, num_epochs + 1):\n        model.train()\n        for i, batch in enumerate(tqdm(train_loader)):\n            if num_steps_per_epoch >= 0 and i >= num_steps_per_epoch:\n                break\n\n            if num_devices > 0:\n                batch = batch.to(rank, 'x', 'y', 'edge_index')\n                # Features loaded in as 16 bits, train in 32 bits:\n                batch['paper'].x = 
batch['paper'].x.to(torch.float32)\n\n            optimizer.zero_grad()\n            loss = training_step(batch, acc, model)\n            loss.backward()\n            optimizer.step()\n\n            if i % log_every_n_steps == 0:\n                if rank == 0:\n                    print(f\"Epoch: {epoch:02d}, Step: {i:d}, \"\n                          f\"Loss: {loss:.4f}, \"\n                          f\"Train Acc: {acc.compute():.4f}\")\n\n        if num_devices > 1:\n            dist.barrier()\n\n        if rank == 0:\n            print(f\"Epoch: {epoch:02d}, Loss: {loss:.4f}, \"\n                  f\"Train Acc: {acc.compute():.4f}\")\n        acc.reset()\n\n        model.eval()\n        with torch.no_grad():\n            for i, batch in enumerate(tqdm(val_loader)):\n                if num_val_steps >= 0 and i >= num_val_steps:\n                    break\n\n                if num_devices > 0:\n                    batch = batch.to(rank, 'x', 'y', 'edge_index')\n                    batch['paper'].x = batch['paper'].x.to(torch.float32)\n\n                validation_step(batch, acc, model)\n\n            if num_devices > 1:\n                dist.barrier()\n\n            if rank == 0:\n                print(f\"Val Acc: {acc.compute():.4f}\")\n            acc.reset()\n\n    model.eval()\n\n    if num_devices > 1:\n        dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--hidden_channels\", type=int, default=1024)\n    parser.add_argument(\"--batch_size\", type=int, default=1024)\n    parser.add_argument(\"--dropout\", type=float, default=0.5)\n    parser.add_argument(\"--lr\", type=float, default=0.001)\n    parser.add_argument(\"--num_epochs\", type=int, default=20)\n    parser.add_argument(\"--num_steps_per_epoch\", type=int, default=-1)\n    parser.add_argument(\"--log_every_n_steps\", type=int, default=100)\n    parser.add_argument(\"--num_val_steps\", type=int, default=-1)\n   
 parser.add_argument(\"--num_neighbors\", type=str, default=\"25-15\")\n    parser.add_argument(\"--num_devices\", type=int, default=1)\n    args = parser.parse_args()\n\n    args.num_neighbors = [int(i) for i in args.num_neighbors.split('-')]\n\n    import warnings\n    warnings.simplefilter(\"ignore\")\n\n    if not torch.cuda.is_available():\n        args.num_devices = 0\n    elif args.num_devices > torch.cuda.device_count():\n        args.num_devices = torch.cuda.device_count()\n\n    dataset = MAG240MDataset()\n    data = dataset.to_pyg_hetero_data()\n\n    if args.num_devices > 1:\n        print(\"Let's use\", args.num_devices, \"GPUs!\")\n        from torch.multiprocessing.spawn import ProcessExitedException\n        try:\n            mp.spawn(\n                run,\n                args=(\n                    data,\n                    args.num_devices,\n                    args.num_epochs,\n                    args.num_steps_per_epoch,\n                    args.log_every_n_steps,\n                    args.batch_size,\n                    args.num_neighbors,\n                    args.hidden_channels,\n                    args.dropout,\n                    args.num_val_steps,\n                    args.lr,\n                ),\n                nprocs=args.num_devices,\n                join=True,\n            )\n        except ProcessExitedException as e:\n            print(\"torch.multiprocessing.spawn.ProcessExitedException:\", e)\n            print(\"Exceptions/SIGBUS/Errors may be caused by a lack of RAM\")\n\n    else:\n        run(\n            0,\n            data,\n            args.num_devices,\n            args.num_epochs,\n            args.num_steps_per_epoch,\n            args.log_every_n_steps,\n            args.batch_size,\n            args.num_neighbors,\n            args.hidden_channels,\n            args.dropout,\n            args.num_val_steps,\n            args.lr,\n        )\n"
  },
  {
    "path": "examples/multi_gpu/model_parallel.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.transforms import NormalizeFeatures\n\nif torch.cuda.device_count() < 2:\n    quit('This example requires multiple GPUs')\n\npath = osp.dirname(osp.realpath(__file__))\npath = osp.join(path, '..', '..', 'data', 'Planetoid')\ndataset = Planetoid(root=path, name='Cora', transform=NormalizeFeatures())\ndata = dataset[0].to('cuda:0')\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, device1, device2):\n        super().__init__()\n        self.device1 = device1\n        self.device2 = device2\n\n        self.conv1 = GCNConv(in_channels, 16).to(device1)\n        self.conv2 = GCNConv(16, out_channels).to(device2)\n\n    def forward(self, x, edge_index):\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv1(x, edge_index).relu()\n        # Move data to the second device:\n        x, edge_index = x.to(self.device2), edge_index.to(self.device2)\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.conv2(x, edge_index)\n        return x\n\n\nmodel = GCN(\n    dataset.num_features,\n    dataset.num_classes,\n    device1='cuda:0',\n    device2='cuda:1',\n)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index).to('cuda:0')\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out = model(data.x, data.edge_index)\n    pred = out.argmax(dim=-1).to('cuda:0')\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return 
accs\n\n\nbest_val_acc = test_acc = 0\ntimes = []\nfor epoch in range(1, 201):\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/multi_gpu/papers100m_gcn.py",
    "content": "import argparse\nimport os\nimport tempfile\nimport time\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import PygNodePropPredDataset\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torchmetrics import Accuracy\n\nimport torch_geometric\nfrom torch_geometric.loader import NeighborLoader\n\n\ndef get_num_workers(world_size):\n    num_work = None\n    if hasattr(os, \"sched_getaffinity\"):\n        try:\n            num_work = len(os.sched_getaffinity(0)) / (2 * world_size)\n        except Exception:\n            pass\n    if num_work is None:\n        num_work = os.cpu_count() / (2 * world_size)\n    return int(num_work)\n\n\ndef run_train(rank, data, world_size, model, epochs, batch_size, fan_out,\n              split_idx, num_classes, wall_clock_start, tempdir=None,\n              num_layers=3):\n\n    # init pytorch worker\n    os.environ['MASTER_ADDR'] = 'localhost'\n    os.environ['MASTER_PORT'] = '12355'\n    dist.init_process_group('nccl', rank=rank, world_size=world_size)\n\n    if world_size > 1:\n        split_idx['train'] = split_idx['train'].split(\n            split_idx['train'].size(0) // world_size, dim=0)[rank].clone()\n        split_idx['valid'] = split_idx['valid'].split(\n            split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()\n        split_idx['test'] = split_idx['test'].split(\n            split_idx['test'].size(0) // world_size, dim=0)[rank].clone()\n    model = model.to(rank)\n    model = DistributedDataParallel(model, device_ids=[rank])\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,\n                                 weight_decay=0.0005)\n\n    kwargs = dict(\n        num_neighbors=[fan_out] * num_layers,\n        batch_size=batch_size,\n    )\n    num_work = get_num_workers(world_size)\n    train_loader = NeighborLoader(data, input_nodes=split_idx['train'],\n                            
      num_workers=num_work, shuffle=True,\n                                  drop_last=True, **kwargs)\n    val_loader = NeighborLoader(data, input_nodes=split_idx['valid'],\n                                num_workers=num_work, **kwargs)\n    test_loader = NeighborLoader(data, input_nodes=split_idx['test'],\n                                 num_workers=num_work, **kwargs)\n\n    eval_steps = 1000\n    warmup_steps = 20\n    acc = Accuracy(task=\"multiclass\", num_classes=num_classes).to(rank)\n    dist.barrier()\n    torch.cuda.synchronize()\n    if rank == 0:\n        prep_time = round(time.perf_counter() - wall_clock_start, 2)\n        print(\"Total time before training begins (prep_time) =\", prep_time,\n              \"seconds\")\n        print(\"Beginning training...\")\n    for epoch in range(epochs):\n        for i, batch in enumerate(train_loader):\n            if i == warmup_steps:\n                torch.cuda.synchronize()\n                start = time.time()\n            batch = batch.to(rank)\n            batch_size = batch.num_sampled_nodes[0]\n            batch.y = batch.y.to(torch.long)\n            optimizer.zero_grad()\n            out = model(batch.x, batch.edge_index)\n            loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size])\n            loss.backward()\n            optimizer.step()\n            if rank == 0 and i % 10 == 0:\n                print(\"Epoch: \" + str(epoch) + \", Iteration: \" + str(i) +\n                      \", Loss: \" + str(loss))\n        nb = i + 1.0\n        dist.barrier()\n        torch.cuda.synchronize()\n        if rank == 0:\n            print(\"Average Training Iteration Time:\",\n                  (time.time() - start) / (nb - warmup_steps), \"s/iter\")\n        with torch.no_grad():\n            for i, batch in enumerate(val_loader):\n                if i >= eval_steps:\n                    break\n\n                batch = batch.to(rank)\n                batch_size = batch.num_sampled_nodes[0]\n\n   
             batch.y = batch.y.to(torch.long)\n                out = model(batch.x, batch.edge_index)\n                acc_i = acc(  # noqa\n                    out[:batch_size].softmax(dim=-1), batch.y[:batch_size])\n            acc_sum = acc.compute()\n            if rank == 0:\n                print(f\"Validation Accuracy: {acc_sum * 100.0:.4f}%\", )\n        dist.barrier()\n        acc.reset()\n\n    with torch.no_grad():\n        for batch in test_loader:\n            batch = batch.to(rank)\n            batch_size = batch.num_sampled_nodes[0]\n\n            batch.y = batch.y.to(torch.long)\n            out = model(batch.x, batch.edge_index)\n            acc_i = acc(  # noqa\n                out[:batch_size].softmax(dim=-1), batch.y[:batch_size])\n        acc_sum = acc.compute()\n        if rank == 0:\n            print(f\"Test Accuracy: {acc_sum * 100.0:.4f}%\", )\n    dist.barrier()\n    acc.reset()\n    if rank == 0:\n        total_time = round(time.perf_counter() - wall_clock_start, 2)\n        print(\"Total Program Runtime (total_time) =\", total_time, \"seconds\")\n        print(\"total_time - prep_time =\", total_time - prep_time, \"seconds\")\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--hidden_channels', type=int, default=256)\n    parser.add_argument('--num_layers', type=int, default=2)\n    parser.add_argument('--lr', type=float, default=0.001)\n    parser.add_argument('--epochs', type=int, default=20)\n    parser.add_argument('--batch_size', type=int, default=1024)\n    parser.add_argument('--fan_out', type=int, default=30)\n    parser.add_argument(\n        \"--use_gat_conv\",\n        action='store_true',\n        help=\"Whether or not to use GATConv. 
(Defaults to using GCNConv)\",\n    )\n    parser.add_argument(\n        \"--n_gat_conv_heads\",\n        type=int,\n        default=4,\n        help=\"If using GATConv, number of attention heads to use\",\n    )\n    parser.add_argument(\n        \"--n_devices\", type=int, default=-1,\n        help=\"1-8 to use that many GPUs. Defaults to all available GPUs\")\n\n    args = parser.parse_args()\n    wall_clock_start = time.perf_counter()\n    if args.n_devices == -1:\n        world_size = torch.cuda.device_count()\n    else:\n        world_size = args.n_devices\n    import psutil\n    gb_ram_needed = 190 + 200 * world_size\n    if (psutil.virtual_memory().total / (1024**3)) < gb_ram_needed:\n        print(\"Warning: may not have enough RAM to use this many GPUs.\")\n        print(\"Consider upgrading RAM or using fewer GPUs if an error occurs.\")\n        print(\"Estimated RAM Needed: ~\" + str(gb_ram_needed) + \" GB\")\n    print('Let\\'s use', world_size, 'GPUs!')\n    dataset = PygNodePropPredDataset(name='ogbn-papers100M',\n                                     root='/datasets/ogb_datasets')\n    split_idx = dataset.get_idx_split()\n    data = dataset[0]\n    data.y = data.y.reshape(-1)\n    if args.use_gat_conv:\n        model = torch_geometric.nn.models.GAT(dataset.num_features,\n                                              args.hidden_channels,\n                                              args.num_layers,\n                                              dataset.num_classes,\n                                              heads=args.n_gat_conv_heads)\n    else:\n        model = torch_geometric.nn.models.GCN(\n            dataset.num_features,\n            args.hidden_channels,\n            args.num_layers,\n            dataset.num_classes,\n        )\n\n    print(\"Data =\", data)\n    with tempfile.TemporaryDirectory() as tempdir:\n        if world_size > 1:\n            mp.spawn(\n                run_train,\n                args=(data, world_size, model, args.epochs, args.batch_size,\n                      args.fan_out, split_idx, dataset.num_classes,\n                      wall_clock_start, tempdir, args.num_layers),\n                nprocs=world_size, join=True)\n        else:\n            run_train(0, data, world_size, model, args.epochs, args.batch_size,\n                      args.fan_out, split_idx, dataset.num_classes,\n                      wall_clock_start, tempdir, args.num_layers)\n"
  },
  {
    "path": "examples/multi_gpu/papers100m_gcn_multinode.py",
    "content": "\"\"\"Multi-node multi-GPU example on ogbn-papers100m.\n\nExample way to run using srun:\nsrun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \\\n--container-name=cont --container-image=<image_url> \\\n--container-mounts=/ogb-papers100m/:/workspace/dataset\npython3 path_to_script.py\n\"\"\"\nimport os\nimport time\nfrom typing import Optional\n\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import PygNodePropPredDataset\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torchmetrics import Accuracy\n\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import GCN\n\n\ndef get_num_workers() -> int:\n    num_workers = None\n    if hasattr(os, \"sched_getaffinity\"):\n        try:\n            num_workers = len(os.sched_getaffinity(0)) // 2\n        except Exception:\n            pass\n    if num_workers is None:\n        num_workers = os.cpu_count() // 2\n    return num_workers\n\n\ndef run(world_size, data, split_idx, model, acc, wall_clock_start):\n    local_id = int(os.environ['LOCAL_RANK'])\n    rank = torch.distributed.get_rank()\n    torch.cuda.set_device(local_id)\n    device = torch.device(local_id)\n    if rank == 0:\n        print(f'Using {nprocs} GPUs...')\n\n    split_idx['train'] = split_idx['train'].split(\n        split_idx['train'].size(0) // world_size, dim=0)[rank].clone()\n    split_idx['valid'] = split_idx['valid'].split(\n        split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()\n    split_idx['test'] = split_idx['test'].split(\n        split_idx['test'].size(0) // world_size, dim=0)[rank].clone()\n\n    model = DistributedDataParallel(model.to(device), device_ids=[local_id])\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001,\n                                 weight_decay=5e-4)\n\n    kwargs = dict(\n        data=data,\n        batch_size=1024,\n        num_workers=get_num_workers(),\n        num_neighbors=[30, 
30],\n    )\n\n    train_loader = NeighborLoader(\n        input_nodes=split_idx['train'],\n        shuffle=True,\n        drop_last=True,\n        **kwargs,\n    )\n    val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)\n    test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)\n\n    val_steps = 1000\n    warmup_steps = 100\n    acc = acc.to(device)\n    dist.barrier()\n    torch.cuda.synchronize()\n    if rank == 0:\n        prep_time = round(time.perf_counter() - wall_clock_start, 2)\n        print(\"Total time before training begins (prep_time)=\", prep_time,\n              \"seconds\")\n        print(\"Beginning training...\")\n\n    for epoch in range(1, 21):\n        model.train()\n        for i, batch in enumerate(train_loader):\n            if i == warmup_steps:\n                torch.cuda.synchronize()\n                start = time.time()\n            batch = batch.to(device)\n            optimizer.zero_grad()\n            y = batch.y[:batch.batch_size].view(-1).to(torch.long)\n            out = model(batch.x, batch.edge_index)[:batch.batch_size]\n            loss = F.cross_entropy(out, y)\n            loss.backward()\n            optimizer.step()\n\n            if rank == 0 and i % 10 == 0:\n                print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')\n\n        dist.barrier()\n        torch.cuda.synchronize()\n        if rank == 0:\n            sec_per_iter = (time.time() - start) / (i + 1 - warmup_steps)\n            print(f\"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter\")\n\n        @torch.no_grad()\n        def test(loader: NeighborLoader, num_steps: Optional[int] = None):\n            model.eval()\n            for j, batch in enumerate(loader):\n                if num_steps is not None and j >= num_steps:\n                    break\n                batch = batch.to(device)\n                out = model(batch.x, batch.edge_index)[:batch.batch_size]\n                y = 
batch.y[:batch.batch_size].view(-1).to(torch.long)\n                acc(out, y)\n            acc_sum = acc.compute()\n            return acc_sum\n\n        eval_acc = test(val_loader, num_steps=val_steps)\n        if rank == 0:\n            print(f\"Val Accuracy: {eval_acc * 100.0:.4f}%\")\n\n        acc.reset()\n        dist.barrier()\n\n    test_acc = test(test_loader)\n    if rank == 0:\n        print(f\"Test Accuracy: {test_acc * 100.0:.4f}%\")\n\n    dist.barrier()\n    acc.reset()\n    torch.cuda.synchronize()\n\n    if rank == 0:\n        total_time = round(time.perf_counter() - wall_clock_start, 2)\n        print(\"Total Program Runtime (total_time) =\", total_time, \"seconds\")\n        print(\"total_time - prep_time =\", total_time - prep_time, \"seconds\")\n\n\nif __name__ == '__main__':\n    wall_clock_start = time.perf_counter()\n    # Setup multi-node:\n    torch.distributed.init_process_group(\"nccl\")\n    nprocs = dist.get_world_size()\n    assert dist.is_initialized(), \"Distributed cluster not initialized\"\n    dataset = PygNodePropPredDataset(name='ogbn-papers100M')\n    split_idx = dataset.get_idx_split()\n    model = GCN(dataset.num_features, 256, 2, dataset.num_classes)\n    acc = Accuracy(task=\"multiclass\", num_classes=dataset.num_classes)\n    data = dataset[0]\n    data.y = data.y.reshape(-1)\n    run(nprocs, data, split_idx, model, acc, wall_clock_start)\n"
  },
  {
    "path": "examples/multi_gpu/pcqm4m_ogb.py",
    "content": "# Code adapted from OGB.\n# https://github.com/snap-stanford/ogb/tree/master/examples/lsc/pcqm4m-v2\nimport argparse\nimport math\nimport os\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder\nfrom torch.nn.parallel import DistributedDataParallel\nfrom torch.optim.lr_scheduler import StepLR\nfrom torch.utils.tensorboard import SummaryWriter\nfrom tqdm.auto import tqdm\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets import PCQM4Mv2\nfrom torch_geometric.io import fs\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import (\n    GlobalAttention,\n    MessagePassing,\n    Set2Set,\n    global_add_pool,\n    global_max_pool,\n    global_mean_pool,\n)\nfrom torch_geometric.utils import degree\n\ntry:\n    from ogb.lsc import PCQM4Mv2Evaluator, PygPCQM4Mv2Dataset\nexcept ImportError as e:\n    raise ImportError(\n        \"`PygPCQM4Mv2Dataset` requires rdkit (`pip install rdkit`)\") from e\n\nfrom ogb.utils import smiles2graph\n\n\ndef ogb_from_smiles_wrapper(smiles, *args, **kwargs):\n    \"\"\"Returns `torch_geometric.data.Data` object from smiles while\n    `ogb.utils.smiles2graph` returns a dict of np arrays.\n    \"\"\"\n    data_dict = smiles2graph(smiles, *args, **kwargs)\n    return Data(\n        x=torch.from_numpy(data_dict['node_feat']),\n        edge_index=torch.from_numpy(data_dict['edge_index']),\n        edge_attr=torch.from_numpy(data_dict['edge_feat']),\n        smiles=smiles,\n    )\n\n\nclass GINConv(MessagePassing):\n    def __init__(self, emb_dim):\n        r\"\"\"GINConv.\n\n        Args:\n            emb_dim (int): node embedding dimensionality\n        \"\"\"\n        super().__init__(aggr=\"add\")\n        self.mlp = torch.nn.Sequential(\n            torch.nn.Linear(emb_dim, emb_dim),\n            
torch.nn.BatchNorm1d(emb_dim),\n            torch.nn.ReLU(),\n            torch.nn.Linear(emb_dim, emb_dim),\n        )\n        self.eps = torch.nn.Parameter(torch.Tensor([0]))\n        self.bond_encoder = BondEncoder(emb_dim=emb_dim)\n\n    def forward(self, x, edge_index, edge_attr):\n        edge_embedding = self.bond_encoder(edge_attr)\n        return self.mlp(\n            (1 + self.eps) * x +\n            self.propagate(edge_index, x=x, edge_attr=edge_embedding))\n\n    def message(self, x_j, edge_attr):\n        return F.relu(x_j + edge_attr)\n\n    def update(self, aggr_out):\n        return aggr_out\n\n\nclass GCNConv(MessagePassing):\n    def __init__(self, emb_dim):\n        super().__init__(aggr='add')\n        self.linear = torch.nn.Linear(emb_dim, emb_dim)\n        self.root_emb = torch.nn.Embedding(1, emb_dim)\n        self.bond_encoder = BondEncoder(emb_dim=emb_dim)\n\n    def forward(self, x, edge_index, edge_attr):\n        x = self.linear(x)\n        edge_embedding = self.bond_encoder(edge_attr)\n        row, col = edge_index\n        deg = degree(row, x.size(0), dtype=x.dtype) + 1\n        deg_inv_sqrt = deg.pow(-0.5)\n        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n        return self.propagate(\n            edge_index, x=x, edge_attr=edge_embedding, norm=norm\n        ) + F.relu(x + self.root_emb.weight) * 1. 
/ deg.view(-1, 1)\n\n    def message(self, x_j, edge_attr, norm):\n        return norm.view(-1, 1) * F.relu(x_j + edge_attr)\n\n    def update(self, aggr_out):\n        return aggr_out\n\n\nclass GNNNode(torch.nn.Module):\n    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK=\"last\",\n                 residual=False, gnn_type='gin'):\n        r\"\"\"GNN Node.\n\n        Args:\n            emb_dim (int): node embedding dimensionality.\n            num_layers (int): number of GNN message passing layers.\n            residual (bool): whether to add residual connection.\n            drop_ratio (float): dropout ratio.\n            JK (str): \"last\" or \"sum\" to choose JK concat strat.\n            residual (bool): Whether or not to add the residual\n            gnn_type (str): Type of GNN to use.\n        \"\"\"\n        super().__init__()\n        if num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        self.num_layers = num_layers\n        self.drop_ratio = drop_ratio\n        self.JK = JK\n        self.residual = residual\n        self.atom_encoder = AtomEncoder(emb_dim)\n        self.convs = torch.nn.ModuleList()\n        self.batch_norms = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            if gnn_type == 'gin':\n                self.convs.append(GINConv(emb_dim))\n            elif gnn_type == 'gcn':\n                self.convs.append(GCNConv(emb_dim))\n            else:\n                raise ValueError(f'Undefined GNN type called {gnn_type}')\n\n            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))\n\n    def forward(self, batched_data):\n        x = batched_data.x\n        edge_index = batched_data.edge_index\n        edge_attr = batched_data.edge_attr\n\n        # compute input node embedding\n        h_list = [self.atom_encoder(x)]\n        for layer in range(self.num_layers):\n            h = self.convs[layer](h_list[layer], edge_index, edge_attr)\n            h 
= self.batch_norms[layer](h)\n\n            if layer == self.num_layers - 1:\n                # remove relu for the last layer\n                h = F.dropout(h, self.drop_ratio, training=self.training)\n            else:\n                h = F.dropout(F.relu(h), self.drop_ratio,\n                              training=self.training)\n\n            if self.residual:\n                h += h_list[layer]\n\n            h_list.append(h)\n\n        # Different implementations of Jk-concat\n        if self.JK == \"last\":\n            node_representation = h_list[-1]\n        elif self.JK == \"sum\":\n            node_representation = 0\n            for layer in range(self.num_layers + 1):\n                node_representation += h_list[layer]\n\n        return node_representation\n\n\nclass GNNNodeVirtualNode(torch.nn.Module):\n    \"\"\"Outputs node representations.\"\"\"\n    def __init__(self, num_layers, emb_dim, drop_ratio=0.5, JK=\"last\",\n                 residual=False, gnn_type='gin'):\n        super().__init__()\n        if num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        self.num_layers = num_layers\n        self.drop_ratio = drop_ratio\n        self.JK = JK\n        self.residual = residual\n        self.atom_encoder = AtomEncoder(emb_dim)\n\n        # set the initial virtual node embedding to 0.\n        self.virtualnode_embedding = torch.nn.Embedding(1, emb_dim)\n        torch.nn.init.constant_(self.virtualnode_embedding.weight.data, 0)\n\n        self.convs = torch.nn.ModuleList()\n        self.batch_norms = torch.nn.ModuleList()\n        self.mlp_virtualnode_list = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            if gnn_type == 'gin':\n                self.convs.append(GINConv(emb_dim))\n            elif gnn_type == 'gcn':\n                self.convs.append(GCNConv(emb_dim))\n            else:\n                raise ValueError('Undefined GNN type called {gnn_type}')\n\n       
     self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))\n\n        for _ in range(num_layers - 1):\n            self.mlp_virtualnode_list.append(\n                torch.nn.Sequential(\n                    torch.nn.Linear(emb_dim, emb_dim),\n                    torch.nn.BatchNorm1d(emb_dim),\n                    torch.nn.ReLU(),\n                    torch.nn.Linear(emb_dim, emb_dim),\n                    torch.nn.BatchNorm1d(emb_dim),\n                    torch.nn.ReLU(),\n                ))\n\n    def forward(self, batched_data):\n        x = batched_data.x\n        edge_index = batched_data.edge_index\n        edge_attr = batched_data.edge_attr\n        batch = batched_data.batch\n\n        # virtual node embeddings for graphs\n        virtualnode_embedding = self.virtualnode_embedding(\n            torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(\n                edge_index.device))\n\n        h_list = [self.atom_encoder(x)]\n        for layer in range(self.num_layers):\n            # add message from virtual nodes to graph nodes\n            h_list[layer] = h_list[layer] + virtualnode_embedding[batch]\n\n            # Message passing among graph nodes\n            h = self.convs[layer](h_list[layer], edge_index, edge_attr)\n\n            h = self.batch_norms[layer](h)\n            if layer == self.num_layers - 1:\n                # remove relu for the last layer\n                h = F.dropout(h, self.drop_ratio, training=self.training)\n            else:\n                h = F.dropout(F.relu(h), self.drop_ratio,\n                              training=self.training)\n\n            if self.residual:\n                h = h + h_list[layer]\n\n            h_list.append(h)\n\n            # update the virtual nodes\n            if layer < self.num_layers - 1:\n                # add message from graph nodes to virtual nodes\n                virtualnode_embedding_temp = global_add_pool(\n                    h_list[layer], batch) + virtualnode_embedding\n   
             # transform virtual nodes using MLP\n\n                if self.residual:\n                    virtualnode_embedding = virtualnode_embedding + F.dropout(\n                        self.mlp_virtualnode_list[layer]\n                        (virtualnode_embedding_temp), self.drop_ratio,\n                        training=self.training)\n                else:\n                    virtualnode_embedding = F.dropout(\n                        self.mlp_virtualnode_list[layer](\n                            virtualnode_embedding_temp), self.drop_ratio,\n                        training=self.training)\n\n        # Different implementations of Jk-concat\n        if self.JK == \"last\":\n            node_representation = h_list[-1]\n        elif self.JK == \"sum\":\n            node_representation = 0\n            for layer in range(self.num_layers + 1):\n                node_representation += h_list[layer]\n\n        return node_representation\n\n\nclass GNN(torch.nn.Module):\n    def __init__(\n        self,\n        num_tasks=1,\n        num_layers=5,\n        emb_dim=300,\n        gnn_type='gin',\n        virtual_node=True,\n        residual=False,\n        drop_ratio=0,\n        JK=\"last\",\n        graph_pooling=\"sum\",\n    ):\n        r\"\"\"GNN.\n\n        Args:\n            num_tasks (int): number of labels to be predicted\n            num_layers (int): number of gnn layers.\n            emb_dim (int): embedding dim to use.\n            gnn_type (str): Type of GNN to use.\n            virtual_node (bool): whether to add virtual node or not.\n            residual (bool): Whether or not to add the residual\n            drop_ratio (float): dropout ratio.\n            JK (str): \"last\" or \"sum\" to choose JK concat strat.\n            graph_pooling (str): Graph pooling strat to use.\n        \"\"\"\n        super().__init__()\n        if num_layers < 2:\n            raise ValueError(\"Number of GNN layers must be greater than 1.\")\n\n        self.num_layers 
= num_layers\n        self.drop_ratio = drop_ratio\n        self.JK = JK\n        self.emb_dim = emb_dim\n        self.num_tasks = num_tasks\n        self.graph_pooling = graph_pooling\n        if virtual_node:\n            self.gnn_node = GNNNodeVirtualNode(\n                num_layers,\n                emb_dim,\n                JK=JK,\n                drop_ratio=drop_ratio,\n                residual=residual,\n                gnn_type=gnn_type,\n            )\n        else:\n            self.gnn_node = GNNNode(\n                num_layers,\n                emb_dim,\n                JK=JK,\n                drop_ratio=drop_ratio,\n                residual=residual,\n                gnn_type=gnn_type,\n            )\n\n        # Pooling function to generate whole-graph embeddings\n        if self.graph_pooling == \"sum\":\n            self.pool = global_add_pool\n        elif self.graph_pooling == \"mean\":\n            self.pool = global_mean_pool\n        elif self.graph_pooling == \"max\":\n            self.pool = global_max_pool\n        elif self.graph_pooling == \"attention\":\n            self.pool = GlobalAttention(gate_nn=torch.nn.Sequential(\n                torch.nn.Linear(emb_dim, emb_dim),\n                torch.nn.BatchNorm1d(emb_dim),\n                torch.nn.ReLU(),\n                torch.nn.Linear(emb_dim, 1),\n            ))\n        elif self.graph_pooling == \"set2set\":\n            self.pool = Set2Set(emb_dim, processing_steps=2)\n        else:\n            raise ValueError(\"Invalid graph pooling type.\")\n\n        if graph_pooling == \"set2set\":\n            self.graph_pred_linear = torch.nn.Linear(2 * emb_dim, num_tasks)\n        else:\n            self.graph_pred_linear = torch.nn.Linear(emb_dim, num_tasks)\n\n    def forward(self, batched_data):\n        h_node = self.gnn_node(batched_data)\n        h_graph = self.pool(h_node, batched_data.batch)\n        output = self.graph_pred_linear(h_graph)\n        if self.training:\n            
return output\n        else:\n            # At inference time, we clamp the value between 0 and 20\n            return torch.clamp(output, min=0, max=20)\n\n\ndef train(model, rank, device, loader, optimizer):\n    model.train()\n    reg_criterion = torch.nn.L1Loss()\n    loss_accum = 0.0\n    for step, batch in enumerate(  # noqa: B007\n            tqdm(loader, desc=\"Training\", disable=(rank > 0))):\n        batch = batch.to(device)\n        pred = model(batch).view(-1, )\n        optimizer.zero_grad()\n        loss = reg_criterion(pred, batch.y)\n        loss.backward()\n        optimizer.step()\n        loss_accum += loss.detach().cpu().item()\n    return loss_accum / (step + 1)\n\n\ndef eval(model, device, loader, evaluator):\n    model.eval()\n    y_true = []\n    y_pred = []\n    for batch in tqdm(loader, desc=\"Evaluating\"):\n        batch = batch.to(device)\n        with torch.no_grad():\n            pred = model(batch).view(-1, )\n\n        y_true.append(batch.y.view(pred.shape).detach().cpu())\n        y_pred.append(pred.detach().cpu())\n\n    y_true = torch.cat(y_true, dim=0)\n    y_pred = torch.cat(y_pred, dim=0)\n    input_dict = {\"y_true\": y_true, \"y_pred\": y_pred}\n    return evaluator.eval(input_dict)[\"mae\"]\n\n\ndef test(model, device, loader):\n    model.eval()\n    y_pred = []\n    for batch in tqdm(loader, desc=\"Testing\"):\n        batch = batch.to(device)\n        with torch.no_grad():\n            pred = model(batch).view(-1, )\n\n        y_pred.append(pred.detach().cpu())\n\n    y_pred = torch.cat(y_pred, dim=0)\n    return y_pred\n\n\ndef run(rank, dataset, args):\n    num_devices = args.num_devices\n    device = torch.device(\n        \"cuda:\" + str(rank)) if num_devices > 0 else torch.device(\"cpu\")\n\n    if num_devices > 1:\n        os.environ[\"MASTER_ADDR\"] = \"localhost\"\n        os.environ[\"MASTER_PORT\"] = \"12355\"\n        dist.init_process_group(\"nccl\", rank=rank, world_size=num_devices)\n\n    if 
args.on_disk_dataset:\n        train_idx = torch.arange(len(dataset.indices()))\n    else:\n        split_idx = dataset.get_idx_split()\n        train_idx = split_idx[\"train\"]\n\n    if num_devices > 1:\n        num_splits = math.ceil(train_idx.size(0) / num_devices)\n        train_idx = train_idx.split(num_splits)[rank]\n\n    if args.train_subset:\n        subset_ratio = 0.1\n        n = len(train_idx)\n        subset_idx = torch.randperm(n)[:int(subset_ratio * n)]\n        train_dataset = dataset[train_idx[subset_idx]]\n    else:\n        train_dataset = dataset[train_idx]\n\n    train_loader = DataLoader(\n        train_dataset,\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    if rank == 0:\n        if args.on_disk_dataset:\n            valid_dataset = PCQM4Mv2(root='on_disk_dataset/', split=\"val\",\n                                     from_smiles_func=ogb_from_smiles_wrapper)\n            test_dev_dataset = PCQM4Mv2(\n                root='on_disk_dataset/', split=\"test\",\n                from_smiles_func=ogb_from_smiles_wrapper)\n            test_challenge_dataset = PCQM4Mv2(\n                root='on_disk_dataset/', split=\"holdout\",\n                from_smiles_func=ogb_from_smiles_wrapper)\n        else:\n            valid_dataset = dataset[split_idx[\"valid\"]]\n            test_dev_dataset = dataset[split_idx[\"test-dev\"]]\n            test_challenge_dataset = dataset[split_idx[\"test-challenge\"]]\n\n        valid_loader = DataLoader(\n            valid_dataset,\n            batch_size=args.batch_size,\n            shuffle=False,\n            num_workers=args.num_workers,\n        )\n        if args.save_test_dir != '':\n            testdev_loader = DataLoader(\n                test_dev_dataset,\n                batch_size=args.batch_size,\n                shuffle=False,\n                num_workers=args.num_workers,\n            )\n            testchallenge_loader = DataLoader(\n   
             test_challenge_dataset,\n                batch_size=args.batch_size,\n                shuffle=False,\n                num_workers=args.num_workers,\n            )\n\n        if args.checkpoint_dir != '':\n            os.makedirs(args.checkpoint_dir, exist_ok=True)\n\n        evaluator = PCQM4Mv2Evaluator()\n\n    gnn_type, virtual_node = args.gnn.split('-')\n    model = GNN(\n        gnn_type=gnn_type,\n        virtual_node=virtual_node,\n        num_layers=args.num_layers,\n        emb_dim=args.emb_dim,\n        drop_ratio=args.drop_ratio,\n        graph_pooling=args.graph_pooling,\n    )\n    if num_devices > 0:\n        model = model.to(rank)\n    if num_devices > 1:\n        model = DistributedDataParallel(model, device_ids=[rank])\n\n    optimizer = optim.Adam(model.parameters(), lr=0.001)\n\n    if args.log_dir != '':\n        writer = SummaryWriter(log_dir=args.log_dir)\n\n    best_valid_mae = 1000\n\n    if args.train_subset:\n        scheduler = StepLR(optimizer, step_size=300, gamma=0.25)\n        args.epochs = 1000\n    else:\n        scheduler = StepLR(optimizer, step_size=30, gamma=0.25)\n\n    current_epoch = 1\n\n    checkpoint_path = os.path.join(args.checkpoint_dir, 'checkpoint.pt')\n    if os.path.isfile(checkpoint_path):\n        checkpoint = fs.torch_load(checkpoint_path)\n        current_epoch = checkpoint['epoch'] + 1\n        model.load_state_dict(checkpoint['model_state_dict'])\n        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])\n        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])\n        best_valid_mae = checkpoint['best_val_mae']\n        print(f\"Found checkpoint, resume training at epoch {current_epoch}\")\n\n    for epoch in range(current_epoch, args.epochs + 1):\n        train_mae = train(model, rank, device, train_loader, optimizer)\n\n        if num_devices > 1:\n            dist.barrier()\n\n        if rank == 0:\n            valid_mae = eval(\n                model.module if 
isinstance(model, DistributedDataParallel) else\n                model, device, valid_loader, evaluator)\n\n            print(f\"Epoch {epoch:02d}, \"\n                  f\"Train MAE: {train_mae:.4f}, \"\n                  f\"Val MAE: {valid_mae:.4f}\")\n\n            if args.log_dir != '':\n                writer.add_scalar('valid/mae', valid_mae, epoch)\n                writer.add_scalar('train/mae', train_mae, epoch)\n\n            if valid_mae < best_valid_mae:\n                best_valid_mae = valid_mae\n                if args.checkpoint_dir != '':\n                    checkpoint = {\n                        'epoch': epoch,\n                        'model_state_dict': model.state_dict(),\n                        'optimizer_state_dict': optimizer.state_dict(),\n                        'scheduler_state_dict': scheduler.state_dict(),\n                        'best_val_mae': best_valid_mae,\n                    }\n                    torch.save(checkpoint, checkpoint_path)\n\n                if args.save_test_dir != '':\n                    test_model = model.module if isinstance(\n                        model, DistributedDataParallel) else model\n\n                    testdev_pred = test(test_model, device, testdev_loader)\n                    evaluator.save_test_submission(\n                        {'y_pred': testdev_pred.cpu().detach().numpy()},\n                        args.save_test_dir,\n                        mode='test-dev',\n                    )\n\n                    testchallenge_pred = test(test_model, device,\n                                              testchallenge_loader)\n                    evaluator.save_test_submission(\n                        {'y_pred': testchallenge_pred.cpu().detach().numpy()},\n                        args.save_test_dir,\n                        mode='test-challenge',\n                    )\n\n            print(f'Best validation MAE so far: {best_valid_mae}')\n\n        if num_devices > 1:\n            
dist.barrier()\n\n        scheduler.step()\n\n    if rank == 0 and args.log_dir != '':\n        writer.close()\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\n        description='GNN baselines on pcqm4m with PyTorch Geometric',\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument('--gnn', type=str, default='gin-virtual',\n                        choices=['gin', 'gin-virtual', 'gcn',\n                                 'gcn-virtual'], help='GNN architecture')\n    parser.add_argument('--graph_pooling', type=str, default='sum',\n                        help='graph pooling strategy mean or sum')\n    parser.add_argument('--drop_ratio', type=float, default=0,\n                        help='dropout ratio')\n    parser.add_argument('--num_layers', type=int, default=5,\n                        help='number of GNN message passing layers')\n    parser.add_argument('--emb_dim', type=int, default=600,\n                        help='dimensionality of hidden units in GNNs')\n    parser.add_argument('--train_subset', action='store_true')\n    parser.add_argument('--batch_size', type=int, default=256,\n                        help='input batch size for training')\n    parser.add_argument('--epochs', type=int, default=100,\n                        help='number of epochs to train')\n    parser.add_argument('--num_workers', type=int, default=0,\n                        help='number of workers')\n    parser.add_argument('--log_dir', type=str, default=\"\",\n                        help='tensorboard log directory')\n    parser.add_argument('--checkpoint_dir', type=str, default='',\n                        help='directory to save checkpoint')\n    parser.add_argument('--save_test_dir', type=str, default='',\n                        help='directory to save test submission file')\n    parser.add_argument('--num_devices', type=int, default=1,\n                        help=\"Number of GPUs, if 0 runs on the CPU\")\n    
parser.add_argument('--on_disk_dataset', action='store_true')\n    args = parser.parse_args()\n\n    available_gpus = torch.cuda.device_count() if torch.cuda.is_available(\n    ) else 0\n    if args.num_devices > available_gpus:\n        if available_gpus == 0:\n            print(\"No GPUs available, running w/ CPU...\")\n        else:\n            raise ValueError(f\"Cannot train with {args.num_devices} GPUs: \"\n                             f\"available GPUs count {available_gpus}\")\n\n    # automatic dataloading and splitting\n    if args.on_disk_dataset:\n        dataset = PCQM4Mv2(root='on_disk_dataset/', split='train',\n                           from_smiles_func=ogb_from_smiles_wrapper)\n    else:\n        dataset = PygPCQM4Mv2Dataset(root='dataset/')\n\n    if args.num_devices > 1:\n        mp.spawn(run, args=(dataset, args), nprocs=args.num_devices, join=True)\n    else:\n        run(0, dataset, args)\n"
  },
  {
    "path": "examples/multi_gpu/taobao.py",
    "content": "# A multi-GPU implementation of unsupervised bipartite GraphSAGE\n# using the Alibaba Taobao dataset.\nimport argparse\nimport os\nimport os.path as osp\n\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nimport tqdm\nfrom sklearn.metrics import roc_auc_score\nfrom torch.nn import Embedding, Linear\nfrom torch.nn.parallel import DistributedDataParallel\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Taobao\nfrom torch_geometric.loader import LinkNeighborLoader\nfrom torch_geometric.nn import SAGEConv\nfrom torch_geometric.utils.convert import to_scipy_sparse_matrix\n\n\nclass ItemGNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv(-1, hidden_channels)\n        self.conv2 = SAGEConv(hidden_channels, hidden_channels)\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index).relu()\n        return self.lin(x)\n\n\nclass UserGNNEncoder(torch.nn.Module):\n    def __init__(self, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), hidden_channels)\n        self.conv3 = SAGEConv((-1, -1), hidden_channels)\n        self.lin = Linear(hidden_channels, out_channels)\n\n    def forward(self, x_dict, edge_index_dict):\n        item_x = self.conv1(\n            x_dict['item'],\n            edge_index_dict[('item', 'to', 'item')],\n        ).relu()\n\n        user_x = self.conv2(\n            (x_dict['item'], x_dict['user']),\n            edge_index_dict[('item', 'rev_to', 'user')],\n        ).relu()\n\n        user_x = self.conv3(\n            (item_x, user_x),\n            edge_index_dict[('item', 'rev_to', 'user')],\n        ).relu()\n\n        
return self.lin(user_x)\n\n\nclass EdgeDecoder(torch.nn.Module):\n    def __init__(self, hidden_channels):\n        super().__init__()\n        self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, 1)\n\n    def forward(self, z_src, z_dst, edge_label_index):\n        row, col = edge_label_index\n        z = torch.cat([z_src[row], z_dst[col]], dim=-1)\n\n        z = self.lin1(z).relu()\n        z = self.lin2(z)\n        return z.view(-1)\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, num_users, num_items, hidden_channels, out_channels):\n        super().__init__()\n        self.user_emb = Embedding(num_users, hidden_channels)\n        self.item_emb = Embedding(num_items, hidden_channels)\n        self.item_encoder = ItemGNNEncoder(hidden_channels, out_channels)\n        self.user_encoder = UserGNNEncoder(hidden_channels, out_channels)\n        self.decoder = EdgeDecoder(out_channels)\n\n    def forward(self, x_dict, edge_index_dict, edge_label_index):\n        z_dict = {}\n        x_dict['user'] = self.user_emb(x_dict['user'])\n        x_dict['item'] = self.item_emb(x_dict['item'])\n        z_dict['item'] = self.item_encoder(\n            x_dict['item'],\n            edge_index_dict[('item', 'to', 'item')],\n        )\n        z_dict['user'] = self.user_encoder(x_dict, edge_index_dict)\n\n        return self.decoder(z_dict['user'], z_dict['item'], edge_label_index)\n\n\ndef run_train(rank, data, train_data, val_data, test_data, args, world_size):\n    if rank == 0:\n        print(\"Setting up Data Loaders...\")\n    train_edge_label_idx = train_data[('user', 'to', 'item')].edge_label_index\n    train_edge_label_idx = train_edge_label_idx.split(\n        train_edge_label_idx.size(1) // world_size, dim=1)[rank].clone()\n    train_loader = LinkNeighborLoader(\n        data=train_data,\n        num_neighbors=[8, 4],\n        edge_label_index=(('user', 'to', 'item'), train_edge_label_idx),\n        
neg_sampling='binary',\n        batch_size=args.batch_size,\n        shuffle=True,\n        num_workers=args.num_workers,\n        drop_last=True,\n    )\n\n    val_loader = LinkNeighborLoader(\n        data=val_data,\n        num_neighbors=[8, 4],\n        edge_label_index=(\n            ('user', 'to', 'item'),\n            val_data[('user', 'to', 'item')].edge_label_index,\n        ),\n        edge_label=val_data[('user', 'to', 'item')].edge_label,\n        batch_size=args.batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    test_loader = LinkNeighborLoader(\n        data=test_data,\n        num_neighbors=[8, 4],\n        edge_label_index=(\n            ('user', 'to', 'item'),\n            test_data[('user', 'to', 'item')].edge_label_index,\n        ),\n        edge_label=test_data[('user', 'to', 'item')].edge_label,\n        batch_size=args.batch_size,\n        shuffle=False,\n        num_workers=args.num_workers,\n    )\n\n    def train():\n        model.train()\n\n        total_loss = total_examples = 0\n        for batch in tqdm.tqdm(train_loader, disable=rank != 0):\n            batch = batch.to(rank)\n            optimizer.zero_grad()\n\n            pred = model(\n                batch.x_dict,\n                batch.edge_index_dict,\n                batch['user', 'item'].edge_label_index,\n            )\n            loss = F.binary_cross_entropy_with_logits(\n                pred, batch['user', 'item'].edge_label)\n\n            loss.backward()\n            optimizer.step()\n            total_loss += float(loss)\n            total_examples += pred.numel()\n\n        return total_loss / total_examples\n\n    @torch.no_grad()\n    def test(loader):\n        model.eval()\n        preds, targets = [], []\n        for batch in tqdm.tqdm(loader, disable=rank != 0):\n            batch = batch.to(rank)\n\n            pred = model(\n                batch.x_dict,\n                batch.edge_index_dict,\n                
batch['user', 'item'].edge_label_index,\n            ).sigmoid().view(-1).cpu()\n            target = batch['user', 'item'].edge_label.long().cpu()\n\n            preds.append(pred)\n            targets.append(target)\n\n        pred = torch.cat(preds, dim=0).numpy()\n        target = torch.cat(targets, dim=0).numpy()\n\n        return roc_auc_score(target, pred)\n\n    os.environ['MASTER_ADDR'] = 'localhost'\n    os.environ['MASTER_PORT'] = '12355'\n    dist.init_process_group('nccl', rank=rank, world_size=world_size)\n    model = Model(\n        num_users=data['user'].num_nodes,\n        num_items=data['item'].num_nodes,\n        hidden_channels=64,\n        out_channels=64,\n    ).to(rank)\n    # Initialize lazy modules\n    for batch in train_loader:\n        batch = batch.to(rank)\n        _ = model(\n            batch.x_dict,\n            batch.edge_index_dict,\n            batch['user', 'item'].edge_label_index,\n        )\n        break\n    model = DistributedDataParallel(model, device_ids=[rank])\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n    best_val_auc = 0\n    for epoch in range(1, args.epochs):\n        print(\"Train\")\n        loss = train()\n        if rank == 0:\n            print(\"Val\")\n            val_auc = test(val_loader)\n            best_val_auc = max(best_val_auc, val_auc)\n        if rank == 0:\n            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, '\n                  f'Val AUC: {val_auc:.4f}')\n    if rank == 0:\n        print(\"Test\")\n        test_auc = test(test_loader)\n        print(f'Total {args.epochs:02d} epochs: Final Loss: {loss:.4f}, '\n              f'Best Val AUC: {best_val_auc:.4f}, '\n              f'Test AUC: {test_auc:.4f}')\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--num_workers', type=int, default=16,\n                        help=\"Number of workers per dataloader\")\n    parser.add_argument('--lr', type=float, default=0.001)\n    
parser.add_argument('--epochs', type=int, default=21)\n    parser.add_argument('--batch_size', type=int, default=2048)\n    parser.add_argument(\n        '--dataset_root_dir', type=str,\n        default=osp.join(osp.dirname(osp.realpath(__file__)),\n                         '../../data/Taobao'))\n    args = parser.parse_args()\n\n    def pre_transform(data):\n        # Compute sparsified item<>item relationships through users:\n        print('Computing item<>item relationships...')\n        mat = to_scipy_sparse_matrix(data['user', 'item'].edge_index).tocsr()\n        mat = mat[:data['user'].num_nodes, :data['item'].num_nodes]\n        comat = mat.T @ mat\n        comat.setdiag(0)\n        comat = comat >= 3.\n        comat = comat.tocoo()\n        row = torch.from_numpy(comat.row).to(torch.long)\n        col = torch.from_numpy(comat.col).to(torch.long)\n        data['item', 'item'].edge_index = torch.stack([row, col], dim=0)\n        return data\n\n    dataset = Taobao(args.dataset_root_dir, pre_transform=pre_transform)\n    data = dataset[0]\n\n    data['user'].x = torch.arange(0, data['user'].num_nodes)\n    data['item'].x = torch.arange(0, data['item'].num_nodes)\n\n    # Only consider user<>item relationships for simplicity:\n    del data['category']\n    del data['item', 'category']\n    del data['user', 'item'].time\n    del data['user', 'item'].behavior\n\n    # Add a reverse ('item', 'rev_to', 'user') relation for message passing:\n    data = T.ToUndirected()(data)\n\n    # Perform a link-level split into training, validation, and test edges:\n    print('Computing data splits...')\n    train_data, val_data, test_data = T.RandomLinkSplit(\n        num_val=0.1,\n        num_test=0.1,\n        neg_sampling_ratio=1.0,\n        add_negative_train_samples=False,\n        edge_types=[('user', 'to', 'item')],\n        rev_edge_types=[('item', 'rev_to', 'user')],\n    )(data)\n    print('Done!')\n\n    world_size = torch.cuda.device_count()\n    print('Let\\'s 
use', world_size, 'GPUs!')\n    mp.spawn(run_train,\n             args=(data, train_data, val_data, test_data, args, world_size),\n             nprocs=world_size, join=True)\n"
  },
  {
    "path": "examples/mutag_gin.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.logging import init_wandb, log\nfrom torch_geometric.nn import MLP, GINConv, global_add_pool\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='MUTAG')\nparser.add_argument('--batch_size', type=int, default=128)\nparser.add_argument('--hidden_channels', type=int, default=32)\nparser.add_argument('--num_layers', type=int, default=5)\nparser.add_argument('--lr', type=float, default=0.01)\nparser.add_argument('--epochs', type=int, default=100)\nparser.add_argument('--wandb', action='store_true', help='Track experiment')\nargs = parser.parse_args()\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    # MPS is currently slower than CPU due to missing int64 min/max ops\n    device = torch.device('cpu')\nelse:\n    device = torch.device('cpu')\n\ninit_wandb(\n    name=f'GIN-{args.dataset}',\n    batch_size=args.batch_size,\n    lr=args.lr,\n    epochs=args.epochs,\n    hidden_channels=args.hidden_channels,\n    num_layers=args.num_layers,\n    device=device,\n)\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU')\ndataset = TUDataset(path, name=args.dataset).shuffle()\n\ntrain_loader = DataLoader(dataset[:0.9], args.batch_size, shuffle=True)\ntest_loader = DataLoader(dataset[0.9:], args.batch_size)\n\n\nclass GIN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            mlp = MLP([in_channels, hidden_channels, hidden_channels])\n            self.convs.append(GINConv(nn=mlp, train_eps=False))\n            in_channels = 
hidden_channels\n\n        self.mlp = MLP([hidden_channels, hidden_channels, out_channels],\n                       norm=None, dropout=0.5)\n\n    def forward(self, x, edge_index, batch):\n        for conv in self.convs:\n            x = conv(x, edge_index).relu()\n        x = global_add_pool(x, batch)\n        return self.mlp(x)\n\n\nmodel = GIN(\n    in_channels=dataset.num_features,\n    hidden_channels=args.hidden_channels,\n    out_channels=dataset.num_classes,\n    num_layers=args.num_layers,\n).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.batch)\n        loss = F.cross_entropy(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * data.num_graphs\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_correct = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.batch)\n        pred = out.argmax(dim=-1)\n        total_correct += int((pred == data.y).sum())\n    return total_correct / len(loader.dataset)\n\n\ntimes = []\nfor epoch in range(1, args.epochs + 1):\n    start = time.time()\n    loss = train()\n    train_acc = test(train_loader)\n    test_acc = test(test_loader)\n    log(Epoch=epoch, Loss=loss, Train=train_acc, Test=test_acc)\n    times.append(time.time() - start)\nprint(f'Median time per epoch: {torch.tensor(times).median():.4f}s')\n"
  },
  {
    "path": "examples/node2vec.py",
    "content": "import os.path as osp\nimport sys\n\nimport matplotlib.pyplot as plt\nimport torch\nfrom sklearn.manifold import TSNE\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import Node2Vec\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora')\ndata = dataset[0]\n\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\nmodel = Node2Vec(\n    data.edge_index,\n    embedding_dim=128,\n    walk_length=20,\n    context_size=10,\n    walks_per_node=10,\n    num_negative_samples=1,\n    p=1.0,\n    q=1.0,\n    sparse=True,\n).to(device)\n\nnum_workers = 4 if sys.platform == 'linux' else 0\nloader = model.loader(batch_size=128, shuffle=True, num_workers=num_workers)\noptimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)\n\n\ndef train():\n    model.train()\n    total_loss = 0\n    for pos_rw, neg_rw in loader:\n        optimizer.zero_grad()\n        loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item()\n    return total_loss / len(loader)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    z = model()\n    acc = model.test(\n        train_z=z[data.train_mask],\n        train_y=data.y[data.train_mask],\n        test_z=z[data.test_mask],\n        test_y=data.y[data.test_mask],\n        max_iter=150,\n    )\n    return acc\n\n\nfor epoch in range(1, 101):\n    loss = train()\n    acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Acc: {acc:.4f}')\n\n\n@torch.no_grad()\ndef plot_points(colors):\n    model.eval()\n    z = model().cpu().numpy()\n    z = TSNE(n_components=2).fit_transform(z)\n    y = data.y.cpu().numpy()\n\n    plt.figure(figsize=(8, 8))\n    for i in range(dataset.num_classes):\n        plt.scatter(z[y == i, 0], z[y == i, 1], s=20, color=colors[i])\n    plt.axis('off')\n    plt.show()\n\n\ncolors = [\n    '#ffc0cb', '#bada55', 
'#008080', '#420420', '#7fe5f0', '#065535', '#ffd700'\n]\nplot_points(colors)\n"
  },
  {
    "path": "examples/ogbn_proteins_deepgcn.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import Evaluator, PygNodePropPredDataset\nfrom torch.nn import LayerNorm, Linear, ReLU\nfrom tqdm import tqdm\n\nfrom torch_geometric.loader import RandomNodeLoader\nfrom torch_geometric.nn import DeepGCNLayer, GENConv\nfrom torch_geometric.utils import scatter\n\ndataset = PygNodePropPredDataset('ogbn-proteins', root='../data')\nsplitted_idx = dataset.get_idx_split()\ndata = dataset[0]\ndata.node_species = None\ndata.y = data.y.to(torch.float)\n\n# Initialize features of nodes by aggregating edge features.\nrow, col = data.edge_index\ndata.x = scatter(data.edge_attr, col, dim_size=data.num_nodes, reduce='sum')\n\n# Set split indices to masks.\nfor split in ['train', 'valid', 'test']:\n    mask = torch.zeros(data.num_nodes, dtype=torch.bool)\n    mask[splitted_idx[split]] = True\n    data[f'{split}_mask'] = mask\n\ntrain_loader = RandomNodeLoader(data, num_parts=40, shuffle=True,\n                                num_workers=5)\ntest_loader = RandomNodeLoader(data, num_parts=5, num_workers=5)\n\n\nclass DeeperGCN(torch.nn.Module):\n    def __init__(self, hidden_channels, num_layers):\n        super().__init__()\n\n        self.node_encoder = Linear(data.x.size(-1), hidden_channels)\n        self.edge_encoder = Linear(data.edge_attr.size(-1), hidden_channels)\n\n        self.layers = torch.nn.ModuleList()\n        for i in range(1, num_layers + 1):\n            conv = GENConv(hidden_channels, hidden_channels, aggr='softmax',\n                           t=1.0, learn_t=True, num_layers=2, norm='layer')\n            norm = LayerNorm(hidden_channels, elementwise_affine=True)\n            act = ReLU(inplace=True)\n\n            layer = DeepGCNLayer(conv, norm, act, block='res+', dropout=0.1,\n                                 ckpt_grad=i % 3)\n            self.layers.append(layer)\n\n        self.lin = Linear(hidden_channels, data.y.size(-1))\n\n    def forward(self, x, edge_index, 
edge_attr):\n        x = self.node_encoder(x)\n        edge_attr = self.edge_encoder(edge_attr)\n\n        x = self.layers[0].conv(x, edge_index, edge_attr)\n\n        for layer in self.layers[1:]:\n            x = layer(x, edge_index, edge_attr)\n\n        x = self.layers[0].act(self.layers[0].norm(x))\n        x = F.dropout(x, p=0.1, training=self.training)\n\n        return self.lin(x)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = DeeperGCN(hidden_channels=64, num_layers=28).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\ncriterion = torch.nn.BCEWithLogitsLoss()\nevaluator = Evaluator('ogbn-proteins')\n\n\ndef train(epoch):\n    model.train()\n\n    pbar = tqdm(total=len(train_loader))\n    pbar.set_description(f'Training epoch: {epoch:04d}')\n\n    total_loss = total_examples = 0\n    for data in train_loader:\n        optimizer.zero_grad()\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.edge_attr)\n        loss = criterion(out[data.train_mask], data.y[data.train_mask])\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * int(data.train_mask.sum())\n        total_examples += int(data.train_mask.sum())\n\n        pbar.update(1)\n\n    pbar.close()\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n\n    y_true = {'train': [], 'valid': [], 'test': []}\n    y_pred = {'train': [], 'valid': [], 'test': []}\n\n    pbar = tqdm(total=len(test_loader))\n    pbar.set_description(f'Evaluating epoch: {epoch:04d}')\n\n    for data in test_loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.edge_attr)\n\n        for split in y_true.keys():\n            mask = data[f'{split}_mask']\n            y_true[split].append(data.y[mask].cpu())\n            y_pred[split].append(out[mask].cpu())\n\n        pbar.update(1)\n\n    pbar.close()\n\n    train_rocauc = 
evaluator.eval({\n        'y_true': torch.cat(y_true['train'], dim=0),\n        'y_pred': torch.cat(y_pred['train'], dim=0),\n    })['rocauc']\n\n    valid_rocauc = evaluator.eval({\n        'y_true': torch.cat(y_true['valid'], dim=0),\n        'y_pred': torch.cat(y_pred['valid'], dim=0),\n    })['rocauc']\n\n    test_rocauc = evaluator.eval({\n        'y_true': torch.cat(y_true['test'], dim=0),\n        'y_pred': torch.cat(y_pred['test'], dim=0),\n    })['rocauc']\n\n    return train_rocauc, valid_rocauc, test_rocauc\n\n\nfor epoch in range(1, 1001):\n    loss = train(epoch)\n    train_rocauc, valid_rocauc, test_rocauc = test()\n    print(f'Loss: {loss:.4f}, Train: {train_rocauc:.4f}, '\n          f'Val: {valid_rocauc:.4f}, Test: {test_rocauc:.4f}')\n"
  },
  {
    "path": "examples/ogbn_train.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport psutil\nimport torch\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import PygNodePropPredDataset\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn.models import GAT, GraphSAGE, Polynormer, SGFormer\nfrom torch_geometric.utils import (\n    add_self_loops,\n    remove_self_loops,\n    to_undirected,\n)\n\nparser = argparse.ArgumentParser(\n    formatter_class=argparse.ArgumentDefaultsHelpFormatter, )\nparser.add_argument(\n    '--dataset',\n    type=str,\n    default='ogbn-arxiv',\n    choices=['ogbn-papers100M', 'ogbn-products', 'ogbn-arxiv'],\n    help='Dataset name.',\n)\nparser.add_argument(\n    '--dataset_dir',\n    type=str,\n    default='./data',\n    help='Root directory of dataset.',\n)\nparser.add_argument(\n    \"--model\",\n    type=str.lower,\n    default='SGFormer',\n    choices=['sage', 'gat', 'sgformer', 'polynormer'],\n    help=\"Model used for training\",\n)\n\nparser.add_argument('-e', '--epochs', type=int, default=50)\nparser.add_argument('-le', '--local_epochs', type=int, default=50,\n                    help='warmup epochs for polynormer')\nparser.add_argument('--num_layers', type=int, default=3)\nparser.add_argument('--num_heads', type=int, default=1,\n                    help='number of heads for GAT or Graph Transformer model.')\nparser.add_argument('-b', '--batch_size', type=int, default=1024)\nparser.add_argument('--num_workers', type=int, default=12)\nparser.add_argument('--fan_out', type=int, default=10,\n                    help='number of neighbors in each layer')\nparser.add_argument('--hidden_channels', type=int, default=256)\nparser.add_argument('--lr', type=float, default=0.003)\nparser.add_argument('--wd', type=float, default=0.0)\nparser.add_argument('--dropout', type=float, default=0.5)\nparser.add_argument(\n    
'--use_directed_graph',\n    action='store_true',\n    help='Whether or not to use directed graph',\n)\nparser.add_argument(\n    '--add_self_loop',\n    action='store_true',\n    help='Whether or not to add self loop',\n)\nargs = parser.parse_args()\n\nwall_clock_start = time.perf_counter()\n\nif (args.dataset == 'ogbn-papers100M'\n        and (psutil.virtual_memory().total / (1024**3)) < 390):\n    print('Warning: may not have enough RAM to run this example.')\n    print('Consider upgrading RAM if an error occurs.')\n    print('Estimated RAM Needed: ~390GB.')\n\nif args.model == 'polynormer' and args.num_layers != 7:\n    print(\"The original polynormer paper recommends 7 layers, you have chosen\",\n          args.num_layers, \"which may affect results. See for details\")\n\nprint(f'Training {args.dataset} with {args.model} model.')\n\nseed_everything(123)\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nnum_epochs = args.epochs\nif args.model == 'polynormer':\n    num_epochs += args.local_epochs\nnum_layers = args.num_layers\nnum_workers = args.num_workers\nnum_hidden_channels = args.hidden_channels\nbatch_size = args.batch_size\nroot = osp.join(args.dataset_dir, args.dataset)\nprint('The root is: ', root)\ndataset = PygNodePropPredDataset(name=args.dataset, root=root)\nsplit_idx = dataset.get_idx_split()\ndata = dataset[0]\n\nif not args.use_directed_graph:\n    data.edge_index = to_undirected(data.edge_index, reduce='mean')\nif args.add_self_loop:\n    data.edge_index, _ = remove_self_loops(data.edge_index)\n    data.edge_index, _ = add_self_loops(data.edge_index,\n                                        num_nodes=data.num_nodes)\n\ndata.to(device, 'x', 'y')\n\n\ndef get_loader(input_nodes: Tensor) -> NeighborLoader:\n    return NeighborLoader(\n        data,\n        input_nodes=input_nodes,\n        num_neighbors=[args.fan_out] * num_layers,\n        batch_size=batch_size,\n        shuffle=True,\n        
num_workers=num_workers,\n        persistent_workers=num_workers > 0,\n        disjoint=args.model in ['sgformer', 'polynormer'],\n    )\n\n\ntrain_loader = get_loader(split_idx['train'])\nval_loader = get_loader(split_idx['valid'])\ntest_loader = get_loader(split_idx['test'])\n\n\ndef train(epoch: int) -> tuple[Tensor, float]:\n    model.train()\n\n    pbar = tqdm(total=split_idx['train'].size(0))\n    pbar.set_description(f'Epoch {epoch:02d}')\n\n    total_loss = total_correct = 0\n    for batch in train_loader:\n        optimizer.zero_grad()\n        if args.model in ['sgformer', 'polynormer']:\n            if args.model == 'polynormer' and epoch == args.local_epochs:\n                print('start global attention')\n                model._global = True\n            out = model(batch.x, batch.edge_index.to(device),\n                        batch.batch.to(device))[:batch.batch_size]\n        else:\n            out = model(batch.x,\n                        batch.edge_index.to(device))[:batch.batch_size]\n        y = batch.y[:batch.batch_size].squeeze().to(torch.long)\n        loss = F.cross_entropy(out, y)\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss)\n        total_correct += int(out.argmax(dim=-1).eq(y).sum())\n        pbar.update(batch.batch_size)\n\n    pbar.close()\n    loss = total_loss / len(train_loader)\n    approx_acc = total_correct / split_idx['train'].size(0)\n    return loss, approx_acc\n\n\n@torch.no_grad()\ndef test(loader: NeighborLoader) -> float:\n    model.eval()\n\n    total_correct = total_examples = 0\n    for batch in loader:\n        batch = batch.to(device)\n        batch_size = batch.num_sampled_nodes[0]\n        if args.model in ['sgformer', 'polynormer']:\n            out = model(batch.x, batch.edge_index,\n                        batch.batch)[:batch.batch_size]\n        else:\n            out = model(batch.x, batch.edge_index)[:batch_size]\n        pred = out.argmax(dim=-1)\n        y = 
batch.y[:batch_size].view(-1).to(torch.long)\n\n        total_correct += int((pred == y).sum())\n        total_examples += y.size(0)\n\n    return total_correct / total_examples\n\n\ndef get_model(model_name: str) -> torch.nn.Module:\n    if model_name == 'gat':\n        model = GAT(\n            in_channels=dataset.num_features,\n            hidden_channels=num_hidden_channels,\n            num_layers=num_layers,\n            out_channels=dataset.num_classes,\n            dropout=args.dropout,\n            heads=args.num_heads,\n        )\n    elif model_name == 'sage':\n        model = GraphSAGE(\n            in_channels=dataset.num_features,\n            hidden_channels=num_hidden_channels,\n            num_layers=num_layers,\n            out_channels=dataset.num_classes,\n            dropout=args.dropout,\n        )\n    elif model_name == 'sgformer':\n        model = SGFormer(\n            in_channels=dataset.num_features,\n            hidden_channels=num_hidden_channels,\n            out_channels=dataset.num_classes,\n            trans_num_heads=args.num_heads,\n            trans_dropout=args.dropout,\n            gnn_num_layers=num_layers,\n            gnn_dropout=args.dropout,\n        )\n    elif model_name == 'polynormer':\n        model = Polynormer(\n            in_channels=dataset.num_features,\n            hidden_channels=num_hidden_channels,\n            out_channels=dataset.num_classes,\n            local_layers=num_layers,\n        )\n    else:\n        raise ValueError(f'Unsupported model type: {model_name}')\n\n    return model\n\n\nmodel = get_model(args.model).to(device)\nmodel.reset_parameters()\noptimizer = torch.optim.Adam(\n    model.parameters(),\n    lr=args.lr,\n    weight_decay=args.wd,\n)\nscheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max',\n                                                       patience=5)\n\nprint(f'Total time before training begins took '\n      f'{time.perf_counter() - 
wall_clock_start:.4f}s')\nprint('Training...')\n\ntimes = []\ntrain_times = []\ninference_times = []\nbest_val = 0.\nfor epoch in range(1, num_epochs + 1):\n    train_start = time.perf_counter()\n    loss, _ = train(epoch)\n    train_times.append(time.perf_counter() - train_start)\n\n    inference_start = time.perf_counter()\n    val_acc = test(val_loader)\n    inference_times.append(time.perf_counter() - inference_start)\n\n    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, ',\n          f'Train Time: {train_times[-1]:.4f}s')\n    print(f'Val: {val_acc * 100.0:.2f}%,')\n\n    if val_acc > best_val:\n        best_val = val_acc\n    times.append(time.perf_counter() - train_start)\n    for param_group in optimizer.param_groups:\n        print('lr:')\n        print(param_group['lr'])\n    scheduler.step(val_acc)\n\nprint(f'Average Epoch Time on training: '\n      f'{torch.tensor(train_times).mean():.4f}s')\nprint(f'Average Epoch Time on inference: '\n      f'{torch.tensor(inference_times).mean():.4f}s')\nprint(f'Average Epoch Time: {torch.tensor(times).mean():.4f}s')\nprint(f'Median Epoch Time: {torch.tensor(times).median():.4f}s')\nprint(f'Best Validation Accuracy: {100.0 * best_val:.2f}%')\n\nprint('Testing...')\ntest_final_acc = test(test_loader)\nprint(f'Test Accuracy: {100.0 * test_final_acc:.2f}%')\nprint(f'Total Program Runtime: '\n      f'{time.perf_counter() - wall_clock_start:.4f}s')\n"
  },
  {
    "path": "examples/ogc.py",
    "content": "# The OGC method from the \"From Cluster Assumption to Graph Convolution:\n# Graph-based Semi-Supervised Learning Revisited\" paper.\n# ArXiv: https://arxiv.org/abs/2309.13599\n\n# Datasets  CiteSeer  Cora   PubMed\n# Acc       0.774     0.869  0.837\n# Time      3.76      1.53   2.92\n\nimport argparse\nimport os.path as osp\nimport time\nimport warnings\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.utils import one_hot\n\nwarnings.filterwarnings('ignore', '.*Sparse CSR tensor support.*')\n\ndecline = 0.9  # decline rate\neta_sup = 0.001  # learning rate for supervised loss\neta_W = 0.5  # learning rate for updating W\nbeta = 0.1  # moving probability that a node moves to neighbors\nmax_sim_tol = 0.995  # max label prediction similarity between iterations\nmax_patience = 2  # tolerance for consecutive similar test predictions\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='Cora')\nargs = parser.parse_args()\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\n\ntransform = T.Compose([\n    T.NormalizeFeatures(),\n    T.GCNNorm(),\n    T.ToSparseTensor(layout=torch.sparse_csr),\n])\ndataset = Planetoid(path, name=args.dataset, transform=transform)\ndata = dataset[0].to(device)\n\ny_one_hot = one_hot(data.y, dataset.num_classes)\ndata.trainval_mask = data.train_mask | data.val_mask\n# LIM track, else use trainval_mask to construct S\nS = torch.diag(data.train_mask).float().to_sparse()\nI_N = torch.eye(data.num_nodes).to_sparse(layout=torch.sparse_csr).to(device)\n\n# Lazy random walk (also known as lazy graph convolution):\nlazy_adj = beta * data.adj_t + (1 - beta) * I_N\n\n\nclass LinearNeuralNetwork(torch.nn.Module):\n    def 
__init__(self, num_features: int, num_classes: int, bias: bool = True):\n        super().__init__()\n        self.W = torch.nn.Linear(num_features, num_classes, bias=bias)\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.W(x)\n\n    @torch.no_grad()\n    def test(self, U: Tensor, y_one_hot: Tensor, data: Data):\n        self.eval()\n        out = self(U)\n\n        loss = F.mse_loss(\n            out[data.trainval_mask],\n            y_one_hot[data.trainval_mask],\n        )\n\n        accs = []\n        pred = out.argmax(dim=-1)\n        for _, mask in data('trainval_mask', 'test_mask'):\n            accs.append(float((pred[mask] == data.y[mask]).sum() / mask.sum()))\n\n        return float(loss), accs[0], accs[1], pred\n\n    def update_W(self, U: Tensor, y_one_hot: Tensor, data: Data):\n        optimizer = torch.optim.SGD(self.parameters(), lr=eta_W)\n        self.train()\n        optimizer.zero_grad()\n        pred = self(U)\n        loss = F.mse_loss(pred[data.trainval_mask], y_one_hot[\n            data.trainval_mask,\n        ], reduction='sum')\n        loss.backward()\n        optimizer.step()\n        return self(U).data, self.W.weight.data\n\n\nmodel = LinearNeuralNetwork(\n    num_features=dataset.num_features,\n    num_classes=dataset.num_classes,\n    bias=False,\n).to(device)\n\n\ndef update_U(U: Tensor, y_one_hot: Tensor, pred: Tensor, W: Tensor):\n    global eta_sup\n\n    # Update the smoothness loss via LGC:\n    U = lazy_adj @ U\n\n    # Update the supervised loss via SEB:\n    dU_sup = 2 * (S @ (-y_one_hot + pred)) @ W\n    U = U - eta_sup * dU_sup\n\n    eta_sup = eta_sup * decline\n    return U\n\n\ndef ogc() -> float:\n    U = data.x\n    _, _, last_acc, last_pred = model.test(U, y_one_hot, data)\n\n    patience = 0\n    for i in range(1, 65):\n        # Updating W by training a simple linear neural network:\n        pred, W = model.update_W(U, y_one_hot, data)\n\n        # Updating U by LGC and SEB jointly:\n        U = 
update_U(U, y_one_hot, pred, W)\n\n        loss, trainval_acc, test_acc, pred = model.test(U, y_one_hot, data)\n        print(f'Epoch: {i:02d}, Loss: {loss:.4f}, '\n              f'Train+Val Acc: {trainval_acc:.4f} Test Acc {test_acc:.4f}')\n\n        sim_rate = float((pred == last_pred).sum()) / pred.size(0)\n        if (sim_rate > max_sim_tol):\n            patience += 1\n            if (patience > max_patience):\n                break\n\n        last_acc, last_pred = test_acc, pred\n\n    return last_acc\n\n\nstart_time = time.time()\ntest_acc = ogc()\nprint(f'Test Accuracy: {test_acc:.4f}')\nprint(f'Total Time: {time.time() - start_time:.4f}s')\n"
  },
  {
    "path": "examples/pmlp.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import PMLP\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora', transform=T.NormalizeFeatures())\ndata = dataset[0].to(device)\n\nmodel = PMLP(\n    in_channels=dataset.num_features,\n    hidden_channels=16,\n    out_channels=dataset.num_classes,\n    num_layers=2,\n    dropout=0.5,\n    norm=False,\n).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), weight_decay=5e-4, lr=0.01)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x)  # MLP during training.\n    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out = model(data.x, data.edge_index)\n    pred = out.argmax(dim=-1)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((pred[mask] == data.y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 201):\n    loss = train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/pna.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Embedding, Linear, ModuleList, ReLU, Sequential\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\nimport torch_geometric\nfrom torch_geometric.datasets import ZINC\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import BatchNorm, PNAConv, global_add_pool\nfrom torch_geometric.utils import degree\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ZINC')\ntrain_dataset = ZINC(path, subset=True, split='train')\nval_dataset = ZINC(path, subset=True, split='val')\ntest_dataset = ZINC(path, subset=True, split='test')\n\ntrain_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=128)\ntest_loader = DataLoader(test_dataset, batch_size=128)\n\n# Compute the maximum in-degree in the training data.\nmax_degree = -1\nfor data in train_dataset:\n    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)\n    max_degree = max(max_degree, int(d.max()))\n\n# Compute the in-degree histogram tensor\ndeg = torch.zeros(max_degree + 1, dtype=torch.long)\nfor data in train_dataset:\n    d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)\n    deg += torch.bincount(d, minlength=deg.numel())\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.node_emb = Embedding(21, 75)\n        self.edge_emb = Embedding(4, 50)\n\n        aggregators = ['mean', 'min', 'max', 'std']\n        scalers = ['identity', 'amplification', 'attenuation']\n\n        self.convs = ModuleList()\n        self.batch_norms = ModuleList()\n        for _ in range(4):\n            conv = PNAConv(in_channels=75, out_channels=75,\n                           aggregators=aggregators, scalers=scalers, deg=deg,\n                           edge_dim=50, towers=5, pre_layers=1, post_layers=1,\n                           
divide_input=False)\n            self.convs.append(conv)\n            self.batch_norms.append(BatchNorm(75))\n\n        self.mlp = Sequential(Linear(75, 50), ReLU(), Linear(50, 25), ReLU(),\n                              Linear(25, 1))\n\n    def forward(self, x, edge_index, edge_attr, batch):\n        x = self.node_emb(x.squeeze())\n        edge_attr = self.edge_emb(edge_attr)\n\n        for conv, batch_norm in zip(self.convs, self.batch_norms):\n            x = F.relu(batch_norm(conv(x, edge_index, edge_attr)))\n\n        x = global_add_pool(x, batch)\n        return self.mlp(x)\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif torch_geometric.is_xpu_available():\n    device = torch.device('xpu')\nelse:\n    device = torch.device('cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\nscheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,\n                              min_lr=0.00001)\n\n\ndef train(epoch):\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.edge_attr, data.batch)\n        loss = (out.squeeze() - data.y).abs().mean()\n        loss.backward()\n        total_loss += loss.item() * data.num_graphs\n        optimizer.step()\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_error = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.edge_attr, data.batch)\n        total_error += (out.squeeze() - data.y).abs().sum().item()\n    return total_error / len(loader.dataset)\n\n\nfor epoch in range(1, 301):\n    loss = train(epoch)\n    val_mae = test(val_loader)\n    test_mae = test(test_loader)\n    scheduler.step(val_mae)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '\n          f'Test: 
{test_mae:.4f}')\n"
  },
  {
    "path": "examples/point_transformer_classification.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear as Lin\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ModelNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import (\n    MLP,\n    PointTransformerConv,\n    fps,\n    global_mean_pool,\n    knn,\n    knn_graph,\n)\nfrom torch_geometric.typing import WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import scatter\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data/ModelNet10')\npre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)\ntrain_dataset = ModelNet(path, '10', True, transform, pre_transform)\ntest_dataset = ModelNet(path, '10', False, transform, pre_transform)\ntrain_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\ntest_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)\n\n\nclass TransformerBlock(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.lin_in = Lin(in_channels, in_channels)\n        self.lin_out = Lin(out_channels, out_channels)\n\n        self.pos_nn = MLP([3, 64, out_channels], norm=None, plain_last=False)\n\n        self.attn_nn = MLP([out_channels, 64, out_channels], norm=None,\n                           plain_last=False)\n\n        self.transformer = PointTransformerConv(in_channels, out_channels,\n                                                pos_nn=self.pos_nn,\n                                                attn_nn=self.attn_nn)\n\n    def forward(self, x, pos, edge_index):\n        x = self.lin_in(x).relu()\n        x = self.transformer(x, pos, edge_index)\n        x = self.lin_out(x).relu()\n        return x\n\n\nclass TransitionDown(torch.nn.Module):\n    \"\"\"Samples the input point cloud by a ratio percentage to reduce\n    cardinality and uses an mlp to 
augment feature dimensionality.\n    \"\"\"\n    def __init__(self, in_channels, out_channels, ratio=0.25, k=16):\n        super().__init__()\n        self.k = k\n        self.ratio = ratio\n        self.mlp = MLP([in_channels, out_channels], plain_last=False)\n\n    def forward(self, x, pos, batch):\n        # FPS sampling\n        id_clusters = fps(pos, ratio=self.ratio, batch=batch)\n\n        # compute for each cluster the k nearest points\n        sub_batch = batch[id_clusters] if batch is not None else None\n\n        # beware of self loop\n        id_k_neighbor = knn(pos, pos[id_clusters], k=self.k, batch_x=batch,\n                            batch_y=sub_batch)\n\n        # transformation of features through a simple MLP\n        x = self.mlp(x)\n\n        # Max pool onto each cluster the features from knn in points\n        x_out = scatter(x[id_k_neighbor[1]], id_k_neighbor[0], dim=0,\n                        dim_size=id_clusters.size(0), reduce='max')\n\n        # keep only the clusters and their max-pooled features\n        sub_pos, out = pos[id_clusters], x_out\n        return out, sub_pos, sub_batch\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, dim_model, k=16):\n        super().__init__()\n        self.k = k\n\n        # dummy feature is created if there is none given\n        in_channels = max(in_channels, 1)\n\n        # first block\n        self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)\n\n        self.transformer_input = TransformerBlock(in_channels=dim_model[0],\n                                                  out_channels=dim_model[0])\n        # backbone layers\n        self.transformers_down = torch.nn.ModuleList()\n        self.transition_down = torch.nn.ModuleList()\n\n        for i in range(len(dim_model) - 1):\n            # Add Transition Down block followed by a Transformer block\n            self.transition_down.append(\n                
TransitionDown(in_channels=dim_model[i],\n                               out_channels=dim_model[i + 1], k=self.k))\n\n            self.transformers_down.append(\n                TransformerBlock(in_channels=dim_model[i + 1],\n                                 out_channels=dim_model[i + 1]))\n\n        # class score computation\n        self.mlp_output = MLP([dim_model[-1], 64, out_channels], norm=None)\n\n    def forward(self, x, pos, batch=None):\n\n        # add dummy features in case none are given\n        if x is None:\n            x = torch.ones((pos.shape[0], 1), device=pos.device)\n\n        # first block\n        x = self.mlp_input(x)\n        edge_index = knn_graph(pos, k=self.k, batch=batch)\n        x = self.transformer_input(x, pos, edge_index)\n\n        # backbone\n        for i in range(len(self.transformers_down)):\n            x, pos, batch = self.transition_down[i](x, pos, batch=batch)\n\n            edge_index = knn_graph(pos, k=self.k, batch=batch)\n            x = self.transformers_down[i](x, pos, edge_index)\n\n        # GlobalAveragePooling\n        x = global_mean_pool(x, batch)\n\n        # Class score\n        out = self.mlp_output(x)\n\n        return F.log_softmax(out, dim=-1)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.pos, data.batch)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        total_loss += loss.item() * data.num_graphs\n        optimizer.step()\n    return total_loss / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    correct = 0\n    for data in loader:\n        data = data.to(device)\n        pred = model(data.x, data.pos, data.batch).max(dim=1)[1]\n        correct += pred.eq(data.y).sum().item()\n    return correct / len(loader.dataset)\n\n\nif __name__ == '__main__':\n\n    device = torch.device('cuda' if 
torch.cuda.is_available() else 'cpu')\n    model = Net(0, train_dataset.num_classes,\n                dim_model=[32, 64, 128, 256, 512], k=16).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20,\n                                                gamma=0.5)\n\n    for epoch in range(1, 201):\n        loss = train()\n        test_acc = test(test_loader)\n        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}, Test: {test_acc:.4f}')\n        scheduler.step()\n"
  },
  {
    "path": "examples/point_transformer_segmentation.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom point_transformer_classification import TransformerBlock, TransitionDown\nfrom torchmetrics.functional import jaccard_index\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ShapeNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP, knn_graph, knn_interpolate\nfrom torch_geometric.typing import WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import scatter\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\n\ncategory = 'Airplane'  # Pass in `None` to train on all categories.\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')\ntransform = T.Compose([\n    T.RandomJitter(0.01),\n    T.RandomRotate(15, axis=0),\n    T.RandomRotate(15, axis=1),\n    T.RandomRotate(15, axis=2),\n])\npre_transform = T.NormalizeScale()\ntrain_dataset = ShapeNet(path, category, split='trainval', transform=transform,\n                         pre_transform=pre_transform)\ntest_dataset = ShapeNet(path, category, split='test',\n                        pre_transform=pre_transform)\ntrain_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)\ntest_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)\n\n\nclass TransitionUp(torch.nn.Module):\n    \"\"\"Reduce features dimensionality and interpolate back to higher\n    resolution and cardinality.\n    \"\"\"\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.mlp_sub = MLP([in_channels, out_channels], plain_last=False)\n        self.mlp = MLP([out_channels, out_channels], plain_last=False)\n\n    def forward(self, x, x_sub, pos, pos_sub, batch=None, batch_sub=None):\n        # transform low-res features and reduce the number of features\n        x_sub = self.mlp_sub(x_sub)\n\n        # interpolate low-res feats to high-res points\n        x_interpolated = 
knn_interpolate(x_sub, pos_sub, pos, k=3,\n                                         batch_x=batch_sub, batch_y=batch)\n\n        x = self.mlp(x) + x_interpolated\n\n        return x\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, dim_model, k=16):\n        super().__init__()\n        self.k = k\n\n        # dummy feature is created if there is none given\n        in_channels = max(in_channels, 1)\n\n        # first block\n        self.mlp_input = MLP([in_channels, dim_model[0]], plain_last=False)\n\n        self.transformer_input = TransformerBlock(\n            in_channels=dim_model[0],\n            out_channels=dim_model[0],\n        )\n\n        # backbone layers\n        self.transformers_up = torch.nn.ModuleList()\n        self.transformers_down = torch.nn.ModuleList()\n        self.transition_up = torch.nn.ModuleList()\n        self.transition_down = torch.nn.ModuleList()\n\n        for i in range(0, len(dim_model) - 1):\n\n            # Add Transition Down block followed by a Point Transformer block\n            self.transition_down.append(\n                TransitionDown(in_channels=dim_model[i],\n                               out_channels=dim_model[i + 1], k=self.k))\n\n            self.transformers_down.append(\n                TransformerBlock(in_channels=dim_model[i + 1],\n                                 out_channels=dim_model[i + 1]))\n\n            # Add Transition Up block followed by Point Transformer block\n            self.transition_up.append(\n                TransitionUp(in_channels=dim_model[i + 1],\n                             out_channels=dim_model[i]))\n\n            self.transformers_up.append(\n                TransformerBlock(in_channels=dim_model[i],\n                                 out_channels=dim_model[i]))\n\n        # summit layers\n        self.mlp_summit = MLP([dim_model[-1], dim_model[-1]], norm=None,\n                              plain_last=False)\n\n        self.transformer_summit = 
TransformerBlock(\n            in_channels=dim_model[-1],\n            out_channels=dim_model[-1],\n        )\n\n        # class score computation\n        self.mlp_output = MLP([dim_model[0], 64, out_channels], norm=None)\n\n    def forward(self, x, pos, batch=None):\n\n        # add dummy features in case none are given\n        if x is None:\n            x = torch.ones((pos.shape[0], 1), device=pos.device)\n\n        out_x = []\n        out_pos = []\n        out_batch = []\n\n        # first block\n        x = self.mlp_input(x)\n        edge_index = knn_graph(pos, k=self.k, batch=batch)\n        x = self.transformer_input(x, pos, edge_index)\n\n        # save outputs for skip connections\n        out_x.append(x)\n        out_pos.append(pos)\n        out_batch.append(batch)\n\n        # backbone down: reduce cardinality and augment dimensionality\n        for i in range(len(self.transformers_down)):\n            x, pos, batch = self.transition_down[i](x, pos, batch=batch)\n            edge_index = knn_graph(pos, k=self.k, batch=batch)\n            x = self.transformers_down[i](x, pos, edge_index)\n\n            out_x.append(x)\n            out_pos.append(pos)\n            out_batch.append(batch)\n\n        # summit\n        x = self.mlp_summit(x)\n        edge_index = knn_graph(pos, k=self.k, batch=batch)\n        x = self.transformer_summit(x, pos, edge_index)\n\n        # backbone up: augment cardinality and reduce dimensionality\n        n = len(self.transformers_down)\n        for i in range(n):\n            x = self.transition_up[-i - 1](x=out_x[-i - 2], x_sub=x,\n                                           pos=out_pos[-i - 2],\n                                           pos_sub=out_pos[-i - 1],\n                                           batch_sub=out_batch[-i - 1],\n                                           batch=out_batch[-i - 2])\n\n            edge_index = knn_graph(out_pos[-i - 2], k=self.k,\n                                   
batch=out_batch[-i - 2])\n            x = self.transformers_up[-i - 1](x, out_pos[-i - 2], edge_index)\n\n        # Class score\n        out = self.mlp_output(x)\n\n        return F.log_softmax(out, dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(3, train_dataset.num_classes, dim_model=[32, 64, 128, 256, 512],\n            k=16).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\nscheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)\n\n\ndef train():\n    model.train()\n\n    total_loss = correct_nodes = total_nodes = 0\n    for i, data in enumerate(train_loader):\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.pos, data.batch)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item()\n        correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()\n        total_nodes += data.num_nodes\n\n        if (i + 1) % 10 == 0:\n            print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '\n                  f'Train Acc: {correct_nodes / total_nodes:.4f}')\n            total_loss = correct_nodes = total_nodes = 0\n\n\ndef test(loader):\n    model.eval()\n\n    ious, categories = [], []\n    y_map = torch.empty(loader.dataset.num_classes, device=device).long()\n    for data in loader:\n        data = data.to(device)\n        outs = model(data.x, data.pos, data.batch)\n\n        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()\n        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),\n                                    data.category.tolist()):\n            category = list(ShapeNet.seg_classes.keys())[category]\n            part = ShapeNet.seg_classes[category]\n            part = torch.tensor(part, device=device)\n\n            y_map[part] = torch.arange(part.size(0), device=device)\n\n            iou = jaccard_index(out[:, 
part].argmax(dim=-1), y_map[y],\n                                num_classes=part.size(0), absent_score=1.0)\n            ious.append(iou)\n\n        categories.append(data.category)\n\n    iou = torch.tensor(ious, device=device)\n    category = torch.cat(categories, dim=0)\n\n    mean_iou = scatter(iou, category, reduce='mean')  # Per-category IoU.\n    return float(mean_iou.mean())  # Global IoU.\n\n\nfor epoch in range(1, 100):\n    train()\n    iou = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Test IoU: {iou:.4f}')\n    scheduler.step()\n"
  },
  {
    "path": "examples/pointnet2_classification.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ModelNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP, PointNetConv, fps, global_max_pool, radius\nfrom torch_geometric.typing import WITH_TORCH_CLUSTER\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\n\n\nclass SAModule(torch.nn.Module):\n    def __init__(self, ratio, r, nn):\n        super().__init__()\n        self.ratio = ratio\n        self.r = r\n        self.conv = PointNetConv(nn, add_self_loops=False)\n\n    def forward(self, x, pos, batch):\n        idx = fps(pos, batch, ratio=self.ratio)\n        row, col = radius(pos, pos[idx], self.r, batch, batch[idx],\n                          max_num_neighbors=64)\n        edge_index = torch.stack([col, row], dim=0)\n        x_dst = None if x is None else x[idx]\n        x = self.conv((x, x_dst), (pos, pos[idx]), edge_index)\n        pos, batch = pos[idx], batch[idx]\n        return x, pos, batch\n\n\nclass GlobalSAModule(torch.nn.Module):\n    def __init__(self, nn):\n        super().__init__()\n        self.nn = nn\n\n    def forward(self, x, pos, batch):\n        x = self.nn(torch.cat([x, pos], dim=1))\n        x = global_max_pool(x, batch)\n        pos = pos.new_zeros((x.size(0), 3))\n        batch = torch.arange(x.size(0), device=batch.device)\n        return x, pos, batch\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        # Input channels account for both `pos` and node features.\n        self.sa1_module = SAModule(0.5, 0.2, MLP([3, 64, 64, 128]))\n        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))\n        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))\n\n        self.mlp = MLP([1024, 512, 256, 10], dropout=0.5, norm=None)\n\n    def forward(self, data):\n        sa0_out = (data.x, 
data.pos, data.batch)\n        sa1_out = self.sa1_module(*sa0_out)\n        sa2_out = self.sa2_module(*sa1_out)\n        sa3_out = self.sa3_module(*sa2_out)\n        x, pos, batch = sa3_out\n\n        return self.mlp(x).log_softmax(dim=-1)\n\n\ndef train(epoch):\n    model.train()\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = F.nll_loss(model(data), data.y)\n        loss.backward()\n        optimizer.step()\n\n\ndef test(loader):\n    model.eval()\n\n    correct = 0\n    for data in loader:\n        data = data.to(device)\n        with torch.no_grad():\n            pred = model(data).max(1)[1]\n        correct += pred.eq(data.y).sum().item()\n    return correct / len(loader.dataset)\n\n\nif __name__ == '__main__':\n    path = osp.join(osp.dirname(osp.realpath(__file__)), '..',\n                    'data/ModelNet10')\n    pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)\n    train_dataset = ModelNet(path, '10', True, transform, pre_transform)\n    test_dataset = ModelNet(path, '10', False, transform, pre_transform)\n    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,\n                              num_workers=6)\n    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,\n                             num_workers=6)\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = Net().to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n    for epoch in range(1, 201):\n        train(epoch)\n        test_acc = test(test_loader)\n        print(f'Epoch: {epoch:03d}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/pointnet2_segmentation.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom pointnet2_classification import GlobalSAModule, SAModule\nfrom torchmetrics.functional import jaccard_index\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ShapeNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP, knn_interpolate\nfrom torch_geometric.typing import WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import scatter\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\n\ncategory = 'Airplane'  # Pass in `None` to train on all categories.\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')\ntransform = T.Compose([\n    T.RandomJitter(0.01),\n    T.RandomRotate(15, axis=0),\n    T.RandomRotate(15, axis=1),\n    T.RandomRotate(15, axis=2)\n])\npre_transform = T.NormalizeScale()\ntrain_dataset = ShapeNet(path, category, split='trainval', transform=transform,\n                         pre_transform=pre_transform)\ntest_dataset = ShapeNet(path, category, split='test',\n                        pre_transform=pre_transform)\ntrain_loader = DataLoader(train_dataset, batch_size=12, shuffle=True,\n                          num_workers=6)\ntest_loader = DataLoader(test_dataset, batch_size=12, shuffle=False,\n                         num_workers=6)\n\n\nclass FPModule(torch.nn.Module):\n    def __init__(self, k, nn):\n        super().__init__()\n        self.k = k\n        self.nn = nn\n\n    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):\n        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)\n        if x_skip is not None:\n            x = torch.cat([x, x_skip], dim=1)\n        x = self.nn(x)\n        return x, pos_skip, batch_skip\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, num_classes):\n        super().__init__()\n\n        # Input channels account for both `pos` and node features.\n        
self.sa1_module = SAModule(0.2, 0.2, MLP([3 + 3, 64, 64, 128]))\n        self.sa2_module = SAModule(0.25, 0.4, MLP([128 + 3, 128, 128, 256]))\n        self.sa3_module = GlobalSAModule(MLP([256 + 3, 256, 512, 1024]))\n\n        self.fp3_module = FPModule(1, MLP([1024 + 256, 256, 256]))\n        self.fp2_module = FPModule(3, MLP([256 + 128, 256, 128]))\n        self.fp1_module = FPModule(3, MLP([128 + 3, 128, 128, 128]))\n\n        self.mlp = MLP([128, 128, 128, num_classes], dropout=0.5, norm=None)\n\n        self.lin1 = torch.nn.Linear(128, 128)\n        self.lin2 = torch.nn.Linear(128, 128)\n        self.lin3 = torch.nn.Linear(128, num_classes)\n\n    def forward(self, data):\n        sa0_out = (data.x, data.pos, data.batch)\n        sa1_out = self.sa1_module(*sa0_out)\n        sa2_out = self.sa2_module(*sa1_out)\n        sa3_out = self.sa3_module(*sa2_out)\n\n        fp3_out = self.fp3_module(*sa3_out, *sa2_out)\n        fp2_out = self.fp2_module(*fp3_out, *sa1_out)\n        x, _, _ = self.fp1_module(*fp2_out, *sa0_out)\n\n        return self.mlp(x).log_softmax(dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(train_dataset.num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    model.train()\n\n    total_loss = correct_nodes = total_nodes = 0\n    for i, data in enumerate(train_loader):\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item()\n        correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()\n        total_nodes += data.num_nodes\n\n        if (i + 1) % 10 == 0:\n            print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '\n                  f'Train Acc: {correct_nodes / total_nodes:.4f}')\n            total_loss = correct_nodes = total_nodes = 0\n\n\n@torch.no_grad()\ndef 
test(loader):\n    model.eval()\n\n    ious, categories = [], []\n    y_map = torch.empty(loader.dataset.num_classes, device=device).long()\n    for data in loader:\n        data = data.to(device)\n        outs = model(data)\n\n        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()\n        for out, y, category in zip(outs.split(sizes), data.y.split(sizes),\n                                    data.category.tolist()):\n            category = list(ShapeNet.seg_classes.keys())[category]\n            part = ShapeNet.seg_classes[category]\n            part = torch.tensor(part, device=device)\n\n            y_map[part] = torch.arange(part.size(0), device=device)\n\n            iou = jaccard_index(out[:, part].argmax(dim=-1), y_map[y],\n                                num_classes=part.size(0), absent_score=1.0)\n            ious.append(iou)\n\n        categories.append(data.category)\n\n    iou = torch.tensor(ious, device=device)\n    category = torch.cat(categories, dim=0)\n\n    mean_iou = scatter(iou, category, reduce='mean')  # Per-category IoU.\n    return float(mean_iou.mean())  # Global IoU.\n\n\nfor epoch in range(1, 31):\n    train()\n    iou = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}')\n"
  },
  {
    "path": "examples/ppi.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom sklearn.metrics import f1_score\n\nfrom torch_geometric.datasets import PPI\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GATConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')\ntrain_dataset = PPI(path, split='train')\nval_dataset = PPI(path, split='val')\ntest_dataset = PPI(path, split='test')\ntrain_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)\ntest_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GATConv(train_dataset.num_features, 256, heads=4,\n                             residual=True)\n        self.conv2 = GATConv(4 * 256, 256, heads=4, residual=True)\n        self.conv3 = GATConv(4 * 256, train_dataset.num_classes, heads=6,\n                             concat=False, residual=True)\n\n    def forward(self, x, edge_index):\n        x = F.elu(self.conv1(x, edge_index))\n        x = F.elu(self.conv2(x, edge_index))\n        x = self.conv3(x, edge_index)\n        return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\nloss_op = torch.nn.BCEWithLogitsLoss()\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = loss_op(model(data.x, data.edge_index), data.y)\n        total_loss += loss.item() * data.num_graphs\n        loss.backward()\n        optimizer.step()\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    ys, preds = [], []\n    for data in loader:\n        ys.append(data.y)\n        out = 
model(data.x.to(device), data.edge_index.to(device))\n        preds.append((out > 0).float().cpu())\n\n    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()\n    return f1_score(y, pred, average='micro') if pred.sum() > 0 else 0\n\n\ntimes = []\nfor epoch in range(1, 101):\n    start = time.time()\n    loss = train()\n    val_f1 = test(val_loader)\n    test_f1 = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_f1:.4f}, '\n          f'Test: {test_f1:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/proteins_diff_pool.py",
    "content": "import os.path as osp\nimport time\nfrom math import ceil\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DenseDataLoader\nfrom torch_geometric.nn import DenseSAGEConv, dense_diff_pool\n\nmax_nodes = 150\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',\n                'PROTEINS_dense')\ndataset = TUDataset(\n    path,\n    name='PROTEINS',\n    transform=T.ToDense(max_nodes),\n    pre_filter=lambda data: data.num_nodes <= max_nodes,\n)\ndataset = dataset.shuffle()\nn = (len(dataset) + 9) // 10\ntest_dataset = dataset[:n]\nval_dataset = dataset[n:2 * n]\ntrain_dataset = dataset[2 * n:]\ntest_loader = DenseDataLoader(test_dataset, batch_size=20)\nval_loader = DenseDataLoader(val_dataset, batch_size=20)\ntrain_loader = DenseDataLoader(train_dataset, batch_size=20)\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels,\n                 normalize=False, lin=True):\n        super().__init__()\n\n        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)\n        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)\n        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)\n        self.bn2 = torch.nn.BatchNorm1d(hidden_channels)\n        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)\n        self.bn3 = torch.nn.BatchNorm1d(out_channels)\n\n        if lin is True:\n            self.lin = torch.nn.Linear(2 * hidden_channels + out_channels,\n                                       out_channels)\n        else:\n            self.lin = None\n\n    def bn(self, i, x):\n        batch_size, num_nodes, num_channels = x.size()\n\n        x = x.view(-1, num_channels)\n        x = getattr(self, f'bn{i}')(x)\n        x = x.view(batch_size, num_nodes, num_channels)\n        return x\n\n    def forward(self, x, adj, 
mask=None):\n        batch_size, num_nodes, in_channels = x.size()\n\n        x0 = x\n        x1 = self.bn(1, self.conv1(x0, adj, mask).relu())\n        x2 = self.bn(2, self.conv2(x1, adj, mask).relu())\n        x3 = self.bn(3, self.conv3(x2, adj, mask).relu())\n\n        x = torch.cat([x1, x2, x3], dim=-1)\n\n        if self.lin is not None:\n            x = self.lin(x).relu()\n\n        return x\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        num_nodes = ceil(0.25 * max_nodes)\n        self.gnn1_pool = GNN(dataset.num_features, 64, num_nodes)\n        self.gnn1_embed = GNN(dataset.num_features, 64, 64, lin=False)\n\n        num_nodes = ceil(0.25 * num_nodes)\n        self.gnn2_pool = GNN(3 * 64, 64, num_nodes)\n        self.gnn2_embed = GNN(3 * 64, 64, 64, lin=False)\n\n        self.gnn3_embed = GNN(3 * 64, 64, 64, lin=False)\n\n        self.lin1 = torch.nn.Linear(3 * 64, 64)\n        self.lin2 = torch.nn.Linear(64, dataset.num_classes)\n\n    def forward(self, x, adj, mask=None):\n        s = self.gnn1_pool(x, adj, mask)\n        x = self.gnn1_embed(x, adj, mask)\n\n        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)\n\n        s = self.gnn2_pool(x, adj)\n        x = self.gnn2_embed(x, adj)\n\n        x, adj, l2, e2 = dense_diff_pool(x, adj, s)\n\n        x = self.gnn3_embed(x, adj)\n\n        x = x.mean(dim=1)\n        x = self.lin1(x).relu()\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1), l1 + l2, e1 + e2\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train(epoch):\n    model.train()\n    loss_all = 0\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        output, _, _ = 
model(data.x, data.adj, data.mask)\n        loss = F.nll_loss(output, data.y.view(-1))\n        loss.backward()\n        loss_all += data.y.size(0) * float(loss)\n        optimizer.step()\n    return loss_all / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n    correct = 0\n\n    for data in loader:\n        data = data.to(device)\n        pred = model(data.x, data.adj, data.mask)[0].max(dim=1)[1]\n        correct += int(pred.eq(data.y.view(-1)).sum())\n    return correct / len(loader.dataset)\n\n\nbest_val_acc = test_acc = 0\ntimes = []\nfor epoch in range(1, 151):\n    start = time.time()\n    train_loss = train(epoch)\n    val_acc = test(val_loader)\n    if val_acc > best_val_acc:\n        test_acc = test(test_loader)\n        best_val_acc = val_acc\n    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '\n          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/proteins_dmon_pool.py",
    "content": "import os.path as osp\nimport time\nfrom math import ceil\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import DenseGraphConv, DMoNPooling, GCNConv\nfrom torch_geometric.utils import to_dense_adj, to_dense_batch\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')\ndataset = TUDataset(path, name='PROTEINS').shuffle()\navg_num_nodes = int(dataset._data.x.size(0) / len(dataset))\nn = (len(dataset) + 9) // 10\ntest_dataset = dataset[:n]\nval_dataset = dataset[n:2 * n]\ntrain_dataset = dataset[2 * n:]\ntest_loader = DataLoader(test_dataset, batch_size=20)\nval_loader = DataLoader(val_dataset, batch_size=20)\ntrain_loader = DataLoader(train_dataset, batch_size=20)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, hidden_channels=32):\n        super().__init__()\n\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        num_nodes = ceil(0.5 * avg_num_nodes)\n        self.pool1 = DMoNPooling([hidden_channels, hidden_channels], num_nodes)\n\n        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)\n        num_nodes = ceil(0.5 * num_nodes)\n        self.pool2 = DMoNPooling([hidden_channels, hidden_channels], num_nodes)\n\n        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)\n\n        self.lin1 = Linear(hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index, batch):\n        x = self.conv1(x, edge_index).relu()\n\n        x, mask = to_dense_batch(x, batch)\n        adj = to_dense_adj(edge_index, batch)\n\n        _, x, adj, sp1, _, c1 = self.pool1(x, adj, mask)\n\n        x = self.conv2(x, adj).relu()\n\n        _, x, adj, sp2, _, c2 = self.pool2(x, adj)\n\n        x = self.conv3(x, adj)\n\n        x = x.mean(dim=1)\n        x 
= self.lin1(x).relu()\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1), sp1 + sp2 + c1 + c2\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(dataset.num_features, dataset.num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train(train_loader):\n    model.train()\n    loss_all = 0\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out, tot_loss = model(data.x, data.edge_index, data.batch)\n        loss = F.nll_loss(out, data.y.view(-1)) + tot_loss\n        loss.backward()\n        loss_all += data.y.size(0) * float(loss.detach())\n        optimizer.step()\n    return loss_all / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n    correct = 0\n    loss_all = 0\n    for data in loader:\n        data = data.to(device)\n        pred, tot_loss = model(data.x, data.edge_index, data.batch)\n        loss = F.nll_loss(pred, data.y.view(-1)) + tot_loss\n        loss_all += data.y.size(0) * float(loss.detach())\n        correct += int(pred.max(dim=1)[1].eq(data.y.view(-1)).sum())\n\n    return loss_all / len(loader.dataset), correct / len(loader.dataset)\n\n\ntimes = []\nfor epoch in range(1, 101):\n    start = time.time()\n    train_loss = train(train_loader)\n    _, train_acc = test(train_loader)\n    val_loss, val_acc = test(val_loader)\n    test_loss, test_acc = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '\n          f'Train Acc: {train_acc:.3f}, Val Loss: {val_loss:.3f}, '\n          f'Val Acc: {val_acc:.3f}, Test Loss: {test_loss:.3f}, '\n          f'Test Acc: {test_acc:.3f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/proteins_gmt.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GCNConv, GraphMultisetTransformer\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')\ndataset = TUDataset(path, name='PROTEINS').shuffle()\n\nn = (len(dataset) + 9) // 10\ntrain_dataset = dataset[2 * n:]\nval_dataset = dataset[n:2 * n]\ntest_dataset = dataset[:n]\n\ntrain_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=128)\ntest_loader = DataLoader(test_dataset, batch_size=128)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.conv1 = GCNConv(dataset.num_features, 32)\n        self.conv2 = GCNConv(32, 32)\n        self.conv3 = GCNConv(32, 32)\n\n        self.pool = GraphMultisetTransformer(96, k=10, heads=4)\n\n        self.lin1 = Linear(96, 16)\n        self.lin2 = Linear(16, dataset.num_classes)\n\n    def forward(self, x0, edge_index, batch):\n        x1 = self.conv1(x0, edge_index).relu()\n        x2 = self.conv2(x1, edge_index).relu()\n        x3 = self.conv3(x2, edge_index).relu()\n        x = torch.cat([x1, x2, x3], dim=-1)\n\n        x = self.pool(x, batch)\n\n        x = self.lin1(x).relu()\n        x = F.dropout(x, p=0.5, training=self.training)\n        x = self.lin2(x)\n\n        return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.batch)\n        loss = F.cross_entropy(out, data.y)\n        loss.backward()\n        
total_loss += data.num_graphs * float(loss.detach())\n        optimizer.step()\n    return total_loss / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_correct = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.batch)\n        total_correct += int((out.argmax(dim=-1) == data.y).sum())\n    return total_correct / len(loader.dataset)\n\n\ntimes = []\nfor epoch in range(1, 201):\n    start = time.time()\n    train_loss = train()\n    val_acc = test(val_loader)\n    test_acc = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Loss: {train_loss:.4f}, '\n          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/proteins_mincut_pool.py",
    "content": "import os.path as osp\nimport time\nfrom math import ceil\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import DenseGraphConv, GCNConv, dense_mincut_pool\nfrom torch_geometric.utils import to_dense_adj, to_dense_batch\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')\ndataset = TUDataset(path, name='PROTEINS').shuffle()\navg_num_nodes = int(dataset._data.x.size(0) / len(dataset))\nn = (len(dataset) + 9) // 10\ntest_dataset = dataset[:n]\nval_dataset = dataset[n:2 * n]\ntrain_dataset = dataset[2 * n:]\ntest_loader = DataLoader(test_dataset, batch_size=20)\nval_loader = DataLoader(val_dataset, batch_size=20)\ntrain_loader = DataLoader(train_dataset, batch_size=20)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, hidden_channels=32):\n        super().__init__()\n\n        self.conv1 = GCNConv(in_channels, hidden_channels)\n        num_nodes = ceil(0.5 * avg_num_nodes)\n        self.pool1 = Linear(hidden_channels, num_nodes)\n\n        self.conv2 = DenseGraphConv(hidden_channels, hidden_channels)\n        num_nodes = ceil(0.5 * num_nodes)\n        self.pool2 = Linear(hidden_channels, num_nodes)\n\n        self.conv3 = DenseGraphConv(hidden_channels, hidden_channels)\n\n        self.lin1 = Linear(hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index, batch):\n        x = self.conv1(x, edge_index).relu()\n\n        x, mask = to_dense_batch(x, batch)\n        adj = to_dense_adj(edge_index, batch)\n\n        s = self.pool1(x)\n        x, adj, mc1, o1 = dense_mincut_pool(x, adj, s, mask)\n\n        x = self.conv2(x, adj).relu()\n        s = self.pool2(x)\n\n        x, adj, mc2, o2 = dense_mincut_pool(x, adj, s)\n\n        x = self.conv3(x, adj)\n\n        x = 
x.mean(dim=1)\n        x = self.lin1(x).relu()\n        x = self.lin2(x)\n        return F.log_softmax(x, dim=-1), mc1 + mc2, o1 + o2\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(dataset.num_features, dataset.num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)\n\n\ndef train(epoch):\n    model.train()\n    loss_all = 0\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)\n        loss = F.nll_loss(out, data.y.view(-1)) + mc_loss + o_loss\n        loss.backward()\n        loss_all += data.y.size(0) * float(loss)\n        optimizer.step()\n    return loss_all / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n    correct = 0\n    loss_all = 0\n\n    for data in loader:\n        data = data.to(device)\n        pred, mc_loss, o_loss = model(data.x, data.edge_index, data.batch)\n        loss = F.nll_loss(pred, data.y.view(-1)) + mc_loss + o_loss\n        loss_all += data.y.size(0) * float(loss)\n        correct += int(pred.max(dim=1)[1].eq(data.y.view(-1)).sum())\n\n    return loss_all / len(loader.dataset), correct / len(loader.dataset)\n\n\ntimes = []\nbest_val_acc = test_acc = 0\nbest_val_loss = float('inf')\npatience = start_patience = 50\nfor epoch in range(1, 15001):\n    start = time.time()\n    train_loss = train(epoch)\n    _, train_acc = test(train_loader)\n    val_loss, val_acc = test(val_loader)\n    if val_loss < best_val_loss:\n        # Track the best validation loss so that early stopping can trigger.\n        best_val_loss = val_loss\n        test_loss, test_acc = test(test_loader)\n        best_val_acc = val_acc\n        patience = start_patience\n    else:\n        patience -= 1\n        if patience == 0:\n            break\n    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.3f}, '\n          f'Train Acc: {train_acc:.3f}, Val Loss: {val_loss:.3f}, '\n          f'Val Acc: {val_acc:.3f}, Test Loss: {test_loss:.3f}, '\n   
       f'Test Acc: {test_acc:.3f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/proteins_topk_pool.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GraphConv, TopKPooling\nfrom torch_geometric.nn import global_max_pool as gmp\nfrom torch_geometric.nn import global_mean_pool as gap\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PROTEINS')\ndataset = TUDataset(path, name='PROTEINS')\ndataset = dataset.shuffle()\nn = len(dataset) // 10\ntest_dataset = dataset[:n]\ntrain_dataset = dataset[n:]\ntest_loader = DataLoader(test_dataset, batch_size=60)\ntrain_loader = DataLoader(train_dataset, batch_size=60)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.conv1 = GraphConv(dataset.num_features, 128)\n        self.pool1 = TopKPooling(128, ratio=0.8)\n        self.conv2 = GraphConv(128, 128)\n        self.pool2 = TopKPooling(128, ratio=0.8)\n        self.conv3 = GraphConv(128, 128)\n        self.pool3 = TopKPooling(128, ratio=0.8)\n\n        self.lin1 = torch.nn.Linear(256, 128)\n        self.lin2 = torch.nn.Linear(128, 64)\n        self.lin3 = torch.nn.Linear(64, dataset.num_classes)\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n\n        x = F.relu(self.conv1(x, edge_index))\n        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, None, batch)\n        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n\n        x = F.relu(self.conv2(x, edge_index))\n        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, None, batch)\n        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n\n        x = F.relu(self.conv3(x, edge_index))\n        x, edge_index, _, batch, _, _ = self.pool3(x, edge_index, None, batch)\n        x3 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)\n\n        x = x1 + x2 + x3\n\n        x = F.relu(self.lin1(x))\n        x = 
F.dropout(x, p=0.5, training=self.training)\n        x = F.relu(self.lin2(x))\n        x = F.log_softmax(self.lin3(x), dim=-1)\n\n        return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.0005)\n\n\ndef train(epoch):\n    model.train()\n\n    loss_all = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        output = model(data)\n        loss = F.nll_loss(output, data.y)\n        loss.backward()\n        loss_all += data.num_graphs * loss.item()\n        optimizer.step()\n    return loss_all / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    correct = 0\n    for data in loader:\n        data = data.to(device)\n        pred = model(data).max(dim=1)[1]\n        correct += pred.eq(data.y).sum().item()\n    return correct / len(loader.dataset)\n\n\nfor epoch in range(1, 201):\n    loss = train(epoch)\n    train_acc = test(train_loader)\n    test_acc = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.5f}, Train Acc: {train_acc:.5f}, '\n          f'Test Acc: {test_acc:.5f}')\n"
  },
  {
    "path": "examples/pytorch_ignite/README.md",
    "content": "# Examples for PyTorch Ignite\n\nThis directory provides examples showcasing the integration of PyG with [PyTorch Ingite](https://pytorch.org/ignite/index.html).\n\n| Example              | Description                                                      |\n| -------------------- | ---------------------------------------------------------------- |\n| [`gin.py`](./gin.py) | Demonstrates how to implement the GIN model using PyTorch Ignite |\n"
  },
  {
    "path": "examples/pytorch_ignite/gin.py",
    "content": "import os.path as osp\n\nimport ignite\nimport ignite.contrib.handlers.tensorboard_logger\nimport ignite.contrib.handlers.tqdm_logger\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric import seed_everything\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GIN, MLP, global_add_pool\n\n\nclass Model(torch.nn.Module):\n    def __init__(self, in_channels: int, out_channels: int,\n                 hidden_channels: int = 64, num_layers: int = 3,\n                 dropout: float = 0.5):\n        super().__init__()\n\n        self.gnn = GIN(in_channels, hidden_channels, num_layers,\n                       dropout=dropout, jk='cat')\n\n        self.classifier = MLP([hidden_channels, hidden_channels, out_channels],\n                              norm=\"batch_norm\", dropout=dropout)\n\n    def forward(self, data):\n        x = self.gnn(data.x, data.edge_index)\n        x = global_add_pool(x, data.batch)\n        x = self.classifier(x)\n        return x\n\n\ndef main():\n    seed_everything(42)\n\n    root = osp.join('data', 'TUDataset')\n    dataset = TUDataset(root, 'IMDB-BINARY', pre_transform=T.OneHotDegree(135))\n\n    dataset = dataset.shuffle()\n    test_dataset = dataset[:len(dataset) // 10]\n    val_dataset = dataset[len(dataset) // 10:2 * len(dataset) // 10]\n    train_dataset = dataset[2 * len(dataset) // 10:]\n\n    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,\n                              pin_memory=True)\n    val_loader = DataLoader(val_dataset, batch_size=64, pin_memory=True)\n    test_loader = DataLoader(test_dataset, batch_size=64, pin_memory=True)\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = Model(dataset.num_node_features, dataset.num_classes).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n    metrics = 
{'acc': ignite.metrics.Accuracy()}\n\n    def prepare_batch_fn(batch, device, non_blocking):\n        return (batch.to(device, non_blocking=non_blocking),\n                batch.y.to(device, non_blocking=non_blocking))\n\n    trainer = ignite.engine.create_supervised_trainer(\n        model=model,\n        optimizer=optimizer,\n        loss_fn=F.cross_entropy,\n        device=device,\n        prepare_batch=prepare_batch_fn,\n        output_transform=lambda x, y, y_pred, loss: loss.item(),\n        amp_mode='amp',\n    )\n\n    # Progress bar for each epoch:\n    pbar = ignite.contrib.handlers.tqdm_logger.ProgressBar()\n    pbar.attach(trainer, output_transform=lambda x: {'loss': x})\n\n    def log_metrics(evaluator, loader, tag):\n        def logger(trainer):\n            evaluator.run(loader)\n            print(f'{tag:10} Epoch: {trainer.state.epoch:02d}, '\n                  f'Acc: {evaluator.state.metrics[\"acc\"]:.4f}')\n\n        return logger\n\n    train_evaluator = ignite.engine.create_supervised_evaluator(\n        model=model,\n        metrics=metrics,\n        device=device,\n        prepare_batch=prepare_batch_fn,\n        output_transform=lambda x, y, y_pred: (y_pred, y),\n        amp_mode='amp',\n    )\n    trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=1))(log_metrics(\n        train_evaluator, train_loader, 'Training'))\n\n    val_evaluator = ignite.engine.create_supervised_evaluator(\n        model=model,\n        metrics=metrics,\n        device=device,\n        prepare_batch=prepare_batch_fn,\n        output_transform=lambda x, y, y_pred: (y_pred, y),\n        amp_mode='amp',\n    )\n    trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=1))(log_metrics(\n        val_evaluator, val_loader, 'Validation'))\n\n    test_evaluator = ignite.engine.create_supervised_evaluator(\n        model=model,\n        metrics=metrics,\n        device=device,\n        prepare_batch=prepare_batch_fn,\n        output_transform=lambda x, y, y_pred: (y_pred, 
y),\n        amp_mode='amp',\n    )\n    trainer.on(ignite.engine.Events.EPOCH_COMPLETED(every=1))(log_metrics(\n        test_evaluator, test_loader, 'Test'))\n\n    # Save checkpoint of the model based on Accuracy on the validation set:\n    checkpoint_handler = ignite.handlers.Checkpoint(\n        {'model': model},\n        'runs/gin',\n        n_saved=2,\n        score_name=list(metrics.keys())[0],\n        filename_pattern='best-{global_step}-{score_name}-{score}.pt',\n        global_step_transform=ignite.handlers.global_step_from_engine(trainer),\n    )\n    val_evaluator.add_event_handler(ignite.engine.Events.EPOCH_COMPLETED,\n                                    checkpoint_handler)\n\n    # Create a tensorboard logger to write logs:\n    tb_logger = ignite.contrib.handlers.tensorboard_logger.TensorboardLogger(\n        log_dir=osp.join('runs/example', 'tb_logs'))\n\n    tb_logger.attach_output_handler(\n        trainer, event_name=ignite.engine.Events.ITERATION_COMPLETED,\n        tag='training', output_transform=lambda loss: {'loss_iteration': loss})\n    tb_logger.attach_output_handler(\n        trainer, event_name=ignite.engine.Events.EPOCH_COMPLETED,\n        tag='training', output_transform=lambda loss: {'loss_epoch': loss})\n    tb_logger.attach_output_handler(\n        train_evaluator,\n        event_name=ignite.engine.Events.EPOCH_COMPLETED,\n        tag='training',\n        metric_names='all',\n        global_step_transform=ignite.handlers.global_step_from_engine(trainer),\n    )\n    tb_logger.attach_output_handler(\n        val_evaluator,\n        event_name=ignite.engine.Events.EPOCH_COMPLETED,\n        tag='validation',\n        metric_names='all',\n        global_step_transform=ignite.handlers.global_step_from_engine(trainer),\n    )\n    tb_logger.attach_output_handler(\n        test_evaluator,\n        event_name=ignite.engine.Events.EPOCH_COMPLETED,\n        tag='test',\n        metric_names='all',\n        
global_step_transform=ignite.handlers.global_step_from_engine(trainer),\n    )\n\n    trainer.run(train_loader, max_epochs=50)\n    tb_logger.close()\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "examples/pytorch_lightning/README.md",
    "content": "# Examples for PyTorch Lightning\n\nThis directory provides examples showcasing the integration of PyG with [PyTorch Lightning](https://github.com/Lightning-AI/pytorch-lightning).\n\n| Example                                    | Description                                                                          |\n| ------------------------------------------ | ------------------------------------------------------------------------------------ |\n| [`graph_sage.py`](./graph_sage.py)         | Combines PyG and PyTorch Lightning for node classification via the `GraphSAGE` model |\n| [`gin.py`](./gin.py)                       | Combines PyG and PyTorch Lightning for graph classification via the `GIN` model      |\n| [`relational_gnn.py`](./relational_gnn.py) | Combines PyG and PyTorch Lightning for heterogeneous node classification             |\n"
  },
  {
    "path": "examples/pytorch_lightning/gin.py",
    "content": "import os.path as osp\n\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn.functional as F\nfrom torchmetrics import Accuracy\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.data.lightning import LightningDataset\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.nn import GIN, MLP, global_add_pool\n\n\nclass Model(pl.LightningModule):\n    def __init__(self, in_channels: int, out_channels: int,\n                 hidden_channels: int = 64, num_layers: int = 3,\n                 dropout: float = 0.5):\n        super().__init__()\n        self.gnn = GIN(in_channels, hidden_channels, num_layers,\n                       dropout=dropout, jk='cat')\n\n        self.classifier = MLP([hidden_channels, hidden_channels, out_channels],\n                              norm=\"batch_norm\", dropout=dropout)\n\n        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)\n\n    def forward(self, x, edge_index, batch):\n        x = self.gnn(x, edge_index)\n        x = global_add_pool(x, batch)\n        x = self.classifier(x)\n        return x\n\n    def training_step(self, data, batch_idx):\n        y_hat = self(data.x, data.edge_index, data.batch)\n        loss = F.cross_entropy(y_hat, data.y)\n        self.train_acc(y_hat.softmax(dim=-1), data.y)\n        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n        return loss\n\n    def validation_step(self, data, batch_idx):\n        y_hat = self(data.x, data.edge_index, data.batch)\n        self.val_acc(y_hat.softmax(dim=-1), data.y)\n        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n\n    def test_step(self, data, batch_idx):\n        y_hat = self(data.x, data.edge_index, data.batch)\n      
  self.test_acc(y_hat.softmax(dim=-1), data.y)\n        self.log('test_acc', self.test_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=0.01)\n\n\nif __name__ == '__main__':\n    root = osp.join('data', 'TUDataset')\n    dataset = TUDataset(root, 'IMDB-BINARY', pre_transform=T.OneHotDegree(135))\n\n    dataset = dataset.shuffle()\n    test_dataset = dataset[:len(dataset) // 10]\n    val_dataset = dataset[len(dataset) // 10:2 * len(dataset) // 10]\n    train_dataset = dataset[2 * len(dataset) // 10:]\n\n    datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,\n                                  batch_size=64, num_workers=4)\n\n    model = Model(dataset.num_node_features, dataset.num_classes)\n\n    devices = torch.cuda.device_count()\n    strategy = pl.strategies.DDPStrategy(accelerator='gpu')\n    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1,\n                                              mode='max')\n    trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=50,\n                         log_every_n_steps=5, callbacks=[checkpoint])\n\n    trainer.fit(model, datamodule)\n    trainer.test(ckpt_path='best', datamodule=datamodule)\n"
  },
  {
    "path": "examples/pytorch_lightning/graph_sage.py",
    "content": "import os.path as osp\n\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import BatchNorm1d\nfrom torchmetrics import Accuracy\n\nfrom torch_geometric.data.lightning import LightningNodeData\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.nn import GraphSAGE\n\n\nclass Model(pl.LightningModule):\n    def __init__(self, in_channels: int, out_channels: int,\n                 hidden_channels: int = 256, num_layers: int = 2,\n                 dropout: float = 0.5):\n        super().__init__()\n        self.gnn = GraphSAGE(in_channels, hidden_channels, num_layers,\n                             out_channels, dropout=dropout,\n                             norm=BatchNorm1d(hidden_channels))\n\n        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)\n\n    def forward(self, x, edge_index):\n        return self.gnn(x, edge_index)\n\n    def training_step(self, data, batch_idx):\n        y_hat = self(data.x, data.edge_index)[:data.batch_size]\n        y = data.y[:data.batch_size]\n        loss = F.cross_entropy(y_hat, y)\n        self.train_acc(y_hat.softmax(dim=-1), y)\n        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n        return loss\n\n    def validation_step(self, data, batch_idx):\n        y_hat = self(data.x, data.edge_index)[:data.batch_size]\n        y = data.y[:data.batch_size]\n        self.val_acc(y_hat.softmax(dim=-1), y)\n        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n\n    def test_step(self, data, batch_idx):\n        y_hat = self(data.x, data.edge_index)[:data.batch_size]\n        y = data.y[:data.batch_size]\n        self.test_acc(y_hat.softmax(dim=-1), y)\n        
self.log('test_acc', self.test_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=0.01)\n\n\nif __name__ == '__main__':\n    dataset = Reddit(osp.join('data', 'Reddit'))\n    data = dataset[0]\n\n    datamodule = LightningNodeData(\n        data,\n        input_train_nodes=data.train_mask,\n        input_val_nodes=data.val_mask,\n        input_test_nodes=data.test_mask,\n        loader='neighbor',\n        num_neighbors=[25, 10],\n        batch_size=1024,\n        num_workers=8,\n    )\n\n    model = Model(dataset.num_node_features, dataset.num_classes)\n\n    strategy = pl.strategies.SingleDeviceStrategy('cuda:0')\n    checkpoint = pl.callbacks.ModelCheckpoint(monitor='val_acc', save_top_k=1,\n                                              mode='max')\n    trainer = pl.Trainer(strategy=strategy, devices=1, max_epochs=20,\n                         callbacks=[checkpoint])\n\n    trainer.fit(model, datamodule)\n    trainer.test(ckpt_path='best', datamodule=datamodule)\n"
  },
  {
    "path": "examples/pytorch_lightning/relational_gnn.py",
    "content": "import os.path as osp\nfrom typing import Dict, List, Tuple\n\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn.functional as F\nfrom pytorch_lightning import LightningModule, Trainer\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom torch import Tensor\nfrom torchmetrics import Accuracy\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.data import Batch\nfrom torch_geometric.data.lightning import LightningNodeData\nfrom torch_geometric.datasets import OGB_MAG\nfrom torch_geometric.nn import Linear, SAGEConv, to_hetero\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self, hidden_channels: int, out_channels: int,\n                 dropout: float):\n        super().__init__()\n        self.dropout = torch.nn.Dropout(p=dropout)\n\n        self.conv1 = SAGEConv((-1, -1), hidden_channels)\n        self.conv2 = SAGEConv((-1, -1), hidden_channels)\n        self.lin = Linear(-1, out_channels)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = self.conv1(x, edge_index).relu()\n        x = self.dropout(x)\n        x = self.conv2(x, edge_index).relu()\n        x = self.dropout(x)\n        return self.lin(x)\n\n\nclass RelationalGNN(LightningModule):\n    def __init__(\n        self,\n        metadata: Tuple[List[NodeType], List[EdgeType]],\n        hidden_channels: int,\n        out_channels: int,\n        dropout: float,\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n\n        model = GNN(hidden_channels, out_channels, dropout)\n        # Convert the homogeneous GNN model to a heterogeneous variant in\n        # which distinct parameters are learned for each node and edge type.\n        self.model = to_hetero(model, metadata, aggr='sum')\n\n        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        
self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)\n\n    def forward(\n        self,\n        x_dict: Dict[NodeType, Tensor],\n        edge_index_dict: Dict[EdgeType, Tensor],\n    ) -> Dict[NodeType, Tensor]:\n        return self.model(x_dict, edge_index_dict)\n\n    def common_step(self, batch: Batch) -> Tuple[Tensor, Tensor]:\n        batch_size = batch['paper'].batch_size\n        y_hat = self(batch.x_dict, batch.edge_index_dict)['paper'][:batch_size]\n        y = batch['paper'].y[:batch_size]\n        return y_hat, y\n\n    def training_step(self, batch: Batch, batch_idx: int) -> Tensor:\n        y_hat, y = self.common_step(batch)\n        loss = F.cross_entropy(y_hat, y)\n        self.train_acc(y_hat.softmax(dim=-1), y)\n        self.log('train_acc', self.train_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n        return loss\n\n    def validation_step(self, batch: Batch, batch_idx: int):\n        y_hat, y = self.common_step(batch)\n        self.val_acc(y_hat.softmax(dim=-1), y)\n        self.log('val_acc', self.val_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n\n    def test_step(self, batch: Batch, batch_idx: int):\n        y_hat, y = self.common_step(batch)\n        self.test_acc(y_hat.softmax(dim=-1), y)\n        self.log('test_acc', self.test_acc, prog_bar=True, on_step=False,\n                 on_epoch=True)\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=0.01)\n\n\ndef main():\n    dataset = OGB_MAG(osp.join('data', 'OGB'), preprocess='metapath2vec',\n                      transform=T.ToUndirected(merge=False))\n    data = dataset[0]\n\n    datamodule = LightningNodeData(\n        data,\n        input_train_nodes=('paper', data['paper'].train_mask),\n        input_val_nodes=('paper', data['paper'].val_mask),\n        input_test_nodes=('paper', data['paper'].test_mask),\n        loader='neighbor',\n        num_neighbors=[10, 10],\n        
batch_size=1024,\n        num_workers=8,\n    )\n\n    model = RelationalGNN(data.metadata(), hidden_channels=64,\n                          out_channels=349, dropout=0.0)\n\n    with torch.no_grad():  # Run a dummy forward pass to initialize lazy model\n        loader = datamodule.train_dataloader()\n        batch = next(iter(loader))\n        model.common_step(batch)\n\n    strategy = pl.strategies.SingleDeviceStrategy('cuda:0')\n    checkpoint = ModelCheckpoint(monitor='val_acc', save_top_k=1, mode='max')\n    trainer = Trainer(strategy=strategy, devices=1, max_epochs=20,\n                      log_every_n_steps=5, callbacks=[checkpoint])\n\n    trainer.fit(model, datamodule)\n    trainer.test(ckpt_path='best', datamodule=datamodule)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/qm9_nn_conv.py",
    "content": "import copy\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import GRU, Linear, ReLU, Sequential\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import QM9\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import NNConv, Set2Set\nfrom torch_geometric.utils import remove_self_loops\n\ntarget = 0\ndim = 64\n\n\nclass MyTransform:\n    def __call__(self, data):\n        data = copy.copy(data)\n        data.y = data.y[:, target]  # Specify target.\n        return data\n\n\nclass Complete:\n    def __call__(self, data):\n        data = copy.copy(data)\n        device = data.edge_index.device\n\n        row = torch.arange(data.num_nodes, dtype=torch.long, device=device)\n        col = torch.arange(data.num_nodes, dtype=torch.long, device=device)\n\n        row = row.view(-1, 1).repeat(1, data.num_nodes).view(-1)\n        col = col.repeat(data.num_nodes)\n        edge_index = torch.stack([row, col], dim=0)\n\n        edge_attr = None\n        if data.edge_attr is not None:\n            idx = data.edge_index[0] * data.num_nodes + data.edge_index[1]\n            size = list(data.edge_attr.size())\n            size[0] = data.num_nodes * data.num_nodes\n            edge_attr = data.edge_attr.new_zeros(size)\n            edge_attr[idx] = data.edge_attr\n\n        edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)\n        data.edge_attr = edge_attr\n        data.edge_index = edge_index\n\n        return data\n\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')\ntransform = T.Compose([MyTransform(), Complete(), T.Distance(norm=False)])\ndataset = QM9(path, transform=transform).shuffle()\n\n# Normalize targets to mean = 0 and std = 1.\nmean = dataset.data.y.mean(dim=0, keepdim=True)\nstd = dataset.data.y.std(dim=0, keepdim=True)\ndataset.data.y = (dataset.data.y - mean) / std\nmean, std = mean[:, target].item(), std[:, target].item()\n\n# 
Split datasets.\ntest_dataset = dataset[:10000]\nval_dataset = dataset[10000:20000]\ntrain_dataset = dataset[20000:]\ntest_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)\nval_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)\ntrain_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin0 = torch.nn.Linear(dataset.num_features, dim)\n\n        nn = Sequential(Linear(5, 128), ReLU(), Linear(128, dim * dim))\n        self.conv = NNConv(dim, dim, nn, aggr='mean')\n        self.gru = GRU(dim, dim)\n\n        self.set2set = Set2Set(dim, processing_steps=3)\n        self.lin1 = torch.nn.Linear(2 * dim, dim)\n        self.lin2 = torch.nn.Linear(dim, 1)\n\n    def forward(self, data):\n        out = F.relu(self.lin0(data.x))\n        h = out.unsqueeze(0)\n\n        for _ in range(3):\n            m = F.relu(self.conv(out, data.edge_index, data.edge_attr))\n            out, h = self.gru(m.unsqueeze(0), h)\n            out = out.squeeze(0)\n\n        out = self.set2set(out, data.batch)\n        out = F.relu(self.lin1(out))\n        out = self.lin2(out)\n        return out.view(-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\nscheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min',\n                                                       factor=0.7, patience=5,\n                                                       min_lr=0.00001)\n\n\ndef train(epoch):\n    model.train()\n    loss_all = 0\n\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        loss = F.mse_loss(model(data), data.y)\n        loss.backward()\n        loss_all += loss.item() * data.num_graphs\n        optimizer.step()\n    return loss_all / len(train_loader.dataset)\n\n\ndef 
test(loader):\n    model.eval()\n    error = 0\n\n    for data in loader:\n        data = data.to(device)\n        error += (model(data) * std - data.y * std).abs().sum().item()  # MAE\n    return error / len(loader.dataset)\n\n\nbest_val_error = None\nfor epoch in range(1, 301):\n    lr = scheduler.optimizer.param_groups[0]['lr']\n    loss = train(epoch)\n    val_error = test(val_loader)\n    scheduler.step(val_error)\n\n    if best_val_error is None or val_error <= best_val_error:\n        test_error = test(test_loader)\n        best_val_error = val_error\n\n    print(f'Epoch: {epoch:03d}, LR: {lr:7f}, Loss: {loss:.7f}, '\n          f'Val MAE: {val_error:.7f}, Test MAE: {test_error:.7f}')\n"
  },
  {
    "path": "examples/qm9_pretrained_dimenet.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\n\nfrom torch_geometric.datasets import QM9\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import DimeNet, DimeNetPlusPlus\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--use_dimenet_plus_plus', action='store_true')\nargs = parser.parse_args()\n\nModel = DimeNetPlusPlus if args.use_dimenet_plus_plus else DimeNet\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')\ndataset = QM9(path)\n\n# DimeNet uses the atomization energy for targets U0, U, H, and G, i.e.:\n# 7 -> 12, 8 -> 13, 9 -> 14, 10 -> 15\nidx = torch.tensor([0, 1, 2, 3, 4, 5, 6, 12, 13, 14, 15, 11])\ndataset.data.y = dataset.data.y[:, idx]\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\nfor target in range(12):\n    # Skip target \\delta\\epsilon, since it can be computed via\n    # \\epsilon_{LUMO} - \\epsilon_{HOMO}:\n    if target == 4:\n        continue\n\n    model, datasets = Model.from_qm9_pretrained(path, dataset, target)\n    train_dataset, val_dataset, test_dataset = datasets\n\n    model = model.to(device)\n    loader = DataLoader(test_dataset, batch_size=256)\n\n    maes = []\n    for data in loader:\n        data = data.to(device)\n        with torch.no_grad():\n            pred = model(data.z, data.pos, data.batch)\n        mae = (pred.view(-1) - data.y[:, target]).abs()\n        maes.append(mae)\n\n    mae = torch.cat(maes, dim=0)\n\n    # Report meV instead of eV:\n    mae = 1000 * mae if target in [2, 3, 4, 6, 7, 8, 9, 10] else mae\n\n    print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}')\n"
  },
  {
    "path": "examples/qm9_pretrained_schnet.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import QM9\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import SchNet\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--cutoff', type=float, default=10.0,\n                    help='Cutoff distance for interatomic interactions')\nargs = parser.parse_args()\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'QM9')\ndataset = QM9(path)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\nfor target in range(12):\n    model, datasets = SchNet.from_qm9_pretrained(path, dataset, target)\n    train_dataset, val_dataset, test_dataset = datasets\n\n    model = model.to(device)\n    loader = DataLoader(test_dataset, batch_size=256)\n\n    maes = []\n    for data in tqdm(loader):\n        data = data.to(device)\n        with torch.no_grad():\n            pred = model(data.z, data.pos, data.batch)\n        mae = (pred.view(-1) - data.y[:, target]).abs()\n        maes.append(mae)\n\n    mae = torch.cat(maes, dim=0)\n\n    # Report meV instead of eV.\n    mae = 1000 * mae if target in [2, 3, 4, 6, 7, 8, 9, 10] else mae\n\n    print(f'Target: {target:02d}, MAE: {mae.mean():.5f} ± {mae.std():.5f}')\n"
  },
  {
    "path": "examples/quiver/README.md",
    "content": "# Using Quiver for PyG Examples\n\n**[Quiver](https://github.com/quiver-team/torch-quiver)** is a **GPU-optimized distributed library** for PyG.\nIt can speed up graph sampling and feature aggregation by running them on the GPU when running PyG examples.\n\n## Installation\n\nAssuming you have installed PyTorch and PyG, you can install Quiver as follows:\n\n```bash\npip install 'torch-quiver>=0.1.1'\n```\n\n## Usage\n\nThe API and design documentation of Quiver can be found [here](https://github.com/quiver-team/torch-quiver).\n\n## Examples\n\nWe provide several examples to showcase the usage of Quiver within PyG:\n\n### Single-GPU Training\n\nThe single-GPU example leverages Quiver's **(i)** GPU-based graph sampling and feature aggregation, and **(ii)** GNN data caching algorithm (which caches hot data in GPU memory) while enabling fast access to CPU data using a Quiver shared tensor implementation:\n\n```bash\npython single_gpu_quiver.py\n```\n\n### Multi-GPU Training\n\nThe multi-GPU example further leverages Quiver's ability to **(i)** distribute sampling and feature aggregation across multiple GPUs, and **(ii)** use the memory of multiple GPUs to cache and replicate hot GNN data:\n\n```bash\npython multi_gpu_quiver.py\n```\n\n### Distributed Training\n\nA Quiver-based distributed PyG example is coming soon.\n"
  },
  {
    "path": "examples/quiver/multi_gpu_quiver.py",
    "content": "# This script shows how to use Quiver in an existing PyG example:\n# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/multi_gpu/distributed_sampling.py\nimport os\nfrom math import ceil\n\nimport quiver\nimport torch\nimport torch.distributed as dist\nimport torch.multiprocessing as mp\nimport torch.nn.functional as F\nfrom torch.nn.parallel import DistributedDataParallel\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import NeighborSampler\nfrom torch_geometric.nn import SAGEConv\n\n\nclass SAGE(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels,\n                 num_layers=2):\n        super().__init__()\n        self.num_layers = num_layers\n\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(SAGEConv(in_channels, hidden_channels))\n        for _ in range(self.num_layers - 2):\n            self.convs.append(SAGEConv(hidden_channels, hidden_channels))\n        self.convs.append(SAGEConv(hidden_channels, out_channels))\n\n    def forward(self, x, adjs):\n        for i, (edge_index, _, size) in enumerate(adjs):\n            x_target = x[:size[1]]  # Target nodes are always placed first.\n            x = self.convs[i]((x, x_target), edge_index)\n            if i != self.num_layers - 1:\n                x = F.relu(x)\n                x = F.dropout(x, p=0.5, training=self.training)\n        return x.log_softmax(dim=-1)\n\n    @torch.no_grad()\n    def inference(self, x_all, device, subgraph_loader):\n        pbar = tqdm(total=x_all.size(0) * self.num_layers)\n        pbar.set_description('Evaluating')\n\n        for i in range(self.num_layers):\n            xs = []\n            for batch_size, n_id, adj in subgraph_loader:\n                edge_index, _, size = adj.to(device)\n                x = x_all[n_id].to(device)\n                x_target = x[:size[1]]\n                x = self.convs[i]((x, x_target), edge_index)\n          
      if i != self.num_layers - 1:\n                    x = F.relu(x)\n                xs.append(x.cpu())\n\n                pbar.update(batch_size)\n\n            x_all = torch.cat(xs, dim=0)\n\n        pbar.close()\n\n        return x_all\n\n\ndef run(rank, world_size, dataset, quiver_feature, quiver_sampler):\n    os.environ['MASTER_ADDR'] = 'localhost'\n    os.environ['MASTER_PORT'] = '12355'\n    dist.init_process_group('nccl', rank=rank, world_size=world_size)\n    torch.cuda.set_device(rank)\n\n    data = dataset[0]\n    train_idx = data.train_mask.nonzero(as_tuple=False).view(-1)\n    train_idx = train_idx.split(ceil(train_idx.size(0) / world_size))[rank]\n\n    train_loader = torch.utils.data.DataLoader(train_idx, batch_size=1024,\n                                               shuffle=True, num_workers=0)\n\n    if rank == 0:\n        subgraph_loader = NeighborSampler(data.edge_index, node_idx=None,\n                                          sizes=[-1], batch_size=2048,\n                                          shuffle=False, num_workers=6)\n\n    torch.manual_seed(12345)\n    model = SAGE(dataset.num_features, 256, dataset.num_classes).to(rank)\n    model = DistributedDataParallel(model, device_ids=[rank])\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    y = data.y.to(rank)\n\n    for epoch in range(1, 21):\n        model.train()\n\n        for seeds in train_loader:\n            n_id, batch_size, adjs = quiver_sampler.sample(seeds)\n            adjs = [adj.to(rank) for adj in adjs]\n\n            optimizer.zero_grad()\n            out = model(quiver_feature[n_id].to(rank), adjs)\n            loss = F.nll_loss(out, y[n_id[:batch_size]])\n            loss.backward()\n            optimizer.step()\n\n        dist.barrier()\n\n        if rank == 0:\n            print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')\n\n        if rank == 0 and epoch % 5 == 0:  # We evaluate on a single GPU for now\n            model.eval()\n            with 
torch.no_grad():\n                out = model.module.inference(quiver_feature, rank,\n                                             subgraph_loader)\n            res = out.argmax(dim=-1) == data.y\n            acc1 = int(res[data.train_mask].sum()) / int(data.train_mask.sum())\n            acc2 = int(res[data.val_mask].sum()) / int(data.val_mask.sum())\n            acc3 = int(res[data.test_mask].sum()) / int(data.test_mask.sum())\n            print(f'Train: {acc1:.4f}, Val: {acc2:.4f}, Test: {acc3:.4f}')\n\n        dist.barrier()\n\n    dist.destroy_process_group()\n\n\nif __name__ == '__main__':\n    dataset = Reddit('../../data/Reddit')\n    data = dataset[0]\n    world_size = torch.cuda.device_count()\n    print('Let\\'s use', world_size, 'GPUs!')\n\n    ########################################################################\n    # The code below enables Quiver for PyG.\n    # Please refer to: https://torch-quiver.readthedocs.io/en/latest/api/ for\n    # how to configure the CSRTopo, Sampler and Feature of Quiver.\n    ########################################################################\n    csr_topo = quiver.CSRTopo(data.edge_index)  # Quiver\n    quiver_sampler = quiver.pyg.GraphSageSampler(csr_topo, [25, 10], 0,\n                                                 mode='GPU')  # Quiver\n    quiver_feature = quiver.Feature(rank=0,\n                                    device_list=list(range(world_size)),\n                                    device_cache_size=\"2G\",\n                                    cache_policy=\"device_replicate\",\n                                    csr_topo=csr_topo)  # Quiver\n    quiver_feature.from_cpu_tensor(data.x)  # Quiver\n\n    mp.spawn(run, args=(world_size, dataset, quiver_feature, quiver_sampler),\n             nprocs=world_size, join=True)\n"
  },
  {
    "path": "examples/quiver/single_gpu_quiver.py",
    "content": "# This script shows how to use Quiver in an existing PyG example:\n# https://github.com/pyg-team/pytorch_geometric/blob/master/examples/reddit.py\nimport os.path as osp\n\nimport quiver\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import NeighborSampler\nfrom torch_geometric.nn import SAGEConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')\ndataset = Reddit(path)\ndata = dataset[0]\n\ntrain_idx = data.train_mask.nonzero(as_tuple=False).view(-1)\n\n################################\n# Step 1: Using Quiver's sampler\n################################\n\ntrain_loader = torch.utils.data.DataLoader(train_idx, batch_size=1024,\n                                           shuffle=True,\n                                           drop_last=True)  # Quiver\n########################################################################\n# The code below enables Quiver for PyG.\n# Please refer to: https://torch-quiver.readthedocs.io/en/latest/api/ for\n# how to configure the CSRTopo, Sampler and Feature of Quiver.\n########################################################################\ncsr_topo = quiver.CSRTopo(data.edge_index)  # Quiver\nquiver_sampler = quiver.pyg.GraphSageSampler(csr_topo, sizes=[25, 10],\n                                             device=0, mode='GPU')  # Quiver\n\nsubgraph_loader = NeighborSampler(data.edge_index, node_idx=None, sizes=[-1],\n                                  batch_size=1024, shuffle=False,\n                                  num_workers=12)\n\n\nclass SAGE(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n\n        self.num_layers = 2\n\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(SAGEConv(in_channels, hidden_channels))\n        self.convs.append(SAGEConv(hidden_channels, out_channels))\n\n    def 
forward(self, x, adjs):\n        # `train_loader` computes the k-hop neighborhood of a batch of nodes,\n        # and returns, for each layer, a bipartite graph object, holding the\n        # bipartite edges `edge_index`, the index `e_id` of the original edges,\n        # and the size/shape `size` of the bipartite graph.\n        # Target nodes are also included in the source nodes so that one can\n        # easily apply skip-connections or add self-loops.\n        for i, (edge_index, _, size) in enumerate(adjs):\n            x_target = x[:size[1]]  # Target nodes are always placed first.\n            x = self.convs[i]((x, x_target), edge_index)\n            if i != self.num_layers - 1:\n                x = F.relu(x)\n                x = F.dropout(x, p=0.5, training=self.training)\n        return x.log_softmax(dim=-1)\n\n    def inference(self, x_all):\n        pbar = tqdm(total=x_all.size(0) * self.num_layers)\n        pbar.set_description('Evaluating')\n\n        # Compute representations of nodes layer by layer, using *all*\n        # available edges. 
This leads to faster computation in contrast to\n        # immediately computing the final representations of each batch.\n        for i in range(self.num_layers):\n            xs = []\n            for batch_size, n_id, adj in subgraph_loader:\n                edge_index, _, size = adj.to(device)\n                x = x_all[n_id].to(device)\n                x_target = x[:size[1]]\n                x = self.convs[i]((x, x_target), edge_index)\n                if i != self.num_layers - 1:\n                    x = F.relu(x)\n                xs.append(x.cpu())\n\n                pbar.update(batch_size)\n\n            x_all = torch.cat(xs, dim=0)\n\n        pbar.close()\n\n        return x_all\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = SAGE(dataset.num_features, 256, dataset.num_classes)\nmodel = model.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n################################\n# Step 2: Using Quiver's Feature\n################################\nx = quiver.Feature(rank=0, device_list=[0], device_cache_size=\"4G\",\n                   cache_policy=\"device_replicate\",\n                   csr_topo=csr_topo)  # Quiver\nx.from_cpu_tensor(data.x)  # Quiver\n\ny = data.y.squeeze().to(device)\n\n\ndef train(epoch):\n    model.train()\n\n    pbar = tqdm(total=int(data.train_mask.sum()))\n    pbar.set_description(f'Epoch {epoch:02d}')\n\n    total_loss = total_correct = 0\n    ############################################\n    # Step 3: Training the PyG Model with Quiver\n    ############################################\n    # for batch_size, n_id, adjs in train_loader: # Original PyG Code\n    for seeds in train_loader:  # Quiver\n        n_id, batch_size, adjs = quiver_sampler.sample(seeds)  # Quiver\n        # `adjs` holds a list of `(edge_index, e_id, size)` tuples.\n        adjs = [adj.to(device) for adj in adjs]\n\n        optimizer.zero_grad()\n        out = model(x[n_id], adjs)\n        loss = 
F.nll_loss(out, y[n_id[:batch_size]])\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss)\n        total_correct += int(out.argmax(dim=-1).eq(y[n_id[:batch_size]]).sum())\n        pbar.update(batch_size)\n\n    pbar.close()\n\n    loss = total_loss / len(train_loader)\n    approx_acc = total_correct / int(data.train_mask.sum())\n\n    return loss, approx_acc\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n\n    out = model.inference(x)\n\n    y_true = y.cpu().unsqueeze(-1)\n    y_pred = out.argmax(dim=-1, keepdim=True)\n\n    results = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        results += [int(y_pred[mask].eq(y_true[mask]).sum()) / int(mask.sum())]\n\n    return results\n\n\nfor epoch in range(1, 11):\n    loss, acc = train(epoch)\n    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')\n    train_acc, val_acc, test_acc = test()\n    print(f'Train: {train_acc:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/randlanet_classification.py",
    "content": "\"\"\"An adaptation of RandLA-Net to the classification task, which was not\naddressed in the `\"RandLA-Net: Efficient Semantic Segmentation of Large-Scale\nPoint Clouds\" <https://arxiv.org/abs/1911.11236>`_ paper.\n\"\"\"\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Linear\nfrom tqdm import tqdm\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ModelNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP\nfrom torch_geometric.nn.aggr import MaxAggregation\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.pool import knn_graph\nfrom torch_geometric.nn.pool.decimation import decimation_indices\nfrom torch_geometric.typing import WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import softmax\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\n\n# Default activation and batch norm parameters used by RandLA-Net:\nlrelu02_kwargs = {'negative_slope': 0.2}\nbn099_kwargs = {'momentum': 0.01, 'eps': 1e-6}\n\n\nclass SharedMLP(MLP):\n    \"\"\"SharedMLP following RandLA-Net paper.\"\"\"\n    def __init__(self, *args, **kwargs):\n        # BN + Act always active even at last layer.\n        kwargs['plain_last'] = False\n        # LeakyRelu with 0.2 slope by default.\n        kwargs['act'] = kwargs.get('act', 'LeakyReLU')\n        kwargs['act_kwargs'] = kwargs.get('act_kwargs', lrelu02_kwargs)\n        # BatchNorm with 1 - 0.99 = 0.01 momentum\n        # and 1e-6 eps by default (tensorflow momentum != pytorch momentum)\n        kwargs['norm_kwargs'] = kwargs.get('norm_kwargs', bn099_kwargs)\n        super().__init__(*args, **kwargs)\n\n\nclass LocalFeatureAggregation(MessagePassing):\n    \"\"\"Positional encoding of points in a neighborhood.\"\"\"\n    def __init__(self, channels):\n        super().__init__(aggr='add')\n        self.mlp_encoder = SharedMLP([10, 
channels // 2])\n        self.mlp_attention = SharedMLP([channels, channels], bias=False,\n                                       act=None, norm=None)\n        self.mlp_post_attention = SharedMLP([channels, channels])\n\n    def forward(self, edge_index, x, pos):\n        out = self.propagate(edge_index, x=x, pos=pos)  # N, d_out\n        out = self.mlp_post_attention(out)  # N, d_out\n        return out\n\n    def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor,\n                index: Tensor) -> Tensor:\n        \"\"\"Local Spatial Encoding (locSE) and attentive pooling of features.\n\n        Args:\n            x_j (Tensor): neighbor features (K,d)\n            pos_i (Tensor): centroid position (repeated) (K,3)\n            pos_j (Tensor): neighbor positions (K,3)\n            index (Tensor): index of centroid positions\n                (e.g. [0,...,0,1,...,1,...,N,...,N])\n\n        Returns:\n            (Tensor): locSE weighted by feature attention scores.\n\n        \"\"\"\n        # Encode local neighborhood structural information\n        pos_diff = pos_j - pos_i\n        distance = torch.sqrt((pos_diff * pos_diff).sum(1, keepdim=True))\n        relative_infos = torch.cat([pos_i, pos_j, pos_diff, distance],\n                                   dim=1)  # N * K, d\n        local_spatial_encoding = self.mlp_encoder(relative_infos)  # N * K, d\n        local_features = torch.cat([x_j, local_spatial_encoding],\n                                   dim=1)  # N * K, 2d\n\n        # Attention will weight the different features of x\n        # along the neighborhood dimension.\n        att_features = self.mlp_attention(local_features)  # N * K, d_out\n        att_scores = softmax(att_features, index=index)  # N * K, d_out\n\n        return att_scores * local_features  # N * K, d_out\n\n\nclass DilatedResidualBlock(torch.nn.Module):\n    def __init__(\n        self,\n        num_neighbors,\n        d_in: int,\n        d_out: int,\n    ):\n        
super().__init__()\n        self.num_neighbors = num_neighbors\n        self.d_in = d_in\n        self.d_out = d_out\n\n        # MLP on input\n        self.mlp1 = SharedMLP([d_in, d_out // 8])\n        # MLP on input, and the result is summed with the output of mlp2\n        self.shortcut = SharedMLP([d_in, d_out], act=None)\n        # MLP on output\n        self.mlp2 = SharedMLP([d_out // 2, d_out], act=None)\n\n        self.lfa1 = LocalFeatureAggregation(d_out // 4)\n        self.lfa2 = LocalFeatureAggregation(d_out // 2)\n\n        self.lrelu = torch.nn.LeakyReLU(**lrelu02_kwargs)\n\n    def forward(self, x, pos, batch):\n        edge_index = knn_graph(pos, self.num_neighbors, batch=batch, loop=True)\n\n        shortcut_of_x = self.shortcut(x)  # N, d_out\n        x = self.mlp1(x)  # N, d_out//8\n        x = self.lfa1(edge_index, x, pos)  # N, d_out//2\n        x = self.lfa2(edge_index, x, pos)  # N, d_out//2\n        x = self.mlp2(x)  # N, d_out\n        x = self.lrelu(x + shortcut_of_x)  # N, d_out\n\n        return x, pos, batch\n\n\ndef decimate(tensors, ptr: Tensor, decimation_factor: int):\n    \"\"\"Decimates each element of the given tuple of tensors.\"\"\"\n    idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor)\n    tensors_decim = tuple(tensor[idx_decim] for tensor in tensors)\n    return tensors_decim, ptr_decim\n\n\nclass Net(torch.nn.Module):\n    def __init__(\n        self,\n        num_features,\n        num_classes,\n        decimation: int = 4,\n        num_neighboors: int = 16,\n        return_logits: bool = False,\n    ):\n        super().__init__()\n        self.decimation = decimation\n        # An option to return logits instead of log probabilities:\n        self.return_logits = return_logits\n        self.fc0 = Linear(in_features=num_features, out_features=8)\n        # 2 DilatedResidualBlock converges better than 4 on ModelNet.\n        self.block1 = DilatedResidualBlock(num_neighboors, 8, 32)\n        self.block2 = 
DilatedResidualBlock(num_neighboors, 32, 128)\n        self.mlp1 = SharedMLP([128, 128])\n        self.max_agg = MaxAggregation()\n        self.mlp_classif = SharedMLP([128, 32], dropout=[0.5])\n        self.fc_classif = Linear(32, num_classes)\n\n    def forward(self, x, pos, batch, ptr):\n        x = x if x is not None else pos\n        b1 = self.block1(self.fc0(x), pos, batch)\n        b1_decimated, ptr1 = decimate(b1, ptr, self.decimation)\n\n        b2 = self.block2(*b1_decimated)\n        b2_decimated, _ = decimate(b2, ptr1, self.decimation)\n\n        x = self.mlp1(b2_decimated[0])\n        x = self.max_agg(x, b2_decimated[2])\n\n        x = self.mlp_classif(x)\n        logits = self.fc_classif(x)\n\n        return logits if self.return_logits else logits.log_softmax(dim=-1)\n\n\ndef train(epoch):\n    model.train()\n\n    total_loss = 0\n    for data in tqdm(train_loader):\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.pos, data.batch, data.ptr)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += data.num_graphs * float(loss)\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    correct = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.pos, data.batch, data.ptr)\n        correct += int((out.argmax(dim=-1) == data.y).sum())\n    return correct / len(loader.dataset)\n\n\nif __name__ == '__main__':\n    path = osp.dirname(osp.realpath(__file__))\n    path = osp.join(path, '..', 'data/ModelNet10')\n    pre_transform, transform = T.NormalizeScale(), T.SamplePoints(1024)\n    train_dataset = ModelNet(path, '10', True, transform, pre_transform)\n    test_dataset = ModelNet(path, '10', False, transform, pre_transform)\n    train_loader = DataLoader(train_dataset, 32, shuffle=True, num_workers=6)\n    test_loader = DataLoader(test_dataset, 32, 
shuffle=False, num_workers=6)\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = Net(3, train_dataset.num_classes).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20,\n                                                gamma=0.5)\n\n    for epoch in range(1, 201):\n        loss = train(epoch)\n        test_acc = test(test_loader)\n        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Test: {test_acc:.4f}')\n        scheduler.step()\n"
  },
  {
    "path": "examples/randlanet_segmentation.py",
    "content": "\"\"\"An implementation of RandLA-Net based on the `\"RandLA-Net: Efficient\nSemantic Segmentation of Large-Scale Point Clouds\"\n<https://arxiv.org/abs/1911.11236>`_ paper.\n\"\"\"\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom randlanet_classification import DilatedResidualBlock, SharedMLP, decimate\nfrom torch.nn import Linear\nfrom torchmetrics.functional import jaccard_index\nfrom tqdm import tqdm\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import ShapeNet\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import knn_interpolate\nfrom torch_geometric.typing import WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import scatter\n\nif not WITH_TORCH_CLUSTER:\n    quit(\"This example requires 'torch-cluster'\")\n\ncategory = 'Airplane'  # Pass in `None` to train on all categories.\ncategory_num_classes = 4  # 4 for Airplane - see ShapeNet for details\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'ShapeNet')\n\ntransform = T.Compose([\n    T.RandomJitter(0.01),\n    T.RandomRotate(15, axis=0),\n    T.RandomRotate(15, axis=1),\n    T.RandomRotate(15, axis=2),\n])\npre_transform = T.NormalizeScale()\n\ntrain_dataset = ShapeNet(\n    path,\n    category,\n    split='trainval',\n    transform=transform,\n    pre_transform=pre_transform,\n)\ntest_dataset = ShapeNet(\n    path,\n    category,\n    split='test',\n    pre_transform=pre_transform,\n)\n\ntrain_loader = DataLoader(train_dataset, 12, shuffle=True, num_workers=6)\ntest_loader = DataLoader(test_dataset, 12, shuffle=False, num_workers=6)\n\n\nclass FPModule(torch.nn.Module):\n    \"\"\"Upsampling with a skip connection.\"\"\"\n    def __init__(self, k, nn):\n        super().__init__()\n        self.k = k\n        self.nn = nn\n\n    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):\n        x = knn_interpolate(x, pos, pos_skip, batch, batch_skip, k=self.k)\n        x = torch.cat([x, 
x_skip], dim=1)\n        x = self.nn(x)\n        return x, pos_skip, batch_skip\n\n\nclass Net(torch.nn.Module):\n    def __init__(\n        self,\n        num_features: int,\n        num_classes: int,\n        decimation: int = 4,\n        num_neighbors: int = 16,\n        return_logits: bool = False,\n    ):\n        super().__init__()\n\n        self.decimation = decimation\n        # An option to return logits instead of log probabilities:\n        self.return_logits = return_logits\n\n        # Authors use 8, which is a bottleneck\n        # for the final MLP, and also when num_classes>8\n        # or num_features>8.\n        d_bottleneck = max(32, num_classes, num_features)\n\n        self.fc0 = Linear(num_features, d_bottleneck)\n        self.block1 = DilatedResidualBlock(num_neighbors, d_bottleneck, 32)\n        self.block2 = DilatedResidualBlock(num_neighbors, 32, 128)\n        self.block3 = DilatedResidualBlock(num_neighbors, 128, 256)\n        self.block4 = DilatedResidualBlock(num_neighbors, 256, 512)\n        self.mlp_summit = SharedMLP([512, 512])\n        self.fp4 = FPModule(1, SharedMLP([512 + 256, 256]))\n        self.fp3 = FPModule(1, SharedMLP([256 + 128, 128]))\n        self.fp2 = FPModule(1, SharedMLP([128 + 32, 32]))\n        self.fp1 = FPModule(1, SharedMLP([32 + 32, d_bottleneck]))\n        self.mlp_classif = SharedMLP([d_bottleneck, 64, 32],\n                                     dropout=[0.0, 0.5])\n        self.fc_classif = Linear(32, num_classes)\n\n    def forward(self, x, pos, batch, ptr):\n        x = x if x is not None else pos\n\n        b1_out = self.block1(self.fc0(x), pos, batch)\n        b1_out_decimated, ptr1 = decimate(b1_out, ptr, self.decimation)\n\n        b2_out = self.block2(*b1_out_decimated)\n        b2_out_decimated, ptr2 = decimate(b2_out, ptr1, self.decimation)\n\n        b3_out = self.block3(*b2_out_decimated)\n        b3_out_decimated, ptr3 = decimate(b3_out, ptr2, self.decimation)\n\n        b4_out = 
self.block4(*b3_out_decimated)\n        b4_out_decimated, _ = decimate(b4_out, ptr3, self.decimation)\n\n        mlp_out = (\n            self.mlp_summit(b4_out_decimated[0]),\n            b4_out_decimated[1],\n            b4_out_decimated[2],\n        )\n\n        fp4_out = self.fp4(*mlp_out, *b3_out_decimated)\n        fp3_out = self.fp3(*fp4_out, *b2_out_decimated)\n        fp2_out = self.fp2(*fp3_out, *b1_out_decimated)\n        fp1_out = self.fp1(*fp2_out, *b1_out)\n\n        x = self.mlp_classif(fp1_out[0])\n        logits = self.fc_classif(x)\n\n        if self.return_logits:\n            return logits\n\n        probas = logits.log_softmax(dim=-1)\n        return probas\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(3, category_num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = correct_nodes = total_nodes = 0\n    for i, data in tqdm(enumerate(train_loader)):\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.pos, data.batch, data.ptr)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += loss.item()\n        correct_nodes += out.argmax(dim=1).eq(data.y).sum().item()\n        total_nodes += data.num_nodes\n\n        if (i + 1) % 10 == 0:\n            print(f'[{i+1}/{len(train_loader)}] Loss: {total_loss / 10:.4f} '\n                  f'Train Acc: {correct_nodes / total_nodes:.4f}')\n            total_loss = correct_nodes = total_nodes = 0\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    ious, categories = [], []\n    y_map = torch.empty(loader.dataset.num_classes, device=device).long()\n    for data in loader:\n        data = data.to(device)\n        outs = model(data.x, data.pos, data.batch, data.ptr)\n\n        sizes = (data.ptr[1:] - data.ptr[:-1]).tolist()\n        for out, y, category in 
zip(outs.split(sizes), data.y.split(sizes),\n                                    data.category.tolist()):\n            category = list(ShapeNet.seg_classes.keys())[category]\n            part = ShapeNet.seg_classes[category]\n            part = torch.tensor(part, device=device)\n\n            y_map[part] = torch.arange(part.size(0), device=device)\n\n            iou = jaccard_index(\n                out[:, part].argmax(dim=-1),\n                y_map[y],\n                num_classes=part.size(0),\n                absent_score=1.0,\n            )\n            ious.append(iou)\n\n        categories.append(data.category)\n\n    iou = torch.tensor(ious, device=device)\n    category = torch.cat(categories, dim=0)\n\n    mean_iou = scatter(iou, category, reduce='mean')  # Per-category IoU.\n    return float(mean_iou.mean())  # Global IoU.\n\n\nfor epoch in range(1, 31):\n    train()\n    iou = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Test IoU: {iou:.4f}')\n"
  },
  {
    "path": "examples/rdl.py",
    "content": "\"\"\"This example demonstrates how to train a Relational Deep Learning model\nusing RelBench.\n\nPlease refer to:\n1. https://arxiv.org/abs/2407.20060 for RelBench, and\n2. https://github.com/snap-stanford/relbench for reproducing the results\n   reported on the RelBench paper.\n\"\"\"\nimport argparse\nimport math\nimport operator\nimport os\nfrom typing import Any, Dict, List, NamedTuple, Optional, Tuple\n\nimport numpy as np\nimport pandas as pd\nimport torch\nimport torch_frame\nfrom relbench.base import EntityTask, Table, TaskType\nfrom relbench.datasets import get_dataset, get_dataset_names\nfrom relbench.modeling.graph import make_pkey_fkey_graph\nfrom relbench.modeling.utils import get_stype_proposal\nfrom relbench.tasks import get_task, get_task_names\nfrom sentence_transformers import SentenceTransformer\nfrom torch import Tensor\nfrom torch_frame.config.text_embedder import TextEmbedderConfig\nfrom torch_frame.data.stats import StatType\nfrom torch_frame.nn.models import ResNet\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import (\n    MLP,\n    HeteroConv,\n    LayerNorm,\n    PositionalEncoding,\n    SAGEConv,\n)\nfrom torch_geometric.seed import seed_everything\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass GloveTextEmbedding:\n    \"\"\"GloveTextEmbedding based on SentenceTransformer.\"\"\"\n    def __init__(self, device: Optional[torch.device] = None) -> None:\n        self.model = SentenceTransformer(\n            \"sentence-transformers/average_word_embeddings_glove.6B.300d\",\n            device=device,\n        )\n\n    def __call__(self, sentences: List[str]) -> Tensor:\n        return torch.from_numpy(self.model.encode(sentences))\n\n\nclass HeteroEncoder(torch.nn.Module):\n    r\"\"\"HeteroEncoder based on PyTorch Frame implemented with ResNet.\n\n    A heterogeneous encoder that processes different node 
types using PyTorch\n    Frame models. For each node type, it creates a separate encoder model\n    that processes the node features according to their data types\n    (categorical, numerical, etc).\n\n    Args:\n        channels: The output channels for each node type.\n        num_layers: The number of layers for the ResNet.\n        col_names_dict: A dictionary mapping from node type to column names\n            dictionary compatible with PyTorch Frame.\n        stats_dict: A dictionary containing statistics for each column in each\n            node type. Used for feature normalization and encoding.\n    \"\"\"\n    def __init__(\n        self,\n        channels: int,\n        num_layers: int,\n        col_names_dict: Dict[NodeType, Dict[torch_frame.stype, List[str]]],\n        stats_dict: Dict[NodeType, Dict[str, Dict[StatType, Any]]],\n    ) -> None:\n        super().__init__()\n\n        self.encoders = torch.nn.ModuleDict()\n\n        for node_type in col_names_dict.keys():\n            stype_encoder_dict = {\n                torch_frame.categorical:\n                torch_frame.nn.EmbeddingEncoder(),\n                torch_frame.numerical:\n                torch_frame.nn.LinearEncoder(),\n                torch_frame.multicategorical:\n                torch_frame.nn.MultiCategoricalEmbeddingEncoder(),\n                torch_frame.embedding:\n                torch_frame.nn.LinearEmbeddingEncoder(),\n                torch_frame.timestamp:\n                torch_frame.nn.TimestampEncoder()\n            }\n            torch_frame_model = ResNet(\n                channels=channels,\n                num_layers=num_layers,\n                out_channels=channels,\n                col_stats=stats_dict[node_type],\n                col_names_dict=col_names_dict[node_type],\n                stype_encoder_dict=stype_encoder_dict,\n            )\n            self.encoders[node_type] = torch_frame_model\n\n    def reset_parameters(self) -> None:\n        \"\"\"Reset the 
parameters of all encoder models.\"\"\"\n        for encoder in self.encoders.values():\n            encoder.reset_parameters()\n\n    def forward(\n        self,\n        tf_dict: Dict[NodeType, torch_frame.TensorFrame],\n    ) -> Dict[NodeType, Tensor]:\n        \"\"\"Forward pass of the heterogeneous encoder.\n\n        Args:\n            tf_dict: A dictionary mapping node types to their corresponding\n                TensorFrame objects containing the node features.\n\n        Returns:\n            Dictionary mapping node types to their encoded representations.\n            Each tensor has shape ``[num_nodes, channels]``.\n        \"\"\"\n        return {\n            node_type: self.encoders[node_type](tf)\n            for node_type, tf in tf_dict.items()\n        }\n\n\nclass HeteroTemporalEncoder(torch.nn.Module):\n    \"\"\"HeteroTemporalEncoder class that uses PositionalEncoding to encode\n    temporal information for heterogeneous graphs.\n\n    This encoder computes relative time embeddings between a seed time and\n    node timestamps, converting the time differences from seconds to days.\n    It applies positional encoding followed by a linear transformation for\n    each node type.\n\n    Args:\n        node_types: List of node types in the heterogeneous graph\n        channels: Number of channels/dimensions for the encoded embeddings\n\n    Example:\n        >>> encoder = HeteroTemporalEncoder(['user', 'item'], channels=64)\n        >>> seed_time = torch.tensor([1000])  # Reference timestamp\n        >>> time_dict = {'user': torch.tensor([800, 900]),\n        >>>             'item': torch.tensor([700, 850])}\n        >>> batch_dict = {'user': torch.tensor([0, 0]),\n        >>>              'item': torch.tensor([0, 0])}\n        >>> out_dict = encoder(seed_time, time_dict, batch_dict)\n        >>> out_dict['user'].shape\n        torch.Size([2, 64])\n    \"\"\"\n    def __init__(self, node_types: List[NodeType], channels: int) -> None:\n        
super().__init__()\n        self.encoder_dict = torch.nn.ModuleDict({\n            node_type:\n            PositionalEncoding(channels)\n            for node_type in node_types\n        })\n        self.lin_dict = torch.nn.ModuleDict({\n            node_type:\n            torch.nn.Linear(channels, channels)\n            for node_type in node_types\n        })\n\n    def reset_parameters(self) -> None:\n        \"\"\"Reset the parameters of all encoders and linear layers.\"\"\"\n        for encoder in self.encoder_dict.values():\n            encoder.reset_parameters()\n        for lin in self.lin_dict.values():\n            lin.reset_parameters()\n\n    def forward(\n        self,\n        seed_time: Tensor,\n        time_dict: Dict[NodeType, Tensor],\n        batch_dict: Dict[NodeType, Tensor],\n    ) -> Dict[NodeType, Tensor]:\n        \"\"\"Forward pass of the temporal encoder.\n\n        Args:\n            seed_time: Reference timestamps for computing relative times\n            time_dict: Dictionary mapping node types to their timestamps\n            batch_dict: Dictionary mapping node types to batch assignments\n\n        Returns:\n            Dictionary mapping node types to their temporal embeddings\n        \"\"\"\n        out_dict: Dict[NodeType, Tensor] = {}\n\n        for node_type, time in time_dict.items():\n            rel_time = seed_time[batch_dict[node_type]] - time\n            rel_time = rel_time / (60 * 60 * 24)  # Convert seconds to days.\n\n            x = self.encoder_dict[node_type](rel_time)\n            x = self.lin_dict[node_type](x)\n            out_dict[node_type] = x\n\n        return out_dict\n\n\nclass HeteroGraphSAGE(torch.nn.Module):\n    \"\"\"Heterogeneous GraphSAGE model with layer normalization.\n\n    This model implements a heterogeneous version of GraphSAGE\n    that operates on multiple node and edge types. 
Each layer\n    consists of a heterogeneous graph convolution followed by\n    layer normalization and ReLU activation.\n\n    Args:\n        node_types: List of node types in the graph\n        edge_types: List of edge types in the graph\n        channels: Number of channels/features\n        aggr: Node aggregation scheme.\n        num_layers: Number of graph convolution layers.\n\n    Example:\n        >>> model = HeteroGraphSAGE(\n        >>>     node_types=['user', 'item'],\n        >>>     edge_types=[('user', 'rates', 'item')],\n        >>>     channels=64)\n        >>> out_dict = model(x_dict, edge_index_dict)\n    \"\"\"\n    def __init__(\n        self,\n        node_types: List[NodeType],\n        edge_types: List[EdgeType],\n        channels: int,\n        aggr: str = \"mean\",\n        num_layers: int = 2,\n    ) -> None:\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            conv = HeteroConv(\n                {\n                    edge_type: SAGEConv(\n                        (channels, channels), channels, aggr=aggr)\n                    for edge_type in edge_types\n                },\n                aggr=\"sum\",\n            )\n            self.convs.append(conv)\n\n        self.norms = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            norm_dict = torch.nn.ModuleDict()\n            for node_type in node_types:\n                norm_dict[node_type] = LayerNorm(channels, mode=\"node\")\n            self.norms.append(norm_dict)\n\n    def reset_parameters(self) -> None:\n        \"\"\"Reset the parameters of all convolution and normalization layers.\"\"\"\n        for conv in self.convs:\n            conv.reset_parameters()\n        for norm_dict in self.norms:\n            for norm in norm_dict.values():\n                norm.reset_parameters()\n\n    def forward(\n        self,\n        x_dict: Dict[NodeType, Tensor],\n        edge_index_dict: Dict[EdgeType, 
Tensor],\n    ) -> Dict[NodeType, Tensor]:\n        \"\"\"Forward pass of the heterogeneous GraphSAGE model.\n\n        Args:\n            x_dict: Node feature dictionary\n            edge_index_dict: Edge index dictionary\n\n        Returns:\n            Updated node features after message passing\n        \"\"\"\n        for _, (conv, norm_dict) in enumerate(zip(self.convs, self.norms)):\n            x_dict = conv(x_dict, edge_index_dict)\n            x_dict = {key: norm_dict[key](x) for key, x in x_dict.items()}\n            x_dict = {key: x.relu() for key, x in x_dict.items()}\n\n        return x_dict\n\n\nclass Model(torch.nn.Module):\n    \"\"\"A heterogeneous graph neural network model for temporal graph learning.\n\n    This model consists of:\n    1. A heterogeneous feature encoder for node attributes\n    2. A temporal encoder for handling time information\n    3. A heterogeneous GraphSAGE model for message passing\n    4. An MLP head for final predictions\n\n    Args:\n        node_types: List of node types in the graph\n        edge_types: List of edge types in the graph\n        col_names_dict: Dictionary mapping node types to their column names and\n            types\n        temporal_node_types: List of node types with temporal features\n        col_stats_dict: Statistics of node features\n        num_layers: Number of GNN layers\n        channels: Hidden dimension size\n        out_channels: Output dimension size\n        aggr: Aggregation method for GNN\n        norm: Normalization method for MLP\n    \"\"\"\n    def __init__(\n        self,\n        node_types: List[NodeType],\n        edge_types: List[EdgeType],\n        col_names_dict: Dict[NodeType, Dict[torch_frame.stype, List[str]]],\n        temporal_node_types: List[NodeType],\n        col_stats_dict: Dict[NodeType, Dict[str, Dict[StatType, Any]]],\n        num_layers: int,\n        channels: int,\n        out_channels: int,\n        aggr: str,\n        norm: str,\n    ) -> None:\n        
super().__init__()\n        self.encoder = HeteroEncoder(\n            channels=channels,\n            num_layers=num_layers,\n            col_names_dict=col_names_dict,\n            stats_dict=col_stats_dict,\n        )\n        self.temporal_encoder = HeteroTemporalEncoder(\n            node_types=temporal_node_types,\n            channels=channels,\n        )\n        self.gnn = HeteroGraphSAGE(\n            node_types=node_types,\n            edge_types=edge_types,\n            channels=channels,\n            aggr=aggr,\n            num_layers=num_layers,\n        )\n        self.head = MLP(\n            channels,\n            out_channels=out_channels,\n            norm=norm,\n            num_layers=1,\n        )\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        \"\"\"Reset the parameters of all model components.\"\"\"\n        self.encoder.reset_parameters()\n        self.temporal_encoder.reset_parameters()\n        self.gnn.reset_parameters()\n        self.head.reset_parameters()\n\n    def forward(\n        self,\n        batch: HeteroData,\n        entity_table: NodeType,\n    ) -> Tensor:\n        \"\"\"Forward pass of the model.\n\n        Steps:\n            1. Get seed time from entity table\n            2. Encode node features using HeteroEncoder\n            3. Encode temporal features using HeteroTemporalEncoder\n            4. Add temporal embeddings to node features\n            5. Apply graph neural network (HeteroGraphSAGE)\n            6. 
Apply final MLP head to target node embeddings\n\n        Args:\n            batch: Batch of heterogeneous graph data\n            entity_table: The target node type for prediction\n\n        Returns:\n            Tensor: Predictions for nodes in the entity table\n        \"\"\"\n        seed_time = batch[entity_table].seed_time\n        x_dict = self.encoder(batch.tf_dict)\n\n        rel_time_dict = self.temporal_encoder(\n            seed_time,\n            batch.time_dict,\n            batch.batch_dict,\n        )\n        for node_type, rel_time in rel_time_dict.items():\n            x_dict[node_type] = x_dict[node_type] + rel_time\n\n        x_dict = self.gnn(x_dict, batch.edge_index_dict)\n\n        return self.head(x_dict[entity_table][:seed_time.size(0)])\n\n\nclass AttachTargetTransform:\n    r\"\"\"Attach the target label to the heterogeneous mini-batch.\n\n    The batch consists of disjoint subgraphs loaded via temporal sampling. The\n    same input node can occur multiple times with different timestamps, and\n    thus different subgraphs and labels. Hence labels cannot be stored in the\n    graph object directly, and must be attached to the batch after the batch is\n    created.\n    \"\"\"\n    def __init__(self, entity: str, target: Tensor) -> None:\n        self.entity = entity\n        self.target = target\n\n    def __call__(self, batch: HeteroData) -> HeteroData:\n        batch[self.entity].y = self.target[batch[self.entity].input_id]\n        return batch\n\n\nclass TrainingTableInput(NamedTuple):\n    r\"\"\"Training table input for node prediction tasks.\n\n    A container for organizing input data needed for node-level predictions.\n\n    Attributes:\n        nodes: Tuple of (node_type, indices_tensor) containing the node type\n            identifier and Tensor of node IDs to predict on.\n        time: Optional Tensor of timestamps for temporal sampling. Shape\n            matches node indices. 
None if task is not temporal.\n        target: Optional Tensor of ground truth labels/values. Shape matches\n            node indices. None during inference.\n        transform: Optional transform that attaches target labels to batches\n            during training. Needed for temporal sampling where nodes can\n            appear multiple times with different labels.\n    \"\"\"\n    nodes: Tuple[NodeType, Tensor]\n    time: Optional[Tensor]\n    target: Optional[Tensor]\n    transform: Optional[AttachTargetTransform]\n\n\ndef get_task_type_params(\n        task: EntityTask) -> Tuple[int, torch.nn.Module, str, bool]:\n    r\"\"\"Get task-specific optimization parameters based on task type.\n\n    Args:\n        task: Task specification containing task type.\n\n    Returns:\n        Tuple containing:\n        - out_channels: Number of output channels\n        - loss_fn: Loss function\n        - tune_metric: Metric to optimize\n        - higher_is_better: Whether higher metric values are better\n    \"\"\"\n    if task.task_type == TaskType.REGRESSION:\n        out_channels = 1\n        loss_fn = torch.nn.L1Loss()\n        tune_metric = \"mae\"\n        higher_is_better = False\n    elif task.task_type == TaskType.BINARY_CLASSIFICATION:\n        out_channels = 1\n        loss_fn = torch.nn.BCEWithLogitsLoss()\n        tune_metric = \"roc_auc\"\n        higher_is_better = True\n    else:\n        raise ValueError(f\"Unsupported task type: {task.task_type}\")\n\n    return out_channels, loss_fn, tune_metric, higher_is_better\n\n\ndef to_unix_time(ser: pd.Series) -> np.ndarray:\n    r\"\"\"Convert a pandas Timestamp series to UNIX timestamp in seconds.\n\n    Args:\n        ser: Input pandas Series containing datetime values.\n\n    Returns:\n        Array of UNIX timestamps in seconds.\n    \"\"\"\n    assert ser.dtype in [np.dtype(\"datetime64[s]\"), np.dtype(\"datetime64[ns]\")]\n    unix_time = ser.astype(\"int64\").values\n    if ser.dtype == 
np.dtype(\"datetime64[ns]\"):\n        unix_time //= 10**9\n    return unix_time\n\n\ndef get_train_table_input(\n    split_table: Table,\n    task: EntityTask,\n) -> TrainingTableInput:\n    r\"\"\"Get the training table input for node prediction.\n\n    Processes a table split and task to create a TrainingTableInput\n    object containing:\n    1. Node indices for the target entity type\n    2. Optional timestamps for temporal sampling\n    3. Optional target labels/values for training\n    4. Optional transform to attach labels during batch loading\n\n    Args:\n        split_table: Table containing node IDs, optional timestamps, and\n            optional target values to predict.\n        task: Task specification containing entity table name, entity column\n            name, target column name, etc.\n\n    Returns:\n        Container with processed node indices, timestamps, target values and\n        transform needed for training/inference.\n    \"\"\"\n    nodes = torch.from_numpy(\n        split_table.df[task.entity_col].astype(int).values)\n\n    time: Optional[Tensor] = None\n    if split_table.time_col is not None:\n        time = torch.from_numpy(\n            to_unix_time(split_table.df[split_table.time_col]))\n\n    target: Optional[Tensor] = None\n    transform: Optional[AttachTargetTransform] = None\n    if task.target_col in split_table.df:\n        target = torch.from_numpy(\n            split_table.df[task.target_col].values.astype(float))\n        transform = AttachTargetTransform(task.entity_table, target)\n\n    return TrainingTableInput(\n        nodes=(task.entity_table, nodes),\n        time=time,\n        target=target,\n        transform=transform,\n    )\n\n\ndef train(\n    model: Model,\n    train_loader: NeighborLoader,\n    task: EntityTask,\n    optimizer: torch.optim.Optimizer,\n    loss_fn: torch.nn.Module,\n    device: torch.device,\n) -> float:\n    model.train()\n\n    loss_accum = torch.zeros(1, device=device).squeeze_()\n    
count_accum = 0\n    for batch in tqdm(train_loader):\n        batch = batch.to(device)\n\n        optimizer.zero_grad()\n        pred = model(batch, task.entity_table)\n        pred = pred.view(-1) if pred.size(1) == 1 else pred\n\n        # Labels are attached to the batch by AttachTargetTransform:\n        loss = loss_fn(pred, batch[task.entity_table].y.float())\n        loss.backward()\n        optimizer.step()\n\n        loss_accum += loss.detach() * pred.size(0)\n        count_accum += pred.size(0)\n\n    return loss_accum.item() / count_accum\n\n\n@torch.no_grad()\ndef test(\n    test_loader: NeighborLoader,\n    model: Model,\n    task: EntityTask,\n    device: torch.device,\n) -> np.ndarray:\n    model.eval()\n\n    pred_list = []\n    for batch in tqdm(test_loader):\n        batch = batch.to(device)\n        pred = model(batch, task.entity_table)\n        pred = pred.view(-1) if pred.size(1) == 1 else pred\n        pred_list.append(pred.detach().cpu())\n    return torch.cat(pred_list, dim=0).numpy()\n\n\ndef main():\n    seed_everything(42)\n\n    parser = argparse.ArgumentParser(\n        formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n    parser.add_argument(\"--dataset\", type=str, default=\"rel-f1\",\n                        choices=get_dataset_names())\n    parser.add_argument(\n        \"--task\", type=str, default=None,\n        help=\"See available tasks at https://relbench.stanford.edu/\")\n    parser.add_argument(\"--batch_size\", type=int, default=512)\n    parser.add_argument(\"--temporal_strategy\", type=str, default=\"uniform\",\n                        choices=[\"uniform\", \"last\"])\n    parser.add_argument(\"--num_neighbors\", type=int, nargs=\"+\",\n                        default=[128, 128])\n    parser.add_argument(\"--channels\", type=int, default=128)\n    parser.add_argument(\"--aggr\", type=str, default=\"sum\")\n    parser.add_argument(\"--norm\", type=str, default=\"batch_norm\")\n    parser.add_argument(\"--epochs\", type=int, default=10)\n    
parser.add_argument(\"--lr\", type=float, default=0.005)\n    args = parser.parse_args()\n\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    print(\"Using device:\", device)\n\n    print(\"Loading dataset and task...\")\n    assert args.task in get_task_names(args.dataset), (\n        f\"Invalid --task '{args.task}' for --dataset '{args.dataset}'. \"\n        f\"Available tasks: {get_task_names(args.dataset)}\")\n    dataset = get_dataset(name=args.dataset, download=True)\n    task = get_task(\n        dataset_name=args.dataset,\n        task_name=args.task,\n        download=True,\n    )\n    print(f\"Task type: {task.task_type}\")\n    print(f\"Target column: '{task.target_col}'\")\n    print(f\"Entity table: '{task.entity_table}'\")\n\n    print(\"Getting column to stype dictionary...\")\n    db = dataset.get_db()\n    col_to_stype_dict = get_stype_proposal(db)\n    print(\"Column to stype dictionary: \", col_to_stype_dict)\n\n    print(\"Defining text embedder...\")\n    text_embedder_cfg = TextEmbedderConfig(\n        text_embedder=GloveTextEmbedding(device=device),\n        batch_size=256,\n    )\n\n    # Transform the dataset into a HeteroData object with torch_frame features\n    # See also:\n    # https://github.com/snap-stanford/relbench/blob/v1.1.0/relbench/modeling/graph.py#L20-L111  # noqa: E501\n    print(\"Transforming dataset into HeteroData object...\")\n    data, col_stats_dict = make_pkey_fkey_graph(\n        db,\n        col_to_stype_dict=col_to_stype_dict,  # specified column types\n        text_embedder_cfg=text_embedder_cfg,  # our chosen text encoder\n        cache_dir=os.path.join(  # store materialized graph for convenience\n            \"./data\",\n            f\"{args.dataset}_{args.task}_materialized_cache\",\n        ),\n    )\n\n    print(\"Preparing data loaders...\")\n    loader_dict = {}\n    num_neighbors_dict = {\n        edge_type: args.num_neighbors\n        for edge_type in data.edge_types\n   
 }\n\n    for split in [\"train\", \"val\", \"test\"]:\n        table = task.get_table(split)\n        print(f\"Creating '{split}' dataloader with columns: \"\n              f\"{list(table.df.columns)}\")\n        table_input = get_train_table_input(split_table=table, task=task)\n        loader_dict[split] = NeighborLoader(\n            data=data,\n            num_neighbors=num_neighbors_dict,\n            input_nodes=table_input.nodes,\n            input_time=table_input.time,\n            time_attr=\"time\",\n            transform=table_input.transform,\n            batch_size=args.batch_size,\n            temporal_strategy=args.temporal_strategy,\n            shuffle=split == \"train\",\n            num_workers=4,\n            persistent_workers=True,\n        )\n\n    print(\"Getting task-specific parameters...\")\n    out_channels, loss_fn, tune_metric, higher_is_better = \\\n        get_task_type_params(task)\n    print(\"out_channels: \", out_channels)\n    print(\"loss_fn: \", loss_fn)\n    print(\"tune_metric: \", tune_metric)\n    print(\"higher_is_better: \", higher_is_better)\n\n    print(\"Initializing the model...\")\n    col_names_dict = {\n        node_type: data[node_type].tf.col_names_dict\n        for node_type in data.node_types\n    }\n    temporal_node_types = [\n        node_type for node_type in data.node_types if \"time\" in data[node_type]\n    ]\n    model = Model(\n        node_types=data.node_types,  # Include all node types\n        edge_types=data.edge_types,  # Include all edge types\n        col_names_dict=col_names_dict,\n        col_stats_dict=col_stats_dict,\n        temporal_node_types=temporal_node_types,\n        num_layers=len(args.num_neighbors),\n        channels=args.channels,\n        out_channels=out_channels,\n        aggr=args.aggr,\n        norm=args.norm,\n    ).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)\n\n    print(\"Training the model...\")\n    best_val_metric = -math.inf if 
higher_is_better else math.inf\n    for epoch in range(1, args.epochs + 1):\n        train_loss = train(\n            model=model,\n            train_loader=loader_dict[\"train\"],\n            task=task,\n            optimizer=optimizer,\n            loss_fn=loss_fn,\n            device=device,\n        )\n        val_pred = test(\n            test_loader=loader_dict[\"val\"],\n            model=model,\n            task=task,\n            device=device,\n        )\n        val_metrics = task.evaluate(val_pred, task.get_table(\"val\"))\n        print(\n            f\"Epoch: {epoch:02d}, \"\n            f\"train_loss: {train_loss:.4f}, \"\n            f\"{', '.join([f'val_{k}: {v:.4f}' for k, v in val_metrics.items()])}\"  # noqa: E501\n        )\n\n        is_better_op = operator.gt if higher_is_better else operator.lt\n        if is_better_op(val_metrics[tune_metric], best_val_metric):\n            best_val_metric = val_metrics[tune_metric]\n            torch.save(model.state_dict(), \"best_model.pt\")\n\n    print(\"Testing the best model...\")\n    model.load_state_dict(torch.load(\"best_model.pt\"))\n    test_pred = test(\n        test_loader=loader_dict[\"test\"],\n        model=model,\n        task=task,\n        device=device,\n    )\n    test_metrics = task.evaluate(test_pred)\n    print(\n        f\"{', '.join([f'test_{k}: {v:.4f}' for k, v in test_metrics.items()])}\"\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "examples/rect.py",
    "content": "import argparse\nimport copy\nimport os.path as osp\n\nimport torch\nfrom sklearn.linear_model import LogisticRegression\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import RECT_L\n\n# RECT focuses on the zero-shot, i.e. completely-imbalanced label setting:\n# For this, we first remove \"unseen\" classes from the training set and train a\n# RECT (or more specifically its supervised part RECT-L) model in the zero-shot\n# label scenario. Lastly, we train a simple classifier to evaluate the final\n# performance of the embeddings based on the original labels.\n\n# Datasets              Citeseer             Cora          Pubmed\n# Unseen Classes  [1, 2, 5]  [3, 4]  [1, 2, 3]  [3, 4, 6]  [2]\n# RECT-L          66.30      68.20   74.60      71.20      75.30\n# GCN             51.80      55.70   55.80      57.10      59.80\n# NodeFeats       61.40      61.40   57.50      57.50      73.10\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='Cora',\n                    choices=['Cora', 'CiteSeer', 'PubMed'])\nparser.add_argument('--unseen-classes', type=int, nargs='*', default=[1, 2, 3])\nargs = parser.parse_args()\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '../data/Planetoid')\ntrain_mask_original = Planetoid(path, args.dataset)[0].train_mask.clone()\ntransform = T.Compose([\n    T.NormalizeFeatures(),\n    T.SVDFeatureReduction(200),\n    T.GDC(),\n])\ndataset = Planetoid(path, args.dataset, transform=transform)\ndata = dataset[0]\nzs_data = T.RemoveTrainingClasses(args.unseen_classes)(copy.copy(data))\n\nmodel = RECT_L(200, 200, normalize=False, dropout=0.0)\nzs_data.y = model.get_semantic_labels(zs_data.x, zs_data.y, zs_data.train_mask)\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = 
torch.device('cpu')\n\nmodel, zs_data = model.to(device), zs_data.to(device)\n\ncriterion = torch.nn.MSELoss(reduction='sum')\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)\n\nmodel.train()\nfor epoch in range(1, 201):\n    optimizer.zero_grad()\n    out = model(zs_data.x, zs_data.edge_index, zs_data.edge_attr)\n    loss = criterion(out[zs_data.train_mask], zs_data.y)\n    loss.backward()\n    optimizer.step()\n    print(f'Epoch {epoch:03d}, Loss {loss:.4f}')\n\nmodel.eval()\nwith torch.no_grad():\n    h = model.embed(zs_data.x, zs_data.edge_index, zs_data.edge_attr).cpu()\n\nreg = LogisticRegression()\nreg.fit(h[data.train_mask].numpy(), data.y[data.train_mask].numpy())\ntest_acc = reg.score(h[data.test_mask].numpy(), data.y[data.test_mask].numpy())\nprint(f'Test Acc: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/reddit.py",
    "content": "import copy\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import Reddit\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import SAGEConv\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Reddit')\ndataset = Reddit(path)\n\n# Already send node features/labels to GPU for faster access during sampling:\ndata = dataset[0].to(device, 'x', 'y')\n\nkwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}\ntrain_loader = NeighborLoader(data, input_nodes=data.train_mask,\n                              num_neighbors=[25, 10], shuffle=True, **kwargs)\n\nsubgraph_loader = NeighborLoader(copy.copy(data), input_nodes=None,\n                                 num_neighbors=[-1], shuffle=False, **kwargs)\n\n# No need to maintain these features during evaluation:\ndel subgraph_loader.data.x, subgraph_loader.data.y\n# Add global node index information.\nsubgraph_loader.data.num_nodes = data.num_nodes\nsubgraph_loader.data.n_id = torch.arange(data.num_nodes)\n\n\nclass SAGE(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.convs = torch.nn.ModuleList()\n        self.convs.append(SAGEConv(in_channels, hidden_channels))\n        self.convs.append(SAGEConv(hidden_channels, out_channels))\n\n    def forward(self, x, edge_index):\n        for i, conv in enumerate(self.convs):\n            x = conv(x, edge_index)\n            if i < len(self.convs) - 1:\n                x = x.relu_()\n                x = F.dropout(x, p=0.5, training=self.training)\n        return x\n\n    @torch.no_grad()\n    def inference(self, x_all, subgraph_loader):\n        pbar = tqdm(total=len(subgraph_loader.dataset) * len(self.convs))\n        pbar.set_description('Evaluating')\n\n        # 
Compute representations of nodes layer by layer, using *all*\n        # available edges. This leads to faster computation in contrast to\n        # immediately computing the final representations of each batch:\n        for i, conv in enumerate(self.convs):\n            xs = []\n            for batch in subgraph_loader:\n                x = x_all[batch.n_id.to(x_all.device)].to(device)\n                x = conv(x, batch.edge_index.to(device))\n                if i < len(self.convs) - 1:\n                    x = x.relu_()\n                xs.append(x[:batch.batch_size].cpu())\n                pbar.update(batch.batch_size)\n            x_all = torch.cat(xs, dim=0)\n        pbar.close()\n        return x_all\n\n\nmodel = SAGE(dataset.num_features, 256, dataset.num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train(epoch):\n    model.train()\n\n    pbar = tqdm(total=int(len(train_loader.dataset)))\n    pbar.set_description(f'Epoch {epoch:02d}')\n\n    total_loss = total_correct = total_examples = 0\n    for batch in train_loader:\n        optimizer.zero_grad()\n        y = batch.y[:batch.batch_size]\n        y_hat = model(batch.x, batch.edge_index.to(device))[:batch.batch_size]\n        loss = F.cross_entropy(y_hat, y)\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * batch.batch_size\n        total_correct += int((y_hat.argmax(dim=-1) == y).sum())\n        total_examples += batch.batch_size\n        pbar.update(batch.batch_size)\n    pbar.close()\n\n    return total_loss / total_examples, total_correct / total_examples\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    y_hat = model.inference(data.x, subgraph_loader).argmax(dim=-1)\n    y = data.y.to(y_hat.device)\n\n    accs = []\n    for mask in [data.train_mask, data.val_mask, data.test_mask]:\n        accs.append(int((y_hat[mask] == y[mask]).sum()) / int(mask.sum()))\n    return accs\n\n\ntimes = []\nfor epoch in range(1, 
11):\n    start = time.time()\n    loss, acc = train(epoch)\n    print(f'Epoch {epoch:02d}, Loss: {loss:.4f}, Approx. Train: {acc:.4f}')\n    train_acc, val_acc, test_acc = test()\n    print(f'Epoch: {epoch:02d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/renet.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import GDELT, ICEWS18\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn.models.re_net import RENet\n\nparser = argparse.ArgumentParser()\nparser.add_argument(\n    '--dataset',\n    type=str,\n    default='GDELT',\n    choices=['ICEWS18', 'GDELT'],\n)\nparser.add_argument('--seq_len', type=int, default=10)\nargs = parser.parse_args()\n\n# Load the dataset and precompute history objects:\npath = osp.dirname(osp.realpath(__file__))\npath = osp.join(path, '..', 'data', args.dataset)\npre_transform = RENet.pre_transform(args.seq_len)\nif args.dataset == 'ICEWS18':\n    train_dataset = ICEWS18(path, pre_transform=pre_transform)\n    test_dataset = ICEWS18(path, split='test', pre_transform=pre_transform)\nelif args.dataset == 'GDELT':\n    train_dataset = GDELT(path, pre_transform=pre_transform)\n    test_dataset = GDELT(path, split='test', pre_transform=pre_transform)\n\n# Create dataloader for training and test dataset.\ntrain_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True,\n                          follow_batch=['h_sub', 'h_obj'], num_workers=6)\ntest_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False,\n                         follow_batch=['h_sub', 'h_obj'], num_workers=6)\n\n# Initialize model and optimizer.\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = RENet(\n    train_dataset.num_nodes,\n    train_dataset.num_rels,\n    hidden_channels=200,\n    seq_len=args.seq_len,\n    dropout=0.5,\n).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001,\n                             weight_decay=0.00001)\n\n\ndef train():\n    model.train()\n\n    # Train model via multi-class classification against the corresponding\n    # object and subject entities.\n    for data in train_loader:\n        data = data.to(device)\n        
optimizer.zero_grad()\n        log_prob_obj, log_prob_sub = model(data)\n        loss_obj = F.nll_loss(log_prob_obj, data.obj)\n        loss_sub = F.nll_loss(log_prob_sub, data.sub)\n        loss = loss_obj + loss_sub\n        loss.backward()\n        optimizer.step()\n\n\ndef test(loader):\n    model.eval()\n\n    # Compute Mean Reciprocal Rank (MRR) and Hits@1/3/10.\n    result = torch.tensor([0, 0, 0, 0], dtype=torch.float)\n    for data in loader:\n        data = data.to(device)\n        with torch.no_grad():\n            log_prob_obj, log_prob_sub = model(data)\n        result += model.test(log_prob_obj, data.obj) * data.obj.size(0)\n        result += model.test(log_prob_sub, data.sub) * data.sub.size(0)\n    result = result / (2 * len(loader.dataset))\n    return result.tolist()\n\n\ntimes = []\nfor epoch in range(1, 21):\n    start = time.time()\n    train()\n    mrr, hits1, hits3, hits10 = test(test_loader)\n    print(f'Epoch: {epoch:02d}, MRR: {mrr:.4f}, Hits@1: {hits1:.4f}, '\n          f'Hits@3: {hits3:.4f}, Hits@10: {hits10:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/rev_gnn.py",
    "content": "# Peak GPU memory usage is around 1.57 G\n# | RevGNN Models           | Test Acc        | Val Acc         |\n# |-------------------------|-----------------|-----------------|\n# | 112 layers 160 channels | 0.8307 ± 0.0030 | 0.9290 ± 0.0007 |\n# | 7 layers 160 channels   | 0.8276 ± 0.0027 | 0.9272 ± 0.0006 |\n\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import Evaluator, PygNodePropPredDataset\nfrom torch.nn import LayerNorm, Linear\nfrom tqdm import tqdm\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.loader import RandomNodeLoader\nfrom torch_geometric.nn import GroupAddRev, SAGEConv\nfrom torch_geometric.utils import index_to_mask\n\n\nclass GNNBlock(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.norm = LayerNorm(in_channels, elementwise_affine=True)\n        self.conv = SAGEConv(in_channels, out_channels)\n\n    def reset_parameters(self):\n        self.norm.reset_parameters()\n        self.conv.reset_parameters()\n\n    def forward(self, x, edge_index, dropout_mask=None):\n        x = self.norm(x).relu()\n        if self.training and dropout_mask is not None:\n            x = x * dropout_mask\n        return self.conv(x, edge_index)\n\n\nclass RevGNN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,\n                 dropout, num_groups=2):\n        super().__init__()\n\n        self.dropout = dropout\n\n        self.lin1 = Linear(in_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, out_channels)\n        self.norm = LayerNorm(hidden_channels, elementwise_affine=True)\n\n        assert hidden_channels % num_groups == 0\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            conv = GNNBlock(\n                hidden_channels // num_groups,\n                hidden_channels // num_groups,\n            
)\n            self.convs.append(GroupAddRev(conv, num_groups=num_groups))\n\n    def reset_parameters(self):\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n        self.norm.reset_parameters()\n        for conv in self.convs:\n            conv.reset_parameters()\n\n    def forward(self, x, edge_index):\n        x = self.lin1(x)\n\n        # Generate a dropout mask which will be shared across GNN blocks:\n        mask = None\n        if self.training and self.dropout > 0:\n            mask = torch.zeros_like(x).bernoulli_(1 - self.dropout)\n            mask = mask.requires_grad_(False)\n            mask = mask / (1 - self.dropout)\n\n        for conv in self.convs:\n            x = conv(x, edge_index, mask)\n        x = self.norm(x).relu()\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        return self.lin2(x)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\ntransform = T.Compose([T.ToDevice(device), T.ToSparseTensor()])\nroot = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'products')\ndataset = PygNodePropPredDataset('ogbn-products', root,\n                                 transform=T.AddSelfLoops())\nevaluator = Evaluator(name='ogbn-products')\n\ndata = dataset[0]\nsplit_idx = dataset.get_idx_split()\nfor split in ['train', 'valid', 'test']:\n    data[f'{split}_mask'] = index_to_mask(split_idx[split], data.y.shape[0])\n\ntrain_loader = RandomNodeLoader(data, num_parts=10, shuffle=True,\n                                num_workers=5)\n# Increase the num_parts of the test loader if you cannot fit\n# the full batch graph into your GPU:\ntest_loader = RandomNodeLoader(data, num_parts=1, num_workers=5)\n\nmodel = RevGNN(\n    in_channels=dataset.num_features,\n    hidden_channels=160,\n    out_channels=dataset.num_classes,\n    num_layers=7,  # You can try 1000 layers for fun\n    dropout=0.5,\n    num_groups=2,\n).to(device)\noptimizer = torch.optim.Adam(model.parameters(), 
lr=0.003)\n\n\ndef train(epoch):\n    model.train()\n\n    pbar = tqdm(total=len(train_loader))\n    pbar.set_description(f'Training epoch: {epoch:03d}')\n\n    total_loss = total_examples = 0\n    for data in train_loader:\n        optimizer.zero_grad()\n\n        # Memory-efficient aggregations:\n        data = transform(data)\n        out = model(data.x, data.adj_t)[data.train_mask]\n        loss = F.cross_entropy(out, data.y[data.train_mask].view(-1))\n        loss.backward()\n        optimizer.step()\n\n        total_loss += float(loss) * int(data.train_mask.sum())\n        total_examples += int(data.train_mask.sum())\n        pbar.update(1)\n\n    pbar.close()\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(epoch):\n    model.eval()\n\n    y_true = {\"train\": [], \"valid\": [], \"test\": []}\n    y_pred = {\"train\": [], \"valid\": [], \"test\": []}\n\n    pbar = tqdm(total=len(test_loader))\n    pbar.set_description(f'Evaluating epoch: {epoch:03d}')\n\n    for data in test_loader:\n        # Memory-efficient aggregations\n        data = transform(data)\n        out = model(data.x, data.adj_t).argmax(dim=-1, keepdim=True)\n\n        for split in ['train', 'valid', 'test']:\n            mask = data[f'{split}_mask']\n            y_true[split].append(data.y[mask].cpu())\n            y_pred[split].append(out[mask].cpu())\n\n        pbar.update(1)\n\n    pbar.close()\n\n    train_acc = evaluator.eval({\n        'y_true': torch.cat(y_true['train'], dim=0),\n        'y_pred': torch.cat(y_pred['train'], dim=0),\n    })['acc']\n\n    valid_acc = evaluator.eval({\n        'y_true': torch.cat(y_true['valid'], dim=0),\n        'y_pred': torch.cat(y_pred['valid'], dim=0),\n    })['acc']\n\n    test_acc = evaluator.eval({\n        'y_true': torch.cat(y_true['test'], dim=0),\n        'y_pred': torch.cat(y_pred['test'], dim=0),\n    })['acc']\n\n    return train_acc, valid_acc, test_acc\n\n\ntimes = []\nbest_val = 0.0\nfinal_train = 0.0\nfinal_test 
= 0.0\nfor epoch in range(1, 1001):\n    start = time.time()\n    loss = train(epoch)\n    train_acc, val_acc, test_acc = test(epoch)\n    if val_acc > best_val:\n        best_val = val_acc\n        final_train = train_acc\n        final_test = test_acc\n    print(f'Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n    times.append(time.time() - start)\n\nprint(f'Final Train: {final_train:.4f}, Best Val: {best_val:.4f}, '\n      f'Final Test: {final_test:.4f}')\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/rgat.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Entities\nfrom torch_geometric.nn import RGATConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')\ndataset = Entities(path, 'AIFB')\ndata = dataset[0]\ndata.x = torch.randn(data.num_nodes, 16)\n\n\nclass RGAT(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels,\n                 num_relations):\n        super().__init__()\n        self.conv1 = RGATConv(in_channels, hidden_channels, num_relations)\n        self.conv2 = RGATConv(hidden_channels, hidden_channels, num_relations)\n        self.lin = torch.nn.Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index, edge_type):\n        x = self.conv1(x, edge_index, edge_type).relu()\n        x = self.conv2(x, edge_index, edge_type).relu()\n        x = self.lin(x)\n        return F.log_softmax(x, dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\ndata = data.to(device)\nmodel = RGAT(16, 16, dataset.num_classes, dataset.num_relations).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index, data.edge_type)\n    loss = F.nll_loss(out[data.train_idx], data.train_y)\n    loss.backward()\n    optimizer.step()\n    return float(loss.detach())\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.x, data.edge_index, data.edge_type).argmax(dim=-1)\n    train_acc = float((pred[data.train_idx] == data.train_y).float().mean())\n    test_acc = float((pred[data.test_idx] == data.test_y).float().mean())\n    return train_acc, test_acc\n\n\ntimes = []\nfor epoch in range(1, 51):\n    start = time.time()\n    loss = train()\n    train_acc, test_acc = test()\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: 
{train_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/rgcn.py",
    "content": "import argparse\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Entities\nfrom torch_geometric.nn import FastRGCNConv, RGCNConv\nfrom torch_geometric.utils import k_hop_subgraph\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='AIFB',\n                    choices=['AIFB', 'MUTAG', 'BGS', 'AM'])\nargs = parser.parse_args()\n\n# Trade memory consumption for faster computation.\nif args.dataset in ['AIFB', 'MUTAG']:\n    Conv = FastRGCNConv\nelse:\n    Conv = RGCNConv\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')\ndataset = Entities(path, args.dataset)\ndata = dataset[0]\n\n# BGS and AM graphs are too big to process them in a full-batch fashion.\n# Since our model does only make use of a rather small receptive field, we\n# filter the graph to only contain the nodes that are at most 2-hop neighbors\n# away from any training/test node.\nnode_idx = torch.cat([data.train_idx, data.test_idx], dim=0)\nnode_idx, edge_index, mapping, edge_mask = k_hop_subgraph(\n    node_idx, 2, data.edge_index, relabel_nodes=True)\n\ndata.num_nodes = node_idx.size(0)\ndata.edge_index = edge_index\ndata.edge_type = data.edge_type[edge_mask]\ndata.train_idx = mapping[:data.train_idx.size(0)]\ndata.test_idx = mapping[data.train_idx.size(0):]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = Conv(data.num_nodes, 16, dataset.num_relations,\n                          num_bases=30)\n        self.conv2 = Conv(16, dataset.num_classes, dataset.num_relations,\n                          num_bases=30)\n\n    def forward(self, edge_index, edge_type):\n        x = F.relu(self.conv1(None, edge_index, edge_type))\n        x = self.conv2(x, edge_index, edge_type)\n        return F.log_softmax(x, dim=1)\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif 
hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\ndevice = torch.device('cpu') if args.dataset == 'AM' else device\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.edge_index, data.edge_type)\n    loss = F.nll_loss(out[data.train_idx], data.train_y)\n    loss.backward()\n    optimizer.step()\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    pred = model(data.edge_index, data.edge_type).argmax(dim=-1)\n    train_acc = float((pred[data.train_idx] == data.train_y).float().mean())\n    test_acc = float((pred[data.test_idx] == data.test_y).float().mean())\n    return train_acc, test_acc\n\n\ntimes = []\nfor epoch in range(1, 51):\n    start = time.time()\n    loss = train()\n    train_acc, test_acc = test()\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/rgcn_link_pred.py",
    "content": "\"\"\"\"\nImplements the link prediction task on the FB15k237 datasets according to the\n`\"Modeling Relational Data with Graph Convolutional Networks\"\n<https://arxiv.org/abs/1703.06103>`_ paper.\n\nCaution: This script is executed in a full-batch fashion, and therefore needs\nto run on CPU (following the experimental setup in the official paper).\n\"\"\"\nimport os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Parameter\nfrom tqdm import tqdm\n\nfrom torch_geometric.datasets import RelLinkPredDataset\nfrom torch_geometric.nn import GAE, RGCNConv\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'RLPD')\ndataset = RelLinkPredDataset(path, 'FB15k-237')\ndata = dataset[0].to(device)\n\n\nclass RGCNEncoder(torch.nn.Module):\n    def __init__(self, num_nodes, hidden_channels, num_relations):\n        super().__init__()\n        self.node_emb = Parameter(torch.empty(num_nodes, hidden_channels))\n        self.conv1 = RGCNConv(hidden_channels, hidden_channels, num_relations,\n                              num_blocks=5)\n        self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations,\n                              num_blocks=5)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.node_emb)\n        self.conv1.reset_parameters()\n        self.conv2.reset_parameters()\n\n    def forward(self, edge_index, edge_type):\n        x = self.node_emb\n        x = self.conv1(x, edge_index, edge_type).relu_()\n        x = F.dropout(x, p=0.2, training=self.training)\n        x = self.conv2(x, edge_index, edge_type)\n        return x\n\n\nclass DistMultDecoder(torch.nn.Module):\n    def __init__(self, num_relations, hidden_channels):\n        super().__init__()\n        self.rel_emb = Parameter(torch.empty(num_relations, hidden_channels))\n        
self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.rel_emb)\n\n    def forward(self, z, edge_index, edge_type):\n        z_src, z_dst = z[edge_index[0]], z[edge_index[1]]\n        rel = self.rel_emb[edge_type]\n        return torch.sum(z_src * rel * z_dst, dim=1)\n\n\nmodel = GAE(\n    RGCNEncoder(data.num_nodes, 500, dataset.num_relations),\n    DistMultDecoder(dataset.num_relations // 2, 500),\n).to(device)\n\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef negative_sampling(edge_index, num_nodes):\n    # Sample edges by corrupting either the subject or the object of each edge.\n    mask_1 = torch.rand(edge_index.size(1)) < 0.5\n    mask_2 = ~mask_1\n\n    neg_edge_index = edge_index.clone()\n    neg_edge_index[0, mask_1] = torch.randint(num_nodes, (mask_1.sum(), ),\n                                              device=neg_edge_index.device)\n    neg_edge_index[1, mask_2] = torch.randint(num_nodes, (mask_2.sum(), ),\n                                              device=neg_edge_index.device)\n    return neg_edge_index\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n\n    z = model.encode(data.edge_index, data.edge_type)\n\n    pos_out = model.decode(z, data.train_edge_index, data.train_edge_type)\n\n    neg_edge_index = negative_sampling(data.train_edge_index, data.num_nodes)\n    neg_out = model.decode(z, neg_edge_index, data.train_edge_type)\n\n    out = torch.cat([pos_out, neg_out])\n    gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])\n    cross_entropy_loss = F.binary_cross_entropy_with_logits(out, gt)\n    reg_loss = z.pow(2).mean() + model.decoder.rel_emb.pow(2).mean()\n    loss = cross_entropy_loss + 1e-2 * reg_loss\n\n    loss.backward()\n    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)\n    optimizer.step()\n\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    z = model.encode(data.edge_index, 
data.edge_type)\n\n    valid_mrr = compute_mrr(z, data.valid_edge_index, data.valid_edge_type)\n    test_mrr = compute_mrr(z, data.test_edge_index, data.test_edge_type)\n\n    return valid_mrr, test_mrr\n\n\n@torch.no_grad()\ndef compute_rank(ranks):\n    # fair ranking prediction as the average\n    # of optimistic and pessimistic ranking\n    true = ranks[0]\n    optimistic = (ranks > true).sum() + 1\n    pessimistic = (ranks >= true).sum()\n    return (optimistic + pessimistic).float() * 0.5\n\n\n@torch.no_grad()\ndef compute_mrr(z, edge_index, edge_type):\n    ranks = []\n    for i in tqdm(range(edge_type.numel())):\n        (src, dst), rel = edge_index[:, i], edge_type[i]\n\n        # Try all nodes as tails, but delete true triplets:\n        tail_mask = torch.ones(data.num_nodes, dtype=torch.bool)\n        for (heads, tails), types in [\n            (data.train_edge_index, data.train_edge_type),\n            (data.valid_edge_index, data.valid_edge_type),\n            (data.test_edge_index, data.test_edge_type),\n        ]:\n            tail_mask[tails[(heads == src) & (types == rel)]] = False\n\n        tail = torch.arange(data.num_nodes)[tail_mask]\n        tail = torch.cat([torch.tensor([dst]), tail])\n        head = torch.full_like(tail, fill_value=src)\n        eval_edge_index = torch.stack([head, tail], dim=0)\n        eval_edge_type = torch.full_like(tail, fill_value=rel)\n\n        out = model.decode(z, eval_edge_index, eval_edge_type)\n        rank = compute_rank(out)\n        ranks.append(rank)\n\n        # Try all nodes as heads, but delete true triplets:\n        head_mask = torch.ones(data.num_nodes, dtype=torch.bool)\n        for (heads, tails), types in [\n            (data.train_edge_index, data.train_edge_type),\n            (data.valid_edge_index, data.valid_edge_type),\n            (data.test_edge_index, data.test_edge_type),\n        ]:\n            head_mask[heads[(tails == dst) & (types == rel)]] = False\n\n        head = 
torch.arange(data.num_nodes)[head_mask]\n        head = torch.cat([torch.tensor([src]), head])\n        tail = torch.full_like(head, fill_value=dst)\n        eval_edge_index = torch.stack([head, tail], dim=0)\n        eval_edge_type = torch.full_like(head, fill_value=rel)\n\n        out = model.decode(z, eval_edge_index, eval_edge_type)\n        rank = compute_rank(out)\n        ranks.append(rank)\n\n    return (1. / torch.tensor(ranks, dtype=torch.float)).mean()\n\n\ntimes = []\nfor epoch in range(1, 10001):\n    start = time.time()\n    loss = train()\n    print(f'Epoch: {epoch:05d}, Loss: {loss:.4f}')\n    if (epoch % 500) == 0:\n        valid_mrr, test_mrr = test()\n        print(f'Val MRR: {valid_mrr:.4f}, Test MRR: {test_mrr:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/seal_link_pred.py",
    "content": "import math\nimport os.path as osp\nimport time\nfrom itertools import chain\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom scipy.sparse.csgraph import shortest_path\nfrom sklearn.metrics import roc_auc_score\nfrom torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import MLP, GCNConv, SortAggregation\nfrom torch_geometric.transforms import RandomLinkSplit\nfrom torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix\n\n\nclass SEALDataset(InMemoryDataset):\n    def __init__(self, dataset, num_hops, split='train'):\n        self._data = dataset[0]\n        self.num_hops = num_hops\n        super().__init__(dataset.root)\n        index = ['train', 'val', 'test'].index(split)\n        self.load(self.processed_paths[index])\n\n    @property\n    def processed_file_names(self):\n        return ['SEAL_train_data.pt', 'SEAL_val_data.pt', 'SEAL_test_data.pt']\n\n    def process(self):\n        transform = RandomLinkSplit(num_val=0.05, num_test=0.1,\n                                    is_undirected=True, split_labels=True)\n        train_data, val_data, test_data = transform(self._data)\n\n        self._max_z = 0\n\n        # Collect a list of subgraphs for training, validation and testing:\n        train_pos_data_list = self.extract_enclosing_subgraphs(\n            train_data.edge_index, train_data.pos_edge_label_index, 1)\n        train_neg_data_list = self.extract_enclosing_subgraphs(\n            train_data.edge_index, train_data.neg_edge_label_index, 0)\n\n        val_pos_data_list = self.extract_enclosing_subgraphs(\n            val_data.edge_index, val_data.pos_edge_label_index, 1)\n        val_neg_data_list = self.extract_enclosing_subgraphs(\n            val_data.edge_index, val_data.neg_edge_label_index, 0)\n\n        
test_pos_data_list = self.extract_enclosing_subgraphs(\n            test_data.edge_index, test_data.pos_edge_label_index, 1)\n        test_neg_data_list = self.extract_enclosing_subgraphs(\n            test_data.edge_index, test_data.neg_edge_label_index, 0)\n\n        # Convert node labeling to one-hot features.\n        for data in chain(train_pos_data_list, train_neg_data_list,\n                          val_pos_data_list, val_neg_data_list,\n                          test_pos_data_list, test_neg_data_list):\n            # We solely learn links from structure, dropping any node features:\n            data.x = F.one_hot(data.z, self._max_z + 1).to(torch.float)\n\n        train_data_list = train_pos_data_list + train_neg_data_list\n        self.save(train_data_list, self.processed_paths[0])\n        val_data_list = val_pos_data_list + val_neg_data_list\n        self.save(val_data_list, self.processed_paths[1])\n        test_data_list = test_pos_data_list + test_neg_data_list\n        self.save(test_data_list, self.processed_paths[2])\n\n    def extract_enclosing_subgraphs(self, edge_index, edge_label_index, y):\n        data_list = []\n        for src, dst in edge_label_index.t().tolist():\n            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(\n                [src, dst], self.num_hops, edge_index, relabel_nodes=True)\n            src, dst = mapping.tolist()\n\n            # Remove target link from the subgraph.\n            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)\n            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)\n            sub_edge_index = sub_edge_index[:, mask1 & mask2]\n\n            # Calculate node labeling.\n            z = self.drnl_node_labeling(sub_edge_index, src, dst,\n                                        num_nodes=sub_nodes.size(0))\n\n            data = Data(x=self._data.x[sub_nodes], z=z,\n                        edge_index=sub_edge_index, y=y)\n            
data_list.append(data)\n\n        return data_list\n\n    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):\n        # Double-radius node labeling (DRNL).\n        src, dst = (dst, src) if src > dst else (src, dst)\n        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()\n\n        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))\n        adj_wo_src = adj[idx, :][:, idx]\n\n        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))\n        adj_wo_dst = adj[idx, :][:, idx]\n\n        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,\n                                 indices=src)\n        dist2src = np.insert(dist2src, dst, 0, axis=0)\n        dist2src = torch.from_numpy(dist2src)\n\n        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,\n                                 indices=dst - 1)\n        dist2dst = np.insert(dist2dst, src, 0, axis=0)\n        dist2dst = torch.from_numpy(dist2dst)\n\n        dist = dist2src + dist2dst\n        dist_over_2, dist_mod_2 = dist // 2, dist % 2\n\n        z = 1 + torch.min(dist2src, dist2dst)\n        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)\n        z[src] = 1.\n        z[dst] = 1.\n        z[torch.isnan(z)] = 0.\n\n        self._max_z = max(int(z.max()), self._max_z)\n\n        return z.to(torch.long)\n\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Planetoid')\ndataset = Planetoid(path, name='Cora')\n\ntrain_dataset = SEALDataset(dataset, num_hops=2, split='train')\nval_dataset = SEALDataset(dataset, num_hops=2, split='val')\ntest_dataset = SEALDataset(dataset, num_hops=2, split='test')\n\ntrain_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=32)\ntest_loader = DataLoader(test_dataset, batch_size=32)\n\n\nclass DGCNN(torch.nn.Module):\n    def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6):\n        
super().__init__()\n\n        if k < 1:  # Transform percentile to number.\n            num_nodes = sorted([data.num_nodes for data in train_dataset])\n            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]\n            k = int(max(10, k))\n\n        self.convs = ModuleList()\n        self.convs.append(GNN(train_dataset.num_features, hidden_channels))\n        for _ in range(0, num_layers - 1):\n            self.convs.append(GNN(hidden_channels, hidden_channels))\n        self.convs.append(GNN(hidden_channels, 1))\n\n        conv1d_channels = [16, 32]\n        total_latent_dim = hidden_channels * num_layers + 1\n        conv1d_kws = [total_latent_dim, 5]\n        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],\n                            conv1d_kws[0])\n        self.pool = SortAggregation(k)\n        self.maxpool1d = MaxPool1d(2, 2)\n        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],\n                            conv1d_kws[1], 1)\n        dense_dim = int((k - 2) / 2 + 1)\n        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]\n        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, norm=None)\n\n    def forward(self, x, edge_index, batch):\n        xs = [x]\n        for conv in self.convs:\n            xs += [conv(xs[-1], edge_index).tanh()]\n        x = torch.cat(xs[1:], dim=-1)\n\n        # Global pooling.\n        x = self.pool(x, batch)\n        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]\n        x = self.conv1(x).relu()\n        x = self.maxpool1d(x)\n        x = self.conv2(x).relu()\n        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]\n\n        return self.mlp(x)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = DGCNN(hidden_channels=32, num_layers=3).to(device)\noptimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)\ncriterion = BCEWithLogitsLoss()\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n  
      data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.batch)\n        loss = criterion(out.view(-1), data.y.to(torch.float))\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * data.num_graphs\n\n    return total_loss / len(train_dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    y_pred, y_true = [], []\n    for data in loader:\n        data = data.to(device)\n        logits = model(data.x, data.edge_index, data.batch)\n        y_pred.append(logits.view(-1).cpu())\n        y_true.append(data.y.view(-1).cpu().to(torch.float))\n\n    return roc_auc_score(torch.cat(y_true), torch.cat(y_pred))\n\n\ntimes = []\nbest_val_auc = test_auc = 0\nfor epoch in range(1, 51):\n    start = time.time()\n    loss = train()\n    val_auc = test(val_loader)\n    if val_auc > best_val_auc:\n        best_val_auc = val_auc\n        test_auc = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '\n          f'Test: {test_auc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/sgc.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import SGConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset)\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SGConv(\n            in_channels=dataset.num_features,\n            out_channels=dataset.num_classes,\n            K=2,\n            cached=True,\n        )\n\n    def forward(self):\n        x, edge_index = data.x, data.edge_index\n        x = self.conv1(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=0.005)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out, accs = model(), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 101):\n    train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '\n          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/shadow.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.datasets import Flickr\nfrom torch_geometric.loader import ShaDowKHopSampler\nfrom torch_geometric.nn import SAGEConv, global_mean_pool\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr')\ndataset = Flickr(path)\ndata = dataset[0]\n\nkwargs = {'batch_size': 1024, 'num_workers': 6, 'persistent_workers': True}\ntrain_loader = ShaDowKHopSampler(data, depth=2, num_neighbors=5,\n                                 node_idx=data.train_mask, **kwargs)\nval_loader = ShaDowKHopSampler(data, depth=2, num_neighbors=5,\n                               node_idx=data.val_mask, **kwargs)\ntest_loader = ShaDowKHopSampler(data, depth=2, num_neighbors=5,\n                                node_idx=data.test_mask, **kwargs)\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self, in_channels, hidden_channels, out_channels):\n        super().__init__()\n        self.conv1 = SAGEConv(in_channels, hidden_channels)\n        self.conv2 = SAGEConv(hidden_channels, hidden_channels)\n        self.conv3 = SAGEConv(hidden_channels, hidden_channels)\n        self.lin = torch.nn.Linear(2 * hidden_channels, out_channels)\n\n    def forward(self, x, edge_index, batch, root_n_id):\n        x = self.conv1(x, edge_index).relu()\n        x = F.dropout(x, p=0.3)\n        x = self.conv2(x, edge_index).relu()\n        x = F.dropout(x, p=0.3, training=self.training)\n        x = self.conv3(x, edge_index).relu()\n        x = F.dropout(x, p=0.3, training=self.training)\n\n        # We merge both central node embeddings and subgraph embeddings:\n        x = torch.cat([x[root_n_id], global_mean_pool(x, batch)], dim=-1)\n\n        x = self.lin(x)\n        return x\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = GNN(dataset.num_features, 256, dataset.num_classes).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef 
train():\n    model.train()\n    total_loss = total_examples = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.batch, data.root_n_id)\n        loss = F.cross_entropy(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * data.num_graphs\n        total_examples += data.num_graphs\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n    total_correct = total_examples = 0\n    for data in loader:\n        data = data.to(device)\n        out = model(data.x, data.edge_index, data.batch, data.root_n_id)\n        total_correct += int((out.argmax(dim=-1) == data.y).sum())\n        total_examples += data.num_graphs\n    return total_correct / total_examples\n\n\nfor epoch in range(1, 51):\n    loss = train()\n    val_acc = test(val_loader)\n    test_acc = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/sign.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\nfrom torch.utils.data import DataLoader\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Flickr\n\nK = 2\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Flickr')\ntransform = T.Compose([T.NormalizeFeatures(), T.SIGN(K)])\ndataset = Flickr(path, transform=transform)\ndata = dataset[0]\n\ntrain_idx = data.train_mask.nonzero(as_tuple=False).view(-1)\nval_idx = data.val_mask.nonzero(as_tuple=False).view(-1)\ntest_idx = data.test_mask.nonzero(as_tuple=False).view(-1)\n\ntrain_loader = DataLoader(train_idx, batch_size=16 * 1024, shuffle=True)\nval_loader = DataLoader(val_idx, batch_size=32 * 1024)\ntest_loader = DataLoader(test_idx, batch_size=32 * 1024)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.lins = torch.nn.ModuleList()\n        for _ in range(K + 1):\n            self.lins.append(Linear(dataset.num_node_features, 1024))\n        self.lin = Linear((K + 1) * 1024, dataset.num_classes)\n\n    def forward(self, xs):\n        hs = []\n        for x, lin in zip(xs, self.lins):\n            h = lin(x).relu()\n            h = F.dropout(h, p=0.5, training=self.training)\n            hs.append(h)\n        h = torch.cat(hs, dim=-1)\n        h = self.lin(h)\n        return h.log_softmax(dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net().to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = total_examples = 0\n    for idx in train_loader:\n        xs = [data.x[idx].to(device)]\n        xs += [data[f'x{i}'][idx].to(device) for i in range(1, K + 1)]\n        y = data.y[idx].to(device)\n\n        optimizer.zero_grad()\n        out = model(xs)\n        loss = F.nll_loss(out, y)\n        loss.backward()\n        optimizer.step()\n\n        
total_loss += float(loss) * idx.numel()\n        total_examples += idx.numel()\n\n    return total_loss / total_examples\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_correct = total_examples = 0\n    for idx in loader:\n        xs = [data.x[idx].to(device)]\n        xs += [data[f'x{i}'][idx].to(device) for i in range(1, K + 1)]\n        y = data.y[idx].to(device)\n\n        out = model(xs)\n        total_correct += int((out.argmax(dim=-1) == y).sum())\n        total_examples += idx.numel()\n\n    return total_correct / total_examples\n\n\nfor epoch in range(1, 201):\n    loss = train()\n    train_acc = test(train_loader)\n    val_acc = test(val_loader)\n    test_acc = test(test_loader)\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/signed_gcn.py",
    "content": "import os.path as osp\n\nimport torch\n\nfrom torch_geometric.datasets import BitcoinOTC\nfrom torch_geometric.nn import SignedGCN\n\nname = 'BitcoinOTC-1'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)\ndataset = BitcoinOTC(path, edge_window_size=1)\n\n# Generate dataset.\npos_edge_indices, neg_edge_indices = [], []\nfor data in dataset:\n    pos_edge_indices.append(data.edge_index[:, data.edge_attr > 0])\n    neg_edge_indices.append(data.edge_index[:, data.edge_attr < 0])\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\npos_edge_index = torch.cat(pos_edge_indices, dim=1).to(device)\nneg_edge_index = torch.cat(neg_edge_indices, dim=1).to(device)\n\n# Build and train model.\nmodel = SignedGCN(64, 64, num_layers=2, lamb=5).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\ntrain_pos_edge_index, test_pos_edge_index = model.split_edges(pos_edge_index)\ntrain_neg_edge_index, test_neg_edge_index = model.split_edges(neg_edge_index)\nx = model.create_spectral_features(train_pos_edge_index, train_neg_edge_index)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    z = model(x, train_pos_edge_index, train_neg_edge_index)\n    loss = model.loss(z, train_pos_edge_index, train_neg_edge_index)\n    loss.backward()\n    optimizer.step()\n    return loss.item()\n\n\ndef test():\n    model.eval()\n    with torch.no_grad():\n        z = model(x, train_pos_edge_index, train_neg_edge_index)\n    return model.test(z, test_pos_edge_index, test_neg_edge_index)\n\n\nfor epoch in range(101):\n    loss = train()\n    auc, f1 = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, AUC: {auc:.4f}, '\n          f'F1: {f1:.4f}')\n"
  },
  {
    "path": "examples/super_gat.py",
    "content": "import os.path as osp\nimport time\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import SuperGATConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.conv1 = SuperGATConv(dataset.num_features, 8, heads=8,\n                                  dropout=0.6, attention_type='MX',\n                                  edge_sample_ratio=0.8, is_undirected=True)\n        self.conv2 = SuperGATConv(8 * 8, dataset.num_classes, heads=8,\n                                  concat=False, dropout=0.6,\n                                  attention_type='MX', edge_sample_ratio=0.8,\n                                  is_undirected=True)\n\n    def forward(self, x, edge_index):\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = F.elu(self.conv1(x, edge_index))\n        att_loss = self.conv1.get_attention_loss()\n        x = F.dropout(x, p=0.6, training=self.training)\n        x = self.conv2(x, data.edge_index)\n        att_loss += self.conv2.get_attention_loss()\n        return F.log_softmax(x, dim=-1), att_loss\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)\n\n\ndef train(data):\n    model.train()\n    optimizer.zero_grad()\n    out, att_loss = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss += 4.0 * att_loss\n    loss.backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test(data):\n    model.eval()\n    out, accs = model(data.x, data.edge_index)[0], []\n    for _, mask in 
data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\ntimes = []\nfor epoch in range(1, 501):\n    start = time.time()\n    train(data)\n    train_acc, val_acc, test_acc = test(data)\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n    times.append(time.time() - start)\nprint(f\"Median time per epoch: {torch.tensor(times).median():.4f}s\")\n"
  },
  {
    "path": "examples/tagcn.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import TAGConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = TAGConv(dataset.num_features, 16)\n        self.conv2 = TAGConv(16, dataset.num_classes)\n\n    def forward(self):\n        x, edge_index = data.x, data.edge_index\n        x = F.relu(self.conv1(x, edge_index))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()\n    optimizer.step()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out, accs = model(), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 201):\n    train()\n    train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '\n          f'Val: {best_val_acc:.4f}, Test: 
{test_acc:.4f}')\n"
  },
  {
    "path": "examples/tensorboard_logging.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.tensorboard import SummaryWriter\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric.nn import GCNConv\n\ndataset = 'Cora'\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset)\ndataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())\ndata = dataset[0]\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(dataset.num_features, 16)\n        self.conv2 = GCNConv(16, dataset.num_classes)\n\n    def forward(self, x, edge_index):\n        x = F.relu(self.conv1(x, edge_index, None))\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index, None)\n        return F.log_softmax(x, dim=1)\n\n\nif torch.cuda.is_available():\n    device = torch.device('cuda')\nelif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n    device = torch.device('mps')\nelse:\n    device = torch.device('cpu')\n\nmodel, data = Net().to(device), data.to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)\n\n\ndef train():\n    model.train()\n    optimizer.zero_grad()\n    out = model(data.x, data.edge_index)\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])\n    loss.backward()\n    optimizer.step()\n    return loss.item()\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n    out, accs = model(data.x, data.edge_index), []\n    for _, mask in data('train_mask', 'val_mask', 'test_mask'):\n        pred = out[mask].argmax(1)\n        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()\n        accs.append(acc)\n    return accs\n\n\nmodel(data.x, data.edge_index)\nwriter = SummaryWriter()\nwriter.add_graph(model, [data.x, data.edge_index])\n\nbest_val_acc = test_acc = 0\nfor epoch in range(1, 201):\n    train_loss = train()\n    
train_acc, val_acc, tmp_test_acc = test()\n    if val_acc > best_val_acc:\n        best_val_acc = val_acc\n        test_acc = tmp_test_acc\n    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '\n          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')\n\n    writer.add_scalar('Loss/train', train_loss, epoch)\n    writer.add_scalar('Accuracy/train', train_acc, epoch)\n    writer.add_scalar('Accuracy/val', val_acc, epoch)\n    writer.add_scalar('Accuracy/test', test_acc, epoch)\n"
  },
  {
    "path": "examples/tgn.py",
    "content": "# This code achieves a performance of around 96.60%. However, it is not\n# directly comparable to the results reported by the TGN paper since a\n# slightly different evaluation setup is used here.\n# In particular, predictions in the same batch are made in parallel, i.e.\n# predictions for interactions later in the batch have no access to any\n# information whatsoever about previous interactions in the same batch.\n# On the contrary, when sampling node neighborhoods for interactions later in\n# the batch, the TGN paper code has access to previous interactions in the\n# batch.\n# While both approaches are correct, together with the authors of the paper we\n# decided to present this version here as it is more realistic and a better\n# test bed for future methods.\n\nimport os.path as osp\n\nimport torch\nfrom sklearn.metrics import average_precision_score, roc_auc_score\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import JODIEDataset\nfrom torch_geometric.loader import TemporalDataLoader\nfrom torch_geometric.nn import TGNMemory, TransformerConv\nfrom torch_geometric.nn.models.tgn import (\n    IdentityMessage,\n    LastAggregator,\n    LastNeighborLoader,\n)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'JODIE')\ndataset = JODIEDataset(path, name='wikipedia')\ndata = dataset[0]\n\n# For small datasets, we can put the whole dataset on GPU and thus avoid\n# expensive memory transfer costs for mini-batches:\ndata = data.to(device)\n\ntrain_data, val_data, test_data = data.train_val_test_split(\n    val_ratio=0.15, test_ratio=0.15)\n\ntrain_loader = TemporalDataLoader(\n    train_data,\n    batch_size=200,\n    neg_sampling_ratio=1.0,\n)\nval_loader = TemporalDataLoader(\n    val_data,\n    batch_size=200,\n    neg_sampling_ratio=1.0,\n)\ntest_loader = TemporalDataLoader(\n    test_data,\n    batch_size=200,\n    
neg_sampling_ratio=1.0,\n)\nneighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)\n\n\nclass GraphAttentionEmbedding(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, msg_dim, time_enc):\n        super().__init__()\n        self.time_enc = time_enc\n        edge_dim = msg_dim + time_enc.out_channels\n        self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,\n                                    dropout=0.1, edge_dim=edge_dim)\n\n    def forward(self, x, last_update, edge_index, t, msg):\n        rel_t = last_update[edge_index[0]] - t\n        rel_t_enc = self.time_enc(rel_t.to(x.dtype))\n        edge_attr = torch.cat([rel_t_enc, msg], dim=-1)\n        return self.conv(x, edge_index, edge_attr)\n\n\nclass LinkPredictor(torch.nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n        self.lin_src = Linear(in_channels, in_channels)\n        self.lin_dst = Linear(in_channels, in_channels)\n        self.lin_final = Linear(in_channels, 1)\n\n    def forward(self, z_src, z_dst):\n        h = self.lin_src(z_src) + self.lin_dst(z_dst)\n        h = h.relu()\n        return self.lin_final(h)\n\n\nmemory_dim = time_dim = embedding_dim = 100\n\nmemory = TGNMemory(\n    data.num_nodes,\n    data.msg.size(-1),\n    memory_dim,\n    time_dim,\n    message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),\n    aggregator_module=LastAggregator(),\n).to(device)\n\ngnn = GraphAttentionEmbedding(\n    in_channels=memory_dim,\n    out_channels=embedding_dim,\n    msg_dim=data.msg.size(-1),\n    time_enc=memory.time_enc,\n).to(device)\n\nlink_pred = LinkPredictor(in_channels=embedding_dim).to(device)\n\noptimizer = torch.optim.Adam(\n    set(memory.parameters()) | set(gnn.parameters())\n    | set(link_pred.parameters()), lr=0.0001)\ncriterion = torch.nn.BCEWithLogitsLoss()\n\n# Helper vector to map global node indices to local ones.\nassoc = torch.empty(data.num_nodes, dtype=torch.long, 
device=device)\n\n\ndef train():\n    memory.train()\n    gnn.train()\n    link_pred.train()\n\n    memory.reset_state()  # Start with a fresh memory.\n    neighbor_loader.reset_state()  # Start with an empty graph.\n\n    total_loss = 0\n    for batch in train_loader:\n        optimizer.zero_grad()\n        batch = batch.to(device)\n\n        n_id, edge_index, e_id = neighbor_loader(batch.n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        # Get updated memory of all nodes involved in the computation.\n        z, last_update = memory(n_id)\n        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),\n                data.msg[e_id].to(device))\n        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])\n        neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])\n\n        loss = criterion(pos_out, torch.ones_like(pos_out))\n        loss += criterion(neg_out, torch.zeros_like(neg_out))\n\n        # Update memory and neighbor loader with ground-truth state.\n        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)\n        neighbor_loader.insert(batch.src, batch.dst)\n\n        loss.backward()\n        optimizer.step()\n        memory.detach()\n        total_loss += float(loss) * batch.num_events\n\n    return total_loss / train_data.num_events\n\n\n@torch.no_grad()\ndef test(loader):\n    memory.eval()\n    gnn.eval()\n    link_pred.eval()\n\n    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs.\n\n    aps, aucs = [], []\n    for batch in loader:\n        batch = batch.to(device)\n\n        n_id, edge_index, e_id = neighbor_loader(batch.n_id)\n        assoc[n_id] = torch.arange(n_id.size(0), device=device)\n\n        z, last_update = memory(n_id)\n        z = gnn(z, last_update, edge_index, data.t[e_id].to(device),\n                data.msg[e_id].to(device))\n        pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])\n        neg_out = 
link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])\n\n        y_pred = torch.cat([pos_out, neg_out], dim=0).sigmoid().cpu()\n        y_true = torch.cat(\n            [torch.ones(pos_out.size(0)),\n             torch.zeros(neg_out.size(0))], dim=0)\n\n        aps.append(average_precision_score(y_true, y_pred))\n        aucs.append(roc_auc_score(y_true, y_pred))\n\n        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)\n        neighbor_loader.insert(batch.src, batch.dst)\n    return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())\n\n\nfor epoch in range(1, 51):\n    loss = train()\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')\n    val_ap, val_auc = test(val_loader)\n    test_ap, test_auc = test(test_loader)\n    print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')\n    print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')\n"
  },
  {
    "path": "examples/triangles_sag_pool.py",
    "content": "import copy\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GCNConv, GINConv, SAGPooling, global_max_pool\nfrom torch_geometric.utils import scatter\n\n\nclass HandleNodeAttention:\n    def __call__(self, data):\n        data = copy.copy(data)\n        data.attn = torch.softmax(data.x, dim=0).flatten()\n        data.x = None\n        return data\n\n\ntransform = T.Compose([HandleNodeAttention(), T.OneHotDegree(max_degree=14)])\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TRIANGLES')\ndataset = TUDataset(path, name='TRIANGLES', use_node_attr=True,\n                    transform=transform)\n\ntrain_loader = DataLoader(dataset[:30000], batch_size=60, shuffle=True)\nval_loader = DataLoader(dataset[30000:35000], batch_size=60)\ntest_loader = DataLoader(dataset[35000:], batch_size=60)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels):\n        super().__init__()\n\n        self.conv1 = GINConv(Seq(Lin(in_channels, 64), ReLU(), Lin(64, 64)))\n        self.pool1 = SAGPooling(64, min_score=0.001, GNN=GCNConv)\n        self.conv2 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64)))\n        self.pool2 = SAGPooling(64, min_score=0.001, GNN=GCNConv)\n        self.conv3 = GINConv(Seq(Lin(64, 64), ReLU(), Lin(64, 64)))\n\n        self.lin = torch.nn.Linear(64, 1)\n\n    def forward(self, data):\n        x, edge_index, batch = data.x, data.edge_index, data.batch\n\n        x = F.relu(self.conv1(x, edge_index))\n        x, edge_index, _, batch, perm, score = self.pool1(\n            x, edge_index, None, batch)\n        x = F.relu(self.conv2(x, edge_index))\n        x, edge_index, _, batch, perm, score = self.pool2(\n            x, 
edge_index, None, batch)\n        ratio = x.size(0) / data.x.size(0)\n\n        x = F.relu(self.conv3(x, edge_index))\n        x = global_max_pool(x, batch)\n        x = self.lin(x).view(-1)\n\n        attn_loss = F.kl_div(torch.log(score + 1e-14), data.attn[perm],\n                             reduction='none')\n        attn_loss = scatter(attn_loss, batch, reduce='mean')\n\n        return x, attn_loss, ratio\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(dataset.num_features).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out, attn_loss, _ = model(data)\n        loss = ((out - data.y).pow(2) + 100 * attn_loss).mean()\n        loss.backward()\n        total_loss += loss.item() * data.num_graphs\n        optimizer.step()\n\n    return total_loss / len(train_loader.dataset)\n\n\ndef test(loader):\n    model.eval()\n\n    corrects, total_ratio = [], 0\n    for data in loader:\n        data = data.to(device)\n        out, _, ratio = model(data)\n        pred = out.round().to(torch.long)\n        corrects.append(pred.eq(data.y.to(torch.long)))\n        total_ratio += ratio\n    return torch.cat(corrects, dim=0), total_ratio / len(loader)\n\n\nfor epoch in range(1, 301):\n    loss = train()\n    train_correct, train_ratio = test(train_loader)\n    val_correct, val_ratio = test(val_loader)\n    test_correct, test_ratio = test(test_loader)\n\n    train_acc = train_correct.sum().item() / train_correct.size(0)\n    val_acc = val_correct.sum().item() / val_correct.size(0)\n\n    test_acc1 = test_correct[:5000].sum().item() / 5000\n    test_acc2 = test_correct[5000:].sum().item() / 5000\n\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {train_acc:.3f}, '\n          f'Val: {val_acc:.3f}, Test Orig: {test_acc1:.3f}, '\n          f'Test Large: 
{test_acc2:.3f}, Train/Val/Test Ratio='\n          f'{train_ratio:.3f}/{val_ratio:.3f}/{test_ratio:.3f}')\n"
  },
  {
    "path": "examples/unimp_arxiv.py",
    "content": "import os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom ogb.nodeproppred import PygNodePropPredDataset\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.nn import MaskLabel, TransformerConv\nfrom torch_geometric.utils import index_to_mask\n\nroot = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'OGB')\ndataset = PygNodePropPredDataset('ogbn-arxiv', root, T.ToUndirected())\n\n\nclass UniMP(torch.nn.Module):\n    def __init__(self, in_channels, num_classes, hidden_channels, num_layers,\n                 heads, dropout=0.3):\n        super().__init__()\n\n        self.label_emb = MaskLabel(num_classes, in_channels)\n\n        self.convs = torch.nn.ModuleList()\n        self.norms = torch.nn.ModuleList()\n        for i in range(1, num_layers + 1):\n            if i < num_layers:\n                out_channels = hidden_channels // heads\n                concat = True\n            else:\n                out_channels = num_classes\n                concat = False\n            conv = TransformerConv(in_channels, out_channels, heads,\n                                   concat=concat, beta=True, dropout=dropout)\n            self.convs.append(conv)\n            in_channels = hidden_channels\n\n            if i < num_layers:\n                self.norms.append(torch.nn.LayerNorm(hidden_channels))\n\n    def forward(self, x, y, edge_index, label_mask):\n        x = self.label_emb(x, y, label_mask)\n        for conv, norm in zip(self.convs, self.norms):\n            x = norm(conv(x, edge_index)).relu()\n        return self.convs[-1](x, edge_index)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\ndata = dataset[0].to(device)\ndata.y = data.y.view(-1)\nmodel = UniMP(dataset.num_features, dataset.num_classes, hidden_channels=64,\n              num_layers=3, heads=2).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)\n\nsplit_idx = 
dataset.get_idx_split()\ntrain_mask = index_to_mask(split_idx['train'], size=data.num_nodes)\nval_mask = index_to_mask(split_idx['valid'], size=data.num_nodes)\ntest_mask = index_to_mask(split_idx['test'], size=data.num_nodes)\n\n\ndef train(label_rate=0.65):  # How many labels to use for propagation.\n    model.train()\n\n    propagation_mask = MaskLabel.ratio_mask(train_mask, ratio=label_rate)\n    supervision_mask = train_mask ^ propagation_mask\n\n    optimizer.zero_grad()\n    out = model(data.x, data.y, data.edge_index, propagation_mask)\n    loss = F.cross_entropy(out[supervision_mask], data.y[supervision_mask])\n    loss.backward()\n    optimizer.step()\n\n    return float(loss)\n\n\n@torch.no_grad()\ndef test():\n    model.eval()\n\n    propagation_mask = train_mask\n    out = model(data.x, data.y, data.edge_index, propagation_mask)\n    pred = out[val_mask].argmax(dim=-1)\n    val_acc = int((pred == data.y[val_mask]).sum()) / pred.size(0)\n\n    propagation_mask = train_mask | val_mask\n    out = model(data.x, data.y, data.edge_index, propagation_mask)\n    pred = out[test_mask].argmax(dim=-1)\n    test_acc = int((pred == data.y[test_mask]).sum()) / pred.size(0)\n\n    return val_acc, test_acc\n\n\nfor epoch in range(1, 501):\n    loss = train()\n    val_acc, test_acc = test()\n    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, '\n          f'Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/upfd.py",
    "content": "import argparse\nimport os.path as osp\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear\n\nfrom torch_geometric.datasets import UPFD\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import GATConv, GCNConv, SAGEConv, global_max_pool\nfrom torch_geometric.transforms import ToUndirected\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--dataset', type=str, default='politifact',\n                    choices=['politifact', 'gossipcop'])\nparser.add_argument('--feature', type=str, default='spacy',\n                    choices=['profile', 'spacy', 'bert', 'content'])\nparser.add_argument('--model', type=str, default='GCN',\n                    choices=['GCN', 'GAT', 'SAGE'])\nargs = parser.parse_args()\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'UPFD')\ntrain_dataset = UPFD(path, args.dataset, args.feature, 'train', ToUndirected())\nval_dataset = UPFD(path, args.dataset, args.feature, 'val', ToUndirected())\ntest_dataset = UPFD(path, args.dataset, args.feature, 'test', ToUndirected())\n\ntrain_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\nval_loader = DataLoader(val_dataset, batch_size=128, shuffle=False)\ntest_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, model, in_channels, hidden_channels, out_channels,\n                 concat=False):\n        super().__init__()\n        self.concat = concat\n\n        if model == 'GCN':\n            self.conv1 = GCNConv(in_channels, hidden_channels)\n        elif model == 'SAGE':\n            self.conv1 = SAGEConv(in_channels, hidden_channels)\n        elif model == 'GAT':\n            self.conv1 = GATConv(in_channels, hidden_channels)\n\n        if self.concat:\n            self.lin0 = Linear(in_channels, hidden_channels)\n            self.lin1 = Linear(2 * hidden_channels, hidden_channels)\n\n        self.lin2 = 
Linear(hidden_channels, out_channels)\n\n    def forward(self, x, edge_index, batch):\n        h = self.conv1(x, edge_index).relu()\n        h = global_max_pool(h, batch)\n\n        if self.concat:\n            # Get the root node (tweet) features of each graph:\n            root = (batch[1:] - batch[:-1]).nonzero(as_tuple=False).view(-1)\n            root = torch.cat([root.new_zeros(1), root + 1], dim=0)\n            news = x[root]\n\n            news = self.lin0(news).relu()\n            h = self.lin1(torch.cat([news, h], dim=-1)).relu()\n\n        h = self.lin2(h)\n        return h.log_softmax(dim=-1)\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nmodel = Net(args.model, train_dataset.num_features, 128,\n            train_dataset.num_classes, concat=True).to(device)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.01)\n\n\ndef train():\n    model.train()\n\n    total_loss = 0\n    for data in train_loader:\n        data = data.to(device)\n        optimizer.zero_grad()\n        out = model(data.x, data.edge_index, data.batch)\n        loss = F.nll_loss(out, data.y)\n        loss.backward()\n        optimizer.step()\n        total_loss += float(loss) * data.num_graphs\n\n    return total_loss / len(train_loader.dataset)\n\n\n@torch.no_grad()\ndef test(loader):\n    model.eval()\n\n    total_correct = total_examples = 0\n    for data in loader:\n        data = data.to(device)\n        pred = model(data.x, data.edge_index, data.batch).argmax(dim=-1)\n        total_correct += int((pred == data.y).sum())\n        total_examples += data.num_graphs\n\n    return total_correct / total_examples\n\n\nfor epoch in range(1, 61):\n    loss = train()\n    train_acc = test(train_loader)\n    val_acc = test(val_loader)\n    test_acc = test(test_loader)\n    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, '\n          f'Val: {val_acc:.4f}, Test: {test_acc:.4f}')\n"
  },
  {
    "path": "examples/wl_kernel.py",
    "content": "import argparse\nimport os.path as osp\nimport warnings\n\nimport torch\nfrom sklearn.exceptions import ConvergenceWarning\nfrom sklearn.metrics import accuracy_score\nfrom sklearn.svm import LinearSVC\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.datasets import TUDataset\nfrom torch_geometric.nn import WLConv\n\nwarnings.filterwarnings('ignore', category=ConvergenceWarning)\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--runs', type=int, default=10)\nargs = parser.parse_args()\n\ntorch.manual_seed(42)\n\npath = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'TU')\ndataset = TUDataset(path, name='ENZYMES')\ndata = Batch.from_data_list(dataset)\n\n\nclass WL(torch.nn.Module):\n    def __init__(self, num_layers):\n        super().__init__()\n        self.convs = torch.nn.ModuleList([WLConv() for _ in range(num_layers)])\n\n    def forward(self, x, edge_index, batch=None):\n        hists = []\n        for conv in self.convs:\n            x = conv(x, edge_index)\n            hists.append(conv.histogram(x, batch, norm=True))\n        return hists\n\n\nwl = WL(num_layers=5)\nhists = wl(data.x, data.edge_index, data.batch)\n\ntest_accs = torch.empty(args.runs, dtype=torch.float)\n\nfor run in range(1, args.runs + 1):\n    perm = torch.randperm(data.num_graphs)\n    val_index = perm[:data.num_graphs // 10]\n    test_index = perm[data.num_graphs // 10:data.num_graphs // 5]\n    train_index = perm[data.num_graphs // 5:]\n\n    best_val_acc = 0\n\n    for hist in hists:\n        train_hist, train_y = hist[train_index], data.y[train_index]\n        val_hist, val_y = hist[val_index], data.y[val_index]\n        test_hist, test_y = hist[test_index], data.y[test_index]\n\n        for C in [10**3, 10**2, 10**1, 10**0, 10**-1, 10**-2, 10**-3]:\n            model = LinearSVC(C=C, tol=0.01, dual=True)\n            model.fit(train_hist, train_y)\n            val_acc = accuracy_score(val_y, model.predict(val_hist))\n          
  if val_acc > best_val_acc:\n                best_val_acc = val_acc\n                test_acc = accuracy_score(test_y, model.predict(test_hist))\n                test_accs[run - 1] = test_acc\n\n    print(f'Run: {run:02d}, Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')\n\nprint(f'Final Test Performance: {test_accs.mean():.4f}±{test_accs.std():.4f}')\n"
  },
  {
    "path": "graphgym/agg_batch.py",
    "content": "import argparse\n\nfrom torch_geometric.graphgym.utils.agg_runs import agg_batch\n\n\ndef parse_args():\n    \"\"\"Parses the arguments.\"\"\"\n    parser = argparse.ArgumentParser(\n        description='Aggregate a batch of results')\n    parser.add_argument('--dir', dest='dir', help='Dir for batch of results',\n                        required=True, type=str)\n    parser.add_argument('--metric', dest='metric',\n                        help='metric to select best epoch', required=False,\n                        type=str, default='auto')\n    return parser.parse_args()\n\n\nargs = parse_args()\nagg_batch(args.dir, args.metric)\n"
  },
  {
    "path": "graphgym/configs/example.yaml",
    "content": "# The recommended basic settings for GNN\nout_dir: results\ndataset:\n  format: PyG\n  name: Cora\n  task: node\n  task_type: classification\n  transductive: true\n  split: [0.8, 0.2]\n  transform: none\ntrain:\n  batch_size: 32\n  eval_period: 20\n  ckpt_period: 100\nmodel:\n  type: gnn\n  loss_fun: cross_entropy\n  edge_decoding: dot\n  graph_pooling: add\ngnn:\n  layers_pre_mp: 1\n  layers_mp: 2\n  layers_post_mp: 1\n  dim_inner: 256\n  layer_type: generalconv\n  stage_type: stack\n  batchnorm: true\n  act: prelu\n  dropout: 0.0\n  agg: add\n  normalize_adj: false\noptim:\n  optimizer: adam\n  base_lr: 0.01\n  max_epoch: 400\n"
  },
  {
    "path": "graphgym/configs/pyg/example_graph.yaml",
    "content": "out_dir: results\ndataset:\n  format: OGB\n  name: ogbg-molhiv\n  task: graph\n  task_type: classification\n  node_encoder: true\n  node_encoder_name: Atom\n  edge_encoder: true\n  edge_encoder_name: Bond\ntrain:\n  batch_size: 128\n  eval_period: 1\n  ckpt_period: 100\n  sampler: full_batch\nmodel:\n  type: gnn\n  loss_fun: cross_entropy\n  edge_decoding: dot\n  graph_pooling: add\ngnn:\n  layers_pre_mp: 1\n  layers_mp: 2\n  layers_post_mp: 1\n  dim_inner: 300\n  layer_type: generalconv\n  stage_type: stack\n  batchnorm: true\n  act: prelu\n  dropout: 0.0\n  agg: mean\n  normalize_adj: false\noptim:\n  optimizer: adam\n  base_lr: 0.01\n  max_epoch: 100\n"
  },
  {
    "path": "graphgym/configs/pyg/example_link.yaml",
    "content": "out_dir: results\ndataset:\n  format: OGB\n  name: ogbl-collab\n  task: link_pred\n  task_type: classification\n  node_encoder: false\n  node_encoder_name: Atom\n  edge_encoder: false\n  edge_encoder_name: Bond\ntrain:\n  batch_size: 128\n  eval_period: 1\n  ckpt_period: 100\n  sampler: full_batch\nmodel:\n  type: gnn\n  loss_fun: cross_entropy\n  edge_decoding: dot\n  graph_pooling: add\ngnn:\n  layers_pre_mp: 1\n  layers_mp: 2\n  layers_post_mp: 1\n  dim_inner: 300\n  layer_type: gcnconv\n  stage_type: stack\n  batchnorm: true\n  act: prelu\n  dropout: 0.0\n  agg: mean\n  normalize_adj: false\noptim:\n  optimizer: adam\n  base_lr: 0.01\n  max_epoch: 100\n"
  },
  {
    "path": "graphgym/configs/pyg/example_node.yaml",
    "content": "out_dir: results\ndataset:\n  format: PyG\n  name: Cora\n  task: node\n  task_type: classification\n  node_encoder: false\n  node_encoder_name: Atom\n  edge_encoder: false\n  edge_encoder_name: Bond\ntrain:\n  batch_size: 128\n  eval_period: 1\n  ckpt_period: 100\n  sampler: full_batch\nmodel:\n  type: gnn\n  loss_fun: cross_entropy\n  edge_decoding: dot\n  graph_pooling: add\ngnn:\n  layers_pre_mp: 0\n  layers_mp: 2\n  layers_post_mp: 1\n  dim_inner: 16\n  layer_type: gcnconv\n  stage_type: stack\n  batchnorm: false\n  act: prelu\n  dropout: 0.1\n  agg: mean\n  normalize_adj: false\noptim:\n  optimizer: adam\n  base_lr: 0.01\n  max_epoch: 200\n"
  },
  {
    "path": "graphgym/configs_gen.py",
    "content": "import argparse\nimport copy\nimport csv\nimport os.path as osp\nimport random\n\nimport numpy as np\nimport yaml\n\nfrom torch_geometric.graphgym.utils.comp_budget import match_baseline_cfg\nfrom torch_geometric.graphgym.utils.io import (\n    makedirs_rm_exist,\n    string_to_python,\n)\n\nrandom.seed(123)\n\n\ndef parse_args():\n    \"\"\"Parses the arguments.\"\"\"\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--config', dest='config',\n                        help='the base configuration file used for edit',\n                        default=None, type=str)\n    parser.add_argument('--grid', dest='grid',\n                        help='configuration file for grid search',\n                        required=True, type=str)\n    parser.add_argument('--sample_alias', dest='sample_alias',\n                        help='configuration file for sample alias',\n                        default=None, required=False, type=str)\n    parser.add_argument('--sample_num', dest='sample_num',\n                        help='Number of random samples in the space',\n                        default=10, type=int)\n    parser.add_argument('--out_dir', dest='out_dir',\n                        help='output directory for generated config files',\n                        default='configs', type=str)\n    parser.add_argument(\n        '--config_budget', dest='config_budget',\n        help='the base configuration file used for matching computation',\n        default=None, type=str)\n    return parser.parse_args()\n\n\ndef get_fname(string):\n    if string is not None:\n        return string.split('/')[-1].split('.')[0]\n    else:\n        return 'default'\n\n\ndef grid2list(grid):\n    list_in = [[]]\n    for grid_temp in grid:\n        list_out = []\n        for val in grid_temp:\n            for list_temp in list_in:\n                list_out.append(list_temp + [val])\n        list_in = list_out\n    return list_in\n\n\ndef lists_distance(l1, l2):\n    
assert len(l1) == len(l2)\n    dist = 0\n    for i in range(len(l1)):\n        if l1[i] != l2[i]:\n            dist += 1\n    return dist\n\n\ndef grid2list_sample(grid, sample=10):\n    configs = []\n    while len(configs) < sample:\n        config = []\n        for grid_temp in grid:\n            config.append(random.choice(grid_temp))\n        if config not in configs:\n            configs.append(config)\n    return configs\n\n\ndef load_config(fname):\n    if fname is not None:\n        with open(fname) as f:\n            return yaml.load(f, Loader=yaml.FullLoader)\n    else:\n        return {}\n\n\ndef load_search_file(fname):\n    with open(fname) as f:\n        out_raw = csv.reader(f, delimiter=' ')\n        outs = []\n        out = []\n        for row in out_raw:\n            if '#' in row:\n                continue\n            elif len(row) > 0:\n                assert len(row) == 3, \\\n                    'Exact 1 space between each grid argument file' \\\n                    'And no spaces within each argument is allowed'\n                out.append(row)\n            else:\n                if len(out) > 0:\n                    outs.append(out)\n                out = []\n        if len(out) > 0:\n            outs.append(out)\n    return outs\n\n\ndef load_alias_file(fname):\n    with open(fname) as f:\n        file = csv.reader(f, delimiter=' ')\n        for line in file:  # noqa: B007\n            break\n    return line\n\n\ndef exclude_list_id(list, id):\n    return [list[i] for i in range(len(list)) if i != id]\n\n\ndef gen_grid(args, config, config_budget=None):\n    if config_budget is None:\n        config_budget = {}\n    task_name = f'{get_fname(args.config)}_grid_{get_fname(args.grid)}'\n    fname_start = get_fname(args.config)\n    out_dir = f'{args.out_dir}/{task_name}'\n    makedirs_rm_exist(out_dir)\n    config['out_dir'] = osp.join(config['out_dir'], task_name)\n\n    outs = load_search_file(args.grid)\n    for i, out in enumerate(outs):\n 
       vars_label = [row[0].split('.') for row in out]\n        vars_alias = [row[1] for row in out]\n        vars_value = grid2list([string_to_python(row[2]) for row in out])\n        if i == 0:\n            print(f'Variable label: {vars_label}')\n            print(f'Variable alias: {vars_alias}')\n\n        for vars in vars_value:\n            config_out = config.copy()\n            fname_out = fname_start\n            for id, var in enumerate(vars):\n                if len(vars_label[id]) == 1:\n                    config_out[vars_label[id][0]] = var\n                elif len(vars_label[id]) == 2:\n                    if vars_label[id][0] in config_out:  # if key1 exist\n                        config_out[vars_label[id][0]][vars_label[id][1]] = var\n                    else:\n                        config_out[vars_label[id][0]] = {\n                            vars_label[id][1]: var\n                        }\n                else:\n                    raise ValueError('Only 2-level config files are supported')\n                var_repr = str(var).strip(\"[]\").strip(\"''\")  # noqa: B005\n                fname_out += f'-{vars_alias[id]}={var_repr}'\n            if len(config_budget) > 0:\n                config_out = match_baseline_cfg(config_out, config_budget)\n            with open(f'{out_dir}/{fname_out}.yaml', 'w') as f:\n                yaml.dump(config_out, f, default_flow_style=False)\n        print(f'{len(vars_value)} configurations saved to: {out_dir}')\n\n\ndef gen_grid_sample(args, config, config_budget=None, compare_alias_list=None):\n    if config_budget is None:\n        config_budget = {}\n    if compare_alias_list is None:\n        compare_alias_list = []\n    task_name = f'{get_fname(args.config)}_grid_{get_fname(args.grid)}'\n    fname_start = get_fname(args.config)\n    out_dir = f'{args.out_dir}/{task_name}'\n    makedirs_rm_exist(out_dir)\n    config['out_dir'] = osp.join(config['out_dir'], task_name)\n    outs = 
load_search_file(args.grid)\n\n    counts = []\n    for out in outs:\n        vars_grid = [string_to_python(row[2]) for row in out]\n        count = 1\n        for var in vars_grid:\n            count *= len(var)\n        counts.append(count)\n    counts = np.array(counts)\n    print('Total size of each chunk of experiment space:', counts)\n    counts = counts / np.sum(counts)\n    counts = np.round(counts * args.sample_num)\n    counts[0] += args.sample_num - np.sum(counts)\n    print('Total sample size of each chunk of experiment space:', counts)\n\n    for i, out in enumerate(outs):\n        vars_label = [row[0].split('.') for row in out]\n        vars_alias = [row[1] for row in out]\n        if i == 0:\n            print(f'Variable label: {vars_label}')\n            print(f'Variable alias: {vars_alias}')\n        vars_grid = [string_to_python(row[2]) for row in out]\n        for alias in compare_alias_list:\n            alias_id = vars_alias.index(alias)\n            vars_grid_select = copy.deepcopy(vars_grid[alias_id])\n            vars_grid[alias_id] = [vars_grid[alias_id][0]]\n            vars_value = grid2list_sample(vars_grid, counts[i])\n\n            vars_value_new = []\n            for vars in vars_value:\n                for grid in vars_grid_select:\n                    vars[alias_id] = grid\n                    vars_value_new.append(copy.deepcopy(vars))\n            vars_value = vars_value_new\n\n            vars_grid[alias_id] = vars_grid_select\n            for vars in vars_value:\n                config_out = config.copy()\n                fname_out = fname_start + f'-sample={vars_alias[alias_id]}'\n                for id, var in enumerate(vars):\n                    if len(vars_label[id]) == 1:\n                        config_out[vars_label[id][0]] = var\n                    elif len(vars_label[id]) == 2:\n                        if vars_label[id][0] in config_out:  # if key1 exist\n                            
config_out[vars_label[id][0]][vars_label[id]\n                                                          [1]] = var\n                        else:\n                            config_out[vars_label[id][0]] = {\n                                vars_label[id][1]: var\n                            }\n                    else:\n                        raise ValueError(\n                            'Only 2-level config files are supported')\n                    var_repr = str(var).strip(\"[]\").strip(\"''\")  # noqa: B005\n                    fname_out += f'-{vars_alias[id]}={var_repr}'\n                if len(config_budget) > 0:\n                    config_out = match_baseline_cfg(config_out, config_budget,\n                                                    verbose=False)\n                with open(f'{out_dir}/{fname_out}.yaml', \"w\") as f:\n                    yaml.dump(config_out, f, default_flow_style=False)\n            print(f'Chunk {i + 1}/{len(outs)}: '\n                  f'Perturbing design dimension {alias}, '\n                  f'{len(vars_value)} configurations saved to: {out_dir}')\n\n\nargs = parse_args()\nconfig = load_config(args.config)\nconfig_budget = load_config(args.config_budget)\nif args.sample_alias is None:\n    gen_grid(args, config, config_budget)\nelse:\n    alias_list = load_alias_file(args.sample_alias)\n    gen_grid_sample(args, config, config_budget, alias_list)\n"
  },
  {
    "path": "graphgym/custom_graphgym/__init__.py",
    "content": "from .act import *  # noqa\nfrom .config import *  # noqa\nfrom .encoder import *  # noqa\nfrom .head import *  # noqa\nfrom .layer import *  # noqa\nfrom .loader import *  # noqa\nfrom .loss import *  # noqa\nfrom .network import *  # noqa\nfrom .optimizer import *  # noqa\nfrom .pooling import *  # noqa\nfrom .stage import *  # noqa\nfrom .train import *  # noqa\nfrom .transform import *  # noqa\n"
  },
  {
    "path": "graphgym/custom_graphgym/act/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/act/example.py",
    "content": "from functools import partial\n\nimport torch\nimport torch.nn as nn\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.register import register_act\n\n\nclass SWISH(nn.Module):\n    def __init__(self, inplace=False):\n        super().__init__()\n        self.inplace = inplace\n\n    def forward(self, x):\n        if self.inplace:\n            x.mul_(torch.sigmoid(x))\n            return x\n        else:\n            return x * torch.sigmoid(x)\n\n\nregister_act('swish', partial(SWISH, inplace=cfg.mem.inplace))\nregister_act('lrelu_03', partial(nn.LeakyReLU, 0.3, inplace=cfg.mem.inplace))\n"
  },
  {
    "path": "graphgym/custom_graphgym/config/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/config/example.py",
    "content": "from yacs.config import CfgNode as CN\n\nfrom torch_geometric.graphgym.register import register_config\n\n\n@register_config('example')\ndef set_cfg_example(cfg):\n    r\"\"\"This function sets the default config value for customized options\n    :return: customized configuration use by the experiment.\n    \"\"\"\n    # ----------------------------------------------------------------------- #\n    # Customized options\n    # ----------------------------------------------------------------------- #\n\n    # example argument\n    cfg.example_arg = 'example'\n\n    # example argument group\n    cfg.example_group = CN()\n\n    # then argument can be specified within the group\n    cfg.example_group.example_arg = 'example'\n"
  },
  {
    "path": "graphgym/custom_graphgym/encoder/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/encoder/example.py",
    "content": "import torch\nfrom ogb.utils.features import get_bond_feature_dims\n\nfrom torch_geometric.graphgym.register import (\n    register_edge_encoder,\n    register_node_encoder,\n)\n\n\n@register_node_encoder('example')\nclass ExampleNodeEncoder(torch.nn.Module):\n    \"\"\"Provides an encoder for integer node features.\n\n    Args:\n        num_classes (int): The number of classes for the embedding mapping to\n            learn.\n    \"\"\"\n    def __init__(self, emb_dim, num_classes=None):\n        super().__init__()\n\n        self.encoder = torch.nn.Embedding(num_classes, emb_dim)\n        torch.nn.init.xavier_uniform_(self.encoder.weight.data)\n\n    def forward(self, batch):\n        # Encode just the first dimension if more exist\n        batch.x = self.encoder(batch.x[:, 0])\n\n        return batch\n\n\n@register_edge_encoder('example')\nclass ExampleEdgeEncoder(torch.nn.Module):\n    def __init__(self, emb_dim):\n        super().__init__()\n\n        self.bond_embedding_list = torch.nn.ModuleList()\n        full_bond_feature_dims = get_bond_feature_dims()\n\n        for dim in full_bond_feature_dims:\n            emb = torch.nn.Embedding(dim, emb_dim)\n            torch.nn.init.xavier_uniform_(emb.weight.data)\n            self.bond_embedding_list.append(emb)\n\n    def forward(self, batch):\n        bond_embedding = 0\n        for i in range(batch.edge_feature.shape[1]):\n            bond_embedding += \\\n                self.bond_embedding_list[i](batch.edge_attr[:, i])\n\n        batch.edge_attr = bond_embedding\n        return batch\n"
  },
  {
    "path": "graphgym/custom_graphgym/head/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/head/example.py",
    "content": "import torch.nn as nn\n\nfrom torch_geometric.graphgym.register import register_head\n\n\n@register_head('head')\nclass ExampleNodeHead(nn.Module):\n    r\"\"\"Head of GNN for node prediction.\"\"\"\n    def __init__(self, dim_in, dim_out):\n        super().__init__()\n        self.layer_post_mp = nn.Linear(dim_in, dim_out, bias=True)\n\n    def _apply_index(self, batch):\n        if batch.node_label_index.shape[0] == batch.node_label.shape[0]:\n            return batch.x[batch.node_label_index], batch.node_label\n        else:\n            return batch.x[batch.node_label_index], \\\n                batch.node_label[batch.node_label_index]\n\n    def forward(self, batch):\n        batch = self.layer_post_mp(batch)\n        pred, label = self._apply_index(batch)\n        return pred, label\n"
  },
  {
    "path": "graphgym/custom_graphgym/layer/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/layer/example.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import Parameter\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.register import register_layer\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import glorot, zeros\n\n# Note: A registered GNN layer should take 'batch' as input\n# and 'batch' as output\n\n\n# Example 1: Directly define a GraphGym format Conv\n# take 'batch' as input and 'batch' as output\n@register_layer('exampleconv1')\nclass ExampleConv1(MessagePassing):\n    r\"\"\"Example GNN layer.\"\"\"\n    def __init__(self, in_channels, out_channels, bias=True, **kwargs):\n        super().__init__(aggr=cfg.gnn.agg, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        self.weight = Parameter(torch.empty(in_channels, out_channels))\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.weight)\n        zeros(self.bias)\n\n    def forward(self, batch):\n        x, edge_index = batch.x, batch.edge_index\n        x = torch.matmul(x, self.weight)\n\n        batch.x = self.propagate(edge_index, x=x)\n\n        return batch\n\n    def message(self, x_j):\n        return x_j\n\n    def update(self, aggr_out):\n        if self.bias is not None:\n            aggr_out = aggr_out + self.bias\n        return aggr_out\n\n\n# Example 2: First define a PyG format Conv layer\n# Then wrap it to become GraphGym format\nclass ExampleConv2Layer(MessagePassing):\n    r\"\"\"Example GNN layer.\"\"\"\n    def __init__(self, in_channels, out_channels, bias=True, **kwargs):\n        super().__init__(aggr=cfg.gnn.agg, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        self.weight = Parameter(torch.empty(in_channels, 
out_channels))\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.weight)\n        zeros(self.bias)\n\n    def forward(self, x, edge_index):\n        x = torch.matmul(x, self.weight)\n\n        return self.propagate(edge_index, x=x)\n\n    def message(self, x_j):\n        return x_j\n\n    def update(self, aggr_out):\n        if self.bias is not None:\n            aggr_out = aggr_out + self.bias\n        return aggr_out\n\n\n@register_layer('exampleconv2')\nclass ExampleConv2(nn.Module):\n    def __init__(self, dim_in, dim_out, bias=False, **kwargs):\n        super().__init__()\n        self.model = ExampleConv2Layer(dim_in, dim_out, bias=bias)\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index)\n        return batch\n"
  },
  {
    "path": "graphgym/custom_graphgym/loader/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/loader/example.py",
    "content": "from torch_geometric.datasets import QM7b\nfrom torch_geometric.graphgym.register import register_loader\n\n\n@register_loader('example')\ndef load_dataset_example(format, name, dataset_dir):\n    dataset_dir = f'{dataset_dir}/{name}'\n    if format == 'PyG':\n        if name == 'QM7b':\n            dataset_raw = QM7b(dataset_dir)\n            return dataset_raw\n"
  },
  {
    "path": "graphgym/custom_graphgym/loss/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/loss/example.py",
    "content": "import torch\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.register import register_loss\n\n\n@register_loss('smoothl1')\ndef loss_example(pred, true):\n    if cfg.model.loss_fun == 'smoothl1':\n        l1_loss = torch.nn.SmoothL1Loss()\n        loss = l1_loss(pred, true)\n        return loss, pred\n"
  },
  {
    "path": "graphgym/custom_graphgym/network/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/network/example.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport torch_geometric.graphgym.models.head  # noqa, register module\nimport torch_geometric.graphgym.register as register\nimport torch_geometric.nn as pyg_nn\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.register import register_network\n\n\n@register_network('example')\nclass ExampleGNN(torch.nn.Module):\n    def __init__(self, dim_in, dim_out, num_layers=2, model_type='GCN'):\n        super().__init__()\n        conv_model = self.build_conv_model(model_type)\n        self.convs = nn.ModuleList()\n        self.convs.append(conv_model(dim_in, dim_in))\n\n        for _ in range(num_layers - 1):\n            self.convs.append(conv_model(dim_in, dim_in))\n\n        GNNHead = register.head_dict[cfg.dataset.task]\n        self.post_mp = GNNHead(dim_in=dim_in, dim_out=dim_out)\n\n    def build_conv_model(self, model_type):\n        if model_type == 'GCN':\n            return pyg_nn.GCNConv\n        elif model_type == 'GAT':\n            return pyg_nn.GATConv\n        elif model_type == \"GraphSage\":\n            return pyg_nn.SAGEConv\n        else:\n            raise ValueError(f'Model {model_type} unavailable')\n\n    def forward(self, batch):\n        x, edge_index = batch.x, batch.edge_index\n\n        for i in range(len(self.convs)):\n            x = self.convs[i](x, edge_index)\n            x = F.relu(x)\n            x = F.dropout(x, p=0.1, training=self.training)\n\n        batch.x = x\n        batch = self.post_mp(batch)\n\n        return batch\n"
  },
  {
    "path": "graphgym/custom_graphgym/optimizer/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/optimizer/example.py",
    "content": "from typing import Iterator\n\nfrom torch.nn import Parameter\nfrom torch.optim import Adagrad, Optimizer\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\nimport torch_geometric.graphgym.register as register\n\n\n@register.register_optimizer('adagrad')\ndef adagrad_optimizer(params: Iterator[Parameter], base_lr: float,\n                      weight_decay: float) -> Adagrad:\n    return Adagrad(params, lr=base_lr, weight_decay=weight_decay)\n\n\n@register.register_scheduler('pleateau')\ndef plateau_scheduler(optimizer: Optimizer, patience: int,\n                      lr_decay: float) -> ReduceLROnPlateau:\n    return ReduceLROnPlateau(optimizer, patience=patience, factor=lr_decay)\n"
  },
  {
    "path": "graphgym/custom_graphgym/pooling/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/pooling/example.py",
    "content": "from torch_geometric.graphgym.register import register_pooling\nfrom torch_geometric.utils import scatter\n\n\n@register_pooling('example')\ndef global_example_pool(x, batch, size=None):\n    size = batch.max().item() + 1 if size is None else size\n    return scatter(x, batch, dim=0, dim_size=size, reduce='sum')\n"
  },
  {
    "path": "graphgym/custom_graphgym/stage/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/stage/example.py",
    "content": "import torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.models.layer import GeneralLayer\nfrom torch_geometric.graphgym.register import register_stage\n\n\ndef GNNLayer(dim_in, dim_out, has_act=True):\n    return GeneralLayer(cfg.gnn.layer_type, dim_in, dim_out, has_act)\n\n\n@register_stage('example')\nclass GNNStackStage(nn.Module):\n    r\"\"\"Simple stage that stacks GNN layers.\"\"\"\n    def __init__(self, dim_in, dim_out, num_layers):\n        super().__init__()\n        for i in range(num_layers):\n            d_in = dim_in if i == 0 else dim_out\n            layer = GNNLayer(d_in, dim_out)\n            self.add_module(f'layer{i}', layer)\n        self.dim_out = dim_out\n\n    def forward(self, batch):\n        for layer in self.children():\n            batch = layer(batch)\n        if cfg.gnn.l2norm:\n            batch.x = F.normalize(batch.x, p=2, dim=-1)\n        return batch\n"
  },
  {
    "path": "graphgym/custom_graphgym/train/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/custom_graphgym/train/example.py",
    "content": "import logging\nimport time\n\nimport torch\n\nfrom torch_geometric.graphgym.checkpoint import (\n    clean_ckpt,\n    load_ckpt,\n    save_ckpt,\n)\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.loss import compute_loss\nfrom torch_geometric.graphgym.register import register_train\nfrom torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch\n\n\ndef train_epoch(logger, loader, model, optimizer, scheduler):\n    model.train()\n    time_start = time.time()\n    for batch in loader:\n        optimizer.zero_grad()\n        batch.to(torch.device(cfg.device))\n        pred, true = model(batch)\n        loss, pred_score = compute_loss(pred, true)\n        loss.backward()\n        optimizer.step()\n        logger.update_stats(true=true.detach().cpu(),\n                            pred=pred_score.detach().cpu(), loss=loss.item(),\n                            lr=scheduler.get_last_lr()[0],\n                            time_used=time.time() - time_start,\n                            params=cfg.params)\n        time_start = time.time()\n    scheduler.step()\n\n\ndef eval_epoch(logger, loader, model):\n    model.eval()\n    time_start = time.time()\n    for batch in loader:\n        batch.to(torch.device(cfg.device))\n        pred, true = model(batch)\n        loss, pred_score = compute_loss(pred, true)\n        logger.update_stats(true=true.detach().cpu(),\n                            pred=pred_score.detach().cpu(), loss=loss.item(),\n                            lr=0, time_used=time.time() - time_start,\n                            params=cfg.params)\n        time_start = time.time()\n\n\n@register_train('example')\ndef train_example(loggers, loaders, model, optimizer, scheduler):\n    start_epoch = 0\n    if cfg.train.auto_resume:\n        start_epoch = load_ckpt(model, optimizer, scheduler,\n                                cfg.train.epoch_resume)\n    if start_epoch == cfg.optim.max_epoch:\n        
logging.info('Checkpoint found, Task already done')\n    else:\n        logging.info('Start from epoch %s', start_epoch)\n\n    num_splits = len(loggers)\n    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):\n        train_epoch(loggers[0], loaders[0], model, optimizer, scheduler)\n        loggers[0].write_epoch(cur_epoch)\n        if is_eval_epoch(cur_epoch):\n            for i in range(1, num_splits):\n                eval_epoch(loggers[i], loaders[i], model)\n                loggers[i].write_epoch(cur_epoch)\n        if is_ckpt_epoch(cur_epoch):\n            save_ckpt(model, optimizer, scheduler, cur_epoch)\n    for logger in loggers:\n        logger.close()\n    if cfg.train.ckpt_clean:\n        clean_ckpt()\n\n    logging.info('Task done, results saved in %s', cfg.run_dir)\n"
  },
  {
    "path": "graphgym/custom_graphgym/transform/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "graphgym/grids/example.txt",
    "content": "# Format for each row: name in config.py; alias; range to search\n# No spaces, except between these 3 fields\n# Line breaks are used to union different grid search spaces\n# Feel free to add '#' to add comments\n\n\n# (1) dataset configurations\ndataset.format format ['PyG']\ndataset.name dataset ['TU_ENZYMES','TU_PROTEINS']\ndataset.task task ['graph']\ndataset.transductive trans [False]\n# (2) The recommended GNN design space, 96 models in total\ngnn.layers_pre_mp l_pre [1,2]\ngnn.layers_mp l_mp [2,4,6,8]\ngnn.layers_post_mp l_post [2,3]\ngnn.stage_type stage ['skipsum','skipconcat']\ngnn.agg agg ['add','mean','max']\n"
  },
  {
    "path": "graphgym/grids/pyg/example.txt",
    "content": "# Format for each row: name in config.py; alias; range to search\n# No spaces, except between these 3 fields\n# Line breaks are used to union different grid search spaces\n# Feel free to add '#' to add comments\n\n\ngnn.layers_pre_mp l_pre [1,2]\ngnn.layers_mp l_mp [2,4,6]\ngnn.layers_post_mp l_post [1,2]\ngnn.stage_type stage ['stack','skipsum','skipconcat']\ngnn.dim_inner dim [64]\noptim.base_lr lr [0.01]\noptim.max_epoch epoch [200]\n"
  },
  {
    "path": "graphgym/main.py",
    "content": "import logging\nimport os\n\nimport custom_graphgym  # noqa, register custom modules\nimport torch\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.graphgym.cmd_args import parse_args\nfrom torch_geometric.graphgym.config import (\n    cfg,\n    dump_cfg,\n    load_cfg,\n    set_out_dir,\n    set_run_dir,\n)\nfrom torch_geometric.graphgym.logger import set_printing\nfrom torch_geometric.graphgym.model_builder import create_model\nfrom torch_geometric.graphgym.train import GraphGymDataModule, train\nfrom torch_geometric.graphgym.utils.agg_runs import agg_runs\nfrom torch_geometric.graphgym.utils.comp_budget import params_count\nfrom torch_geometric.graphgym.utils.device import auto_select_device\n\nif __name__ == '__main__':\n    # Load cmd line args\n    args = parse_args()\n    # Load config file\n    load_cfg(cfg, args)\n    set_out_dir(cfg.out_dir, args.cfg_file)\n    # Set Pytorch environment\n    torch.set_num_threads(cfg.num_threads)\n    dump_cfg(cfg)\n    # Repeat for different random seeds\n    for _ in range(args.repeat):\n        set_run_dir(cfg.out_dir)\n        set_printing()\n        # Set configurations for each run\n        cfg.seed = cfg.seed + 1\n        seed_everything(cfg.seed)\n        auto_select_device()\n        # Set machine learning pipeline\n        datamodule = GraphGymDataModule()\n        model = create_model()\n        # Print model info\n        logging.info(model)\n        logging.info(cfg)\n        cfg.params = params_count(model)\n        logging.info('Num parameters: %s', cfg.params)\n        train(model, datamodule, logger=True)\n\n    # Aggregate results from different seeds\n    agg_runs(cfg.out_dir, cfg.metric_best)\n    # When being launched in batch mode, mark a yaml as done\n    if args.mark_done:\n        os.rename(args.cfg_file, f'{args.cfg_file}_done')\n"
  },
  {
    "path": "graphgym/parallel.sh",
    "content": "CONFIG_DIR=$1\nREPEAT=$2\nMAX_JOBS=${3:-2}\nSLEEP=${4:-1}\nMAIN=${5:-main}\n\n(\n  trap 'kill 0' SIGINT\n  CUR_JOBS=0\n  for CONFIG in \"$CONFIG_DIR\"/*.yaml; do\n    if [ \"$CONFIG\" != \"$CONFIG_DIR/*.yaml\" ]; then\n      ((CUR_JOBS >= MAX_JOBS)) && wait -n\n      python $MAIN.py --cfg $CONFIG --repeat $REPEAT --mark_done &\n      echo $CONFIG\n      sleep $SLEEP\n      ((++CUR_JOBS))\n    fi\n  done\n\n  wait\n)\n"
  },
  {
    "path": "graphgym/run_batch.sh",
    "content": "#!/usr/bin/env bash\n\nCONFIG=${CONFIG:-example_node}\nGRID=${GRID:-example}\nREPEAT=${REPEAT:-3}\nMAX_JOBS=${MAX_JOBS:-8}\nSLEEP=${SLEEP:-1}\nMAIN=${MAIN:-main}\n\n# generate configs (after controlling computational budget)\n# please remove --config_budget, if don't control computational budget\npython configs_gen.py --config configs/pyg/${CONFIG}.yaml \\\n  --grid grids/pyg/${GRID}.txt \\\n  --out_dir configs\n#python configs_gen.py --config configs/ChemKG/${CONFIG}.yaml --config_budget configs/ChemKG/${CONFIG}.yaml --grid grids/ChemKG/${GRID}.txt --out_dir configs\n# run batch of configs\n# Args: config_dir, num of repeats, max jobs running, sleep time\nbash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP $MAIN\n# rerun missed / stopped experiments\nbash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP $MAIN\n# rerun missed / stopped experiments\nbash parallel.sh configs/${CONFIG}_grid_${GRID} $REPEAT $MAX_JOBS $SLEEP $MAIN\n\n# aggregate results for the batch\npython agg_batch.py --dir results/${CONFIG}_grid_${GRID}\n"
  },
  {
    "path": "graphgym/run_single.sh",
    "content": "#!/usr/bin/env bash\n\n# Test for running a single experiment. --repeat means run how many different random seeds.\npython main.py --cfg configs/pyg/example_node.yaml --repeat 3 # node classification\npython main.py --cfg configs/pyg/example_link.yaml --repeat 3 # link prediction\npython main.py --cfg configs/pyg/example_graph.yaml --repeat 3 # graph classification\n"
  },
  {
    "path": "graphgym/sample/dimensions.txt",
    "content": "act bn drop agg l_mp l_pre l_post stage batch lr optim epoch\n"
  },
  {
    "path": "graphgym/sample/dimensionsatt.txt",
    "content": "l_tw\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires=[\"flit_core >=3.12,<4\"]\nbuild-backend=\"flit_core.buildapi\"\n\n[project]\nname=\"torch-geometric\"\nversion=\"2.8.0\"\nauthors=[\n    {name=\"Matthias Fey\", email=\"matthias@pyg.org\"},\n]\ndescription=\"Graph Neural Network Library for PyTorch\"\nreadme=\"README.md\"\nrequires-python=\">=3.10\"\nkeywords=[\n    \"deep-learning\",\n    \"pytorch\",\n    \"geometric-deep-learning\",\n    \"graph-neural-networks\",\n    \"graph-convolutional-networks\",\n]\nlicense = \"MIT\"\nlicense-files = [\"LICENSE\"]\nclassifiers=[\n    \"Development Status :: 5 - Production/Stable\",\n    \"Programming Language :: Python\",\n    \"Programming Language :: Python :: 3.10\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3.12\",\n    \"Programming Language :: Python :: 3.13\",\n    \"Programming Language :: Python :: 3.14\",\n    \"Programming Language :: Python :: 3 :: Only\",\n]\ndependencies=[\n    \"aiohttp\",\n    \"fsspec\",\n    \"jinja2\",\n    \"numpy\",\n    \"psutil>=5.8.0\",\n    \"pyparsing\",\n    \"requests\",\n    \"tqdm\",\n    \"xxhash\",\n]\n\n[project.optional-dependencies]\ngraphgym=[\n    \"protobuf<4.21\",\n    \"pytorch-lightning\",\n    \"yacs\",\n]\nmodelhub=[\n    \"huggingface_hub\"\n]\nbenchmark=[\n    \"matplotlib\",\n    \"networkx\",\n    \"pandas\",\n    \"protobuf<4.21\",\n    \"wandb\",\n]\nrag=[\n    \"pcst_fast\",\n    \"datasets\",\n    \"transformers\",\n    \"pandas\",\n    \"sentencepiece\",\n    \"accelerate\",\n    \"torchmetrics\",\n    \"peft\",\n]\ntest=[\n    \"onnx\",\n    \"onnxruntime\",\n    \"onnxscript\",\n    \"pytest\",\n    \"pytest-cov\",\n]\ndev=[\n    \"ipython\",\n    \"matplotlib-inline\",\n    \"pre-commit\",\n    \"torch_geometric[test]\",\n]\nfull = [\n    \"scipy\",\n    \"scikit-learn\",\n    \"ase\",\n    \"captum<0.7.0\",\n    \"graphviz\",\n    \"h5py\",\n    \"matplotlib\",\n    \"networkx\",\n    \"numba<0.60.0\",\n    
\"opt_einsum\",\n    \"pandas\",\n    # See https://github.com/pgmpy/pgmpy/issues/2360.\n    # \"pgmpy\",\n    \"pynndescent\",\n    \"pytorch-memlab\",\n    \"rdflib\",\n    \"rdkit\",\n    \"scikit-image\",\n    \"statsmodels\",\n    \"sympy\",\n    \"tabulate\",\n    \"torch_geometric[graphgym, modelhub]\",\n    \"torchmetrics\",\n    \"trimesh\",\n]\n\n[project.urls]\nhomepage=\"https://pyg.org\"\ndocumentation=\"https://pytorch-geometric.readthedocs.io\"\nrepository=\"https://github.com/pyg-team/pytorch_geometric.git\"\nchangelog=\"https://github.com/pyg-team/pytorch_geometric/blob/master/CHANGELOG.md\"\n\n[tool.flit.module]\nname=\"torch_geometric\"\n\n[tool.yapf]\nbased_on_style = \"pep8\"\nsplit_before_named_assigns = false\nblank_line_before_nested_class_or_def = false\n\n[tool.mypy]\nfiles = [\"torch_geometric\"]\ninstall_types = true\nnon_interactive = true\nignore_missing_imports = true\nshow_error_codes = true\nwarn_redundant_casts = true\nwarn_unused_configs = true\nwarn_unused_ignores = true\ndisallow_untyped_defs = true\ndisallow_incomplete_defs = true\n\n[[tool.mypy.overrides]]\nignore_errors = true\nmodule = [\n    \"torch_geometric.data.*\",\n    \"torch_geometric.sampler.*\",\n    \"torch_geometric.loader.*\",\n    \"torch_geometric.nn.*\",\n    \"torch_geometric.explain.*\",\n    \"torch_geometric.profile.*\",\n    \"torch_geometric.contrib.*\",\n    \"torch_geometric.graphgym.*\",\n    \"torch_geometric.distributed.*\",\n    \"torch_geometric.llm.*\",\n]\n\n[tool.isort]\nmulti_line_output = 3\ninclude_trailing_comma = true\nskip = [\".gitignore\", \"__init__.py\"]\n\n[tool.ruff]  # https://docs.astral.sh/ruff/rules\nsrc = [\"torch_geometric\"]\nline-length = 80\nindent-width = 4\ntarget-version = \"py310\"\n\n[tool.ruff.lint]\nselect = [\n    \"B\",  # flake8-bugbear\n    \"D\",  # pydocstyle\n]\nignore = [\n    \"B905\",  # TODO Don't ignore \"zip with strict=False\"\n    \"D100\",  # TODO Don't ignore \"Missing docstring in public module\"\n 
   \"D101\",  # TODO Don't ignore \"Missing docstring in public class\"\n    \"D102\",  # TODO Don't ignore \"Missing docstring in public method\"\n    \"D103\",  # TODO Don't ignore \"Missing docstring in public function\"\n    \"D104\",  # TODO Don't ignore \"Missing docstring in public package\"\n    \"D105\",  # Ignore \"Missing docstring in magic method\"\n    \"D107\",  # Ignore \"Missing docstring in __init__\"\n    \"D205\",  # Ignore \"blank line required between summary line and description\"\n]\n\n[tool.ruff.format]\nquote-style = \"single\"\n\n[tool.ruff.lint.pydocstyle]\nconvention = \"google\"\n\n[tool.pytest.ini_options]\naddopts = [\n    \"--capture=no\",\n    \"--color=yes\",\n    \"-vv\",\n]\nfilterwarnings = [\n    \"ignore:distutils:DeprecationWarning\",\n    \"ignore:'torch_geometric.contrib' contains experimental code:UserWarning\",\n    # Filter `torch` warnings:\n    \"ignore:The PyTorch API of nested tensors is in prototype stage:UserWarning\",\n    \"ignore:scatter_reduce():UserWarning\",\n    \"ignore:Sparse CSR tensor support is in beta state:UserWarning\",\n    \"ignore:Sparse CSC tensor support is in beta state:UserWarning\",\n    \"ignore:torch.distributed._sharded_tensor will be deprecated:DeprecationWarning\",\n    # Filter `torch.compile` warnings:\n    \"ignore:pkg_resources is deprecated as an API\",\n    \"ignore:Deprecated call to `pkg_resources.declare_namespace\",\n    # Filter `captum` warnings:\n    \"ignore:Setting backward hooks on ReLU activations:UserWarning\",\n    \"ignore:.*did not already require gradients, required_grads has been set automatically:UserWarning\",\n    # Filter `pytorch_lightning` warnings:\n    \"ignore:GPU available but not used:UserWarning\",\n    \"error:.*torch_geometric.*:DeprecationWarning\",\n    # TODO(rishipuri98): Remove usage of `torch_geometric.distributed` from `torch_geometric.llm`\n    \"ignore:.*torch_geometric.distributed.*:DeprecationWarning\",\n    # Filter `torch.jit.*` 
deprecation warnings:\n    \"ignore:.*torch.jit.*:DeprecationWarning\",\n]\nmarkers = [\n    \"rag: mark test as RAG test\",\n]\n\n[tool.coverage.run]\nsource = [\"torch_geometric\"]\nomit = [\n    \"torch_geometric/distributed/*\",\n    \"torch_geometric/datasets/*\",\n    \"torch_geometric/data/extract.py\",\n    \"torch_geometric/nn/data_parallel.py\",\n]\n\n[tool.coverage.report]\nexclude_lines = [\n    \"pragma: no cover\",\n    \"pass\",\n    \"raise NotImplementedError\",\n    \"register_parameter\",\n    \"torch.cuda.is_available\",\n]\n\n[tool.setuptools]\npy-modules = []\n"
  },
  {
    "path": "readthedocs.yml",
    "content": "version: 2\n\nsphinx:\n   configuration: docs/source/conf.py\n\nbuild:\n   os: ubuntu-24.04\n   tools:\n      python: \"3.10\"\n\npython:\n   install:\n      - requirements: docs/requirements.txt\n      - method: pip\n        path: .\n\nformats: []\n"
  },
  {
    "path": "test/conftest.py",
    "content": "import functools\nimport logging\nimport os.path as osp\nfrom typing import Callable\n\nimport pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Dataset\nfrom torch_geometric.io import fs\n\n\ndef load_dataset(root: str, name: str, *args, **kwargs) -> Dataset:\n    r\"\"\"Returns a variety of datasets according to :obj:`name`.\"\"\"\n    if 'karate' in name.lower():\n        from torch_geometric.datasets import KarateClub\n        return KarateClub(*args, **kwargs)\n    if name.lower() in ['cora', 'citeseer', 'pubmed']:\n        from torch_geometric.datasets import Planetoid\n        path = osp.join(root, 'Planetoid', name)\n        return Planetoid(path, name, *args, **kwargs)\n    if name in ['BZR', 'ENZYMES', 'IMDB-BINARY', 'MUTAG']:\n        from torch_geometric.datasets import TUDataset\n        path = osp.join(root, 'TUDataset')\n        return TUDataset(path, name, *args, **kwargs)\n    if name in ['ego-facebook', 'soc-Slashdot0811', 'wiki-vote']:\n        from torch_geometric.datasets import SNAPDataset\n        path = osp.join(root, 'SNAPDataset')\n        return SNAPDataset(path, name, *args, **kwargs)\n    if name.lower() in ['bashapes']:\n        from torch_geometric.datasets import BAShapes\n        return BAShapes(*args, **kwargs)\n    if name in ['citationCiteseer', 'illc1850']:\n        from torch_geometric.datasets import SuiteSparseMatrixCollection\n        path = osp.join(root, 'SuiteSparseMatrixCollection')\n        return SuiteSparseMatrixCollection(path, *args, name=name, **kwargs)\n    if 'elliptic' in name.lower():\n        from torch_geometric.datasets import EllipticBitcoinDataset\n        path = osp.join(root, 'EllipticBitcoinDataset')\n        return EllipticBitcoinDataset(path, *args, **kwargs)\n    if name.lower() in ['hetero']:\n        from torch_geometric.testing import FakeHeteroDataset\n        return FakeHeteroDataset(*args, **kwargs)\n\n    raise ValueError(f\"Cannot load 
dataset with name '{name}'\")\n\n\n@pytest.fixture(scope='session')\ndef get_dataset() -> Callable:\n    # TODO Support memory filesystem on Windows.\n    if torch_geometric.typing.WITH_WINDOWS:\n        root = osp.join('/', 'tmp', 'pyg_test_datasets')\n    else:\n        root = 'memory://pyg_test_datasets'\n\n    yield functools.partial(load_dataset, root)\n\n    if fs.exists(root):\n        fs.rm(root)\n\n\n@pytest.fixture\ndef enable_extensions():  # Nothing to do.\n    yield\n\n\n@pytest.fixture\ndef disable_extensions():\n    def is_setting(name: str) -> bool:\n        if not name.startswith('WITH_'):\n            return False\n        if name.startswith('WITH_PT') or name.startswith('WITH_WINDOWS'):\n            return False\n        return True\n\n    settings = dir(torch_geometric.typing)\n    settings = [key for key in settings if is_setting(key)]\n    state = {key: getattr(torch_geometric.typing, key) for key in settings}\n\n    for key in state.keys():\n        setattr(torch_geometric.typing, key, False)\n    yield\n    for key, value in state.items():\n        setattr(torch_geometric.typing, key, value)\n\n\n@pytest.fixture\ndef without_extensions(request):\n    request.getfixturevalue(request.param)\n    return request.param == 'disable_extensions'\n\n\n@pytest.fixture(scope='function')\ndef spawn_context():\n    torch.multiprocessing.set_start_method('spawn', force=True)\n    logging.info(\"Setting torch.multiprocessing context to 'spawn'\")\n"
  },
  {
    "path": "test/contrib/explain/test_pgm_explainer.py",
    "content": "import numpy as np\nimport pytest\nimport torch\n\nfrom torch_geometric.contrib.explain import PGMExplainer\nfrom torch_geometric.explain import Explainer\nfrom torch_geometric.explain.config import ModelConfig\nfrom torch_geometric.nn import GCNConv, global_add_pool\nfrom torch_geometric.testing import onlyLinux, withPackage\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, model_config: ModelConfig):\n        super().__init__()\n        self.model_config = model_config\n\n        if model_config.mode.value == 'multiclass_classification':\n            out_channels = 7\n        else:\n            out_channels = 1\n\n        self.conv1 = GCNConv(3, 16)\n        self.conv2 = GCNConv(16, out_channels)\n\n    def forward(self, x, edge_index, edge_weight=None, batch=None, **kwargs):\n        x = self.conv1(x, edge_index, edge_weight).relu()\n        x = self.conv2(x, edge_index, edge_weight).relu()\n\n        if self.model_config.task_level.value == 'graph':\n            x = global_add_pool(x, batch)\n\n        if self.model_config.mode.value == 'binary_classification':\n            x = x.sigmoid()\n        elif self.model_config.mode.value == 'multiclass_classification':\n            x = x.log_softmax(dim=-1)\n\n        return x\n\n\nx = torch.randn(8, 3)\nedge_index = torch.tensor([\n    [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n    [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n])\ntarget = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])\nedge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]])\n\n\n@onlyLinux\n@withPackage('pgmpy', 'pandas')\n@pytest.mark.parametrize('node_idx', [2, 6])\n@pytest.mark.parametrize('task_level, perturbation_mode', [\n    ('node', 'randint'),\n    ('graph', 'mean'),\n    ('graph', 'max'),\n    ('graph', 'min'),\n    ('graph', 'zero'),\n])\ndef test_pgm_explainer_classification(node_idx, task_level, perturbation_mode):\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        
task_level=task_level,\n        return_type='raw',\n    )\n\n    model = GCN(model_config)\n    logits = model(x, edge_index)\n    target = logits.argmax(dim=1)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=PGMExplainer(feature_index=[0],\n                               perturbation_mode=perturbation_mode),\n        explanation_type='phenomenon',\n        node_mask_type='object',\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        x=x,\n        edge_index=edge_index,\n        index=node_idx,\n        target=target,\n    )\n\n    assert 'node_mask' in explanation\n    assert 'pgm_stats' in explanation\n    assert explanation.node_mask.size(0) == explanation.num_nodes\n    assert explanation.node_mask.min() >= 0\n    assert explanation.node_mask.max() <= 1\n\n\nclass DummyModel(torch.nn.Module):\n    def forward(self, x, edge_index, **kwargs):\n        return torch.tensor([[0.2, 0.8]], requires_grad=True)\n\n\ndef test_batch_perturb_features_on_node():\n    model = DummyModel()\n    explainer = PGMExplainer(num_samples=1)  # just one sample for testing\n\n    # Minimal graph data with 1 node and 2 features\n    x = torch.randn(1, 2)\n    edge_index = torch.tensor([[0], [0]])  # dummy self-loop\n    indices_to_perturb = np.array([0])  # only node 0 can be perturbed\n\n    # Simulate kwargs that would include prediction details\n    kwargs = {\n        \"soft_pred\": torch.tensor([0.4,\n                                   0.6]),  # soft prediction of original input\n        \"pred_label\": 1,\n        \"num_nodes\": 1,\n    }\n\n    samples = explainer._batch_perturb_features_on_node(\n        model=model, x=x, edge_index=edge_index,\n        indices_to_perturb=indices_to_perturb, **kwargs)\n\n    assert isinstance(samples, torch.Tensor)\n    assert samples.shape == (1, 2)\n    assert torch.all(samples[0] >= 0)  # pred_change should be non-negative\n"
  },
  {
    "path": "test/contrib/nn/models/test_rbcd_attack.py",
    "content": "import pytest\nimport torch\nfrom torch.nn import Linear\n\nfrom torch_geometric.contrib.nn import GRBCDAttack, PRBCDAttack\nfrom torch_geometric.nn import GCNConv, global_add_pool\nfrom torch_geometric.utils import to_undirected\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(3, 16)\n        self.conv2 = GCNConv(16, 7)\n\n    def forward(self, x, edge_index, edge_weight):\n        x = self.conv1(x, edge_index, edge_weight).relu()\n        x = self.conv2(x, edge_index, edge_weight)\n        return x\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(3, 16)\n        self.conv2 = GCNConv(16, 16)\n        self.lin = Linear(16, 7)\n\n    def forward(self, x, edge_index, edge_weight, batch=None):\n        x = self.conv1(x, edge_index, edge_weight).relu()\n        x = self.conv2(x, edge_index, edge_weight).relu()\n        x = global_add_pool(x, batch)\n        x = self.lin(x)\n        return x\n\n\n@pytest.mark.parametrize('model', [GCN, GNN])\n@pytest.mark.parametrize('budget', [1])\n@pytest.mark.parametrize('loss',\n                         ['masked', 'margin', 'prob_margin', 'tanh_margin'])\n@pytest.mark.parametrize('is_undirected', [False, True])\n@pytest.mark.parametrize('with_early_stopping', [False, True])\ndef test_prbcd_attack(model, budget, loss, is_undirected, with_early_stopping):\n    attack = PRBCDAttack(model(), block_size=10_000, epochs=4,\n                         epochs_resampling=2, loss=loss, max_final_samples=2,\n                         log=False, is_undirected=is_undirected,\n                         with_early_stopping=with_early_stopping)\n    assert str(attack) == 'PRBCDAttack()'\n\n    x = torch.randn(8, 3)\n    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7]])\n    if is_undirected:\n        edge_index = to_undirected(edge_index)\n\n    if model == GNN:\n        y = 
torch.tensor([0])\n        # All nodes belong to same graph:\n        kwargs = dict(batch=edge_index.new_zeros(x.size(0)))\n    else:\n        y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])\n        kwargs = {}\n\n    pert_edge_index, pert = attack.attack(x, edge_index, y, budget, **kwargs)\n\n    m = edge_index.size(1)\n    if budget == 1:\n        assert pert.size() in [(2, 0), (2, 1)]\n\n        if pert.size(1):\n            if is_undirected:\n                possible_m = [m - 2, m + 2]\n            else:\n                possible_m = [m - 1, m + 1]\n        else:\n            possible_m = [m]\n        assert pert_edge_index.size(1) in possible_m\n\n\n@pytest.mark.parametrize('model', [GCN, GNN])\n@pytest.mark.parametrize('budget', [1])\n@pytest.mark.parametrize('is_undirected', [False, True])\ndef test_grbcd_attack(model, budget, is_undirected):\n    attack = GRBCDAttack(model(), block_size=10_000, epochs=4, log=False,\n                         is_undirected=is_undirected)\n    assert str(attack) == 'GRBCDAttack()'\n\n    x = torch.randn(8, 3)\n    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7]])\n    if is_undirected:\n        edge_index = to_undirected(edge_index)\n\n    if model == GNN:\n        y = torch.tensor([0])\n        # All nodes belong to same graph:\n        kwargs = dict(batch=edge_index.new_zeros(x.size(0)))\n    else:\n        y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])\n        kwargs = {}\n\n    pert_edge_index, pert = attack.attack(x, edge_index, y, budget, **kwargs)\n\n    m = edge_index.size(1)\n    if budget == 1:\n        assert pert.size() == (2, 1)\n\n        if is_undirected:\n            possible_m = [m - 2, m + 2]\n        else:\n            possible_m = [m - 1, m + 1]\n        assert pert_edge_index.size(1) in possible_m\n"
  },
  {
    "path": "test/data/lightning/test_datamodule.py",
    "content": "import math\nfrom contextlib import contextmanager\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.lightning import (\n    LightningDataset,\n    LightningLinkData,\n    LightningNodeData,\n)\nfrom torch_geometric.nn import global_mean_pool\nfrom torch_geometric.sampler import BaseSampler, NeighborSampler\nfrom torch_geometric.testing import (\n    MyFeatureStore,\n    MyGraphStore,\n    get_random_edge_index,\n    has_package,\n    onlyCUDA,\n    onlyFullTest,\n    onlyNeighborSampler,\n    onlyOnline,\n    withPackage,\n)\n\ntry:\n    from pytorch_lightning import LightningModule\nexcept ImportError:\n    LightningModule = torch.nn.Module\n\n\nclass LinearGraphModule(LightningModule):\n    def __init__(self, in_channels: int, hidden_channels: int,\n                 out_channels: int):\n        super().__init__()\n        from torchmetrics import Accuracy\n\n        self.lin1 = torch.nn.Linear(in_channels, hidden_channels)\n        self.lin2 = torch.nn.Linear(hidden_channels, out_channels)\n\n        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)\n\n    def forward(self, x: Tensor, batch: Data) -> Tensor:\n        # Basic test to ensure that the dataset is not replicated:\n        self.trainer.datamodule.train_dataset._data.x.add_(1)\n\n        x = self.lin1(x).relu()\n        x = global_mean_pool(x, batch)\n        x = self.lin2(x)\n        return x\n\n    def training_step(self, data: Data, batch_idx: int):\n        y_hat = self(data.x, data.batch)\n        loss = F.cross_entropy(y_hat, data.y)\n        self.train_acc(y_hat.softmax(dim=-1), data.y)\n        self.log('loss', loss, batch_size=data.num_graphs)\n        self.log('train_acc', 
self.train_acc, batch_size=data.num_graphs)\n        return loss\n\n    def validation_step(self, data: Data, batch_idx: int):\n        y_hat = self(data.x, data.batch)\n        self.val_acc(y_hat.softmax(dim=-1), data.y)\n        self.log('val_acc', self.val_acc, batch_size=data.num_graphs)\n\n    def test_step(self, data: Data, batch_idx: int):\n        y_hat = self(data.x, data.batch)\n        self.test_acc(y_hat.softmax(dim=-1), data.y)\n        self.log('test_acc', self.test_acc, batch_size=data.num_graphs)\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=0.01)\n\n\n@onlyCUDA\n@onlyOnline\n@onlyFullTest\n@withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0')\n@pytest.mark.parametrize('strategy_type', [None, 'ddp'])\ndef test_lightning_dataset(get_dataset, strategy_type):\n    import pytorch_lightning as pl\n    from pytorch_lightning.utilities import rank_zero_only\n\n    @contextmanager\n    def expect_rank_zero_user_warning(match: str):\n        if rank_zero_only.rank == 0:\n            with pytest.warns(UserWarning, match=match):\n                yield\n        else:\n            yield\n\n    dataset = get_dataset(name='MUTAG').shuffle()\n    train_dataset = dataset[:50]\n    val_dataset = dataset[50:80]\n    test_dataset = dataset[80:90]\n    pred_dataset = dataset[90:]\n\n    devices = 1 if strategy_type is None else torch.cuda.device_count()\n    if strategy_type == 'ddp':\n        strategy = pl.strategies.DDPStrategy(accelerator='gpu')\n    else:\n        strategy = pl.strategies.SingleDeviceStrategy(device='cuda:0')\n\n    model = LinearGraphModule(dataset.num_features, 64, dataset.num_classes)\n\n    trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=1,\n                         log_every_n_steps=1)\n    with pytest.warns(UserWarning, match=\"'shuffle=True' option is ignored\"):\n        datamodule = LightningDataset(train_dataset, val_dataset, test_dataset,\n                      
                pred_dataset, batch_size=5,\n                                      num_workers=3, shuffle=True)\n        assert 'shuffle' not in datamodule.kwargs\n    old_x = train_dataset._data.x.clone()\n    if has_package('pytorch_lightning>=2.5.0'):\n        datamodule_repr = ('{Train dataloader: size=50}\\n'\n                           '{Validation dataloader: size=30}\\n'\n                           '{Test dataloader: size=10}\\n'\n                           '{Predict dataloader: size=98}')\n    else:\n        datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), '\n                           'val_dataset=MUTAG(30), '\n                           'test_dataset=MUTAG(10), '\n                           'pred_dataset=MUTAG(98), batch_size=5, '\n                           'num_workers=3, pin_memory=True, '\n                           'persistent_workers=True)')\n    assert str(datamodule) == datamodule_repr\n\n    trainer.fit(model, datamodule)\n    trainer.test(model, datamodule)\n    new_x = train_dataset._data.x\n    assert torch.all(new_x > old_x)  # Ensure shared data.\n    assert trainer.validate_loop._data_source.is_defined()\n    assert trainer.test_loop._data_source.is_defined()\n\n    # Test with `val_dataset=None` and `test_dataset=None`:\n    if strategy_type is None:\n        trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=1,\n                             log_every_n_steps=1)\n\n        datamodule = LightningDataset(train_dataset, batch_size=5)\n        if has_package('pytorch_lightning>=2.5.0'):\n            datamodule_repr = ('{Train dataloader: size=50}\\n'\n                               '{Validation dataloader: None}\\n'\n                               '{Test dataloader: None}\\n{'\n                               'Predict dataloader: None}')\n        else:\n            datamodule_repr = ('LightningDataset(train_dataset=MUTAG(50), '\n                               'batch_size=5, num_workers=0, '\n                       
        'pin_memory=True, '\n                               'persistent_workers=False)')\n        assert str(datamodule) == datamodule_repr\n\n        with expect_rank_zero_user_warning(\"defined a `validation_step`\"):\n            trainer.fit(model, datamodule)\n\n        assert not trainer.validate_loop._data_source.is_defined()\n        assert not trainer.test_loop._data_source.is_defined()\n\n\nclass LinearNodeModule(LightningModule):\n    def __init__(self, in_channels: int, out_channels: int):\n        super().__init__()\n        from torchmetrics import Accuracy\n\n        self.lin = torch.nn.Linear(in_channels, out_channels)\n\n        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)\n\n    def forward(self, x: Tensor) -> Tensor:\n        # Basic test to ensure that the dataset is not replicated:\n        self.trainer.datamodule.data.x.add_(1)\n\n        return self.lin(x)\n\n    def training_step(self, data: Data, batch_idx: int):\n        y_hat = self(data.x)[data.train_mask]\n        y = data.y[data.train_mask]\n        loss = F.cross_entropy(y_hat, y)\n        self.train_acc(y_hat.softmax(dim=-1), y)\n        self.log('loss', loss, batch_size=y.size(0))\n        self.log('train_acc', self.train_acc, batch_size=y.size(0))\n        return loss\n\n    def validation_step(self, data: Data, batch_idx: int):\n        y_hat = self(data.x)[data.val_mask]\n        y = data.y[data.val_mask]\n        self.val_acc(y_hat.softmax(dim=-1), y)\n        self.log('val_acc', self.val_acc, batch_size=y.size(0))\n\n    def test_step(self, data: Data, batch_idx: int):\n        y_hat = self(data.x)[data.test_mask]\n        y = data.y[data.test_mask]\n        self.test_acc(y_hat.softmax(dim=-1), y)\n        self.log('test_acc', self.test_acc, batch_size=y.size(0))\n\n    def 
configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=0.01)\n\n\n@onlyCUDA\n@onlyOnline\n@onlyFullTest\n@onlyNeighborSampler\n@withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0', 'scipy')\n@pytest.mark.parametrize('loader', ['full', 'neighbor'])\n@pytest.mark.parametrize('strategy_type', [None, 'ddp'])\ndef test_lightning_node_data(get_dataset, strategy_type, loader):\n    import pytorch_lightning as pl\n\n    dataset = get_dataset(name='Cora')\n    data = dataset[0]\n    data_repr = ('Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], '\n                 'train_mask=[2708], val_mask=[2708], test_mask=[2708])')\n\n    model = LinearNodeModule(dataset.num_features, dataset.num_classes)\n\n    if strategy_type is None or loader == 'full':\n        devices = 1\n    else:\n        devices = torch.cuda.device_count()\n\n    if strategy_type == 'ddp':\n        strategy = pl.strategies.DDPStrategy(accelerator='gpu')\n    else:\n        strategy = pl.strategies.SingleDeviceStrategy(device='cuda:0')\n\n    if loader == 'full':  # Set reasonable defaults for full-batch training:\n        batch_size = 1\n        num_workers = 0\n    else:\n        batch_size = 32\n        num_workers = 3\n    kwargs, kwargs_repr = {}, ''\n    if loader == 'neighbor':\n        kwargs['num_neighbors'] = [5]\n        kwargs_repr += 'num_neighbors=[5], '\n\n    trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=5,\n                         log_every_n_steps=1)\n    datamodule = LightningNodeData(data, loader=loader, batch_size=batch_size,\n                                   num_workers=num_workers, **kwargs)\n\n    old_x = data.x.clone().cpu()\n    flag = loader != 'full'\n    if has_package('pytorch_lightning>=2.5.0'):\n        datamodule_repr = (\n            '{Train dataloader: ' + f'size={140 if flag else 1}' + '}\\n'\n            '{Validation dataloader: ' + f'size={500 if flag else 1}' + '}\\n'\n            '{Test dataloader: ' + 
f'size={1000 if flag else 1}' + '}\\n'\n            '{Predict dataloader: ' + f'size={2708 if flag else 1}' + '}')\n    else:\n        datamodule_repr = (f'LightningNodeData(data={data_repr}, '\n                           f'loader={loader}, batch_size={batch_size}, '\n                           f'num_workers={num_workers}, {kwargs_repr}'\n                           f'pin_memory={flag}, '\n                           f'persistent_workers={flag})')\n    assert str(datamodule) == datamodule_repr\n\n    trainer.fit(model, datamodule)\n    trainer.test(model, datamodule)\n    new_x = data.x.cpu()\n    assert torch.all(new_x > old_x)  # Ensure shared data.\n    assert trainer.validate_loop._data_source.is_defined()\n    assert trainer.test_loop._data_source.is_defined()\n\n\nclass LinearHeteroNodeModule(LightningModule):\n    def __init__(self, in_channels: int, out_channels: int):\n        super().__init__()\n        from torchmetrics import Accuracy\n\n        self.lin = torch.nn.Linear(in_channels, out_channels)\n\n        self.train_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.val_acc = Accuracy(task='multiclass', num_classes=out_channels)\n        self.test_acc = Accuracy(task='multiclass', num_classes=out_channels)\n\n    def forward(self, x: Tensor) -> Tensor:\n        # Basic test to ensure that the dataset is not replicated:\n        self.trainer.datamodule.data['paper'].x.add_(1)\n\n        return self.lin(x)\n\n    def training_step(self, data: HeteroData, batch_idx: int):\n        y_hat = self(data['paper'].x)[data['paper'].train_mask]\n        y = data['paper'].y[data['paper'].train_mask]\n        loss = F.cross_entropy(y_hat, y)\n        self.train_acc(y_hat.softmax(dim=-1), y)\n        self.log('loss', loss, batch_size=y.size(0))\n        self.log('train_acc', self.train_acc, batch_size=y.size(0))\n        return loss\n\n    def validation_step(self, data: HeteroData, batch_idx: int):\n        y_hat = 
self(data['paper'].x)[data['paper'].val_mask]\n        y = data['paper'].y[data['paper'].val_mask]\n        self.val_acc(y_hat.softmax(dim=-1), y)\n        self.log('val_acc', self.val_acc, batch_size=y.size(0))\n\n    def test_step(self, data: HeteroData, batch_idx: int):\n        y_hat = self(data['paper'].x)[data['paper'].test_mask]\n        y = data['paper'].y[data['paper'].test_mask]\n        self.test_acc(y_hat.softmax(dim=-1), y)\n        self.log('test_acc', self.test_acc, batch_size=y.size(0))\n\n    def configure_optimizers(self):\n        return torch.optim.Adam(self.parameters(), lr=0.01)\n\n\n@pytest.fixture\ndef preserve_context():\n    num_threads = torch.get_num_threads()\n    yield\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n    torch.set_num_threads(num_threads)\n\n\n@onlyCUDA\n@onlyFullTest\n@onlyNeighborSampler\n@withPackage('pytorch_lightning>=2.0.0', 'torchmetrics>=0.11.0')\ndef test_lightning_hetero_node_data(preserve_context, get_dataset):\n    import pytorch_lightning as pl\n\n    data = get_dataset(name='hetero')[0]\n\n    model = LinearHeteroNodeModule(data['paper'].num_features,\n                                   int(data['paper'].y.max()) + 1)\n\n    devices = torch.cuda.device_count()\n    strategy = pl.strategies.DDPStrategy(accelerator='gpu')\n\n    trainer = pl.Trainer(strategy=strategy, devices=devices, max_epochs=5,\n                         log_every_n_steps=1)\n    datamodule = LightningNodeData(data, loader='neighbor', num_neighbors=[5],\n                                   batch_size=32, num_workers=3)\n    assert isinstance(datamodule.graph_sampler, NeighborSampler)\n    original_x = data['paper'].x.clone()\n    trainer.fit(model, datamodule)\n    trainer.test(model, datamodule)\n    assert torch.all(data['paper'].x > original_x)  # Ensure shared data.\n    assert trainer.validate_loop._data_source.is_defined()\n    assert 
trainer.test_loop._data_source.is_defined()\n\n\n@withPackage('pytorch_lightning')\ndef test_lightning_data_custom_sampler():\n    class DummySampler(BaseSampler):\n        def sample_from_edges(self, *args, **kwargs):\n            pass\n\n        def sample_from_nodes(self, *args, **kwargs):\n            pass\n\n    data = Data(num_nodes=2, edge_index=torch.tensor([[0, 1], [1, 0]]))\n\n    datamodule = LightningNodeData(data, node_sampler=DummySampler(),\n                                   input_train_nodes=torch.arange(2))\n    assert isinstance(datamodule.graph_sampler, DummySampler)\n\n    datamodule = LightningLinkData(\n        data, link_sampler=DummySampler(),\n        input_train_edges=torch.tensor([[0, 1], [0, 1]]))\n    assert isinstance(datamodule.graph_sampler, DummySampler)\n\n\n@onlyCUDA\n@onlyFullTest\n@onlyNeighborSampler\n@withPackage('pytorch_lightning')\ndef test_lightning_hetero_link_data():\n    torch.manual_seed(12345)\n\n    data = HeteroData()\n\n    data['paper'].x = torch.arange(10)\n    data['author'].x = torch.arange(10)\n    data['term'].x = torch.arange(10)\n\n    data['paper', 'author'].edge_index = get_random_edge_index(10, 10, 10)\n    data['author', 'paper'].edge_index = get_random_edge_index(10, 10, 10)\n    data['paper', 'term'].edge_index = get_random_edge_index(10, 10, 10)\n    data['author', 'term'].edge_index = get_random_edge_index(10, 10, 10)\n\n    datamodule = LightningLinkData(\n        data,\n        input_train_edges=('author', 'paper'),\n        input_val_edges=('paper', 'author'),\n        input_test_edges=('paper', 'term'),\n        input_pred_edges=('author', 'term'),\n        loader='neighbor',\n        num_neighbors=[5],\n        batch_size=32,\n        num_workers=0,\n    )\n\n    assert isinstance(datamodule.graph_sampler, NeighborSampler)\n    assert isinstance(datamodule.eval_graph_sampler, NeighborSampler)\n\n    for batch in datamodule.train_dataloader():\n        assert 'edge_label_index' in 
batch['author', 'paper']\n    for batch in datamodule.val_dataloader():\n        assert 'edge_label_index' in batch['paper', 'author']\n    for batch in datamodule.test_dataloader():\n        assert 'edge_label_index' in batch['paper', 'term']\n    for batch in datamodule.predict_dataloader():\n        assert 'edge_label_index' in batch['author', 'term']\n\n    data['author'].time = torch.arange(data['author'].num_nodes)\n    data['paper'].time = torch.arange(data['paper'].num_nodes)\n    data['term'].time = torch.arange(data['term'].num_nodes)\n\n    datamodule = LightningLinkData(\n        data,\n        input_train_edges=('author', 'paper'),\n        input_train_time=torch.arange(data['author', 'paper'].num_edges),\n        loader='neighbor',\n        num_neighbors=[5],\n        batch_size=32,\n        num_workers=0,\n        time_attr='time',\n    )\n\n    for batch in datamodule.train_dataloader():\n        assert 'edge_label_index' in batch['author', 'paper']\n        assert 'edge_label_time' in batch['author', 'paper']\n\n\n@onlyNeighborSampler\n@withPackage('pytorch_lightning')\ndef test_lightning_hetero_link_data_custom_store():\n    torch.manual_seed(12345)\n\n    feature_store = MyFeatureStore()\n    graph_store = MyGraphStore()\n\n    x = torch.arange(10)\n    feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None)\n    feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)\n    feature_store.put_tensor(x, group_name='term', attr_name='x', index=None)\n\n    edge_index = get_random_edge_index(10, 10, 10)\n    graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),\n                               edge_type=('paper', 'to', 'author'),\n                               layout='coo', size=(10, 10))\n    graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),\n                               edge_type=('author', 'to', 'paper'),\n                               layout='coo', size=(10, 10))\n    
graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),\n                               edge_type=('paper', 'to', 'term'), layout='coo',\n                               size=(10, 10))\n\n    datamodule = LightningLinkData(\n        (feature_store, graph_store),\n        input_train_edges=('author', 'to', 'paper'),\n        loader='neighbor',\n        num_neighbors=[5],\n        batch_size=32,\n        num_workers=0,\n    )\n\n    batch = next(iter(datamodule.train_dataloader()))\n    assert 'edge_label_index' in batch['author', 'paper']\n\n\n@onlyOnline\n@onlyNeighborSampler\n@withPackage('pytorch_lightning', 'scipy')\ndef test_eval_loader_kwargs(get_dataset):\n    data = get_dataset(name='Cora')[0]\n\n    node_sampler = NeighborSampler(data, num_neighbors=[5])\n\n    datamodule = LightningNodeData(\n        data,\n        node_sampler=node_sampler,\n        batch_size=32,\n        eval_loader_kwargs=dict(num_neighbors=[-1], batch_size=64),\n    )\n\n    assert datamodule.loader_kwargs['batch_size'] == 32\n    assert datamodule.graph_sampler.num_neighbors.values == [5]\n    assert datamodule.eval_loader_kwargs['batch_size'] == 64\n    assert datamodule.eval_graph_sampler.num_neighbors.values == [-1]\n\n    train_loader = datamodule.train_dataloader()\n    assert math.ceil(int(data.train_mask.sum()) / 32) == len(train_loader)\n\n    val_loader = datamodule.val_dataloader()\n    assert math.ceil(int(data.val_mask.sum()) / 64) == len(val_loader)\n\n    test_loader = datamodule.test_dataloader()\n    assert math.ceil(int(data.test_mask.sum()) / 64) == len(test_loader)\n\n    pred_loader = datamodule.predict_dataloader()\n    assert math.ceil(data.num_nodes / 64) == len(pred_loader)\n"
  },
  {
    "path": "test/data/test_batch.py",
    "content": "import os.path as osp\n\nimport numpy as np\nimport pytest\nimport torch\n\nimport torch_geometric\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.data import Batch, Data, HeteroData\nfrom torch_geometric.testing import get_random_edge_index, withPackage\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_edge_index, to_torch_sparse_tensor\n\n\ndef test_batch_basic():\n    torch_geometric.set_debug(True)\n\n    x = torch.tensor([1.0, 2.0, 3.0])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    data1 = Data(x=x, y=1, edge_index=edge_index, string='1', array=['1', '2'],\n                 num_nodes=3)\n\n    x = torch.tensor([1.0, 2.0])\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    data2 = Data(x=x, y=2, edge_index=edge_index, string='2',\n                 array=['3', '4', '5'], num_nodes=2)\n\n    x = torch.tensor([1.0, 2.0, 3.0, 4.0])\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n    data3 = Data(x=x, y=3, edge_index=edge_index, string='3',\n                 array=['6', '7', '8', '9'], num_nodes=4)\n\n    batch = Batch.from_data_list([data1])\n    assert str(batch) == ('DataBatch(x=[3], edge_index=[2, 4], y=[1], '\n                          'string=[1], array=[1], num_nodes=3, batch=[3], '\n                          'ptr=[2])')\n    assert batch.num_graphs == len(batch) == 1\n    assert batch.x.tolist() == [1, 2, 3]\n    assert batch.y.tolist() == [1]\n    assert batch.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert batch.string == ['1']\n    assert batch.array == [['1', '2']]\n    assert batch.num_nodes == 3\n    assert batch.batch.tolist() == [0, 0, 0]\n    assert batch.ptr.tolist() == [0, 3]\n\n    batch = Batch.from_data_list([data1, data2, data3],\n                                 follow_batch=['string'])\n\n    assert str(batch) == ('DataBatch(x=[9], edge_index=[2, 12], y=[3], '\n                          'string=[3], 
string_batch=[3], string_ptr=[4], '\n                          'array=[3], num_nodes=9, batch=[9], ptr=[4])')\n    assert batch.num_graphs == len(batch) == 3\n    assert batch.x.tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4]\n    assert batch.y.tolist() == [1, 2, 3]\n    assert batch.edge_index.tolist() == [[0, 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8],\n                                         [1, 0, 2, 1, 4, 3, 6, 5, 7, 6, 8, 7]]\n    assert batch.string == ['1', '2', '3']\n    assert batch.string_batch.tolist() == [0, 1, 2]\n    assert batch.string_ptr.tolist() == [0, 1, 2, 3]\n    assert batch.array == [['1', '2'], ['3', '4', '5'], ['6', '7', '8', '9']]\n    assert batch.num_nodes == 9\n    assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]\n    assert batch.ptr.tolist() == [0, 3, 5, 9]\n\n    assert str(batch[0]) == (\"Data(x=[3], edge_index=[2, 4], y=[1], \"\n                             \"string='1', array=[2], num_nodes=3)\")\n    assert str(batch[1]) == (\"Data(x=[2], edge_index=[2, 2], y=[1], \"\n                             \"string='2', array=[3], num_nodes=2)\")\n    assert str(batch[2]) == (\"Data(x=[4], edge_index=[2, 6], y=[1], \"\n                             \"string='3', array=[4], num_nodes=4)\")\n\n    assert len(batch.index_select([1, 0])) == 2\n    assert len(batch.index_select(torch.tensor([1, 0]))) == 2\n    assert len(batch.index_select(torch.tensor([True, False]))) == 1\n    assert len(batch.index_select(np.array([1, 0], dtype=np.int64))) == 2\n    assert len(batch.index_select(np.array([True, False]))) == 1\n    assert len(batch[:2]) == 2\n\n    data_list = batch.to_data_list()\n    assert len(data_list) == 3\n\n    assert len(data_list[0]) == 6\n    assert data_list[0].x.tolist() == [1, 2, 3]\n    assert data_list[0].y.tolist() == [1]\n    assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert data_list[0].string == '1'\n    assert data_list[0].array == ['1', '2']\n    assert data_list[0].num_nodes == 3\n\n    
assert len(data_list[1]) == 6\n    assert data_list[1].x.tolist() == [1, 2]\n    assert data_list[1].y.tolist() == [2]\n    assert data_list[1].edge_index.tolist() == [[0, 1], [1, 0]]\n    assert data_list[1].string == '2'\n    assert data_list[1].array == ['3', '4', '5']\n    assert data_list[1].num_nodes == 2\n\n    assert len(data_list[2]) == 6\n    assert data_list[2].x.tolist() == [1, 2, 3, 4]\n    assert data_list[2].y.tolist() == [3]\n    assert data_list[2].edge_index.tolist() == [[0, 1, 1, 2, 2, 3],\n                                                [1, 0, 2, 1, 3, 2]]\n    assert data_list[2].string == '3'\n    assert data_list[2].array == ['6', '7', '8', '9']\n    assert data_list[2].num_nodes == 4\n\n    torch_geometric.set_debug(False)\n\n\ndef test_index():\n    index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)\n    index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True)\n\n    data1 = Data(index=index1, num_nodes=3)\n    data2 = Data(index=index2, num_nodes=4)\n\n    batch = Batch.from_data_list([data1, data2])\n\n    assert len(batch) == 2\n    assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1]))\n    assert batch.ptr.equal(torch.tensor([0, 3, 7]))\n    assert isinstance(batch.index, Index)\n    assert batch.index.equal(torch.tensor([0, 1, 1, 2, 3, 4, 4, 5, 5, 6]))\n    assert batch.index.dim_size == 7\n    assert batch.index.is_sorted\n\n    for i, index in enumerate([index1, index2]):\n        data = batch[i]\n        assert isinstance(data.index, Index)\n        assert data.index.equal(index)\n        assert data.index.dim_size == index.dim_size\n        assert data.index.is_sorted == index.is_sorted\n\n\ndef test_edge_index():\n    edge_index1 = EdgeIndex(\n        [[0, 1, 1, 2], [1, 0, 2, 1]],\n        sparse_size=(3, 3),\n        sort_order='row',\n        is_undirected=True,\n    )\n    edge_index2 = EdgeIndex(\n        [[1, 0, 2, 1, 3, 2], [0, 1, 1, 2, 2, 3]],\n        sparse_size=(4, 4),\n        sort_order='col',\n    
)\n\n    data1 = Data(edge_index=edge_index1)\n    data2 = Data(edge_index=edge_index2)\n\n    batch = Batch.from_data_list([data1, data2])\n\n    assert len(batch) == 2\n    assert batch.batch.equal(torch.tensor([0, 0, 0, 1, 1, 1, 1]))\n    assert batch.ptr.equal(torch.tensor([0, 3, 7]))\n    assert isinstance(batch.edge_index, EdgeIndex)\n    assert batch.edge_index.equal(\n        torch.tensor([\n            [0, 1, 1, 2, 4, 3, 5, 4, 6, 5],\n            [1, 0, 2, 1, 3, 4, 4, 5, 5, 6],\n        ]))\n    assert batch.edge_index.sparse_size() == (7, 7)\n    assert batch.edge_index.sort_order is None\n    assert not batch.edge_index.is_undirected\n\n    for i, edge_index in enumerate([edge_index1, edge_index2]):\n        data = batch[i]\n        assert isinstance(data.edge_index, EdgeIndex)\n        assert data.edge_index.equal(edge_index)\n        assert data.edge_index.sparse_size() == edge_index.sparse_size()\n        assert data.edge_index.sort_order == edge_index.sort_order\n        assert data.edge_index.is_undirected == edge_index.is_undirected\n\n\n@withPackage('torch_sparse')\ndef test_batch_with_sparse_tensor():\n    x = SparseTensor.from_dense(torch.tensor([[1.0], [2.0], [3.0]]))\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    adj = SparseTensor.from_edge_index(edge_index)\n    data1 = Data(x=x, adj=adj)\n\n    x = SparseTensor.from_dense(torch.tensor([[1.0], [2.0]]))\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    adj = SparseTensor.from_edge_index(edge_index)\n    data2 = Data(x=x, adj=adj)\n\n    x = SparseTensor.from_dense(torch.tensor([[1.0], [2.0], [3.0], [4.0]]))\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n    adj = SparseTensor.from_edge_index(edge_index)\n    data3 = Data(x=x, adj=adj)\n\n    batch = Batch.from_data_list([data1])\n    assert str(batch) == ('DataBatch(x=[3, 1, nnz=3], adj=[3, 3, nnz=4], '\n                          'batch=[3], ptr=[2])')\n    assert batch.num_graphs == 
len(batch) == 1\n    assert batch.x.to_dense().tolist() == [[1], [2], [3]]\n    assert batch.adj.coo()[0].tolist() == [0, 1, 1, 2]\n    assert batch.adj.coo()[1].tolist() == [1, 0, 2, 1]\n    assert batch.batch.tolist() == [0, 0, 0]\n    assert batch.ptr.tolist() == [0, 3]\n\n    batch = Batch.from_data_list([data1, data2, data3])\n\n    assert str(batch) == ('DataBatch(x=[9, 1, nnz=9], adj=[9, 9, nnz=12], '\n                          'batch=[9], ptr=[4])')\n    assert batch.num_graphs == len(batch) == 3\n    assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4]\n    assert batch.adj.coo()[0].tolist() == [0, 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8]\n    assert batch.adj.coo()[1].tolist() == [1, 0, 2, 1, 4, 3, 6, 5, 7, 6, 8, 7]\n    assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]\n    assert batch.ptr.tolist() == [0, 3, 5, 9]\n\n    assert str(batch[0]) == (\"Data(x=[3, 1, nnz=3], adj=[3, 3, nnz=4])\")\n    assert str(batch[1]) == (\"Data(x=[2, 1, nnz=2], adj=[2, 2, nnz=2])\")\n    assert str(batch[2]) == (\"Data(x=[4, 1, nnz=4], adj=[4, 4, nnz=6])\")\n\n    data_list = batch.to_data_list()\n    assert len(data_list) == 3\n\n    assert len(data_list[0]) == 2\n    assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]]\n    assert data_list[0].adj.coo()[0].tolist() == [0, 1, 1, 2]\n    assert data_list[0].adj.coo()[1].tolist() == [1, 0, 2, 1]\n\n    assert len(data_list[1]) == 2\n    assert data_list[1].x.to_dense().tolist() == [[1], [2]]\n    assert data_list[1].adj.coo()[0].tolist() == [0, 1]\n    assert data_list[1].adj.coo()[1].tolist() == [1, 0]\n\n    assert len(data_list[2]) == 2\n    assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]]\n    assert data_list[2].adj.coo()[0].tolist() == [0, 1, 1, 2, 2, 3]\n    assert data_list[2].adj.coo()[1].tolist() == [1, 0, 2, 1, 3, 2]\n\n\ndef test_batch_with_torch_coo_tensor():\n    x = torch.tensor([[1.0], [2.0], [3.0]]).to_sparse_coo()\n    data1 = Data(x=x)\n\n    x = 
torch.tensor([[1.0], [2.0]]).to_sparse_coo()\n    data2 = Data(x=x)\n\n    x = torch.tensor([[1.0], [2.0], [3.0], [4.0]]).to_sparse_coo()\n    data3 = Data(x=x)\n\n    batch = Batch.from_data_list([data1])\n    assert str(batch) == ('DataBatch(x=[3, 1], batch=[3], ptr=[2])')\n    assert batch.num_graphs == len(batch) == 1\n    assert batch.x.to_dense().tolist() == [[1], [2], [3]]\n    assert batch.batch.tolist() == [0, 0, 0]\n    assert batch.ptr.tolist() == [0, 3]\n\n    batch = Batch.from_data_list([data1, data2, data3])\n\n    assert str(batch) == ('DataBatch(x=[9, 1], batch=[9], ptr=[4])')\n    assert batch.num_graphs == len(batch) == 3\n    assert batch.x.to_dense().view(-1).tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4]\n    assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2]\n    assert batch.ptr.tolist() == [0, 3, 5, 9]\n\n    assert str(batch[0]) == (\"Data(x=[3, 1])\")\n    assert str(batch[1]) == (\"Data(x=[2, 1])\")\n    assert str(batch[2]) == (\"Data(x=[4, 1])\")\n\n    data_list = batch.to_data_list()\n    assert len(data_list) == 3\n\n    assert len(data_list[0]) == 1\n    assert data_list[0].x.to_dense().tolist() == [[1], [2], [3]]\n\n    assert len(data_list[1]) == 1\n    assert data_list[1].x.to_dense().tolist() == [[1], [2]]\n\n    assert len(data_list[2]) == 1\n    assert data_list[2].x.to_dense().tolist() == [[1], [2], [3], [4]]\n\n\ndef test_batching_with_new_dimension():\n    torch_geometric.set_debug(True)\n\n    class MyData(Data):\n        def __cat_dim__(self, key, value, *args, **kwargs):\n            if key == 'foo':\n                return None\n            else:\n                return super().__cat_dim__(key, value, *args, **kwargs)\n\n    x1 = torch.tensor([1, 2, 3], dtype=torch.float)\n    foo1 = torch.randn(4)\n    y1 = torch.tensor(1)\n\n    x2 = torch.tensor([1, 2], dtype=torch.float)\n    foo2 = torch.randn(4)\n    y2 = torch.tensor(2)\n\n    batch = Batch.from_data_list(\n        [MyData(x=x1, foo=foo1, y=y1),\n         
MyData(x=x2, foo=foo2, y=y2)])\n\n    assert str(batch) == ('MyDataBatch(x=[5], y=[2], foo=[2, 4], batch=[5], '\n                          'ptr=[3])')\n    assert batch.num_graphs == len(batch) == 2\n    assert batch.x.tolist() == [1, 2, 3, 1, 2]\n    assert batch.foo.size() == (2, 4)\n    assert batch.foo[0].tolist() == foo1.tolist()\n    assert batch.foo[1].tolist() == foo2.tolist()\n    assert batch.y.tolist() == [1, 2]\n    assert batch.batch.tolist() == [0, 0, 0, 1, 1]\n    assert batch.ptr.tolist() == [0, 3, 5]\n    assert batch.num_graphs == 2\n\n    data = batch[0]\n    assert str(data) == ('MyData(x=[3], y=[1], foo=[4])')\n    data = batch[1]\n    assert str(data) == ('MyData(x=[2], y=[1], foo=[4])')\n\n    torch_geometric.set_debug(False)\n\n\ndef test_pickling(tmp_path):\n    data = Data(x=torch.randn(5, 16))\n    batch = Batch.from_data_list([data, data, data, data])\n    assert id(batch._store._parent()) == id(batch)\n    assert batch.num_nodes == 20\n\n    # filename = f'{random.randrange(sys.maxsize)}.pt'\n    path = osp.join(tmp_path, 'batch.pt')\n    torch.save(batch, path)\n    assert id(batch._store._parent()) == id(batch)\n    assert batch.num_nodes == 20\n\n    batch = torch.load(path, weights_only=False)\n    assert id(batch._store._parent()) == id(batch)\n    assert batch.num_nodes == 20\n\n    assert batch.__class__.__name__ == 'DataBatch'\n    assert batch.num_graphs == len(batch) == 4\n\n\ndef test_recursive_batch():\n    data1 = Data(\n        x={\n            '1': torch.randn(10, 32),\n            '2': torch.randn(20, 48)\n        },\n        edge_index=[\n            get_random_edge_index(30, 30, 50),\n            get_random_edge_index(30, 30, 70)\n        ],\n        num_nodes=30,\n    )\n\n    data2 = Data(\n        x={\n            '1': torch.randn(20, 32),\n            '2': torch.randn(40, 48)\n        },\n        edge_index=[\n            get_random_edge_index(60, 60, 80),\n            get_random_edge_index(60, 60, 90)\n        ],\n 
       num_nodes=60,\n    )\n\n    batch = Batch.from_data_list([data1, data2])\n\n    assert batch.num_graphs == len(batch) == 2\n    assert batch.num_nodes == 90\n\n    assert torch.allclose(batch.x['1'],\n                          torch.cat([data1.x['1'], data2.x['1']], dim=0))\n    assert torch.allclose(batch.x['2'],\n                          torch.cat([data1.x['2'], data2.x['2']], dim=0))\n    assert (batch.edge_index[0].tolist() == torch.cat(\n        [data1.edge_index[0], data2.edge_index[0] + 30], dim=1).tolist())\n    assert (batch.edge_index[1].tolist() == torch.cat(\n        [data1.edge_index[1], data2.edge_index[1] + 30], dim=1).tolist())\n    assert batch.batch.size() == (90, )\n    assert batch.ptr.size() == (3, )\n\n    out1 = batch[0]\n    assert len(out1) == 3\n    assert out1.num_nodes == 30\n    assert torch.allclose(out1.x['1'], data1.x['1'])\n    assert torch.allclose(out1.x['2'], data1.x['2'])\n    assert out1.edge_index[0].tolist() == data1.edge_index[0].tolist()\n    assert out1.edge_index[1].tolist() == data1.edge_index[1].tolist()\n\n    out2 = batch[1]\n    assert len(out2) == 3\n    assert out2.num_nodes == 60\n    assert torch.allclose(out2.x['1'], data2.x['1'])\n    assert torch.allclose(out2.x['2'], data2.x['2'])\n    assert out2.edge_index[0].tolist() == data2.edge_index[0].tolist()\n    assert out2.edge_index[1].tolist() == data2.edge_index[1].tolist()\n\n\ndef test_batching_of_batches():\n    data = Data(x=torch.randn(2, 16))\n    batch = Batch.from_data_list([data, data])\n\n    batch = Batch.from_data_list([batch, batch])\n    assert batch.num_graphs == len(batch) == 2\n    assert batch.x[0:2].tolist() == data.x.tolist()\n    assert batch.x[2:4].tolist() == data.x.tolist()\n    assert batch.x[4:6].tolist() == data.x.tolist()\n    assert batch.x[6:8].tolist() == data.x.tolist()\n    assert batch.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]\n\n\ndef test_hetero_batch():\n    e1 = ('p', 'a')\n    e2 = ('a', 'p')\n    data1 = HeteroData()\n   
 data1['p'].x = torch.randn(100, 128)\n    data1['a'].x = torch.randn(200, 128)\n    data1[e1].edge_index = get_random_edge_index(100, 200, 500)\n    data1[e1].edge_attr = torch.randn(500, 32)\n    data1[e2].edge_index = get_random_edge_index(200, 100, 400)\n    data1[e2].edge_attr = torch.randn(400, 32)\n\n    data2 = HeteroData()\n    data2['p'].x = torch.randn(50, 128)\n    data2['a'].x = torch.randn(100, 128)\n    data2[e1].edge_index = get_random_edge_index(50, 100, 300)\n    data2[e1].edge_attr = torch.randn(300, 32)\n    data2[e2].edge_index = get_random_edge_index(100, 50, 200)\n    data2[e2].edge_attr = torch.randn(200, 32)\n\n    batch = Batch.from_data_list([data1, data2])\n\n    assert batch.num_graphs == len(batch) == 2\n    assert batch.num_nodes == 450\n\n    assert torch.allclose(batch['p'].x[:100], data1['p'].x)\n    assert torch.allclose(batch['a'].x[:200], data1['a'].x)\n    assert torch.allclose(batch['p'].x[100:], data2['p'].x)\n    assert torch.allclose(batch['a'].x[200:], data2['a'].x)\n    assert (batch[e1].edge_index.tolist() == torch.cat([\n        data1[e1].edge_index,\n        data2[e1].edge_index + torch.tensor([[100], [200]])\n    ], 1).tolist())\n    assert torch.allclose(\n        batch[e1].edge_attr,\n        torch.cat([data1[e1].edge_attr, data2[e1].edge_attr], 0))\n    assert (batch[e2].edge_index.tolist() == torch.cat([\n        data1[e2].edge_index,\n        data2[e2].edge_index + torch.tensor([[200], [100]])\n    ], 1).tolist())\n    assert torch.allclose(\n        batch[e2].edge_attr,\n        torch.cat([data1[e2].edge_attr, data2[e2].edge_attr], 0))\n    assert batch['p'].batch.size() == (150, )\n    assert batch['p'].ptr.size() == (3, )\n    assert batch['a'].batch.size() == (300, )\n    assert batch['a'].ptr.size() == (3, )\n\n    out1 = batch[0]\n    assert len(out1) == 3\n    assert out1.num_nodes == 300\n    assert torch.allclose(out1['p'].x, data1['p'].x)\n    assert torch.allclose(out1['a'].x, data1['a'].x)\n    assert 
out1[e1].edge_index.tolist() == data1[e1].edge_index.tolist()\n    assert torch.allclose(out1[e1].edge_attr, data1[e1].edge_attr)\n    assert out1[e2].edge_index.tolist() == data1[e2].edge_index.tolist()\n    assert torch.allclose(out1[e2].edge_attr, data1[e2].edge_attr)\n\n    out2 = batch[1]\n    assert len(out2) == 3\n    assert out2.num_nodes == 150\n    assert torch.allclose(out2['p'].x, data2['p'].x)\n    assert torch.allclose(out2['a'].x, data2['a'].x)\n    assert out2[e1].edge_index.tolist() == data2[e1].edge_index.tolist()\n    assert torch.allclose(out2[e1].edge_attr, data2[e1].edge_attr)\n    assert out2[e2].edge_index.tolist() == data2[e2].edge_index.tolist()\n    assert torch.allclose(out2[e2].edge_attr, data2[e2].edge_attr)\n\n\ndef test_pair_data_batching():\n    class PairData(Data):\n        def __inc__(self, key, value, *args, **kwargs):\n            if key == 'edge_index_s':\n                return self.x_s.size(0)\n            if key == 'edge_index_t':\n                return self.x_t.size(0)\n            return super().__inc__(key, value, *args, **kwargs)\n\n    x_s = torch.randn(5, 16)\n    edge_index_s = torch.tensor([\n        [0, 0, 0, 0],\n        [1, 2, 3, 4],\n    ])\n    x_t = torch.randn(4, 16)\n    edge_index_t = torch.tensor([\n        [0, 0, 0],\n        [1, 2, 3],\n    ])\n\n    data = PairData(x_s=x_s, edge_index_s=edge_index_s, x_t=x_t,\n                    edge_index_t=edge_index_t)\n    batch = Batch.from_data_list([data, data])\n\n    assert torch.allclose(batch.x_s, torch.cat([x_s, x_s], dim=0))\n    assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5],\n                                           [1, 2, 3, 4, 6, 7, 8, 9]]\n\n    assert torch.allclose(batch.x_t, torch.cat([x_t, x_t], dim=0))\n    assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4],\n                                           [1, 2, 3, 5, 6, 7]]\n\n\ndef test_batch_with_empty_list():\n    x = torch.randn(4, 1)\n    edge_index = 
torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n    data = Data(x=x, edge_index=edge_index, nontensor=[])\n\n    batch = Batch.from_data_list([data, data])\n    assert batch.nontensor == [[], []]\n    assert batch[0].nontensor == []\n    assert batch[1].nontensor == []\n\n\ndef test_nested_follow_batch():\n    def tr(n, m):\n        return torch.rand((n, m))\n\n    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)], a={\"aa\": tr(11, 3)},\n              x=tr(10, 5))\n    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)], a={\"aa\": tr(2, 3)},\n              x=tr(11, 5))\n    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)], a={\"aa\": tr(4, 3)},\n              x=tr(9, 5))\n    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={\"aa\": tr(8, 3)},\n              x=tr(8, 5))\n\n    data_list = [d1, d2, d3, d4]\n\n    batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a'])\n\n    assert batch.xs[0].shape == (19, 3)\n    assert batch.xs[1].shape == (56, 4)\n    assert batch.xs[2].shape == (7, 2)\n    assert batch.a['aa'].shape == (25, 3)\n\n    assert len(batch.xs_batch) == 3\n    assert len(batch.a_batch) == 1\n\n    assert batch.xs_batch[0].tolist() == \\\n           [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3]\n    assert batch.xs_batch[1].tolist() == \\\n           [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16\n    assert batch.xs_batch[2].tolist() == \\\n           [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1\n\n    assert batch.a_batch['aa'].tolist() == \\\n           [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8\n\n\n@withPackage('torch>=2.0.0')\n@pytest.mark.parametrize('layout', [\n    torch.sparse_coo,\n    torch.sparse_csr,\n    torch.sparse_csc,\n])\ndef test_torch_sparse_batch(layout):\n    x_dense = torch.randn(3, 4)\n    x = x_dense.to_sparse(layout=layout)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_attr = torch.rand(4)\n    adj = to_torch_sparse_tensor(edge_index, edge_attr, layout=layout)\n\n    data = Data(x=x, adj=adj)\n\n    
batch = Batch.from_data_list([data, data])\n\n    assert batch.x.size() == (6, 4)\n    assert batch.x.layout in {torch.sparse_coo, torch.sparse_csr}\n    assert torch.equal(batch.x.to_dense(), torch.cat([x_dense, x_dense], 0))\n\n    assert batch.adj.size() == (6, 6)\n    assert batch.adj.layout == layout\n    out = to_edge_index(batch.adj.to_sparse(layout=torch.sparse_csr))\n    assert torch.equal(out[0], torch.cat([edge_index, edge_index + 3], 1))\n    assert torch.equal(out[1], torch.cat([edge_attr, edge_attr], 0))\n\n\ndef test_torch_nested_batch():\n    from torch.nested import nested_tensor\n\n    class MyData(Data):\n        def __inc__(self, key, value, *args, **kwargs) -> int:\n            return 2\n\n    x1 = nested_tensor([torch.randn(3), torch.randn(4)])\n    data1 = MyData(x=x1)\n    assert str(data1) == 'MyData(x=[2, 4])'\n\n    x2 = nested_tensor([torch.randn(3), torch.randn(4), torch.randn(5)])\n    data2 = MyData(x=x2)\n    assert str(data2) == 'MyData(x=[3, 5])'\n\n    batch = Batch.from_data_list([data1, data2])\n    assert str(batch) == 'MyDataBatch(x=[5, 5], batch=[5], ptr=[3])'\n\n    expected = nested_tensor(list(x1.unbind() + (x2 + 2).unbind()))\n    assert torch.equal(\n        batch.x.to_padded_tensor(0.0),\n        expected.to_padded_tensor(0.0),\n    )\n"
  },
  {
    "path": "test/data/test_data.py",
    "content": "import copy\n\nimport pytest\nimport torch\nimport torch.multiprocessing as mp\n\nimport torch_geometric\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.storage import AttrType\nfrom torch_geometric.testing import get_random_tensor_frame, withPackage\n\n\ndef test_data():\n    torch_geometric.set_debug(True)\n\n    x = torch.tensor([[1, 3, 5], [2, 4, 6]], dtype=torch.float).t()\n    edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]])\n    data = Data(x=x, edge_index=edge_index).to(torch.device('cpu'))\n    data.validate(raise_on_error=True)\n\n    N = data.num_nodes\n    assert N == 3\n\n    assert data.node_attrs() == ['x']\n    assert data.edge_attrs() == ['edge_index']\n\n    assert data.x.tolist() == x.tolist()\n    assert data['x'].tolist() == x.tolist()\n    assert data.get('x').tolist() == x.tolist()\n    assert data.get('y', 2) == 2\n    assert data.get('y', None) is None\n    assert data.num_edge_types == 1\n    assert data.num_node_types == 1\n    assert next(data('x')) == ('x', x)\n\n    assert sorted(data.keys()) == ['edge_index', 'x']\n    assert len(data) == 2\n    assert 'x' in data and 'edge_index' in data and 'pos' not in data\n\n    data.apply_(lambda x: x.mul_(2), 'x')\n    assert torch.allclose(data.x, x)\n\n    data.requires_grad_('x')\n    assert data.x.requires_grad is True\n\n    D = data.to_dict()\n    assert len(D) == 2\n    assert 'x' in D and 'edge_index' in D\n\n    D = data.to_namedtuple()\n    assert len(D) == 2\n    assert D.x is not None and D.edge_index is not None\n\n    assert data.__cat_dim__('x', data.x) == 0\n    assert data.__cat_dim__('edge_index', data.edge_index) == -1\n    assert data.__inc__('x', data.x) == 0\n    assert data.__inc__('edge_index', data.edge_index) == data.num_nodes\n\n    assert not data.x.is_contiguous()\n    data.contiguous()\n    assert data.x.is_contiguous()\n\n    assert not data.is_coalesced()\n    data = data.coalesce()\n    assert 
data.is_coalesced()\n\n    clone = data.clone()\n    assert clone != data\n    assert len(clone) == len(data)\n    assert clone.x.data_ptr() != data.x.data_ptr()\n    assert clone.x.tolist() == data.x.tolist()\n    assert clone.edge_index.data_ptr() != data.edge_index.data_ptr()\n    assert clone.edge_index.tolist() == data.edge_index.tolist()\n\n    # Test `data.to_heterogeneous()`:\n    out = data.to_heterogeneous()\n    assert torch.allclose(data.x, out['0'].x)\n    assert torch.allclose(data.edge_index, out['0', '0'].edge_index)\n\n    data.edge_type = torch.tensor([0, 0, 1, 0])\n    out = data.to_heterogeneous()\n    assert torch.allclose(data.x, out['0'].x)\n    assert [store.num_edges for store in out.edge_stores] == [3, 1]\n    data.edge_type = None\n\n    data['x'] = x + 1\n    assert data.x.tolist() == (x + 1).tolist()\n\n    assert str(data) == 'Data(x=[3, 2], edge_index=[2, 4])'\n\n    dictionary = {'x': data.x, 'edge_index': data.edge_index}\n    data = Data.from_dict(dictionary)\n    assert sorted(data.keys()) == ['edge_index', 'x']\n\n    assert not data.has_isolated_nodes()\n    assert not data.has_self_loops()\n    assert data.is_undirected()\n    assert not data.is_directed()\n\n    assert data.num_nodes == 3\n    assert data.num_edges == 4\n    with pytest.warns(UserWarning, match='deprecated'):\n        assert data.num_faces is None\n    assert data.num_node_features == 2\n    assert data.num_features == 2\n\n    data.edge_attr = torch.randn(data.num_edges, 2)\n    assert data.num_edge_features == 2\n    data.edge_attr = None\n\n    data.x = None\n    with pytest.warns(UserWarning, match='Unable to accurately infer'):\n        assert data.num_nodes == 3\n\n    data.edge_index = None\n    with pytest.warns(UserWarning, match='Unable to accurately infer'):\n        assert data.num_nodes is None\n    assert data.num_edges == 0\n\n    data.num_nodes = 4\n    assert data.num_nodes == 4\n\n    data = Data(x=x, attribute=x)\n    assert len(data) == 2\n 
   assert data.x.tolist() == x.tolist()\n    assert data.attribute.tolist() == x.tolist()\n\n    face = torch.tensor([[0, 1], [1, 2], [2, 3]])\n    data = Data(num_nodes=4, face=face)\n    with pytest.warns(UserWarning, match='deprecated'):\n        assert data.num_faces == 2\n    assert data.num_nodes == 4\n\n    data = Data(title='test')\n    assert str(data) == \"Data(title='test')\"\n    assert data.num_node_features == 0\n    assert data.num_edge_features == 0\n\n    key = value = 'test_value'\n    data[key] = value\n    assert data[key] == value\n    del data[value]\n    del data[value]  # Deleting unset attributes should work as well.\n\n    assert data.get(key) is None\n    assert data.get('title') == 'test'\n\n    torch_geometric.set_debug(False)\n\n\ndef test_data_attr_cache():\n    x = torch.randn(3, 16)\n    edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]])\n    edge_attr = torch.randn(5, 4)\n    y = torch.tensor([0])\n\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)\n\n    assert data.is_node_attr('x')\n    assert 'x' in data._store._cached_attr[AttrType.NODE]\n    assert 'x' not in data._store._cached_attr[AttrType.EDGE]\n    assert 'x' not in data._store._cached_attr[AttrType.OTHER]\n\n    assert not data.is_node_attr('edge_index')\n    assert 'edge_index' not in data._store._cached_attr[AttrType.NODE]\n    assert 'edge_index' in data._store._cached_attr[AttrType.EDGE]\n    assert 'edge_index' not in data._store._cached_attr[AttrType.OTHER]\n\n    assert data.is_edge_attr('edge_attr')\n    assert 'edge_attr' not in data._store._cached_attr[AttrType.NODE]\n    assert 'edge_attr' in data._store._cached_attr[AttrType.EDGE]\n    assert 'edge_attr' not in data._store._cached_attr[AttrType.OTHER]\n\n    assert not data.is_edge_attr('y')\n    assert 'y' not in data._store._cached_attr[AttrType.NODE]\n    assert 'y' not in data._store._cached_attr[AttrType.EDGE]\n    assert 'y' in 
data._store._cached_attr[AttrType.OTHER]\n\n\ndef test_data_attr_cache_not_shared():\n    x = torch.rand((4, 4))\n    edge_index = torch.tensor([[0, 1, 2, 3, 0, 1], [0, 1, 2, 3, 0, 1]])\n    time = torch.arange(edge_index.size(1))\n    data = Data(x=x, edge_index=edge_index, time=time)\n    assert data.is_node_attr('x')\n\n    out = data.up_to(3.5)\n    # This is expected behavior due to the ambiguity between node-level and\n    # edge-level tensors when they share the same number of nodes/edges.\n    assert out.is_node_attr('time')\n    assert not data.is_node_attr('time')\n\n\ndef test_to_heterogeneous_empty_edge_index():\n    data = Data(\n        x=torch.randn(5, 10),\n        edge_index=torch.empty(2, 0, dtype=torch.long),\n    )\n    hetero_data = data.to_heterogeneous()\n    assert hetero_data.node_types == ['0']\n    assert hetero_data.edge_types == []\n    assert len(hetero_data) == 1\n    assert torch.equal(hetero_data['0'].x, data.x)\n\n    hetero_data = data.to_heterogeneous(\n        node_type_names=['0'],\n        edge_type_names=[('0', 'to', '0')],\n    )\n    assert hetero_data.node_types == ['0']\n    assert hetero_data.edge_types == [('0', 'to', '0')]\n    assert len(hetero_data) == 2\n    assert torch.equal(hetero_data['0'].x, data.x)\n    assert torch.equal(hetero_data['0', 'to', '0'].edge_index, data.edge_index)\n\n\ndef test_data_subgraph():\n    x = torch.arange(5)\n    y = torch.tensor([0.])\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],\n                               [1, 0, 2, 1, 3, 2, 4, 3]])\n    edge_weight = torch.arange(edge_index.size(1))\n\n    data = Data(x=x, y=y, edge_index=edge_index, edge_weight=edge_weight,\n                num_nodes=5)\n\n    out = data.subgraph(torch.tensor([1, 2, 3]))\n    assert len(out) == 5\n    assert torch.equal(out.x, torch.arange(1, 4))\n    assert torch.equal(out.y, data.y)\n    assert out.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert torch.equal(out.edge_weight, 
edge_weight[torch.arange(2, 6)])\n    assert out.num_nodes == 3\n\n    # Test unordered selection:\n    out = data.subgraph(torch.tensor([3, 1, 2]))\n    assert len(out) == 5\n    assert torch.equal(out.x, torch.tensor([3, 1, 2]))\n    assert torch.equal(out.y, data.y)\n    assert out.edge_index.tolist() == [[1, 2, 2, 0], [2, 1, 0, 2]]\n    assert torch.equal(out.edge_weight, edge_weight[torch.arange(2, 6)])\n    assert out.num_nodes == 3\n\n    out = data.subgraph(torch.tensor([False, False, False, True, True]))\n    assert len(out) == 5\n    assert torch.equal(out.x, torch.arange(3, 5))\n    assert torch.equal(out.y, data.y)\n    assert out.edge_index.tolist() == [[0, 1], [1, 0]]\n    assert torch.equal(out.edge_weight, edge_weight[torch.arange(6, 8)])\n    assert out.num_nodes == 2\n\n    out = data.edge_subgraph(torch.tensor([1, 2, 3]))\n    assert len(out) == 5\n    assert out.num_nodes == data.num_nodes\n    assert torch.equal(out.x, data.x)\n    assert torch.equal(out.y, data.y)\n    assert out.edge_index.tolist() == [[1, 1, 2], [0, 2, 1]]\n    assert torch.equal(out.edge_weight, edge_weight[torch.tensor([1, 2, 3])])\n\n    out = data.edge_subgraph(\n        torch.tensor([False, True, True, True, False, False, False, False]))\n    assert len(out) == 5\n    assert out.num_nodes == data.num_nodes\n    assert torch.equal(out.x, data.x)\n    assert torch.equal(out.y, data.y)\n    assert out.edge_index.tolist() == [[1, 1, 2], [0, 2, 1]]\n    assert torch.equal(out.edge_weight, edge_weight[torch.tensor([1, 2, 3])])\n\n\ndef test_data_subgraph_with_list_field():\n    x = torch.arange(5)\n    y = list(range(5))\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],\n                               [1, 0, 2, 1, 3, 2, 4, 3]])\n    data = Data(x=x, y=y, edge_index=edge_index)\n\n    out = data.subgraph(torch.tensor([1, 2, 3]))\n    assert len(out) == 3\n    assert out.x.tolist() == out.y == [1, 2, 3]\n\n    out = data.subgraph(torch.tensor([False, True, True, True, 
False]))\n    assert len(out) == 3\n    assert out.x.tolist() == out.y == [1, 2, 3]\n\n\ndef test_data_empty_subgraph():\n    data = Data(x=torch.arange(5), y=torch.tensor(0.0))\n\n    out = data.subgraph(torch.tensor([1, 2, 3]))\n    assert 'edge_index' not in out\n    assert torch.equal(out.x, torch.arange(1, 4))\n    assert torch.equal(out.y, data.y)\n    assert out.num_nodes == 3\n\n\ndef test_copy_data():\n    data = Data(x=torch.randn(20, 5))\n\n    out = copy.copy(data)\n    assert id(data) != id(out)\n    assert id(data._store) != id(out._store)\n    assert len(data.stores) == len(out.stores)\n    for store1, store2 in zip(data.stores, out.stores):\n        assert id(store1) != id(store2)\n        assert id(data) == id(store1._parent())\n        assert id(out) == id(store2._parent())\n    assert data.x.data_ptr() == out.x.data_ptr()\n\n    out = copy.deepcopy(data)\n    assert id(data) != id(out)\n    assert id(data._store) != id(out._store)\n    assert len(data.stores) == len(out.stores)\n    for store1, store2 in zip(data.stores, out.stores):\n        assert id(store1) != id(store2)\n        assert id(data) == id(store1._parent())\n        assert id(out) == id(store2._parent())\n    assert data.x.data_ptr() != out.x.data_ptr()\n    assert data.x.tolist() == out.x.tolist()\n\n\ndef test_data_sort():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 2, 1, 3], [1, 2, 3, 0, 0, 0]])\n    edge_attr = torch.randn(6, 8)\n\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n    assert not data.is_sorted(sort_by_row=True)\n    assert not data.is_sorted(sort_by_row=False)\n\n    out = data.sort(sort_by_row=True)\n    assert out.is_sorted(sort_by_row=True)\n    assert not out.is_sorted(sort_by_row=False)\n    assert torch.equal(out.x, data.x)\n    assert out.edge_index.tolist() == [[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]\n    assert torch.equal(\n        out.edge_attr,\n        data.edge_attr[torch.tensor([0, 1, 2, 4, 3, 5])],\n    
)\n\n    out = data.sort(sort_by_row=False)\n    assert not out.is_sorted(sort_by_row=True)\n    assert out.is_sorted(sort_by_row=False)\n    assert torch.equal(out.x, data.x)\n    assert out.edge_index.tolist() == [[1, 2, 3, 0, 0, 0], [0, 0, 0, 1, 2, 3]]\n    assert torch.equal(\n        out.edge_attr,\n        data.edge_attr[torch.tensor([4, 3, 5, 0, 1, 2])],\n    )\n\n\ndef test_debug_data():\n    torch_geometric.set_debug(True)\n\n    Data()\n    Data(edge_index=torch.zeros((2, 0), dtype=torch.long), num_nodes=10)\n    Data(face=torch.zeros((3, 0), dtype=torch.long), num_nodes=10)\n    Data(edge_index=torch.tensor([[0, 1], [1, 0]]), edge_attr=torch.randn(2))\n    Data(x=torch.randn(5, 3), num_nodes=5)\n    Data(pos=torch.randn(5, 3), num_nodes=5)\n    Data(norm=torch.randn(5, 3), num_nodes=5)\n\n    torch_geometric.set_debug(False)\n\n\ndef run(rank, data_list):\n    for data in data_list:\n        assert data.x.is_shared()\n        data.x.add_(1)\n\n\ndef test_data_share_memory():\n    data_list = [Data(x=torch.zeros(8)) for _ in range(10)]\n\n    for data in data_list:\n        assert not data.x.is_shared()\n        assert torch.all(data.x == 0.0)\n\n    mp.spawn(run, args=(data_list, ), nprocs=4, join=True)\n\n    for data in data_list:\n        assert data.x.is_shared()\n        assert torch.all(data.x > 0.0)\n\n\ndef test_data_setter_properties():\n    class MyData(Data):\n        def __init__(self):\n            super().__init__()\n            self.my_attr1 = 1\n            self.my_attr2 = 2\n\n        @property\n        def my_attr1(self):\n            return self._my_attr1\n\n        @my_attr1.setter\n        def my_attr1(self, value):\n            self._my_attr1 = value\n\n    data = MyData()\n    assert data.my_attr2 == 2\n\n    assert 'my_attr1' not in data._store\n    assert data.my_attr1 == 1\n\n    data.my_attr1 = 2\n    assert 'my_attr1' not in data._store\n    assert data.my_attr1 == 2\n\n\ndef test_data_update():\n    data = 
Data(x=torch.arange(0, 5), y=torch.arange(5, 10))\n    other = Data(z=torch.arange(10, 15), x=torch.arange(15, 20))\n    data.update(other)\n\n    assert len(data) == 3\n    assert torch.equal(data.x, torch.arange(15, 20))\n    assert torch.equal(data.y, torch.arange(5, 10))\n    assert torch.equal(data.z, torch.arange(10, 15))\n\n\ndef test_data_connected_components():\n    data = Data()\n    data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])\n    data.y = torch.tensor([[1.1, 1.2], [2.1, 2.2], [3.1, 3.2], [4.1, 4.2],\n                           [5.1, 5.2]])\n    data.edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]],\n                                   dtype=torch.long)\n\n    split_data = data.connected_components()\n    assert isinstance(split_data, list)\n    assert len(split_data) == 3\n\n    assert torch.equal(split_data[0].x, torch.tensor([[1.0], [2.0]]))\n    assert torch.equal(split_data[0].y, torch.tensor([[1.1, 1.2], [2.1, 2.2]]))\n    assert torch.equal(split_data[0].edge_index, torch.tensor([[0, 1], [1,\n                                                                        0]]))\n\n    assert torch.equal(split_data[1].x, torch.tensor([[3.0], [4.0]]))\n    assert torch.equal(split_data[1].y, torch.tensor([[3.1, 3.2], [4.1, 4.2]]))\n    assert torch.equal(split_data[1].edge_index, torch.tensor([[0, 1], [1,\n                                                                        0]]))\n\n    assert torch.equal(split_data[2].x, torch.tensor([[5.0]]))\n    assert torch.equal(split_data[2].y, torch.tensor([[5.1, 5.2]]))\n    assert torch.equal(split_data[2].edge_index,\n                       torch.tensor([[], []], dtype=torch.long))\n\n\ndef test_data_find_parent():\n\n    # Case 1: Parent does not exist\n    data = Data()\n    data._parents = {}\n    data._ranks = {}\n    node = 1\n    assert data._find_parent(node) == node\n    assert data._parents == {1: 1}\n    assert data._ranks == {1: 0}\n\n    # Case 2: Parent exists\n    
data._parents[node] = 0\n    assert data._find_parent(node) == 0\n\n\ndef test_data_union():\n\n    # Setup: two nodes in different sets\n    data = Data()\n    data._parents = {}\n    data._ranks = {}\n    node1 = 1\n    node2 = 2\n\n    # Initially, both nodes are their own parents with rank 0\n    assert data._find_parent(node1) == node1\n    assert data._find_parent(node2) == node2\n    data._ranks[node1] = 0\n    data._ranks[node2] = 0\n\n    # Union them: node2 should now point to node1, and node1's rank increases\n    data._union(node1, node2)\n    assert data._find_parent(node1) == node1\n    assert data._find_parent(node2) == node1\n    assert data._ranks[node1] == 1\n\n    # Add a third node with higher rank and union with node1\n    node3 = 3\n    data._parents[node3] = node3\n    data._ranks[node3] = 2\n    data._union(node1, node3)\n    # node1's parent should now be node3, since node3 has higher rank\n    assert data._find_parent(node1) == node3\n    assert data._find_parent(node3) == node3\n\n    # Add a fourth node with lower rank and union with node3\n    node4 = 4\n    data._parents[node4] = node4\n    data._ranks[node4] = 0\n    data._union(node3, node4)\n    assert data._find_parent(node4) == node3\n    assert data._find_parent(node3) == node3\n\n    # Union of already connected nodes should not change anything\n    prev_ranks = data._ranks.copy()\n    prev_parents = data._parents.copy()\n    data._union(node1, node3)\n    assert data._ranks == prev_ranks\n    assert data._parents == prev_parents\n\n\n# Feature Store ###############################################################\n\n\ndef test_basic_feature_store():\n    data = Data()\n    x = torch.randn(20, 20)\n    data.not_a_tensor_attr = 10  # don't include, not a tensor attr\n    data.bad_attr = torch.randn(10, 20)  # don't include, bad cat_dim\n\n    # Put tensor:\n    assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None)\n    assert torch.equal(data.x, x)\n\n    # Put 
(modify) tensor slice:\n    x[15:] = 0\n    data.put_tensor(0, attr_name='x', index=slice(15, None, None))\n\n    # Get tensor:\n    out = data.get_tensor(attr_name='x', index=None)\n    assert torch.equal(x, out)\n\n    # Get tensor size:\n    assert data.get_tensor_size(attr_name='x') == (20, 20)\n\n    # Get tensor attrs:\n    tensor_attrs = data.get_all_tensor_attrs()\n    assert len(tensor_attrs) == 1\n    assert tensor_attrs[0].attr_name == 'x'\n\n    # Remove tensor:\n    assert 'x' in data.__dict__['_store']\n    data.remove_tensor(attr_name='x', index=None)\n    assert 'x' not in data.__dict__['_store']\n\n\n# Graph Store #################################################################\n\n\n@withPackage('torch_sparse')\ndef test_basic_graph_store():\n    r\"\"\"Test the core graph store API.\"\"\"\n    data = Data()\n\n    def assert_equal_tensor_tuple(expected, actual):\n        assert len(expected) == len(actual)\n        for i in range(len(expected)):\n            assert torch.equal(expected[i], actual[i])\n\n    # We put all three tensor types: COO, CSR, and CSC, and we get them back\n    # to confirm that `GraphStore` works as intended.\n    coo = (torch.tensor([0, 1]), torch.tensor([1, 2]))\n    csr = (torch.tensor([0, 1, 2, 2]), torch.tensor([1, 2]))\n    csc = (torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2]))\n\n    # Put:\n    data.put_edge_index(coo, layout='coo', size=(3, 3))\n    data.put_edge_index(csr, layout='csr')\n    data.put_edge_index(csc, layout='csc')\n\n    # Get:\n    assert_equal_tensor_tuple(coo, data.get_edge_index('coo'))\n    assert_equal_tensor_tuple(csr, data.get_edge_index('csr'))\n    assert_equal_tensor_tuple(csc, data.get_edge_index('csc'))\n\n    # Get attrs:\n    edge_attrs = data.get_all_edge_attrs()\n    assert len(edge_attrs) == 3\n\n    # Remove:\n    coo, csr, csc = edge_attrs\n    data.remove_edge_index(coo)\n    data.remove_edge_index(csr)\n    data.remove_edge_index(csc)\n    assert 
len(data.get_all_edge_attrs()) == 0\n\n\ndef test_data_generate_ids():\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]])\n\n    data = Data(x=x, edge_index=edge_index)\n    assert len(data) == 2\n\n    data.generate_ids()\n    assert len(data) == 4\n    assert data.n_id.tolist() == [0, 1, 2]\n    assert data.e_id.tolist() == [0, 1, 2, 3, 4]\n\n\n@withPackage('torch_frame')\ndef test_data_with_tensor_frame():\n    tf = get_random_tensor_frame(num_rows=10)\n    data = Data(tf=tf, edge_index=torch.randint(0, 10, size=(2, 20)))\n\n    # Test basic attributes:\n    assert data.is_node_attr('tf')\n    assert data.num_nodes == tf.num_rows\n    assert data.num_edges == 20\n    assert data.num_node_features == tf.num_cols\n\n    # Test subgraph:\n    index = torch.tensor([1, 2, 3])\n    sub_data = data.subgraph(index)\n    assert sub_data.num_nodes == 3\n    for key, value in sub_data.tf.feat_dict.items():\n        assert torch.allclose(value, tf.feat_dict[key][index])\n\n    mask = torch.tensor(\n        [False, True, True, True, False, False, False, False, False, False])\n    data_sub = data.subgraph(mask)\n    assert data_sub.num_nodes == 3\n    for key, value in data_sub.tf.feat_dict.items():\n        assert torch.allclose(value, tf.feat_dict[key][mask])\n\n\n@pytest.mark.parametrize('num_nodes', [4])\n@pytest.mark.parametrize('num_edges', [8])\ndef test_data_time_handling(num_nodes, num_edges):\n    data = Data(\n        x=torch.randn(num_nodes, 12),\n        edge_index=torch.randint(0, num_nodes, (2, num_edges)),\n        edge_attr=torch.rand((num_edges, 16)),\n        time=torch.arange(num_edges),\n        num_nodes=num_nodes,\n    )\n\n    assert data.is_edge_attr('time')\n    assert not data.is_node_attr('time')\n    assert data.is_sorted_by_time()\n\n    out = data.up_to(5)\n    assert out.num_edges == 6\n    assert torch.allclose(out.x, data.x)\n    assert torch.equal(out.edge_index, data.edge_index[:, :6])\n    assert 
torch.allclose(out.edge_attr, data.edge_attr[:6])\n    assert torch.equal(out.time, data.time[:6])\n\n    out = data.snapshot(2, 5)\n    assert out.num_edges == 4\n    assert torch.allclose(out.x, data.x)\n    assert torch.equal(out.edge_index, data.edge_index[:, 2:6])\n    assert torch.allclose(out.edge_attr, data.edge_attr[2:6, :])\n    assert torch.equal(out.time, data.time[2:6])\n\n    out = data.sort_by_time()\n    assert out.is_sorted_by_time()\n\n    out = data.concat(data)\n    assert out.num_nodes == 8\n    assert not out.is_sorted_by_time()\n\n    assert torch.allclose(out.x, torch.cat([data.x, data.x], dim=0))\n    assert torch.equal(\n        out.edge_index,\n        torch.cat([data.edge_index, data.edge_index], dim=1),\n    )\n    assert torch.allclose(\n        out.edge_attr,\n        torch.cat([data.edge_attr, data.edge_attr], dim=0),\n    )\n    assert torch.allclose(out.time, torch.cat([data.time, data.time], dim=0))\n\n    out = out.sort_by_time()\n    assert torch.equal(out.time, data.time.repeat_interleave(2))\n\n\ndef test_data_inc():\n    data = Data(edge_index=torch.tensor([[0, 1], [1, 0]]))\n    with pytest.warns(UserWarning, match=\"Unable to accurately infer\"):\n        assert data.__inc__('edge_index', data.edge_index) == 2\n\n    data = Data(index=torch.empty(2, 0, dtype=torch.long))\n    with pytest.raises(RuntimeError, match=\"Unable to infer\"):\n        with pytest.warns(UserWarning, match=\"Unable to accurately infer\"):\n            data.__inc__('index', data.edge_index)\n"
  },
  {
    "path": "test/data/test_database.py",
    "content": "import math\nimport os.path as osp\n\nimport pytest\nimport torch\n\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.data import Data, RocksDatabase, SQLiteDatabase\nfrom torch_geometric.data.database import TensorInfo\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import has_package, withPackage\n\nAVAILABLE_DATABASES = []\nif has_package('sqlite3'):\n    AVAILABLE_DATABASES.append(SQLiteDatabase)\nif has_package('rocksdict'):\n    AVAILABLE_DATABASES.append(RocksDatabase)\n\n\n@pytest.mark.parametrize('Database', AVAILABLE_DATABASES)\n@pytest.mark.parametrize('batch_size', [None, 1])\ndef test_database_single_tensor(tmp_path, Database, batch_size):\n    kwargs = dict(path=osp.join(tmp_path, 'storage.db'))\n    if Database == SQLiteDatabase:\n        kwargs['name'] = 'test_table'\n\n    db = Database(**kwargs)\n    assert db.schema == {0: object}\n\n    try:\n        assert len(db) == 0\n        assert str(db) == f'{Database.__name__}(0)'\n    except NotImplementedError:\n        assert str(db) == f'{Database.__name__}()'\n\n    data = torch.randn(5)\n    db.insert(0, data)\n    try:\n        assert len(db) == 1\n    except NotImplementedError:\n        pass\n    assert torch.equal(db.get(0), data)\n\n    indices = torch.tensor([1, 2])\n    data_list = torch.randn(2, 5)\n    db.multi_insert(indices, data_list, batch_size=batch_size)\n    try:\n        assert len(db) == 3\n    except NotImplementedError:\n        pass\n    out_list = db.multi_get(indices, batch_size=batch_size)\n    assert isinstance(out_list, list)\n    assert len(out_list) == 2\n    assert torch.equal(out_list[0], data_list[0])\n    assert torch.equal(out_list[1], data_list[1])\n\n    db.close()\n\n\n@pytest.mark.parametrize('Database', AVAILABLE_DATABASES)\ndef test_database_schema(tmp_path, Database):\n    kwargs = dict(name='test_table') if Database == SQLiteDatabase else {}\n\n    path = osp.join(tmp_path, 'tuple_storage.db')\n  
  schema = (int, float, str, dict(dtype=torch.float, size=(2, -1)), object)\n    db = Database(path, schema=schema, **kwargs)\n    assert db.schema == {\n        0: int,\n        1: float,\n        2: str,\n        3: TensorInfo(dtype=torch.float, size=(2, -1)),\n        4: object,\n    }\n\n    data1 = (1, 0.1, 'a', torch.randn(2, 8), Data(x=torch.randn(8)))\n    data2 = (2, float('inf'), 'b', torch.randn(2, 16), Data(x=torch.randn(8)))\n    data3 = (3, float('NaN'), 'c', torch.randn(2, 32), Data(x=torch.randn(8)))\n    db.insert(0, data1)\n    db.multi_insert([1, 2], [data2, data3])\n\n    out1 = db.get(0)\n    out2, out3 = db.multi_get([1, 2])\n\n    for out, data in zip([out1, out2, out3], [data1, data2, data3]):\n        assert out[0] == data[0]\n        if math.isnan(data[1]):\n            assert math.isnan(out[1])\n        else:\n            assert out[1] == data[1]\n        assert out[2] == data[2]\n        assert torch.equal(out[3], data[3])\n        assert isinstance(out[4], Data) and len(out[4]) == 1\n        assert torch.equal(out[4].x, data[4].x)\n\n    db.close()\n\n    path = osp.join(tmp_path, 'dict_storage.db')\n    schema = {\n        'int': int,\n        'float': float,\n        'str': str,\n        'tensor': dict(dtype=torch.float, size=(2, -1)),\n        'data': object\n    }\n    db = Database(path, schema=schema, **kwargs)\n    assert db.schema == {\n        'int': int,\n        'float': float,\n        'str': str,\n        'tensor': TensorInfo(dtype=torch.float, size=(2, -1)),\n        'data': object,\n    }\n\n    data1 = {\n        'int': 1,\n        'float': 0.1,\n        'str': 'a',\n        'tensor': torch.randn(2, 8),\n        'data': Data(x=torch.randn(1, 8)),\n    }\n    data2 = {\n        'int': 2,\n        'float': 0.2,\n        'str': 'b',\n        'tensor': torch.randn(2, 16),\n        'data': Data(x=torch.randn(2, 8)),\n    }\n    data3 = {\n        'int': 3,\n        'float': 0.3,\n        'str': 'c',\n        'tensor': 
torch.randn(2, 32),\n        'data': Data(x=torch.randn(3, 8)),\n    }\n    db.insert(0, data1)\n    db.multi_insert([1, 2], [data2, data3])\n\n    out1 = db.get(0)\n    out2, out3 = db.multi_get([1, 2])\n\n    for out, data in zip([out1, out2, out3], [data1, data2, data3]):\n        assert out['int'] == data['int']\n        assert out['float'] == data['float']\n        assert out['str'] == data['str']\n        assert torch.equal(out['tensor'], data['tensor'])\n        assert isinstance(out['data'], Data) and len(out['data']) == 1\n        assert torch.equal(out['data'].x, data['data'].x)\n\n    db.close()\n\n\n@pytest.mark.parametrize('Database', AVAILABLE_DATABASES)\ndef test_index(tmp_path, Database):\n    kwargs = dict(name='test_table') if Database == SQLiteDatabase else {}\n\n    path = osp.join(tmp_path, 'tuple_storage.db')\n    schema = dict(dtype=torch.long, is_index=True)\n    db = Database(path, schema=schema, **kwargs)\n    assert db.schema == {\n        0: TensorInfo(dtype=torch.long, is_index=True),\n    }\n\n    index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)\n    index2 = Index([0, 1, 1, 2, 2, 3], dim_size=None, is_sorted=True)\n    index3 = Index([], dtype=torch.long)\n\n    db.insert(0, index1)\n    db.multi_insert([1, 2], [index2, index3])\n\n    out1 = db.get(0)\n    out2, out3 = db.multi_get([1, 2])\n\n    for out, index in zip([out1, out2, out3], [index1, index2, index3]):\n        assert index.equal(out)\n        assert index.dtype == out.dtype\n        assert index.dim_size == out.dim_size\n        assert index.is_sorted == out.is_sorted\n\n    db.close()\n\n\n@pytest.mark.parametrize('Database', AVAILABLE_DATABASES)\ndef test_edge_index(tmp_path, Database):\n    kwargs = dict(name='test_table') if Database == SQLiteDatabase else {}\n\n    path = osp.join(tmp_path, 'tuple_storage.db')\n    schema = dict(dtype=torch.long, is_edge_index=True)\n    db = Database(path, schema=schema, **kwargs)\n    assert db.schema == {\n        0: 
TensorInfo(dtype=torch.long, size=(2, -1), is_edge_index=True),\n    }\n\n    adj1 = EdgeIndex(\n        [[0, 1, 1, 2], [1, 0, 2, 1]],\n        sparse_size=(3, 3),\n        sort_order='row',\n        is_undirected=True,\n    )\n    adj2 = EdgeIndex(\n        [[1, 0, 2, 1, 3, 2], [0, 1, 1, 2, 2, 3]],\n        sparse_size=(4, 4),\n        sort_order='col',\n    )\n    adj3 = EdgeIndex([[], []], dtype=torch.long)\n\n    db.insert(0, adj1)\n    db.multi_insert([1, 2], [adj2, adj3])\n\n    out1 = db.get(0)\n    out2, out3 = db.multi_get([1, 2])\n\n    for out, adj in zip([out1, out2, out3], [adj1, adj2, adj3]):\n        assert adj.equal(out)\n        assert adj.dtype == out.dtype\n        assert adj.sparse_size() == out.sparse_size()\n        assert adj.sort_order == out.sort_order\n        assert adj.is_undirected == out.is_undirected\n\n    db.close()\n\n\n@withPackage('sqlite3')\ndef test_database_syntactic_sugar(tmp_path):\n    path = osp.join(tmp_path, 'storage.db')\n    db = SQLiteDatabase(path, name='test_table')\n\n    data = torch.randn(5, 16)\n    db[0] = data[0]\n    db[1:3] = data[1:3]\n    db[torch.tensor([3, 4])] = data[torch.tensor([3, 4])]\n    assert len(db) == 5\n\n    assert torch.equal(db[0], data[0])\n    assert torch.equal(torch.stack(db[:3], dim=0), data[:3])\n    assert torch.equal(torch.stack(db[3:], dim=0), data[3:])\n    assert torch.equal(torch.stack(db[1::2], dim=0), data[1::2])\n    assert torch.equal(torch.stack(db[[4, 3]], dim=0), data[[4, 3]])\n    assert torch.equal(\n        torch.stack(db[torch.tensor([4, 3])], dim=0),\n        data[torch.tensor([4, 3])],\n    )\n    assert torch.equal(\n        torch.stack(db[torch.tensor([4, 4])], dim=0),\n        data[torch.tensor([4, 4])],\n    )\n\n\nif __name__ == '__main__':\n    import argparse\n    import tempfile\n    import time\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--numel', type=int, default=100_000)\n    parser.add_argument('--batch_size', type=int, 
default=256)\n    args = parser.parse_args()\n\n    data = torch.randn(args.numel, 128)\n    tmp_dir = tempfile.TemporaryDirectory()\n\n    path = osp.join(tmp_dir.name, 'sqlite.db')\n    sqlite_db = SQLiteDatabase(path, name='test_table')\n    t = time.perf_counter()\n    sqlite_db.multi_insert(range(args.numel), data, batch_size=100, log=True)\n    print(f'Initialized SQLiteDB in {time.perf_counter() - t:.2f} seconds')\n\n    path = osp.join(tmp_dir.name, 'rocks.db')\n    rocks_db = RocksDatabase(path)\n    t = time.perf_counter()\n    rocks_db.multi_insert(range(args.numel), data, batch_size=100, log=True)\n    print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds')\n\n    def in_memory_get(data):\n        index = torch.randint(0, args.numel, (args.batch_size, ))\n        return data[index]\n\n    def db_get(db):\n        index = torch.randint(0, args.numel, (args.batch_size, ))\n        return db[index]\n\n    benchmark(\n        funcs=[in_memory_get, db_get, db_get],\n        func_names=['In-Memory', 'SQLite', 'RocksDB'],\n        args=[(data, ), (sqlite_db, ), (rocks_db, )],\n        num_steps=50,\n        num_warmups=5,\n    )\n\n    tmp_dir.cleanup()\n"
  },
  {
    "path": "test/data/test_datapipes.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import DatasetAdapter\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.utils import to_smiles\n\n\n@pytest.fixture()\ndef dataset_adapter() -> DatasetAdapter:\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])\n    data = Data(x=x, edge_index=edge_index)\n    return DatasetAdapter([data, data, data, data])\n\n\ndef test_dataset_adapter(dataset_adapter):\n    loader = DataLoader(dataset_adapter, batch_size=2)\n    batch = next(iter(loader))\n    assert batch.x.shape == (6, 8)\n    assert len(loader) == 2\n\n    # Test sharding:\n    dataset_adapter.apply_sharding(2, 0)\n    assert len([data for data in dataset_adapter]) == 2\n\n    assert dataset_adapter.is_shardable()\n\n\ndef test_datapipe_batch_graphs(dataset_adapter):\n    dp = dataset_adapter.batch_graphs(batch_size=2)\n    assert len(dp) == 2\n    batch = next(iter(dp))\n    assert batch.x.shape == (6, 8)\n\n\ndef test_functional_transform(dataset_adapter):\n    assert next(iter(dataset_adapter)).is_directed()\n    dataset_adapter = dataset_adapter.to_undirected()\n    assert next(iter(dataset_adapter)).is_undirected()\n\n\n@withPackage('rdkit')\ndef test_datapipe_parse_smiles():\n    smiles = 'F/C=C/F'\n\n    dp = DatasetAdapter([smiles])\n    dp = dp.parse_smiles()\n    assert to_smiles(next(iter(dp))) == smiles\n\n    dp = DatasetAdapter([{'abc': smiles, 'cba': '1.0'}])\n    dp = dp.parse_smiles(smiles_key='abc', target_key='cba')\n    assert to_smiles(next(iter(dp))) == smiles\n"
  },
  {
    "path": "test/data/test_dataset.py",
    "content": "import copy\n\nimport pytest\nimport torch\n\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.data import Data, HeteroData, InMemoryDataset\nfrom torch_geometric.datasets import KarateClub\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.typing import SparseTensor\n\n\nclass MyTestDataset(InMemoryDataset):\n    def __init__(self, data_list, transform=None):\n        super().__init__(None, transform=transform)\n        self.data, self.slices = self.collate(data_list)\n\n\nclass MyStoredTestDataset(InMemoryDataset):\n    def __init__(self, root, data_list, transform=None):\n        self.data_list = data_list\n        super().__init__(root, transform=transform)\n        self.load(self.processed_paths[0], data_cls=data_list[0].__class__)\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def process(self):\n        self.save(self.data_list, self.processed_paths[0])\n\n\ndef test_in_memory_dataset():\n    x1 = torch.tensor([[1.0], [1.0], [1.0]])\n    x2 = torch.tensor([[2.0], [2.0], [2.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    face = torch.tensor([[0], [1], [2]])\n\n    data1 = Data(x1, edge_index, face=face, test_int=1, test_str='1')\n    data1.num_nodes = 10\n\n    data2 = Data(x2, edge_index, face=face, test_int=2, test_str='2')\n    data2.num_nodes = 5\n\n    dataset = MyTestDataset([data1, data2])\n    assert str(dataset) == 'MyTestDataset(2)'\n    assert len(dataset) == 2\n\n    assert len(dataset[0]) == 6\n    assert dataset[0].num_nodes == 10\n    assert dataset[0].x.tolist() == x1.tolist()\n    assert dataset[0].edge_index.tolist() == edge_index.tolist()\n    assert dataset[0].face.tolist() == face.tolist()\n    assert dataset[0].test_int == 1\n    assert dataset[0].test_str == '1'\n\n    assert len(dataset[1]) == 6\n    assert dataset[1].num_nodes == 5\n    assert 
dataset[1].x.tolist() == x2.tolist()\n    assert dataset[1].edge_index.tolist() == edge_index.tolist()\n    assert dataset[1].face.tolist() == face.tolist()\n    assert dataset[1].test_int == 2\n    assert dataset[1].test_str == '2'\n\n    with pytest.warns(UserWarning, match=\"internal storage format\"):\n        dataset.data  # noqa: B018\n\n    assert torch.equal(dataset.x, torch.cat([x1, x2], dim=0))\n    assert dataset.edge_index.tolist() == [\n        [0, 1, 1, 2, 10, 11, 11, 12],\n        [1, 0, 2, 1, 11, 10, 12, 11],\n    ]\n    assert torch.equal(dataset[1:].x, x2)\n\n\ndef test_stored_in_memory_dataset(tmp_path):\n    x1 = torch.tensor([[1.0], [1.0], [1.0]])\n    x2 = torch.tensor([[2.0], [2.0], [2.0], [2.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    data1 = Data(x1, edge_index, num_nodes=3, test_int=1, test_str='1')\n    data2 = Data(x2, edge_index, num_nodes=4, test_int=2, test_str='2')\n\n    dataset = MyStoredTestDataset(tmp_path, [data1, data2])\n    assert dataset._data.num_nodes == 7\n    assert dataset._data._num_nodes == [3, 4]\n\n    assert torch.equal(dataset[0].x, x1)\n    assert torch.equal(dataset[0].edge_index, edge_index)\n    assert dataset[0].num_nodes == 3\n    assert torch.equal(dataset[0].test_int, torch.tensor([1]))\n    assert dataset[0].test_str == '1'\n\n    assert torch.equal(dataset[1].x, x2)\n    assert torch.equal(dataset[1].edge_index, edge_index)\n    assert dataset[1].num_nodes == 4\n    assert torch.equal(dataset[1].test_int, torch.tensor([2]))\n    assert dataset[1].test_str == '2'\n\n\ndef test_stored_hetero_in_memory_dataset(tmp_path):\n    x1 = torch.tensor([[1.0], [1.0], [1.0]])\n    x2 = torch.tensor([[2.0], [2.0], [2.0], [2.0]])\n\n    data1 = HeteroData()\n    data1['paper'].x = x1\n    data1['paper'].num_nodes = 3\n\n    data2 = HeteroData()\n    data2['paper'].x = x2\n    data2['paper'].num_nodes = 4\n\n    dataset = MyStoredTestDataset(tmp_path, [data1, data2])\n    assert 
dataset._data['paper'].num_nodes == 7\n    assert dataset._data['paper']._num_nodes == [3, 4]\n\n    assert torch.equal(dataset[0]['paper'].x, x1)\n    assert dataset[0]['paper'].num_nodes == 3\n\n    assert torch.equal(dataset[1]['paper'].x, x2)\n    assert dataset[1]['paper'].num_nodes == 4\n\n\ndef test_index(tmp_path):\n    index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)\n    index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True)\n\n    data1 = Data(batch=index1)\n    data2 = Data(batch=index2)\n\n    dataset = MyTestDataset([data1, data2])\n    assert len(dataset) == 2\n    for data, index in zip(dataset, [index1, index2]):\n        assert isinstance(data.batch, Index)\n        assert data.batch.equal(index)\n        assert data.batch.dim_size == index.dim_size\n        assert data.batch.is_sorted == index.is_sorted\n\n    dataset = MyStoredTestDataset(tmp_path, [data1, data2])\n    assert len(dataset) == 2\n    for data, index in zip(dataset, [index1, index2]):\n        assert isinstance(data.batch, Index)\n        assert data.batch.equal(index)\n        assert data.batch.dim_size == index.dim_size\n        assert data.batch.is_sorted == index.is_sorted\n\n\ndef test_edge_index(tmp_path):\n    edge_index1 = EdgeIndex(\n        [[0, 1, 1, 2], [1, 0, 2, 1]],\n        sparse_size=(3, 3),\n        sort_order='row',\n        is_undirected=True,\n    )\n    edge_index2 = EdgeIndex(\n        [[1, 0, 2, 1, 3, 2], [0, 1, 1, 2, 2, 3]],\n        sparse_size=(4, 4),\n        sort_order='col',\n    )\n\n    data1 = Data(edge_index=edge_index1)\n    data2 = Data(edge_index=edge_index2)\n\n    dataset = MyTestDataset([data1, data2])\n    assert len(dataset) == 2\n    for data, edge_index in zip(dataset, [edge_index1, edge_index2]):\n        assert isinstance(data.edge_index, EdgeIndex)\n        assert data.edge_index.equal(edge_index)\n        assert data.edge_index.sparse_size() == edge_index.sparse_size()\n        assert data.edge_index.sort_order == 
edge_index.sort_order\n        assert data.edge_index.is_undirected == edge_index.is_undirected\n\n    dataset = MyStoredTestDataset(tmp_path, [data1, data2])\n    assert len(dataset) == 2\n    for data, edge_index in zip(dataset, [edge_index1, edge_index2]):\n        assert isinstance(data.edge_index, EdgeIndex)\n        assert data.edge_index.equal(edge_index)\n        assert data.edge_index.sparse_size() == edge_index.sparse_size()\n        assert data.edge_index.sort_order == edge_index.sort_order\n        assert data.edge_index.is_undirected == edge_index.is_undirected\n\n\ndef test_in_memory_num_classes():\n    dataset = MyTestDataset([Data(), Data()])\n    assert dataset.num_classes == 0\n\n    dataset = MyTestDataset([Data(y=0), Data(y=1)])\n    assert dataset.num_classes == 2\n\n    dataset = MyTestDataset([Data(y=1.5), Data(y=2.5), Data(y=3.5)])\n    with pytest.warns(UserWarning, match=\"unique elements\"):\n        assert dataset.num_classes == 3\n\n    dataset = MyTestDataset([\n        Data(y=torch.tensor([[0, 1, 0, 1]])),\n        Data(y=torch.tensor([[1, 0, 0, 0]])),\n        Data(y=torch.tensor([[0, 0, 1, 0]])),\n    ])\n    assert dataset.num_classes == 4\n\n    # Test when `__getitem__` returns a tuple of data objects.\n    def transform(data):\n        copied_data = copy.copy(data)\n        copied_data.y += 1\n        return data, copied_data, 'foo'\n\n    dataset = MyTestDataset([Data(y=0), Data(y=1)], transform=transform)\n    assert dataset.num_classes == 3\n\n\ndef test_in_memory_dataset_copy():\n    data_list = [Data(x=torch.randn(5, 16)) for _ in range(4)]\n    dataset = MyTestDataset(data_list)\n\n    copied_dataset = dataset.copy()\n    assert id(copied_dataset) != id(dataset)\n\n    assert len(copied_dataset) == len(dataset) == 4\n    for copied_data, data in zip(copied_dataset, dataset):\n        assert torch.equal(copied_data.x, data.x)\n\n    copied_dataset = dataset.copy([1, 2])\n    assert len(copied_dataset) == 2\n    assert 
torch.equal(copied_dataset[0].x, data_list[1].x)\n    assert torch.equal(copied_dataset[1].x, data_list[2].x)\n\n\ndef test_to_datapipe():\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    data = Data(x=x, edge_index=edge_index)\n    dataset = MyTestDataset([data, data])\n\n    dp = dataset.to_datapipe()\n\n    assert isinstance(dp, torch.utils.data.IterDataPipe)\n    assert len(dp) == 2\n\n    assert torch.equal(dataset[0].x, list(dp)[0].x)\n    assert torch.equal(dataset[0].edge_index, list(dp)[0].edge_index)\n    assert torch.equal(dataset[1].x, list(dp)[1].x)\n    assert torch.equal(dataset[1].edge_index, list(dp)[1].edge_index)\n\n\n@withPackage('torch_sparse')\ndef test_in_memory_sparse_tensor_dataset():\n    x = torch.randn(11, 16)\n    adj = SparseTensor(\n        row=torch.tensor([4, 1, 3, 2, 2, 3]),\n        col=torch.tensor([1, 3, 2, 3, 3, 2]),\n        sparse_sizes=(11, 11),\n    )\n    data = Data(x=x, adj=adj)\n\n    dataset = MyTestDataset([data, data])\n    assert len(dataset) == 2\n    assert torch.allclose(dataset[0].x, x)\n    assert dataset[0].adj.sparse_sizes() == (11, 11)\n    assert torch.allclose(dataset[1].x, x)\n    assert dataset[1].adj.sparse_sizes() == (11, 11)\n\n\ndef test_collate_with_new_dimension():\n    class MyData(Data):\n        def __cat_dim__(self, key, value, *args, **kwargs):\n            if key == 'foo':\n                return None\n            else:\n                return super().__cat_dim__(key, value, *args, **kwargs)\n\n    x = torch.tensor([1, 2, 3], dtype=torch.float)\n    foo = torch.randn(4)\n    y = torch.tensor(1)\n\n    data = MyData(x=x, foo=foo, y=y)\n\n    dataset = MyTestDataset([data, data])\n    assert str(dataset) == 'MyTestDataset(2)'\n    assert len(dataset) == 2\n\n    data1 = dataset[0]\n    assert len(data1) == 3\n    assert data1.x.tolist() == x.tolist()\n    assert data1.foo.tolist() == foo.tolist()\n    assert data1.y.tolist() == [1]\n\n    data2 = 
dataset[0]\n    assert len(data2) == 3\n    assert data2.x.tolist() == x.tolist()\n    assert data2.foo.tolist() == foo.tolist()\n    assert data2.y.tolist() == [1]\n\n\ndef test_hetero_in_memory_dataset():\n    data1 = HeteroData()\n    data1.y = torch.randn(5)\n    data1['paper'].x = torch.randn(10, 16)\n    data1['paper', 'paper'].edge_index = torch.randint(0, 10, (2, 30)).long()\n\n    data2 = HeteroData()\n    data2.y = torch.randn(5)\n    data2['paper'].x = torch.randn(10, 16)\n    data2['paper', 'paper'].edge_index = torch.randint(0, 10, (2, 30)).long()\n\n    dataset = MyTestDataset([data1, data2])\n    assert str(dataset) == 'MyTestDataset(2)'\n    assert len(dataset) == 2\n\n    assert len(dataset[0]) == 3\n    assert dataset[0].y.tolist() == data1.y.tolist()\n    assert dataset[0]['paper'].x.tolist() == data1['paper'].x.tolist()\n    assert (dataset[0]['paper', 'paper'].edge_index.tolist() == data1[\n        'paper', 'paper'].edge_index.tolist())\n\n    assert len(dataset[1]) == 3\n    assert dataset[1].y.tolist() == data2.y.tolist()\n    assert dataset[1]['paper'].x.tolist() == data2['paper'].x.tolist()\n    assert (dataset[1]['paper', 'paper'].edge_index.tolist() == data2[\n        'paper', 'paper'].edge_index.tolist())\n\n\ndef test_override_behavior():\n    class DS1(InMemoryDataset):\n        def __init__(self):\n            self.enter_download = False\n            self.enter_process = False\n            super().__init__()\n\n        def _download(self):\n            self.enter_download = True\n\n        def _process(self):\n            self.enter_process = True\n\n        def download(self):\n            pass\n\n        def process(self):\n            pass\n\n    class DS2(InMemoryDataset):\n        def __init__(self):\n            self.enter_download = False\n            self.enter_process = False\n            super().__init__()\n\n        def _download(self):\n            self.enter_download = True\n\n        def _process(self):\n            
self.enter_process = True\n\n        def process(self):\n            pass\n\n    class DS3(InMemoryDataset):\n        def __init__(self):\n            self.enter_download = False\n            self.enter_process = False\n            super().__init__()\n\n        def _download(self):\n            self.enter_download = True\n\n        def _process(self):\n            self.enter_process = True\n\n    class DS4(DS1):\n        pass\n\n    ds = DS1()\n    assert ds.enter_download\n    assert ds.enter_process\n\n    ds = DS2()\n    assert not ds.enter_download\n    assert ds.enter_process\n\n    ds = DS3()\n    assert not ds.enter_download\n    assert not ds.enter_process\n\n    ds = DS4()\n    assert ds.enter_download\n    assert ds.enter_process\n\n\ndef test_lists_of_tensors_in_memory_dataset():\n    def tr(n, m):\n        return torch.rand((n, m))\n\n    d1 = Data(xs=[tr(4, 3), tr(11, 4), tr(1, 2)])\n    d2 = Data(xs=[tr(5, 3), tr(14, 4), tr(3, 2)])\n    d3 = Data(xs=[tr(6, 3), tr(15, 4), tr(2, 2)])\n    d4 = Data(xs=[tr(4, 3), tr(16, 4), tr(1, 2)])\n\n    data_list = [d1, d2, d3, d4]\n\n    dataset = MyTestDataset(data_list)\n    assert len(dataset) == 4\n    assert dataset[0].xs[1].size() == (11, 4)\n    assert dataset[0].xs[2].size() == (1, 2)\n    assert dataset[1].xs[0].size() == (5, 3)\n    assert dataset[2].xs[1].size() == (15, 4)\n    assert dataset[3].xs[1].size() == (16, 4)\n\n\n@withPackage('torch_sparse')\ndef test_lists_of_sparse_tensors():\n    e1 = torch.tensor([[4, 1, 3, 2, 2, 3], [1, 3, 2, 3, 3, 2]])\n    e2 = torch.tensor([[0, 1, 4, 7, 2, 9], [7, 2, 2, 1, 4, 7]])\n    e3 = torch.tensor([[3, 5, 1, 2, 3, 3], [5, 0, 2, 1, 3, 7]])\n    e4 = torch.tensor([[0, 1, 9, 2, 0, 3], [1, 1, 2, 1, 3, 2]])\n    adj1 = SparseTensor.from_edge_index(e1, sparse_sizes=(11, 11))\n    adj2 = SparseTensor.from_edge_index(e2, sparse_sizes=(22, 22))\n    adj3 = SparseTensor.from_edge_index(e3, sparse_sizes=(12, 12))\n    adj4 = SparseTensor.from_edge_index(e4, 
sparse_sizes=(15, 15))\n\n    d1 = Data(adj_test=[adj1, adj2])\n    d2 = Data(adj_test=[adj3, adj4])\n\n    data_list = [d1, d2]\n    dataset = MyTestDataset(data_list)\n    assert len(dataset) == 2\n    assert dataset[0].adj_test[0].sparse_sizes() == (11, 11)\n    assert dataset[0].adj_test[1].sparse_sizes() == (22, 22)\n    assert dataset[1].adj_test[0].sparse_sizes() == (12, 12)\n    assert dataset[1].adj_test[1].sparse_sizes() == (15, 15)\n\n\ndef test_file_names_as_property_and_method():\n    class MyTestDataset(InMemoryDataset):\n        def __init__(self):\n            super().__init__('/tmp/MyTestDataset')\n\n        @property\n        def raw_file_names(self):\n            return ['test_file']\n\n        def download(self):\n            pass\n\n    MyTestDataset()\n\n    class MyTestDataset(InMemoryDataset):\n        def __init__(self):\n            super().__init__('/tmp/MyTestDataset')\n\n        def raw_file_names(self):\n            return ['test_file']\n\n        def download(self):\n            pass\n\n    MyTestDataset()\n\n\n@withPackage('sqlite3')\ndef test_to_on_disk_dataset(tmp_path):\n    class MyTransform(BaseTransform):\n        def forward(self, data: Data) -> Data:\n            data.z = 'test_str'\n            return data\n\n    in_memory_dataset = KarateClub(transform=MyTransform())\n\n    with pytest.raises(ValueError, match=\"root directory of 'KarateClub'\"):\n        in_memory_dataset.to_on_disk_dataset()\n\n    on_disk_dataset = in_memory_dataset.to_on_disk_dataset(tmp_path, log=False)\n    assert str(on_disk_dataset) == 'OnDiskKarateClub()'\n    assert on_disk_dataset.schema == {\n        'x': dict(dtype=torch.float32, size=(-1, 34)),\n        'edge_index': dict(dtype=torch.int64, size=(2, -1)),\n        'y': dict(dtype=torch.int64, size=(-1, )),\n        'train_mask': dict(dtype=torch.bool, size=(-1, )),\n    }\n    assert in_memory_dataset.transform == on_disk_dataset.transform\n\n    data1 = in_memory_dataset[0]\n    data2 = 
on_disk_dataset[0]\n\n    assert len(data1) == len(data2)\n    assert torch.allclose(data1.x, data2.x)\n    assert torch.equal(data1.edge_index, data2.edge_index)\n    assert torch.equal(data1.y, data2.y)\n    assert torch.equal(data1.train_mask, data2.train_mask)\n    assert data1.z == data2.z\n\n    on_disk_dataset.close()\n"
  },
  {
    "path": "test/data/test_dataset_summary.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.data.summary import Stats, Summary\nfrom torch_geometric.datasets import FakeDataset, FakeHeteroDataset\nfrom torch_geometric.testing import withPackage\n\n\ndef check_stats(stats: Stats, expected: Tensor):\n    expected = expected.to(torch.float)\n    assert stats.mean == float(expected.mean())\n    assert stats.std == float(expected.std())\n    assert stats.min == float(expected.min())\n    assert stats.quantile25 == float(expected.quantile(0.25))\n    assert stats.median == float(expected.median())\n    assert stats.quantile75 == float(expected.quantile(0.75))\n    assert stats.max == float(expected.max())\n\n\ndef test_dataset_summary():\n    dataset = FakeDataset(num_graphs=10)\n    num_nodes = torch.tensor([data.num_nodes for data in dataset])\n    num_edges = torch.tensor([data.num_edges for data in dataset])\n\n    summary = dataset.get_summary()\n\n    assert summary.name == 'FakeDataset'\n    assert summary.num_graphs == 10\n\n    check_stats(summary.num_nodes, num_nodes)\n    check_stats(summary.num_edges, num_edges)\n\n\n@withPackage('tabulate')\ndef test_dataset_summary_representation():\n    dataset = FakeDataset(num_graphs=10)\n\n    summary1 = Summary.from_dataset(dataset, per_type=False)\n    summary2 = Summary.from_dataset(dataset, per_type=True)\n\n    assert str(summary1) == str(summary2)\n\n\n@withPackage('tabulate')\ndef test_dataset_summary_hetero():\n    dataset1 = FakeHeteroDataset(num_graphs=10)\n    summary1 = Summary.from_dataset(dataset1, per_type=False)\n\n    dataset2 = [data.to_homogeneous() for data in dataset1]\n    summary2 = Summary.from_dataset(dataset2)\n    summary2.name = 'FakeHeteroDataset'\n\n    assert summary1 == summary2\n    assert str(summary1) == str(summary2)\n\n\n@withPackage('tabulate')\ndef test_dataset_summary_hetero_representation_length():\n    dataset = FakeHeteroDataset(num_graphs=10)\n    summary = Summary.from_dataset(dataset)\n  
  num_lines = len(str(summary).splitlines())\n\n    stats_len = len(Stats.__dataclass_fields__)\n    len_header_and_border = 5\n    num_tables = 3  # general, stats per node type, stats per edge type\n\n    assert num_lines == num_tables * (stats_len + len_header_and_border)\n\n\ndef test_dataset_summary_hetero_per_type_check():\n    dataset = FakeHeteroDataset(num_graphs=10)\n    exp_num_nodes = torch.tensor([data.num_nodes for data in dataset])\n    exp_num_edges = torch.tensor([data.num_edges for data in dataset])\n\n    summary = dataset.get_summary()\n\n    assert summary.name == 'FakeHeteroDataset'\n    assert summary.num_graphs == 10\n\n    check_stats(summary.num_nodes, exp_num_nodes)\n    check_stats(summary.num_edges, exp_num_edges)\n\n    num_nodes_per_type = {}\n    for node_type in dataset.node_types:\n        num_nodes = [data[node_type].num_nodes for data in dataset]\n        num_nodes_per_type[node_type] = torch.tensor(num_nodes)\n\n    assert len(summary.num_nodes_per_type) == len(dataset.node_types)\n    for node_type, stats in summary.num_nodes_per_type.items():\n        check_stats(stats, num_nodes_per_type[node_type])\n\n    num_edges_per_type = {}\n    for edge_type in dataset.edge_types:\n        num_edges = [data[edge_type].num_edges for data in dataset]\n        num_edges_per_type[edge_type] = torch.tensor(num_edges)\n\n    assert len(summary.num_edges_per_type) == len(dataset.edge_types)\n    for edge_type, stats in summary.num_edges_per_type.items():\n        check_stats(stats, num_edges_per_type[edge_type])\n"
  },
  {
    "path": "test/data/test_feature_store.py",
    "content": "from dataclasses import dataclass\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import TensorAttr\nfrom torch_geometric.data.feature_store import AttrView, _FieldStatus\nfrom torch_geometric.testing import MyFeatureStore\n\n\n@dataclass\nclass MyTensorAttrNoGroupName(TensorAttr):\n    def __init__(self, attr_name=_FieldStatus.UNSET, index=_FieldStatus.UNSET):\n        # Treat group_name as optional, and move it to the end\n        super().__init__(None, attr_name, index)\n\n\nclass MyFeatureStoreNoGroupName(MyFeatureStore):\n    def __init__(self):\n        super().__init__()\n        self._tensor_attr_cls = MyTensorAttrNoGroupName\n\n\ndef test_feature_store():\n    store = MyFeatureStore()\n    tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])\n\n    group_name = 'A'\n    attr_name = 'feat'\n    index = torch.tensor([0, 1, 2])\n    attr = TensorAttr(group_name, attr_name, index)\n    assert TensorAttr(group_name).update(attr) == attr\n\n    # Normal API:\n    store.put_tensor(tensor, attr)\n    assert torch.equal(store.get_tensor(attr), tensor)\n    assert torch.equal(\n        store.get_tensor(group_name, attr_name, index=torch.tensor([0, 2])),\n        tensor[torch.tensor([0, 2])],\n    )\n\n    assert store.update_tensor(tensor + 1, attr)\n    assert torch.equal(store.get_tensor(attr), tensor + 1)\n\n    store.remove_tensor(attr)\n    with pytest.raises(KeyError):\n        _ = store.get_tensor(attr)\n\n    # Views:\n    view = store.view(group_name=group_name)\n    view.attr_name = attr_name\n    view['index'] = index\n    assert view != \"not a 'AttrView' object\"\n    assert view == AttrView(store, TensorAttr(group_name, attr_name, index))\n    assert str(view) == (\"AttrView(store=MyFeatureStore(), \"\n                         \"attr=TensorAttr(group_name='A', attr_name='feat', \"\n                         \"index=tensor([0, 1, 2])))\")\n\n    # Indexing:\n    store[group_name, attr_name, index] = 
tensor\n\n    # Fully-specified forms, all of which produce a tensor output\n    assert torch.equal(store[group_name, attr_name, index], tensor)\n    assert torch.equal(store[group_name, attr_name, None], tensor)\n    assert torch.equal(store[group_name, attr_name, :], tensor)\n    assert torch.equal(store[group_name][attr_name][:], tensor)\n    assert torch.equal(store[group_name].feat[:], tensor)\n    assert torch.equal(store.view().A.feat[:], tensor)\n\n    with pytest.raises(AttributeError) as exc_info:\n        _ = store.view(group_name=group_name, index=None).feat.A\n        print(exc_info)\n\n    # Partially-specified forms, which produce an AttrView object\n    assert store[group_name] == store.view(TensorAttr(group_name=group_name))\n    assert store[group_name].feat == store.view(\n        TensorAttr(group_name=group_name, attr_name=attr_name))\n\n    # Partially-specified forms, when called, produce a Tensor output\n    # from the `TensorAttr` that has been partially specified.\n    store[group_name, None, None] = tensor\n    assert isinstance(store[group_name], AttrView)\n    assert torch.equal(store[group_name](), tensor)\n\n    # Deletion:\n    del store[group_name, attr_name, index]\n    with pytest.raises(KeyError):\n        _ = store[group_name, attr_name, index]\n    del store[group_name]\n    with pytest.raises(KeyError):\n        _ = store[group_name]()\n\n\ndef test_feature_store_override():\n    store = MyFeatureStoreNoGroupName()\n    tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])\n\n    attr_name = 'feat'\n    index = torch.tensor([0, 1, 2])\n\n    # Only use attr_name and index, in that order:\n    store[attr_name, index] = tensor\n\n    # A few assertions to ensure group_name is not needed:\n    assert isinstance(store[attr_name], AttrView)\n    assert torch.equal(store[attr_name, index], tensor)\n    assert torch.equal(store[attr_name][index], tensor)\n    assert torch.equal(store[attr_name][:], tensor)\n    
assert torch.equal(store[attr_name, :], tensor)\n"
  },
  {
    "path": "test/data/test_graph_store.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data.graph_store import EdgeAttr, EdgeLayout\nfrom torch_geometric.testing import MyGraphStore, get_random_edge_index\nfrom torch_geometric.utils import (\n    to_torch_coo_tensor,\n    to_torch_csc_tensor,\n    to_torch_csr_tensor,\n)\n\n\ndef test_graph_store():\n    graph_store = MyGraphStore()\n\n    assert str(graph_store) == 'MyGraphStore()'\n\n    coo = torch.tensor([0, 1]), torch.tensor([1, 2])\n    csr = torch.tensor([0, 1, 2]), torch.tensor([1, 2])\n    csc = torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2])\n\n    graph_store['edge_type', 'coo'] = coo\n    graph_store['edge_type', 'csr'] = csr\n    graph_store['edge_type', 'csc'] = csc\n\n    assert torch.equal(graph_store['edge_type', 'coo'][0], coo[0])\n    assert torch.equal(graph_store['edge_type', 'coo'][1], coo[1])\n    assert torch.equal(graph_store['edge_type', 'csr'][0], csr[0])\n    assert torch.equal(graph_store['edge_type', 'csr'][1], csr[1])\n    assert torch.equal(graph_store['edge_type', 'csc'][0], csc[0])\n    assert torch.equal(graph_store['edge_type', 'csc'][1], csc[1])\n\n    assert len(graph_store.get_all_edge_attrs()) == 3\n\n    del graph_store['edge_type', 'coo']\n    with pytest.raises(KeyError):\n        graph_store['edge_type', 'coo']\n\n    with pytest.raises(KeyError):\n        graph_store['edge_type_2', 'coo']\n\n\ndef test_graph_store_conversion():\n    graph_store = MyGraphStore()\n\n    edge_index = get_random_edge_index(100, 100, 300)\n    adj = to_torch_coo_tensor(edge_index, size=(100, 100))\n    coo = (adj.indices()[0], adj.indices()[1])\n    adj = to_torch_csr_tensor(edge_index, size=(100, 100))\n    csr = (adj.crow_indices(), adj.col_indices())\n    adj = to_torch_csc_tensor(edge_index, size=(100, 100))\n    csc = (adj.row_indices(), adj.ccol_indices())\n\n    graph_store.put_edge_index(coo, ('v', '1', 'v'), 'coo', size=(100, 100))\n    graph_store.put_edge_index(csr, ('v', '2', 'v'), 'csr', 
size=(100, 100))\n    graph_store.put_edge_index(csc, ('v', '3', 'v'), 'csc', size=(100, 100))\n\n    # Convert to COO:\n    row_dict, col_dict, perm_dict = graph_store.coo()\n    assert len(row_dict) == len(col_dict) == len(perm_dict) == 3\n    for row, col, perm in zip(row_dict.values(), col_dict.values(),\n                              perm_dict.values()):\n        assert torch.equal(row.sort()[0], coo[0].sort()[0])\n        assert torch.equal(col.sort()[0], coo[1].sort()[0])\n        assert perm is None\n\n    # Convert to CSR:\n    row_dict, col_dict, perm_dict = graph_store.csr()\n    assert len(row_dict) == len(col_dict) == len(perm_dict) == 3\n    for row, col in zip(row_dict.values(), col_dict.values()):\n        assert torch.equal(row, csr[0])\n        assert torch.equal(col.sort()[0], csr[1].sort()[0])\n\n    # Convert to CSC:\n    row_dict, col_dict, perm_dict = graph_store.csc()\n    assert len(row_dict) == len(col_dict) == len(perm_dict) == 3\n    for row, col in zip(row_dict.values(), col_dict.values()):\n        assert torch.equal(row.sort()[0], csc[0].sort()[0])\n        assert torch.equal(col, csc[1])\n\n    # Ensure that 'edge_types' parameters work as intended:\n    out = graph_store.coo([('v', '1', 'v')])\n    assert torch.equal(list(out[0].values())[0], coo[0])\n    assert torch.equal(list(out[1].values())[0], coo[1])\n\n    # Ensure that 'store' parameter works as intended:\n    key = EdgeAttr(edge_type=('v', '1', 'v'), layout=EdgeLayout.CSR,\n                   is_sorted=False, size=(100, 100))\n    with pytest.raises(KeyError):\n        graph_store[key]\n\n    out = graph_store.csr([('v', '1', 'v')], store=True)\n    assert torch.equal(list(out[0].values())[0], csr[0])\n    assert torch.equal(list(out[1].values())[0].sort()[0], csr[1].sort()[0])\n\n    out = graph_store[key]\n    assert torch.equal(out[0], csr[0])\n    assert torch.equal(out[1].sort()[0], csr[1].sort()[0])\n"
  },
  {
    "path": "test/data/test_hetero_data.py",
    "content": "import copy\nimport warnings\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.data.storage import EdgeStorage\nfrom torch_geometric.testing import (\n    get_random_edge_index,\n    get_random_tensor_frame,\n    withPackage,\n)\nfrom torch_geometric.typing import TensorFrame\n\nx_paper = torch.randn(10, 16)\nx_author = torch.randn(5, 32)\nx_conference = torch.randn(5, 8)\n\nidx_paper = torch.randint(x_paper.size(0), (100, ), dtype=torch.long)\nidx_author = torch.randint(x_author.size(0), (100, ), dtype=torch.long)\nidx_conference = torch.randint(x_conference.size(0), (100, ), dtype=torch.long)\n\nedge_index_paper_paper = torch.stack([idx_paper[:50], idx_paper[:50]], dim=0)\nedge_index_paper_author = torch.stack([idx_paper[:30], idx_author[:30]], dim=0)\nedge_index_author_paper = torch.stack([idx_author[:30], idx_paper[:30]], dim=0)\nedge_index_paper_conference = torch.stack(\n    [idx_paper[:25], idx_conference[:25]], dim=0)\n\nedge_attr_paper_paper = torch.randn(edge_index_paper_paper.size(1), 8)\nedge_attr_author_paper = torch.randn(edge_index_author_paper.size(1), 8)\n\n\ndef test_init_hetero_data():\n    data = HeteroData()\n    data['v1'].x = 1\n    data['paper'].x = x_paper\n    data['author'].x = x_author\n    data['paper', 'paper'].edge_index = edge_index_paper_paper\n    data['paper', 'author'].edge_index = edge_index_paper_author\n    data['author', 'paper'].edge_index = edge_index_author_paper\n    with pytest.warns(UserWarning, match=\"{'v1'} are isolated\"):\n        data.validate(raise_on_error=True)\n\n    assert len(data) == 2\n    assert data.node_types == ['v1', 'paper', 'author']\n    assert len(data.node_stores) == 3\n    assert len(data.node_items()) == 3\n    assert len(data.edge_types) == 3\n    assert len(data.edge_stores) == 3\n    assert len(data.edge_items()) == 3\n\n    data = HeteroData(\n        v1={'x': 1},\n        paper={'x': x_paper},\n        author={'x': 
x_author},\n        paper__paper={'edge_index': edge_index_paper_paper},\n        paper__author={'edge_index': edge_index_paper_author},\n        author__paper={'edge_index': edge_index_author_paper},\n    )\n\n    assert len(data) == 2\n    assert data.node_types == ['v1', 'paper', 'author']\n    assert len(data.node_stores) == 3\n    assert len(data.node_items()) == 3\n    assert len(data.edge_types) == 3\n    assert len(data.edge_stores) == 3\n    assert len(data.edge_items()) == 3\n\n    data = HeteroData({\n        'v1': {\n            'x': 1\n        },\n        'paper': {\n            'x': x_paper\n        },\n        'author': {\n            'x': x_author\n        },\n        ('paper', 'paper'): {\n            'edge_index': edge_index_paper_paper\n        },\n        ('paper', 'author'): {\n            'edge_index': edge_index_paper_author\n        },\n        ('author', 'paper'): {\n            'edge_index': edge_index_author_paper\n        },\n    })\n\n    assert len(data) == 2\n    assert data.node_types == ['v1', 'paper', 'author']\n    assert len(data.node_stores) == 3\n    assert len(data.node_items()) == 3\n    assert len(data.edge_types) == 3\n    assert len(data.edge_stores) == 3\n    assert len(data.edge_items()) == 3\n\n\ndef test_hetero_data_to_from_dict():\n    data = HeteroData()\n    data.global_id = '1'\n    data['v1'].x = torch.randn(5, 16)\n    data['v2'].y = torch.randn(4, 16)\n    data['v1', 'v2'].edge_index = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])\n\n    out = HeteroData.from_dict(data.to_dict())\n    assert out.global_id == data.global_id\n    assert torch.equal(out['v1'].x, data['v1'].x)\n    assert torch.equal(out['v2'].y, data['v2'].y)\n    assert torch.equal(out['v1', 'v2'].edge_index, data['v1', 'v2'].edge_index)\n\n\ndef test_hetero_data_functions():\n    data = HeteroData()\n    with pytest.raises(KeyError, match=\"did not find any occurrences of it\"):\n        data.collect('x')\n    data['paper'].x = x_paper\n    
data['author'].x = x_author\n    data['paper', 'paper'].edge_index = edge_index_paper_paper\n    data['paper', 'author'].edge_index = edge_index_paper_author\n    data['author', 'paper'].edge_index = edge_index_author_paper\n    data['paper', 'paper'].edge_attr = edge_attr_paper_paper\n    assert len(data) == 3\n    assert sorted(data.keys()) == ['edge_attr', 'edge_index', 'x']\n    assert 'x' in data and 'edge_index' in data and 'edge_attr' in data\n    assert data.num_nodes == 15\n    assert data.num_edges == 110\n\n    assert data.node_attrs() == ['x']\n    assert sorted(data.edge_attrs()) == ['edge_attr', 'edge_index']\n\n    assert data.num_node_features == {'paper': 16, 'author': 32}\n    assert data.num_edge_features == {\n        ('paper', 'to', 'paper'): 8,\n        ('paper', 'to', 'author'): 0,\n        ('author', 'to', 'paper'): 0,\n    }\n\n    node_types, edge_types = data.metadata()\n    assert node_types == ['paper', 'author']\n    assert edge_types == [\n        ('paper', 'to', 'paper'),\n        ('paper', 'to', 'author'),\n        ('author', 'to', 'paper'),\n    ]\n\n    x_dict = data.collect('x')\n    assert len(x_dict) == 2\n    assert x_dict['paper'].tolist() == x_paper.tolist()\n    assert x_dict['author'].tolist() == x_author.tolist()\n    assert x_dict == data.x_dict\n\n    data.y = 0\n    assert data['y'] == 0 and data.y == 0\n    assert len(data) == 4\n    assert sorted(data.keys()) == ['edge_attr', 'edge_index', 'x', 'y']\n\n    del data['paper', 'author']\n    node_types, edge_types = data.metadata()\n    assert node_types == ['paper', 'author']\n    assert edge_types == [('paper', 'to', 'paper'), ('author', 'to', 'paper')]\n\n    assert len(data.to_dict()) == 5\n    assert len(data.to_namedtuple()) == 5\n    assert data.to_namedtuple().y == 0\n    assert len(data.to_namedtuple().paper) == 1\n\n\ndef test_hetero_data_set_value_dict():\n    data = HeteroData()\n    data.set_value_dict('x', {\n        'paper': torch.randn(4, 16),\n        
'author': torch.randn(8, 32),\n    })\n    assert data.node_types == ['paper', 'author']\n    assert data.edge_types == []\n    assert data['paper'].x.size() == (4, 16)\n    assert data['author'].x.size() == (8, 32)\n\n\ndef test_hetero_data_rename():\n    data = HeteroData()\n    data['paper'].x = x_paper\n    data['author'].x = x_author\n    data['paper', 'paper'].edge_index = edge_index_paper_paper\n    data['paper', 'author'].edge_index = edge_index_paper_author\n    data['author', 'paper'].edge_index = edge_index_author_paper\n\n    data = data.rename('paper', 'article')\n    assert data.node_types == ['author', 'article']\n    assert data.edge_types == [\n        ('article', 'to', 'article'),\n        ('article', 'to', 'author'),\n        ('author', 'to', 'article'),\n    ]\n\n    assert data['article'].x.tolist() == x_paper.tolist()\n    edge_index = data['article', 'article'].edge_index\n    assert edge_index.tolist() == edge_index_paper_paper.tolist()\n\n\ndef test_dangling_types():\n    data = HeteroData()\n    data['src', 'to', 'dst'].edge_index = torch.randint(0, 10, (2, 20))\n    with pytest.raises(ValueError, match=\"do not exist as node types\"):\n        data.validate()\n\n    data = HeteroData()\n    data['node'].num_nodes = 10\n    with pytest.warns(UserWarning, match=\"{'node'} are isolated\"):\n        data.validate()\n\n\ndef test_hetero_data_subgraph():\n    data = HeteroData()\n    data.num_node_types = 3\n    data['paper'].x = x_paper\n    data['paper'].name = 'paper'\n    data['paper'].num_nodes = x_paper.size(0)\n    data['author'].x = x_author\n    data['author'].num_nodes = x_author.size(0)\n    data['conf'].x = x_conference\n    data['conf'].num_nodes = x_conference.size(0)\n    data['paper', 'paper'].edge_index = edge_index_paper_paper\n    data['paper', 'paper'].edge_attr = edge_attr_paper_paper\n    data['paper', 'paper'].name = 'cites'\n    data['author', 'paper'].edge_index = edge_index_author_paper\n    data['paper', 
'author'].edge_index = edge_index_paper_author\n    data['paper', 'conf'].edge_index = edge_index_paper_conference\n\n    subset = {\n        'paper': torch.randperm(x_paper.size(0))[:4],\n        'author': torch.randperm(x_author.size(0))[:2],\n        'conf': torch.randperm(x_conference.size(0))[:2],\n    }\n\n    out = data.subgraph(subset)\n    out.validate(raise_on_error=True)\n\n    assert out.num_node_types == data.num_node_types\n    assert out.node_types == ['paper', 'author', 'conf']\n\n    for key in out.node_types:\n        assert len(out[key]) == len(data[key])\n        assert torch.allclose(out[key].x, data[key].x[subset[key]])\n        assert out[key].num_nodes == subset[key].size(0)\n        if key == 'paper':\n            assert out['paper'].name == 'paper'\n\n    # Construct correct edge index manually:\n    node_mask = {}  # for each node type a mask of nodes in the subgraph\n    node_map = {}  # for each node type a map from old node id to new node id\n    for key in out.node_types:\n        node_mask[key] = torch.zeros((data[key].num_nodes, ), dtype=torch.bool)\n        node_map[key] = torch.zeros((data[key].num_nodes, ), dtype=torch.long)\n        node_mask[key][subset[key]] = True\n        node_map[key][subset[key]] = torch.arange(subset[key].size(0))\n\n    edge_mask = {}  # for each edge type a mask of edges in the subgraph\n    subgraph_edge_index = {\n    }  # for each edge type the edge index of the subgraph\n    for key in out.edge_types:\n        edge_mask[key] = (node_mask[key[0]][data[key].edge_index[0]]\n                          & node_mask[key[-1]][data[key].edge_index[1]])\n        subgraph_edge_index[key] = data[key].edge_index[:, edge_mask[key]]\n        subgraph_edge_index[key][0] = node_map[key[0]][subgraph_edge_index[key]\n                                                       [0]]\n        subgraph_edge_index[key][1] = node_map[key[-1]][\n            subgraph_edge_index[key][1]]\n\n    assert out.edge_types == [\n        
('paper', 'to', 'paper'),\n        ('author', 'to', 'paper'),\n        ('paper', 'to', 'author'),\n        ('paper', 'to', 'conf'),\n    ]\n\n    for key in out.edge_types:\n        assert len(out[key]) == len(data[key])\n        assert torch.equal(out[key].edge_index, subgraph_edge_index[key])\n        if key == ('paper', 'to', 'paper'):\n            assert torch.allclose(out[key].edge_attr,\n                                  data[key].edge_attr[edge_mask[key]])\n            assert out[key].name == 'cites'\n\n    # Test for bool and long in `subset_dict`.\n    author_mask = torch.zeros((x_author.size(0), ), dtype=torch.bool)\n    author_mask[subset['author']] = True\n    subset_mixed = {\n        'paper': subset['paper'],\n        'author': author_mask,\n    }\n    out = data.subgraph(subset_mixed)\n    out.validate(raise_on_error=True)\n\n    assert out.num_node_types == data.num_node_types\n    assert out.node_types == ['paper', 'author', 'conf']\n    assert out['paper'].num_nodes == subset['paper'].size(0)\n    assert out['author'].num_nodes == subset['author'].size(0)\n    assert out['conf'].num_nodes == data['conf'].num_nodes\n    assert out.edge_types == [\n        ('paper', 'to', 'paper'),\n        ('author', 'to', 'paper'),\n        ('paper', 'to', 'author'),\n        ('paper', 'to', 'conf'),\n    ]\n\n    out = data.node_type_subgraph(['paper', 'author'])\n    assert out.node_types == ['paper', 'author']\n    assert out.edge_types == [('paper', 'to', 'paper'),\n                              ('author', 'to', 'paper'),\n                              ('paper', 'to', 'author')]\n\n    out = data.edge_type_subgraph([('paper', 'author')])\n    assert out.node_types == ['paper', 'author']\n    assert out.edge_types == [('paper', 'to', 'author')]\n\n    subset = {\n        ('paper', 'to', 'paper'): torch.arange(4),\n    }\n\n    out = data.edge_subgraph(subset)\n    assert out.node_types == data.node_types\n    assert out.edge_types == data.edge_types\n    assert 
data['paper'] == out['paper']\n    assert data['author'] == out['author']\n    assert data['paper', 'author'] == out['paper', 'author']\n    assert data['author', 'paper'] == out['author', 'paper']\n\n    assert out['paper', 'paper'].num_edges == 4\n    assert out['paper', 'paper'].edge_index.size() == (2, 4)\n    assert out['paper', 'paper'].edge_attr.size() == (4, 8)\n\n\ndef test_hetero_data_empty_subgraph():\n    data = HeteroData()\n    data.num_node_types = 3\n    data['paper'].x = torch.arange(5)\n    data['author'].x = torch.arange(5)\n    data['paper', 'author'].edge_weight = torch.arange(5)\n\n    out = data.subgraph(subset_dict={\n        'paper': torch.tensor([1, 2, 3]),\n        'author': torch.tensor([1, 2, 3]),\n    })\n\n    assert torch.equal(out['paper'].x, torch.arange(1, 4))\n    assert out['paper'].num_nodes == 3\n    assert torch.equal(out['author'].x, torch.arange(1, 4))\n    assert out['author'].num_nodes == 3\n    assert 'edge_index' not in out['paper', 'author']\n    assert torch.equal(out['paper', 'author'].edge_weight, torch.arange(5))\n\n\ndef test_copy_hetero_data():\n    data = HeteroData()\n    data['paper'].x = x_paper\n    data['paper', 'to', 'paper'].edge_index = edge_index_paper_paper\n\n    out = copy.copy(data)\n    assert id(data) != id(out)\n    assert len(data.stores) == len(out.stores)\n    for store1, store2 in zip(data.stores, out.stores):\n        assert id(store1) != id(store2)\n        assert id(data) == id(store1._parent())\n        assert id(out) == id(store2._parent())\n    assert out['paper']._key == 'paper'\n    assert data['paper'].x.data_ptr() == out['paper'].x.data_ptr()\n    assert out['to']._key == ('paper', 'to', 'paper')\n    assert data['to'].edge_index.data_ptr() == out['to'].edge_index.data_ptr()\n\n    out = copy.deepcopy(data)\n    assert id(data) != id(out)\n    assert len(data.stores) == len(out.stores)\n    for store1, store2 in zip(data.stores, out.stores):\n        assert id(store1) != 
id(store2)\n    assert id(out) == id(out['paper']._parent())\n    assert out['paper']._key == 'paper'\n    assert data['paper'].x.data_ptr() != out['paper'].x.data_ptr()\n    assert data['paper'].x.tolist() == out['paper'].x.tolist()\n    assert id(out) == id(out['to']._parent())\n    assert out['to']._key == ('paper', 'to', 'paper')\n    assert data['to'].edge_index.data_ptr() != out['to'].edge_index.data_ptr()\n    assert data['to'].edge_index.tolist() == out['to'].edge_index.tolist()\n\n\ndef test_to_homogeneous_and_vice_versa():\n    data = HeteroData()\n\n    data['paper'].x = torch.randn(100, 128)\n    data['paper'].y = torch.randint(0, 10, (100, ))\n    data['author'].x = torch.randn(200, 128)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 250)\n    data['paper', 'paper'].edge_weight = torch.randn(250, )\n    data['paper', 'paper'].edge_attr = torch.randn(250, 64)\n\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 500)\n    data['paper', 'author'].edge_weight = torch.randn(500, )\n    data['paper', 'author'].edge_attr = torch.randn(500, 64)\n\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000)\n    data['author', 'paper'].edge_weight = torch.randn(1000, )\n    data['author', 'paper'].edge_attr = torch.randn(1000, 64)\n\n    out = data.to_homogeneous()\n    assert len(out) == 7\n    assert out.num_nodes == 300\n    assert out.num_edges == 1750\n    assert out.num_node_features == 128\n    assert out.num_edge_features == 64\n    assert out.node_type.size() == (300, )\n    assert out.node_type.min() == 0\n    assert out.node_type.max() == 1\n    assert out.edge_type.size() == (1750, )\n    assert out.edge_type.min() == 0\n    assert out.edge_type.max() == 2\n    assert len(out._node_type_names) == 2\n    assert len(out._edge_type_names) == 3\n    assert out.y.size() == (300, )\n    assert torch.allclose(out.y[:100], data['paper'].y)\n    assert torch.all(out.y[100:] == -1)\n    
assert 'y' not in data['author']\n\n    out = out.to_heterogeneous()\n    assert len(out) == 5\n    assert torch.allclose(data['paper'].x, out['paper'].x)\n    assert torch.allclose(data['author'].x, out['author'].x)\n    assert torch.allclose(data['paper'].y, out['paper'].y)\n\n    edge_index1 = data['paper', 'paper'].edge_index\n    edge_index2 = out['paper', 'paper'].edge_index\n    assert edge_index1.tolist() == edge_index2.tolist()\n    assert torch.allclose(\n        data['paper', 'paper'].edge_weight,\n        out['paper', 'paper'].edge_weight,\n    )\n    assert torch.allclose(\n        data['paper', 'paper'].edge_attr,\n        out['paper', 'paper'].edge_attr,\n    )\n\n    edge_index1 = data['paper', 'author'].edge_index\n    edge_index2 = out['paper', 'author'].edge_index\n    assert edge_index1.tolist() == edge_index2.tolist()\n    assert torch.allclose(\n        data['paper', 'author'].edge_weight,\n        out['paper', 'author'].edge_weight,\n    )\n    assert torch.allclose(\n        data['paper', 'author'].edge_attr,\n        out['paper', 'author'].edge_attr,\n    )\n\n    edge_index1 = data['author', 'paper'].edge_index\n    edge_index2 = out['author', 'paper'].edge_index\n    assert edge_index1.tolist() == edge_index2.tolist()\n    assert torch.allclose(\n        data['author', 'paper'].edge_weight,\n        out['author', 'paper'].edge_weight,\n    )\n    assert torch.allclose(\n        data['author', 'paper'].edge_attr,\n        out['author', 'paper'].edge_attr,\n    )\n\n    out = data.to_homogeneous()\n    node_type = out.node_type\n    edge_type = out.edge_type\n    del out.node_type\n    del out.edge_type\n    del out._edge_type_names\n    del out._node_type_names\n    out = out.to_heterogeneous(node_type, edge_type)\n    assert len(out) == 5\n    assert torch.allclose(data['paper'].x, out['0'].x)\n    assert torch.allclose(data['author'].x, out['1'].x)\n\n    edge_index1 = data['paper', 'paper'].edge_index\n    edge_index2 = out['0', 
'0'].edge_index\n    assert edge_index1.tolist() == edge_index2.tolist()\n    assert torch.allclose(\n        data['paper', 'paper'].edge_weight,\n        out['0', '0'].edge_weight,\n    )\n    assert torch.allclose(\n        data['paper', 'paper'].edge_attr,\n        out['0', '0'].edge_attr,\n    )\n\n    edge_index1 = data['paper', 'author'].edge_index\n    edge_index2 = out['0', '1'].edge_index\n    assert edge_index1.tolist() == edge_index2.tolist()\n    assert torch.allclose(\n        data['paper', 'author'].edge_weight,\n        out['0', '1'].edge_weight,\n    )\n    assert torch.allclose(\n        data['paper', 'author'].edge_attr,\n        out['0', '1'].edge_attr,\n    )\n\n    edge_index1 = data['author', 'paper'].edge_index\n    edge_index2 = out['1', '0'].edge_index\n    assert edge_index1.tolist() == edge_index2.tolist()\n    assert torch.allclose(\n        data['author', 'paper'].edge_weight,\n        out['1', '0'].edge_weight,\n    )\n    assert torch.allclose(\n        data['author', 'paper'].edge_attr,\n        out['1', '0'].edge_attr,\n    )\n\n    data = HeteroData()\n\n    data['paper'].num_nodes = 100\n    data['author'].num_nodes = 200\n\n    out = data.to_homogeneous(add_node_type=False)\n    assert len(out) == 1\n    assert out.num_nodes == 300\n\n    out = data.to_homogeneous().to_heterogeneous()\n    assert len(out) == 1\n    assert out['paper'].num_nodes == 100\n    assert out['author'].num_nodes == 200\n\n\ndef test_to_homogeneous_padding():\n    data = HeteroData()\n    data['paper'].x = torch.randn(100, 128)\n    data['author'].x = torch.randn(50, 64)\n\n    out = data.to_homogeneous()\n    assert len(out) == 2\n    assert out.node_type.size() == (150, )\n    assert out.node_type[:100].abs().sum() == 0\n    assert out.node_type[100:].sub(1).abs().sum() == 0\n    assert out.x.size() == (150, 128)\n    assert torch.equal(out.x[:100], data['paper'].x)\n    assert torch.equal(out.x[100:, :64], data['author'].x)\n    assert out.x[100:, 
64:].abs().sum() == 0\n\n\ndef test_hetero_data_to_canonical():\n    data = HeteroData()\n    assert isinstance(data['user', 'product'], EdgeStorage)\n    assert len(data.edge_types) == 1\n    assert isinstance(data['user', 'to', 'product'], EdgeStorage)\n    assert len(data.edge_types) == 1\n\n    data = HeteroData()\n    assert isinstance(data['user', 'buys', 'product'], EdgeStorage)\n    assert isinstance(data['user', 'clicks', 'product'], EdgeStorage)\n    assert len(data.edge_types) == 2\n\n    with pytest.raises(TypeError, match=\"missing 1 required\"):\n        data['user', 'product']\n\n\ndef test_hetero_data_invalid_names():\n    data = HeteroData()\n    with pytest.warns(UserWarning, match=\"single underscores\"):\n        data['my test', 'a__b', 'my test'].edge_attr = torch.randn(10, 16)\n    with warnings.catch_warnings():  # No warning should be raised afterwards:\n        warnings.simplefilter('error')\n        data['my test', 'a__c', 'my test'].edge_attr = torch.randn(10, 16)\n    assert data.edge_types == [\n        ('my test', 'a__b', 'my test'),\n        ('my test', 'a__c', 'my test'),\n    ]\n\n\ndef test_hetero_data_update():\n    data = HeteroData()\n    data['paper'].x = torch.arange(0, 5)\n    data['paper'].y = torch.arange(5, 10)\n    data['author'].x = torch.arange(10, 15)\n\n    other = HeteroData()\n    other['paper'].x = torch.arange(15, 20)\n    other['author'].y = torch.arange(20, 25)\n    other['paper', 'paper'].edge_index = torch.randint(5, (2, 20))\n\n    data.update(other)\n    assert len(data) == 3\n    assert torch.equal(data['paper'].x, torch.arange(15, 20))\n    assert torch.equal(data['paper'].y, torch.arange(5, 10))\n    assert torch.equal(data['author'].x, torch.arange(10, 15))\n    assert torch.equal(data['author'].y, torch.arange(20, 25))\n    assert torch.equal(data['paper', 'paper'].edge_index,\n                       other['paper', 'paper'].edge_index)\n\n\ndef test_hetero_data_connected_components():\n    data = 
HeteroData()\n    data[\"red\"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])\n    data[\"red\"].y = torch.tensor([1, 2, 3, 4, 5])\n    data[\"red\"].z = torch.tensor([[1.1, 1.2], [2.1, 2.2], [3.1, 3.2],\n                                  [4.1, 4.2], [5.1, 5.2]])\n    data[\"blue\"].x = torch.tensor([[6.0], [7.0]])\n    data[\"red\", \"to\",\n         \"red\"].edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]],\n                                          dtype=torch.long)\n    data[\"red\", \"with\", \"red\"].edge_index = torch.tensor([[1], [1]],\n                                                         dtype=torch.long)\n    data[\"red\", \"to\", \"blue\"].edge_index = torch.tensor([[0], [0]])\n    split_data = data.connected_components()\n\n    assert isinstance(split_data, list)\n    assert len(split_data) == 4\n    assert isinstance(split_data[0], HeteroData)\n    assert isinstance(split_data[1], HeteroData)\n    assert isinstance(split_data[2], HeteroData)\n    assert isinstance(split_data[3], HeteroData)\n    assert split_data[0].node_types == ['red', 'blue']\n    assert split_data[1].node_types == ['red', 'blue']\n    assert split_data[2].node_types == ['red', 'blue']\n    assert split_data[3].node_types == ['red', 'blue']\n    assert split_data[0].edge_types == [('red', 'to', 'red'),\n                                        ('red', 'with', 'red'),\n                                        ('red', 'to', 'blue')]\n    assert split_data[1].edge_types == [('red', 'to', 'red'),\n                                        ('red', 'with', 'red'),\n                                        ('red', 'to', 'blue')]\n    assert split_data[2].edge_types == [('red', 'to', 'red'),\n                                        ('red', 'with', 'red'),\n                                        ('red', 'to', 'blue')]\n    assert split_data[3].edge_types == [('red', 'to', 'red'),\n                                        ('red', 'with', 'red'),\n                                  
      ('red', 'to', 'blue')]\n    assert torch.equal(split_data[0][\"red\"].x, torch.tensor([[1.0], [2.0]]))\n    assert torch.equal(split_data[0][\"red\"].y, torch.tensor([1, 2]))\n    assert torch.equal(split_data[0][\"red\"].z,\n                       torch.tensor([[1.1, 1.2], [2.1, 2.2]]))\n    assert torch.equal(split_data[0][\"blue\"].x, torch.tensor([[6.0]]))\n\n    assert torch.equal(split_data[1][\"red\"].x, torch.tensor([[3.0], [4.0]]))\n    assert torch.equal(split_data[1][\"red\"].y, torch.tensor([3, 4]))\n    assert torch.equal(split_data[1][\"red\"].z,\n                       torch.tensor([[3.1, 3.2], [4.1, 4.2]]))\n    assert torch.equal(split_data[1][\"blue\"].x, torch.empty((0, 1)))\n\n    assert torch.equal(split_data[2][\"red\"].x, torch.tensor([[5.0]]))\n    assert torch.equal(split_data[2][\"red\"].y, torch.tensor([5]))\n    assert torch.equal(split_data[2][\"red\"].z, torch.tensor([[5.1, 5.2]]))\n    assert torch.equal(split_data[2][\"blue\"].x, torch.empty((0, 1)))\n\n    assert torch.equal(split_data[3][\"red\"].x, torch.empty((0, 1)))\n    assert torch.equal(split_data[3][\"red\"].y,\n                       torch.empty((0, ), dtype=torch.int64))\n    assert torch.equal(split_data[3][\"red\"].z, torch.empty((0, 2)))\n    assert torch.equal(split_data[3][\"blue\"].x, torch.tensor([[7.0]]))\n\n    assert torch.equal(split_data[0][\"red\", \"to\", \"red\"].edge_index,\n                       torch.tensor([[0, 1], [1, 0]]))\n    assert torch.equal(split_data[0][\"red\", \"with\", \"red\"].edge_index,\n                       torch.tensor([[1], [1]]))\n    assert torch.equal(split_data[0][\"red\", \"to\", \"blue\"].edge_index,\n                       torch.tensor([[0], [0]]))\n\n    assert torch.equal(split_data[1][\"red\", \"to\", \"red\"].edge_index,\n                       torch.tensor([[0, 1], [1, 0]]))\n    assert torch.equal(split_data[1][\"red\", \"with\", \"red\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n 
   assert torch.equal(split_data[1][\"red\", \"to\", \"blue\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n\n    assert torch.equal(split_data[2][\"red\", \"to\", \"red\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n    assert torch.equal(split_data[2][\"red\", \"with\", \"red\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n    assert torch.equal(split_data[2][\"red\", \"to\", \"blue\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n\n    assert torch.equal(split_data[3][\"red\", \"to\", \"red\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n    assert torch.equal(split_data[3][\"red\", \"with\", \"red\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n    assert torch.equal(split_data[3][\"red\", \"to\", \"blue\"].edge_index,\n                       torch.empty((2, 0), dtype=torch.long))\n\n\ndef test_hetero_data_connected_components_single_component():\n    data = HeteroData()\n    data[\"red\"].x = torch.tensor([[1.0], [2.0]])\n    data[\"red\"].y = torch.tensor([1, 2])\n    data[\"red\"].z = torch.tensor([[1.1, 1.2], [2.1, 2.2]])\n    data[\"blue\"].x = torch.tensor([[3.0]])\n    data[\"red\", \"to\", \"red\"].edge_index = torch.tensor([[0, 1], [1, 0]],\n                                                       dtype=torch.long)\n    data[\"red\", \"to\", \"blue\"].edge_index = torch.tensor([[0], [0]])\n    split_data = data.connected_components()\n\n    assert isinstance(split_data, list)\n    assert len(split_data) == 1\n\n\ndef test_hetero_data_find_parent():\n    # Case 1: Parent does not exist\n    data = HeteroData()\n    data._parents = {}\n    data._ranks = {}\n    node = ('paper', 1)\n    assert data._find_parent(node) == node\n    assert data._parents == {node: node}\n    assert data._ranks == {node: 0}\n\n    # Case 2: Parent exists\n    data._parents[node] = ('paper', 0)\n  
  assert data._find_parent(node) == ('paper', 0)\n\n\ndef test_hetero_data_union():\n    # Setup: two nodes in different sets\n    data = HeteroData()\n    data._parents = {}\n    data._ranks = {}\n    node1 = ('paper', 1)\n    node2 = ('paper', 2)\n\n    # Initially, both nodes are their own parents with rank 0\n    assert data._find_parent(node1) == node1\n    assert data._find_parent(node2) == node2\n    data._ranks[node1] = 0\n    data._ranks[node2] = 0\n\n    # Union them: node2 should now point to node1, and node1's rank increases\n    data._union(node1, node2)\n    assert data._find_parent(node1) == node1\n    assert data._find_parent(node2) == node1\n    assert data._ranks[node1] == 1\n\n    # Add a third node with higher rank and union with node1\n    node3 = ('paper', 3)\n    data._parents[node3] = node3\n    data._ranks[node3] = 2\n    data._union(node1, node3)\n    # node1's parent should now be node3, since node3 has higher rank\n    assert data._find_parent(node1) == node3\n    assert data._find_parent(node3) == node3\n\n    # Add a fourth node with lower rank and union with node3\n    node4 = ('paper', 4)\n    data._parents[node4] = node4\n    data._ranks[node4] = 0\n    data._union(node3, node4)\n    assert data._find_parent(node4) == node3\n    assert data._find_parent(node3) == node3\n\n    # Union of already connected nodes should not change anything\n    prev_ranks = data._ranks.copy()\n    prev_parents = data._parents.copy()\n    data._union(node1, node3)\n    assert data._ranks == prev_ranks\n    assert data._parents == prev_parents\n\n\n# Feature Store ###############################################################\n\n\ndef test_basic_feature_store():\n    data = HeteroData()\n    x = torch.randn(20, 20)\n\n    # Put tensor:\n    assert data.put_tensor(copy.deepcopy(x), group_name='paper', attr_name='x',\n                           index=None)\n    assert torch.equal(data['paper'].x, x)\n\n    # Put (modify) tensor slice:\n    x[15:] = 0\n    
data.put_tensor(0, group_name='paper', attr_name='x',\n                    index=slice(15, None, None))\n\n    # Get tensor:\n    out = data.get_tensor(group_name='paper', attr_name='x', index=None)\n    assert torch.equal(x, out)\n\n    # Get tensor size:\n    assert data.get_tensor_size(group_name='paper', attr_name='x') == (20, 20)\n\n    # Get tensor attrs:\n    data['paper'].num_nodes = 20  # don't include, not a tensor attr\n    data['paper'].bad_attr = torch.randn(10, 20)  # don't include, bad cat_dim\n\n    tensor_attrs = data.get_all_tensor_attrs()\n    assert len(tensor_attrs) == 1\n    assert tensor_attrs[0].group_name == 'paper'\n    assert tensor_attrs[0].attr_name == 'x'\n\n    # Remove tensor:\n    assert 'x' in data['paper'].__dict__['_mapping']\n    data.remove_tensor(group_name='paper', attr_name='x', index=None)\n    assert 'x' not in data['paper'].__dict__['_mapping']\n\n\n@withPackage('torch_frame')\ndef test_hetero_data_with_tensor_frame():\n    data = HeteroData()\n    data['paper'].tf = get_random_tensor_frame(num_rows=x_paper.size(0))\n    data['author'].tf = get_random_tensor_frame(num_rows=x_author.size(0))\n    data['author', 'paper'].edge_index = edge_index_author_paper\n\n    # Basic functionality:\n    assert set(data.node_attrs()) == {'tf'}\n    assert data.num_nodes == x_paper.size(0) + x_author.size(0)\n    assert data.num_node_features['paper'] == 5\n    assert data.num_node_features['author'] == 5\n\n    # Test subgraph:\n    subset = {\n        'paper': torch.tensor([1, 2, 3, 4]),\n        'author': torch.tensor([0, 1, 2, 3]),\n    }\n    out = data.subgraph(subset)\n    assert set(out.node_attrs()) == {'tf'}\n    assert out.num_nodes == 8\n    for key, value in out['paper'].tf.feat_dict.items():\n        assert value.size(0) == 4\n        assert torch.allclose(value, data['paper'].tf.feat_dict[key][1:5])\n    for key, value in out['author'].tf.feat_dict.items():\n        assert value.size(0) == 4\n        assert 
torch.allclose(value, data['author'].tf.feat_dict[key][0:4])\n\n    # Test conversion to homogeneous graphs and back:\n    for node_attrs in [None, ['tf']]:\n        out = data.to_homogeneous(node_attrs=node_attrs)\n        assert isinstance(out.tf, TensorFrame)\n        assert len(out.tf) == data.num_nodes\n        assert out.num_nodes == data.num_nodes\n        assert out.num_node_features == 5\n        for key, value in out.tf.feat_dict.items():\n            assert torch.allclose(\n                value,\n                torch.cat([\n                    data['paper'].tf.feat_dict[key],\n                    data['author'].tf.feat_dict[key],\n                ], dim=0),\n            )\n\n        out = out.to_heterogeneous()\n        for node_type in data.node_types:\n            for key, value in data[node_type].tf.feat_dict.items():\n                assert torch.allclose(value, out[node_type].tf.feat_dict[key])\n\n\n# Graph Store #################################################################\n\n\n@withPackage('torch_sparse')\ndef test_basic_graph_store():\n    data = HeteroData()\n\n    def assert_equal_tensor_tuple(expected, actual):\n        assert len(expected) == len(actual)\n        for i in range(len(expected)):\n            assert torch.equal(expected[i], actual[i])\n\n    # We put all three tensor types: COO, CSR, and CSC, and we get them back\n    # to confirm that `GraphStore` works as intended.\n    coo = (torch.tensor([0, 1]), torch.tensor([1, 2]))\n    csr = (torch.tensor([0, 1, 2, 2]), torch.tensor([1, 2]))\n    csc = (torch.tensor([0, 1]), torch.tensor([0, 0, 1, 2]))\n\n    # Put:\n    data.put_edge_index(coo, layout='coo', edge_type=('a', 'to', 'b'),\n                        size=(3, 3))\n    data.put_edge_index(csr, layout='csr', edge_type=('a', 'to', 'c'),\n                        size=(3, 3))\n    data.put_edge_index(csc, layout='csc', edge_type=('b', 'to', 'c'),\n                        size=(3, 3))\n\n    # Get:\n    
assert_equal_tensor_tuple(\n        coo, data.get_edge_index(layout='coo', edge_type=('a', 'to', 'b')))\n    assert_equal_tensor_tuple(\n        csr, data.get_edge_index(layout='csr', edge_type=('a', 'to', 'c')))\n    assert_equal_tensor_tuple(\n        csc, data.get_edge_index(layout='csc', edge_type=('b', 'to', 'c')))\n\n    # Get attrs:\n    edge_attrs = data.get_all_edge_attrs()\n    assert len(edge_attrs) == 3\n\n    # Remove:\n    coo, csr, csc = edge_attrs\n    data.remove_edge_index(coo)\n    data.remove_edge_index(csr)\n    data.remove_edge_index(csc)\n    assert len(data.get_all_edge_attrs()) == 0\n\n\ndef test_generate_ids():\n    data = HeteroData()\n\n    data['paper'].x = torch.randn(100, 128)\n    data['author'].x = torch.randn(200, 128)\n\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 300)\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 400)\n    assert len(data) == 2\n\n    data.generate_ids()\n    assert len(data) == 4\n    assert data['paper'].n_id.tolist() == list(range(100))\n    assert data['author'].n_id.tolist() == list(range(200))\n    assert data['paper', 'author'].e_id.tolist() == list(range(300))\n    assert data['author', 'paper'].e_id.tolist() == list(range(400))\n\n\ndef test_invalid_keys():\n    data = HeteroData()\n\n    data['paper'].x = torch.randn(10, 128)\n    data['paper'].node_attrs = ['y']\n    data['paper', 'paper'].edge_index = get_random_edge_index(10, 10, 20)\n    data['paper', 'paper'].edge_attrs = ['edge_attr']\n\n    assert data['paper'].node_attrs() == ['x']\n    assert data['paper']['node_attrs'] == ['y']\n    assert data['paper', 'paper'].edge_attrs() == ['edge_index']\n    assert data['paper', 'paper']['edge_attrs'] == ['edge_attr']\n\n    out = data.to_homogeneous()\n    assert set(out.node_attrs()) == {'x', 'node_type'}\n    assert set(out.edge_attrs()) == {'edge_index', 'edge_type'}\n"
  },
  {
    "path": "test/data/test_hypergraph_data.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric\nfrom torch_geometric.data.hypergraph_data import HyperGraphData\nfrom torch_geometric.loader import DataLoader\n\n\ndef test_hypergraph_data():\n    torch_geometric.set_debug(True)\n\n    x = torch.tensor([[1, 3, 5, 7], [2, 4, 6, 8], [7, 8, 9, 10]],\n                     dtype=torch.float).t()\n    edge_index = torch.tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3],\n                               [0, 0, 0, 1, 1, 1, 2, 2, 2]])\n    data = HyperGraphData(x=x, edge_index=edge_index).to(torch.device('cpu'))\n    data.validate(raise_on_error=True)\n\n    assert data.num_nodes == 4\n    assert data.num_edges == 3\n\n    assert data.node_attrs() == ['x']\n    assert data.edge_attrs() == ['edge_index']\n\n    assert data.x.tolist() == x.tolist()\n    assert data['x'].tolist() == x.tolist()\n    assert data.get('x').tolist() == x.tolist()\n    assert data.get('y', 2) == 2\n    assert data.get('y', None) is None\n\n    assert sorted(data.keys()) == ['edge_index', 'x']\n    assert len(data) == 2\n    assert 'x' in data and 'edge_index' in data and 'pos' not in data\n\n    D = data.to_dict()\n    assert len(D) == 2\n    assert 'x' in D and 'edge_index' in D\n\n    D = data.to_namedtuple()\n    assert len(D) == 2\n    assert D.x is not None and D.edge_index is not None\n\n    assert data.__cat_dim__('x', data.x) == 0\n    assert data.__cat_dim__('edge_index', data.edge_index) == -1\n    assert data.__inc__('x', data.x) == 0\n    assert torch.equal(data.__inc__('edge_index', data.edge_index),\n                       torch.tensor([[data.num_nodes], [data.num_edges]]))\n    data_list = [data, data]\n    loader = DataLoader(data_list, batch_size=2)\n    batch = next(iter(loader))\n    batched_edge_index = batch.edge_index\n    assert batched_edge_index.tolist() == [[\n        0, 1, 2, 1, 2, 3, 0, 2, 3, 4, 5, 6, 5, 6, 7, 4, 6, 7\n    ], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]]\n\n    assert not 
data.x.is_contiguous()\n    data.contiguous()\n    assert data.x.is_contiguous()\n\n    assert not data.is_coalesced()\n    data = data.coalesce()\n    assert data.is_coalesced()\n\n    clone = data.clone()\n    assert clone != data\n    assert len(clone) == len(data)\n    assert clone.x.data_ptr() != data.x.data_ptr()\n    assert clone.x.tolist() == data.x.tolist()\n    assert clone.edge_index.data_ptr() != data.edge_index.data_ptr()\n    assert clone.edge_index.tolist() == data.edge_index.tolist()\n\n    data['x'] = x + 1\n    assert data.x.tolist() == (x + 1).tolist()\n\n    assert str(data) == 'HyperGraphData(x=[4, 3], edge_index=[2, 9])'\n\n    dictionary = {'x': data.x, 'edge_index': data.edge_index}\n    data = HyperGraphData.from_dict(dictionary)\n    assert sorted(data.keys()) == ['edge_index', 'x']\n\n    assert not data.has_isolated_nodes()\n    # assert not data.has_self_loops()\n    # assert data.is_undirected()\n    # assert not data.is_directed()\n\n    assert data.num_nodes == 4\n    assert data.num_edges == 3\n    with pytest.warns(UserWarning, match='deprecated'):\n        assert data.num_faces is None\n    assert data.num_node_features == 3\n    assert data.num_features == 3\n\n    data.edge_attr = torch.randn(data.num_edges, 2)\n    assert data.num_edge_features == 2\n    assert data.is_edge_attr('edge_attr')\n    data.edge_attr = None\n\n    data.x = None\n    with pytest.warns(UserWarning, match='Unable to accurately infer'):\n        assert data.num_nodes == 4\n\n    data.edge_index = None\n    with pytest.warns(UserWarning, match='Unable to accurately infer'):\n        assert data.num_nodes is None\n    assert data.num_edges == 0\n\n    data.num_nodes = 4\n    assert data.num_nodes == 4\n\n    data = HyperGraphData(x=x, attribute=x)\n    assert len(data) == 2\n    assert data.x.tolist() == x.tolist()\n    assert data.attribute.tolist() == x.tolist()\n\n    face = torch.tensor([[0, 1], [1, 2], [2, 3]])\n    data = HyperGraphData(num_nodes=4, 
face=face)\n    with pytest.warns(UserWarning, match='deprecated'):\n        assert data.num_faces == 2\n    assert data.num_nodes == 4\n\n    data = HyperGraphData(title='test')\n    assert str(data) == \"HyperGraphData(title='test')\"\n    assert data.num_node_features == 0\n    # assert data.num_edge_features == 0\n\n    key = value = 'test_value'\n    data[key] = value\n    assert data[key] == value\n    del data[value]\n    del data[value]  # Deleting unset attributes should work as well.\n\n    assert data.get(key) is None\n    assert data.get('title') == 'test'\n\n    torch_geometric.set_debug(False)\n\n\ndef test_hypergraphdata_subgraph():\n    x = torch.arange(5)\n    y = torch.tensor([0.])\n    edge_index = torch.tensor([[0, 1, 3, 2, 4, 0, 3, 4, 2, 1, 2, 3],\n                               [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3]])\n    edge_attr = torch.rand(4, 2)\n    data = HyperGraphData(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr,\n                          num_nodes=5)\n\n    out = data.subgraph(torch.tensor([1, 2, 4]))\n    assert len(out) == 5\n    assert torch.equal(out.x, torch.tensor([1, 2, 4]))\n    assert torch.equal(out.y, data.y)\n    assert out.edge_index.tolist() == [[1, 2, 2, 1, 0, 1], [0, 0, 1, 1, 2, 2]]\n    assert torch.equal(out.edge_attr, edge_attr[[1, 2, 3]])\n    assert out.num_nodes == 3\n\n    # Test unordered selection:\n    out = data.subgraph(torch.tensor([3, 1, 2]))\n    assert len(out) == 5\n    assert torch.equal(out.x, torch.tensor([3, 1, 2]))\n    assert torch.equal(out.y, data.y)\n    assert out.edge_index.tolist() == [[0, 2, 0, 2, 1, 2, 0],\n                                       [0, 0, 1, 1, 2, 2, 2]]\n    assert torch.equal(out.edge_attr, edge_attr[[1, 2, 3]])\n    assert out.num_nodes == 3\n\n    out = data.subgraph(torch.tensor([False, False, False, True, True]))\n    assert len(out) == 5\n    assert torch.equal(out.x, torch.arange(3, 5))\n    assert torch.equal(out.y, data.y)\n    assert 
out.edge_index.tolist() == [[0, 1, 0, 1], [0, 0, 1, 1]]\n    assert torch.equal(out.edge_attr, edge_attr[[1, 2]])\n    assert out.num_nodes == 2\n"
  },
  {
    "path": "test/data/test_inherit.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, Dataset, InMemoryDataset\n\n\nclass MyData(Data):\n    def __init__(self, x=None, edge_index=None, arg=None):\n        super().__init__(x=x, edge_index=edge_index, arg=arg)\n\n    def random(self):\n        return torch.randn(list(self.x.size()) + list(self.arg.size()))\n\n\nclass MyInMemoryDataset(InMemoryDataset):\n    def __init__(self):\n        super().__init__('/tmp/MyInMemoryDataset')\n\n        x = torch.randn(4, 5)\n        edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])\n        arg = torch.randn(4, 3)\n\n        data_list = [MyData(x, edge_index, arg) for _ in range(10)]\n        self.data, self.slices = self.collate(data_list)\n\n    def _download(self):\n        pass\n\n    def _process(self):\n        pass\n\n\nclass MyDataset(Dataset):\n    def __init__(self):\n        super().__init__('/tmp/MyDataset')\n\n    def _download(self):\n        pass\n\n    def _process(self):\n        pass\n\n    def len(self):\n        return 10\n\n    def get(self, idx):\n        x = torch.randn(4, 5)\n        edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])\n        arg = torch.randn(4, 3)\n        return MyData(x, edge_index, arg)\n\n\ndef test_inherit():\n    dataset = MyDataset()\n    assert len(dataset) == 10\n    data = dataset[0]\n    assert data.random().size() == (4, 5, 4, 3)\n\n    dataset = MyInMemoryDataset()\n    assert len(dataset) == 10\n    data = dataset[0]\n    assert data.random().size() == (4, 5, 4, 3)\n"
  },
  {
    "path": "test/data/test_on_disk_dataset.py",
    "content": "import os.path as osp\nfrom typing import Any, Dict\n\nimport torch\n\nfrom torch_geometric.data import Data, OnDiskDataset\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('sqlite3')\ndef test_pickle(tmp_path):\n    dataset = OnDiskDataset(tmp_path)\n    assert len(dataset) == 0\n    assert str(dataset) == 'OnDiskDataset(0)'\n    assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db'))\n\n    data_list = [\n        Data(\n            x=torch.randn(5, 8),\n            edge_index=torch.randint(0, 5, (2, 16)),\n            num_nodes=5,\n        ) for _ in range(4)\n    ]\n\n    dataset.append(data_list[0])\n    assert len(dataset) == 1\n\n    dataset.extend(data_list[1:])\n    assert len(dataset) == 4\n\n    out = dataset.get(0)\n    assert torch.equal(out.x, data_list[0].x)\n    assert torch.equal(out.edge_index, data_list[0].edge_index)\n    assert out.num_nodes == data_list[0].num_nodes\n\n    out_list = dataset.multi_get([1, 2, 3])\n    for out, data in zip(out_list, data_list[1:]):\n        assert torch.equal(out.x, data.x)\n        assert torch.equal(out.edge_index, data.edge_index)\n        assert out.num_nodes == data.num_nodes\n\n    dataset.close()\n\n    # Test persistence of datasets:\n    dataset = OnDiskDataset(tmp_path)\n    assert len(dataset) == 4\n\n    out = dataset.get(0)\n    assert torch.equal(out.x, data_list[0].x)\n    assert torch.equal(out.edge_index, data_list[0].edge_index)\n    assert out.num_nodes == data_list[0].num_nodes\n\n    dataset.close()\n\n\n@withPackage('sqlite3')\ndef test_custom_schema(tmp_path):\n    class CustomSchemaOnDiskDataset(OnDiskDataset):\n        def __init__(self, root: str):\n            schema = {\n                'x': dict(dtype=torch.float, size=(-1, 8)),\n                'edge_index': dict(dtype=torch.long, size=(2, -1)),\n                'num_nodes': int,\n            }\n            self.serialize_count = 0\n            self.deserialize_count = 0\n            
super().__init__(root, schema=schema)\n\n        def serialize(self, data: Data) -> Dict[str, Any]:\n            self.serialize_count += 1\n            return data.to_dict()\n\n        def deserialize(self, mapping: Dict[str, Any]) -> Any:\n            self.deserialize_count += 1\n            return Data.from_dict(mapping)\n\n    dataset = CustomSchemaOnDiskDataset(tmp_path)\n    assert len(dataset) == 0\n    assert str(dataset) == 'CustomSchemaOnDiskDataset(0)'\n    assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db'))\n\n    data_list = [\n        Data(\n            x=torch.randn(5, 8),\n            edge_index=torch.randint(0, 5, (2, 16)),\n            num_nodes=5,\n        ) for _ in range(4)\n    ]\n\n    dataset.append(data_list[0])\n    assert dataset.serialize_count == 1\n    assert len(dataset) == 1\n\n    dataset.extend(data_list[1:])\n    assert dataset.serialize_count == 4\n    assert len(dataset) == 4\n\n    out = dataset.get(0)\n    assert dataset.deserialize_count == 1\n    assert torch.equal(out.x, data_list[0].x)\n    assert torch.equal(out.edge_index, data_list[0].edge_index)\n    assert out.num_nodes == data_list[0].num_nodes\n\n    out_list = dataset.multi_get([1, 2, 3])\n    assert dataset.deserialize_count == 4\n    for out, data in zip(out_list, data_list[1:]):\n        assert torch.equal(out.x, data.x)\n        assert torch.equal(out.edge_index, data.edge_index)\n        assert out.num_nodes == data.num_nodes\n\n    dataset.close()\n"
  },
  {
    "path": "test/data/test_remote_backend_utils.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.data.remote_backend_utils import num_nodes, size\nfrom torch_geometric.testing import (\n    MyFeatureStore,\n    MyGraphStore,\n    get_random_edge_index,\n)\n\n\n@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroData])\n@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroData])\ndef test_num_nodes_size(FeatureStore, GraphStore):\n    feature_store = FeatureStore()\n    graph_store = GraphStore()\n\n    # Infer num nodes from features:\n    x = torch.arange(100)\n    feature_store.put_tensor(x, group_name='x', attr_name='x', index=None)\n    assert num_nodes(feature_store, graph_store, 'x') == 100\n\n    # Infer num nodes and size from edges:\n    xy = get_random_edge_index(100, 50, 20)\n    graph_store.put_edge_index(xy, edge_type=('x', 'to', 'y'), layout='coo',\n                               size=(100, 50))\n    assert num_nodes(feature_store, graph_store, 'y') == 50\n    assert size(feature_store, graph_store, ('x', 'to', 'y')) == (100, 50)\n\n    # Throw an error if we cannot infer for an unknown node type:\n    with pytest.raises(ValueError, match=\"Unable to accurately infer\"):\n        _ = num_nodes(feature_store, graph_store, 'z')\n"
  },
  {
    "path": "test/data/test_storage.py",
    "content": "import copy\nfrom typing import Any\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data.storage import BaseStorage\n\n\ndef test_base_storage():\n    storage = BaseStorage()\n    assert storage._mapping == {}\n    storage.x = torch.zeros(1)\n    storage.y = torch.ones(1)\n    assert len(storage) == 2\n    assert storage._mapping == {'x': torch.zeros(1), 'y': torch.ones(1)}\n    assert storage.x is not None\n    assert storage.y is not None\n\n    assert torch.allclose(storage.get('x', None), storage.x)\n    assert torch.allclose(storage.get('y', None), storage.y)\n    assert storage.get('z', 2) == 2\n    assert storage.get('z', None) is None\n    assert len(list(storage.keys('x', 'y', 'z'))) == 2\n    assert len(list(storage.keys('x', 'y', 'z'))) == 2\n    assert len(list(storage.values('x', 'y', 'z'))) == 2\n    assert len(list(storage.items('x', 'y', 'z'))) == 2\n\n    del storage.y\n    assert len(storage) == 1\n    assert storage.x is not None\n\n    storage = BaseStorage({'x': torch.zeros(1)})\n    assert len(storage) == 1\n    assert storage.x is not None\n\n    storage = BaseStorage(x=torch.zeros(1))\n    assert len(storage) == 1\n    assert storage.x is not None\n\n    storage = BaseStorage(x=torch.zeros(1))\n    copied_storage = copy.copy(storage)\n    assert storage == copied_storage\n    assert id(storage) != id(copied_storage)\n    assert storage.x.data_ptr() == copied_storage.x.data_ptr()\n    assert int(storage.x) == 0\n    assert int(copied_storage.x) == 0\n\n    deepcopied_storage = copy.deepcopy(storage)\n    assert storage == deepcopied_storage\n    assert id(storage) != id(deepcopied_storage)\n    assert storage.x.data_ptr() != deepcopied_storage.x.data_ptr()\n    assert int(storage.x) == 0\n    assert int(deepcopied_storage.x) == 0\n\n    with pytest.raises(AttributeError, match=\"has no attribute 'asdf'\"):\n        storage.asdf  # noqa: B018\n\n\ndef test_storage_tensor_methods():\n    x = torch.randn(5)\n    storage = 
BaseStorage({'x': x})\n\n    storage = storage.clone()\n    assert storage.x.data_ptr() != x.data_ptr()\n\n    storage = storage.contiguous()\n    assert storage.x.is_contiguous()\n\n    storage = storage.to('cpu')\n    assert storage.x.device == torch.device('cpu')\n\n    storage = storage.cpu()\n    assert storage.x.device == torch.device('cpu')\n\n    if torch.cuda.is_available():\n        storage = storage.pin_memory()\n        assert storage.x.is_pinned()\n\n    storage = storage.share_memory_()\n    assert storage.x.is_shared()\n\n    storage = storage.detach_()\n    assert not storage.x.requires_grad\n\n    storage = storage.detach()\n    assert not storage.x.requires_grad\n\n    storage = storage.requires_grad_()\n    assert storage.x.requires_grad\n\n\ndef test_setter_and_getter():\n    class MyStorage(BaseStorage):\n        @property\n        def my_property(self) -> Any:\n            return self._my_property\n\n        @my_property.setter\n        def my_property(self, value: Any):\n            self._my_property = value\n\n    storage = MyStorage()\n    storage.my_property = 'hello'\n    assert storage.my_property == 'hello'\n    assert storage._my_property == 'hello'\n"
  },
  {
    "path": "test/data/test_temporal.py",
    "content": "import copy\n\nimport torch\n\nfrom torch_geometric.data import TemporalData\n\n\ndef get_temporal_data(num_events, msg_channels):\n    return TemporalData(\n        src=torch.arange(num_events),\n        dst=torch.arange(num_events, num_events * 2),\n        t=torch.arange(0, num_events * 1000, step=1000),\n        msg=torch.randn(num_events, msg_channels),\n        y=torch.randint(0, 2, (num_events, )),\n    )\n\n\ndef test_temporal_data():\n    data = get_temporal_data(num_events=3, msg_channels=16)\n    assert str(data) == (\"TemporalData(src=[3], dst=[3], t=[3], \"\n                         \"msg=[3, 16], y=[3])\")\n\n    assert data.num_nodes == 6\n    assert data.num_events == data.num_edges == len(data) == 3\n\n    assert data.src.tolist() == [0, 1, 2]\n    assert data['src'].tolist() == [0, 1, 2]\n\n    assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]]\n    data.edge_index = 'edge_index'\n    assert data.edge_index == 'edge_index'\n    del data.edge_index\n    assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]]\n\n    assert sorted(data.keys()) == ['dst', 'msg', 'src', 't', 'y']\n    assert sorted(data.to_dict().keys()) == sorted(data.keys())\n\n    data_tuple = data.to_namedtuple()\n    assert len(data_tuple) == 5\n    assert data_tuple.src is not None\n    assert data_tuple.dst is not None\n    assert data_tuple.t is not None\n    assert data_tuple.msg is not None\n    assert data_tuple.y is not None\n\n    assert data.__cat_dim__('src', data.src) == 0\n    assert data.__inc__('src', data.src) == 6\n\n    clone = data.clone()\n    assert clone != data\n    assert len(clone) == len(data)\n    assert clone.src.data_ptr() != data.src.data_ptr()\n    assert clone.src.tolist() == data.src.tolist()\n    assert clone.dst.data_ptr() != data.dst.data_ptr()\n    assert clone.dst.tolist() == data.dst.tolist()\n\n    deepcopy = copy.deepcopy(data)\n    assert deepcopy != data\n    assert len(deepcopy) == len(data)\n    assert 
deepcopy.src.data_ptr() != data.src.data_ptr()\n    assert deepcopy.src.tolist() == data.src.tolist()\n    assert deepcopy.dst.data_ptr() != data.dst.data_ptr()\n    assert deepcopy.dst.tolist() == data.dst.tolist()\n\n    key = value = 'test_value'\n    data[key] = value\n    assert data[key] == value\n    assert data.test_value == value\n    del data[key]\n    del data[key]  # Deleting unset attributes should work as well.\n\n    assert data.get(key, 10) == 10\n\n    assert len([event for event in data]) == 3\n\n    assert len([attr for attr in data()]) == 5\n\n    assert data.size() == (2, 5)\n\n    del data.src\n    assert 'src' not in data\n\n\ndef test_train_val_test_split():\n    data = get_temporal_data(num_events=100, msg_channels=16)\n\n    train_data, val_data, test_data = data.train_val_test_split(\n        val_ratio=0.2, test_ratio=0.15)\n\n    assert len(train_data) == 65\n    assert len(val_data) == 20\n    assert len(test_data) == 15\n\n    assert train_data.t.max() < val_data.t.min()\n    assert val_data.t.max() < test_data.t.min()\n\n\ndef test_temporal_indexing():\n    data = get_temporal_data(num_events=10, msg_channels=16)\n\n    elem = data[0]\n    assert isinstance(elem, TemporalData)\n    assert len(elem) == 1\n    assert elem.src.tolist() == data.src[0:1].tolist()\n    assert elem.dst.tolist() == data.dst[0:1].tolist()\n    assert elem.t.tolist() == data.t[0:1].tolist()\n    assert elem.msg.tolist() == data.msg[0:1].tolist()\n    assert elem.y.tolist() == data.y[0:1].tolist()\n\n    subset = data[0:5]\n    assert isinstance(subset, TemporalData)\n    assert len(subset) == 5\n    assert subset.src.tolist() == data.src[0:5].tolist()\n    assert subset.dst.tolist() == data.dst[0:5].tolist()\n    assert subset.t.tolist() == data.t[0:5].tolist()\n    assert subset.msg.tolist() == data.msg[0:5].tolist()\n    assert subset.y.tolist() == data.y[0:5].tolist()\n\n    index = [0, 4, 8]\n    subset = data[torch.tensor(index)]\n    assert 
isinstance(subset, TemporalData)\n    assert len(subset) == 3\n    assert subset.src.tolist() == data.src[0::4].tolist()\n    assert subset.dst.tolist() == data.dst[0::4].tolist()\n    assert subset.t.tolist() == data.t[0::4].tolist()\n    assert subset.msg.tolist() == data.msg[0::4].tolist()\n    assert subset.y.tolist() == data.y[0::4].tolist()\n\n    mask = [True, False, True, False, True, False, True, False, True, False]\n    subset = data[torch.tensor(mask)]\n    assert isinstance(subset, TemporalData)\n    assert len(subset) == 5\n    assert subset.src.tolist() == data.src[0::2].tolist()\n    assert subset.dst.tolist() == data.dst[0::2].tolist()\n    assert subset.t.tolist() == data.t[0::2].tolist()\n    assert subset.msg.tolist() == data.msg[0::2].tolist()\n    assert subset.y.tolist() == data.y[0::2].tolist()\n"
  },
  {
    "path": "test/data/test_view.py",
    "content": "from torch_geometric.data.storage import BaseStorage\n\n\ndef test_views():\n    storage = BaseStorage(x=1, y=2, z=3)\n\n    assert str(storage.keys()) == \"KeysView({'x': 1, 'y': 2, 'z': 3})\"\n    assert len(storage.keys()) == 3\n    assert list(storage.keys()) == ['x', 'y', 'z']\n\n    assert str(storage.values()) == \"ValuesView({'x': 1, 'y': 2, 'z': 3})\"\n    assert len(storage.values()) == 3\n    assert list(storage.values()) == [1, 2, 3]\n\n    assert str(storage.items()) == \"ItemsView({'x': 1, 'y': 2, 'z': 3})\"\n    assert len(storage.items()) == 3\n    assert list(storage.items()) == [('x', 1), ('y', 2), ('z', 3)]\n\n    args = ['x', 'z', 'foo']\n\n    assert str(storage.keys(*args)) == \"KeysView({'x': 1, 'z': 3})\"\n    assert len(storage.keys(*args)) == 2\n    assert list(storage.keys(*args)) == ['x', 'z']\n\n    assert str(storage.values(*args)) == \"ValuesView({'x': 1, 'z': 3})\"\n    assert len(storage.values(*args)) == 2\n    assert list(storage.values(*args)) == [1, 3]\n\n    assert str(storage.items(*args)) == \"ItemsView({'x': 1, 'z': 3})\"\n    assert len(storage.items(*args)) == 2\n    assert list(storage.items(*args)) == [('x', 1), ('z', 3)]\n"
  },
  {
    "path": "test/datasets/graph_generator/test_ba_graph.py",
    "content": "from torch_geometric.datasets.graph_generator import BAGraph\n\n\ndef test_ba_graph():\n    graph_generator = BAGraph(num_nodes=300, num_edges=5)\n    assert str(graph_generator) == 'BAGraph(num_nodes=300, num_edges=5)'\n\n    data = graph_generator()\n    assert len(data) == 2\n    assert data.num_nodes == 300\n    assert data.num_edges <= 2 * 300 * 5\n"
  },
  {
    "path": "test/datasets/graph_generator/test_er_graph.py",
    "content": "from torch_geometric.datasets.graph_generator import ERGraph\n\n\ndef test_er_graph():\n    graph_generator = ERGraph(num_nodes=300, edge_prob=0.1)\n    assert str(graph_generator) == 'ERGraph(num_nodes=300, edge_prob=0.1)'\n\n    data = graph_generator()\n    assert len(data) == 2\n    assert data.num_nodes == 300\n    assert data.num_edges >= 300 * 300 * 0.05\n    assert data.num_edges <= 300 * 300 * 0.15\n"
  },
  {
    "path": "test/datasets/graph_generator/test_grid_graph.py",
    "content": "from torch_geometric.datasets.graph_generator import GridGraph\n\n\ndef test_grid_graph():\n    graph_generator = GridGraph(height=10, width=10)\n    assert str(graph_generator) == 'GridGraph(height=10, width=10)'\n\n    data = graph_generator()\n    assert len(data) == 2\n    assert data.num_nodes == 100\n    assert data.num_edges == 784\n"
  },
  {
    "path": "test/datasets/graph_generator/test_tree_graph.py",
    "content": "import pytest\n\nfrom torch_geometric.datasets.graph_generator import TreeGraph\n\n\n@pytest.mark.parametrize('undirected', [False, True])\ndef test_tree_graph(undirected):\n    graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected)\n    assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, '\n                                    f'undirected={undirected})')\n\n    data = graph_generator()\n    assert len(data) == 3\n    assert data.num_nodes == 7\n    assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2]\n    if not undirected:\n        assert data.edge_index.tolist() == [\n            [0, 0, 1, 1, 2, 2],\n            [1, 2, 3, 4, 5, 6],\n        ]\n    else:\n        assert data.edge_index.tolist() == [\n            [0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6],\n            [1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2],\n        ]\n"
  },
  {
    "path": "test/datasets/motif_generator/test_custom_motif.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import CustomMotif\nfrom torch_geometric.testing import withPackage\n\n\ndef test_custom_motif_pyg_data():\n    structure = Data(\n        num_nodes=3,\n        edge_index=torch.tensor([[0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]]),\n    )\n\n    motif_generator = CustomMotif(structure)\n    assert str(motif_generator) == 'CustomMotif()'\n\n    assert structure == motif_generator()\n\n\n@withPackage('networkx')\ndef test_custom_motif_networkx():\n    import networkx as nx\n\n    structure = nx.gnm_random_graph(5, 10, seed=2000)\n\n    motif_generator = CustomMotif(structure)\n    assert str(motif_generator) == 'CustomMotif()'\n\n    out = motif_generator()\n    assert len(out) == 2\n    assert out.num_nodes == 5\n    assert out.num_edges == 20\n\n\ndef test_custom_motif_unknown():\n    with pytest.raises(ValueError, match=\"motif structure of type\"):\n        CustomMotif(structure='unknown')\n"
  },
  {
    "path": "test/datasets/motif_generator/test_cycle_motif.py",
    "content": "from torch_geometric.datasets.motif_generator import CycleMotif\n\n\ndef test_cycle_motif():\n    motif_generator = CycleMotif(5)\n    assert str(motif_generator) == 'CycleMotif(5)'\n\n    motif = motif_generator()\n    assert len(motif) == 2\n    assert motif.num_nodes == 5\n    assert motif.num_edges == 10\n    assert motif.edge_index.tolist() == [\n        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4],\n        [1, 4, 0, 2, 1, 3, 2, 4, 0, 3],\n    ]\n"
  },
  {
    "path": "test/datasets/motif_generator/test_grid_motif.py",
    "content": "from torch_geometric.datasets.motif_generator import GridMotif\n\n\ndef test_grid_motif():\n    motif_generator = GridMotif()\n    assert str(motif_generator) == 'GridMotif()'\n\n    motif = motif_generator()\n    assert len(motif) == 3\n    assert motif.num_nodes == 9\n    assert motif.num_edges == 24\n    assert motif.edge_index.size() == (2, 24)\n    assert motif.edge_index.min() == 0\n    assert motif.edge_index.max() == 8\n    assert motif.y.size() == (9, )\n    assert motif.y.min() == 0\n    assert motif.y.max() == 2\n"
  },
  {
    "path": "test/datasets/motif_generator/test_house_motif.py",
    "content": "from torch_geometric.datasets.motif_generator import HouseMotif\n\n\ndef test_house_motif():\n    motif_generator = HouseMotif()\n    assert str(motif_generator) == 'HouseMotif()'\n\n    motif = motif_generator()\n    assert len(motif) == 3\n    assert motif.num_nodes == 5\n    assert motif.num_edges == 12\n    assert motif.y.min() == 0 and motif.y.max() == 2\n"
  },
  {
    "path": "test/datasets/test_ba_shapes.py",
    "content": "import pytest\n\n\ndef test_ba_shapes(get_dataset):\n    with pytest.warns(UserWarning, match=\"is deprecated\"):\n        dataset = get_dataset(name='BAShapes')\n\n    assert str(dataset) == 'BAShapes()'\n    assert len(dataset) == 1\n    assert dataset.num_features == 10\n    assert dataset.num_classes == 4\n\n    data = dataset[0]\n    assert len(data) == 5\n    assert data.edge_index.size(1) >= 1120\n    assert data.x.size() == (700, 10)\n    assert data.y.size() == (700, )\n    assert data.expl_mask.sum() == 60\n    assert data.edge_label.sum() == 960\n"
  },
  {
    "path": "test/datasets/test_bzr.py",
    "content": "from torch_geometric.testing import onlyFullTest, onlyOnline\n\n\n@onlyOnline\n@onlyFullTest\ndef test_bzr(get_dataset):\n    dataset = get_dataset(name='BZR')\n    assert len(dataset) == 405\n    assert dataset.num_features == 53\n    assert dataset.num_node_labels == 53\n    assert dataset.num_node_attributes == 3\n    assert dataset.num_classes == 2\n    assert str(dataset) == 'BZR(405)'\n    assert len(dataset[0]) == 3\n\n\n@onlyOnline\n@onlyFullTest\ndef test_bzr_with_node_attr(get_dataset):\n    dataset = get_dataset(name='BZR', use_node_attr=True)\n    assert dataset.num_features == 56\n    assert dataset.num_node_labels == 53\n    assert dataset.num_node_attributes == 3\n"
  },
  {
    "path": "test/datasets/test_elliptic.py",
    "content": "from torch_geometric.testing import onlyFullTest, onlyOnline\n\n\n@onlyOnline\n@onlyFullTest\ndef test_elliptic_bitcoin_dataset(get_dataset):\n    dataset = get_dataset(name='EllipticBitcoinDataset')\n    assert str(dataset) == 'EllipticBitcoinDataset()'\n    assert len(dataset) == 1\n    assert dataset.num_features == 165\n    assert dataset.num_classes == 2\n\n    data = dataset[0]\n    assert len(data) == 5\n    assert data.x.size() == (203769, 165)\n    assert data.edge_index.size() == (2, 234355)\n    assert data.y.size() == (203769, )\n\n    assert data.train_mask.size() == (203769, )\n    assert data.train_mask.sum() > 0\n    assert data.test_mask.size() == (203769, )\n    assert data.test_mask.sum() > 0\n    assert data.train_mask.sum() + data.test_mask.sum() == 4545 + 42019\n    assert data.y[data.train_mask].sum() == 3462\n    assert data.y[data.test_mask].sum() == 1083\n    assert data.y[data.train_mask].sum() + data.y[data.test_mask].sum() == 4545\n    assert data.y[data.test_mask | data.train_mask].min() == 0\n    assert data.y[data.test_mask | data.train_mask].max() == 1\n"
  },
  {
    "path": "test/datasets/test_enzymes.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.loader import DataListLoader, DataLoader, DenseDataLoader\nfrom torch_geometric.testing import onlyOnline\nfrom torch_geometric.transforms import ToDense\n\n\n@onlyOnline\ndef test_enzymes(get_dataset):\n    dataset = get_dataset(name='ENZYMES')\n    assert len(dataset) == 600\n    assert dataset.num_features == 3\n    assert dataset.num_classes == 6\n    assert str(dataset) == 'ENZYMES(600)'\n\n    assert len(dataset[0]) == 3\n    assert len(dataset.shuffle()) == 600\n    assert len(dataset.shuffle(return_perm=True)) == 2\n    assert len(dataset[:100]) == 100\n    assert len(dataset[0.1:0.2]) == 60\n    assert len(dataset[torch.arange(100, dtype=torch.long)]) == 100\n    mask = torch.zeros(600, dtype=torch.bool)\n    mask[:100] = 1\n    assert len(dataset[mask]) == 100\n\n    loader = DataLoader(dataset, batch_size=len(dataset))\n    for batch in loader:\n        assert batch.num_graphs == len(batch) == 600\n\n        avg_num_nodes = batch.num_nodes / batch.num_graphs\n        assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63\n\n        avg_num_edges = batch.num_edges / (2 * batch.num_graphs)\n        assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14\n\n        assert list(batch.x.size()) == [batch.num_nodes, 3]\n        assert list(batch.y.size()) == [batch.num_graphs]\n        assert batch.y.max() + 1 == 6\n        assert list(batch.batch.size()) == [batch.num_nodes]\n        assert batch.ptr.numel() == batch.num_graphs + 1\n\n        assert batch.has_isolated_nodes()\n        assert not batch.has_self_loops()\n        assert batch.is_undirected()\n\n    loader = DataListLoader(dataset, batch_size=len(dataset))\n    for data_list in loader:\n        assert len(data_list) == 600\n\n    dataset.transform = ToDense(num_nodes=126)\n    loader = DenseDataLoader(dataset, batch_size=len(dataset))\n    for data in loader:\n        assert list(data.x.size()) == [600, 126, 3]\n        assert 
list(data.adj.size()) == [600, 126, 126]\n        assert list(data.mask.size()) == [600, 126]\n        assert list(data.y.size()) == [600, 1]\n\n\n@onlyOnline\ndef test_enzymes_with_node_attr(get_dataset):\n    dataset = get_dataset(name='ENZYMES', use_node_attr=True)\n    assert dataset.num_node_features == 21\n    assert dataset.num_features == 21\n    assert dataset.num_edge_features == 0\n\n\n@onlyOnline\ndef test_cleaned_enzymes(get_dataset):\n    dataset = get_dataset(name='ENZYMES', cleaned=True)\n    assert len(dataset) == 595\n"
  },
  {
    "path": "test/datasets/test_explainer_dataset.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.datasets import ExplainerDataset\nfrom torch_geometric.datasets.graph_generator import BAGraph\nfrom torch_geometric.datasets.motif_generator import HouseMotif\n\n\n@pytest.mark.parametrize('graph_generator', [\n    pytest.param(BAGraph(num_nodes=80, num_edges=5), id='BAGraph'),\n])\n@pytest.mark.parametrize('motif_generator', [\n    pytest.param(HouseMotif(), id='HouseMotif'),\n    'house',\n])\ndef test_explainer_dataset_ba_house(graph_generator, motif_generator):\n    dataset = ExplainerDataset(graph_generator, motif_generator, num_motifs=2)\n    assert str(dataset) == ('ExplainerDataset(1, graph_generator='\n                            'BAGraph(num_nodes=80, num_edges=5), '\n                            'motif_generator=HouseMotif(), num_motifs=2)')\n    assert len(dataset) == 1\n\n    data = dataset[0]\n    assert len(data) == 4\n    assert data.num_nodes == 80 + (2 * 5)\n    assert data.edge_index.min() >= 0\n    assert data.edge_index.max() < data.num_nodes\n    assert data.y.min() == 0 and data.y.max() == 3\n    assert data.node_mask.size() == (data.num_nodes, )\n    assert data.edge_mask.size() == (data.num_edges, )\n    assert data.node_mask.min() == 0 and data.node_mask.max() == 1\n    assert data.node_mask.sum() == 2 * 5\n    assert data.edge_mask.min() == 0 and data.edge_mask.max() == 1\n    assert data.edge_mask.sum() == 2 * 12\n\n\ndef test_explainer_dataset_reproducibility():\n    seed_everything(12345)\n    data1 = ExplainerDataset(BAGraph(num_nodes=80, num_edges=5), HouseMotif(),\n                             num_motifs=2)[0]\n\n    seed_everything(12345)\n    data2 = ExplainerDataset(BAGraph(num_nodes=80, num_edges=5), HouseMotif(),\n                             num_motifs=2)[0]\n\n    assert torch.equal(data1.edge_index, data2.edge_index)\n"
  },
  {
    "path": "test/datasets/test_fake.py",
    "content": "import pytest\n\nfrom torch_geometric.datasets import FakeDataset, FakeHeteroDataset\n\n\n@pytest.mark.parametrize('num_graphs', [1, 10])\n@pytest.mark.parametrize('edge_dim', [0, 1, 4])\n@pytest.mark.parametrize('task', ['node', 'graph', 'auto'])\ndef test_fake_dataset(num_graphs, edge_dim, task):\n    dataset = FakeDataset(num_graphs, edge_dim=edge_dim, task=task,\n                          global_features=3)\n\n    if num_graphs > 1:\n        assert str(dataset) == f'FakeDataset({num_graphs})'\n    else:\n        assert str(dataset) == 'FakeDataset()'\n\n    assert len(dataset) == num_graphs\n\n    data = dataset[0]\n\n    assert data.num_features == 64\n\n    if edge_dim == 0:\n        assert len(data) == 4\n    elif edge_dim == 1:\n        assert len(data) == 5\n        assert data.edge_weight.size() == (data.num_edges, )\n        assert data.edge_weight.min() >= 0 and data.edge_weight.max() < 1\n    else:\n        assert len(data) == 5\n        assert data.edge_attr.size() == (data.num_edges, edge_dim)\n        assert data.edge_attr.min() >= 0 and data.edge_attr.max() < 1\n\n    assert data.y.min() >= 0 and data.y.max() < 10\n    if task == 'node' or (task == 'auto' and num_graphs == 1):\n        assert data.y.size() == (data.num_nodes, )\n    else:\n        assert data.y.size() == (1, )\n\n    assert data.global_features.size() == (3, )\n\n\n@pytest.mark.parametrize('num_graphs', [1, 10])\n@pytest.mark.parametrize('edge_dim', [0, 1, 4])\n@pytest.mark.parametrize('task', ['node', 'graph', 'auto'])\ndef test_fake_hetero_dataset(num_graphs, edge_dim, task):\n    dataset = FakeHeteroDataset(num_graphs, edge_dim=edge_dim, task=task,\n                                global_features=3)\n\n    if num_graphs > 1:\n        assert str(dataset) == f'FakeHeteroDataset({num_graphs})'\n    else:\n        assert str(dataset) == 'FakeHeteroDataset()'\n\n    assert len(dataset) == num_graphs\n\n    data = dataset[0]\n\n    for store in data.node_stores:\n      
  assert store.num_features > 0\n\n        if task == 'node' or (task == 'auto' and num_graphs == 1):\n            if store._key == 'v0':\n                assert store.y.min() >= 0 and store.y.max() < 10\n                assert store.y.size() == (store.num_nodes, )\n\n    for store in data.edge_stores:\n        if edge_dim == 0:\n            assert len(data) == 4\n        elif edge_dim == 1:\n            assert len(data) == 5\n            assert store.edge_weight.size() == (store.num_edges, )\n            assert store.edge_weight.min() >= 0 and store.edge_weight.max() < 1\n        else:\n            assert len(data) == 5\n            assert store.edge_attr.size() == (store.num_edges, edge_dim)\n            assert store.edge_attr.min() >= 0 and store.edge_attr.max() < 1\n\n    if task == 'graph' or (task == 'auto' and num_graphs > 1):\n        assert data.y.min() >= 0 and data.y.max() < 10\n        assert data.y.size() == (1, )\n\n    assert data.global_features.size() == (3, )\n"
  },
  {
    "path": "test/datasets/test_git_mol_dataset.py",
    "content": "from typing import Tuple\n\nimport pytest\n\nfrom torch_geometric.datasets import GitMolDataset\nfrom torch_geometric.testing import onlyFullTest, withPackage\n\n\n@onlyFullTest\n@withPackage('torchvision', 'rdkit', 'PIL')\n@pytest.mark.parametrize('split', [\n    (0, 3610),\n    (1, 451),\n    (2, 451),\n])\ndef test_git_mol_dataset(split: Tuple[int, int]) -> None:\n    dataset = GitMolDataset(root='./data/GITMol', split=split[0])\n\n    assert len(dataset) == split[1]\n    assert dataset[0].image.size() == (1, 3, 224, 224)\n    assert dataset[0].num_node_features == 9\n    assert dataset[0].num_edge_features == 3\n"
  },
  {
    "path": "test/datasets/test_imdb_binary.py",
    "content": "from torch_geometric.testing import onlyFullTest, onlyOnline\n\n\n@onlyOnline\n@onlyFullTest\ndef test_imdb_binary(get_dataset):\n    dataset = get_dataset(name='IMDB-BINARY')\n    assert len(dataset) == 1000\n    assert dataset.num_features == 0\n    assert dataset.num_classes == 2\n    assert str(dataset) == 'IMDB-BINARY(1000)'\n\n    data = dataset[0]\n    assert len(data) == 3\n    assert data.edge_index.size() == (2, 146)\n    assert data.y.size() == (1, )\n    assert data.num_nodes == 20\n"
  },
  {
    "path": "test/datasets/test_infection_dataset.py",
    "content": "import torch\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets import InfectionDataset\nfrom torch_geometric.datasets.graph_generator import ERGraph, GraphGenerator\n\n\nclass DummyGraph(GraphGenerator):\n    def __call__(self) -> Data:\n        edge_index = torch.tensor([\n            [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9],\n            [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8],\n        ])\n        return Data(num_nodes=10, edge_index=edge_index)\n\n\ndef test_infection_dataset():\n    seed_everything(12345)\n    graph_generator = DummyGraph()\n    dataset = InfectionDataset(graph_generator, num_infected_nodes=2,\n                               max_path_length=2)\n    assert str(dataset) == ('InfectionDataset(1, '\n                            'graph_generator=DummyGraph(), '\n                            'num_infected_nodes=2, '\n                            'max_path_length=2)')\n    assert len(dataset) == 1\n\n    data = dataset[0]\n    assert len(data) == 4\n    assert data.x.size() == (10, 2)\n    assert data.x[:, 0].sum() == 8 and data.x[:, 1].sum() == 2\n    assert torch.equal(data.edge_index, graph_generator().edge_index)\n    assert data.y.size() == (10, )\n\n    # With `seed=12345`, node 0 and node 7 will be infected:\n    assert data.x[0].tolist() == [0, 1]  # First infected node.\n    assert data.x[7].tolist() == [0, 1]  # Second infected node.\n    assert data.y.tolist() == [0, 1, 2, 3, 3, 2, 1, 0, 1, 2]\n    assert data.edge_mask.tolist() == [\n        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0\n    ]\n\n\ndef test_infection_dataset_reproducibility():\n    graph_generator = ERGraph(num_nodes=500, edge_prob=0.004)\n\n    seed_everything(12345)\n    dataset1 = InfectionDataset(graph_generator, num_infected_nodes=50,\n                                max_path_length=5)\n\n    seed_everything(12345)\n    dataset2 = 
InfectionDataset(graph_generator, num_infected_nodes=50,\n                                max_path_length=5)\n\n    assert torch.equal(dataset1[0].edge_mask, dataset2[0].edge_mask)\n"
  },
  {
    "path": "test/datasets/test_karate.py",
    "content": "def test_karate(get_dataset):\n    dataset = get_dataset(name='KarateClub')\n    assert str(dataset) == 'KarateClub()'\n    assert len(dataset) == 1\n    assert dataset.num_features == 34\n    assert dataset.num_classes == 4\n\n    assert len(dataset[0]) == 4\n    assert dataset[0].edge_index.size() == (2, 156)\n    assert dataset[0].x.size() == (34, 34)\n    assert dataset[0].y.size() == (34, )\n    assert dataset[0].train_mask.size() == (34, )\n    assert dataset[0].train_mask.sum().item() == 4\n"
  },
  {
    "path": "test/datasets/test_medshapenet.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets import MedShapeNet\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('MedShapeNet')\ndef test_medshapenet():\n    dataset = MedShapeNet(root=\"./data/MedShapeNet\", size=1)\n\n    assert str(dataset) == f'MedShapeNet({len(dataset)})'\n\n    assert isinstance(dataset[0], Data)\n    assert dataset.num_classes == 8\n\n    assert isinstance(dataset[0].pos, torch.Tensor)\n    assert len(dataset[0].pos) > 0\n\n    assert isinstance(dataset[0].face, torch.Tensor)\n    assert len(dataset[0].face) == 3\n\n    assert isinstance(dataset[0].y, torch.Tensor)\n    assert len(dataset[0].y) == 1\n"
  },
  {
    "path": "test/datasets/test_molecule_gpt_dataset.py",
    "content": "from torch_geometric.datasets import MoleculeGPTDataset\nfrom torch_geometric.testing import onlyOnline, withPackage\n\n\n@onlyOnline\n@withPackage('transformers', 'sentencepiece', 'accelerate', 'rdkit')\ndef test_molecule_gpt_dataset():\n    dataset = MoleculeGPTDataset(\n        root='./data/MoleculeGPT',\n        num_units=10,\n    )\n    assert str(dataset) == f'MoleculeGPTDataset({len(dataset)})'\n    assert dataset.num_edge_features == 4\n    assert dataset.num_node_features == 6\n"
  },
  {
    "path": "test/datasets/test_mutag.py",
    "content": "from torch_geometric.testing import onlyOnline\n\n\n@onlyOnline\ndef test_mutag(get_dataset):\n    dataset = get_dataset(name='MUTAG')\n    assert len(dataset) == 188\n    assert dataset.num_features == 7\n    assert dataset.num_classes == 2\n    assert str(dataset) == 'MUTAG(188)'\n\n    assert len(dataset[0]) == 4\n    assert dataset[0].edge_attr.size(1) == 4\n\n\n@onlyOnline\ndef test_mutag_with_node_attr(get_dataset):\n    dataset = get_dataset(name='MUTAG', use_node_attr=True)\n    assert dataset.num_features == 7\n"
  },
  {
    "path": "test/datasets/test_planetoid.py",
    "content": "from torch_geometric.loader import DataLoader\nfrom torch_geometric.testing import onlyOnline, withPackage\n\n\n@onlyOnline\n@withPackage('scipy')\ndef test_citeseer(get_dataset):\n    dataset = get_dataset(name='CiteSeer')\n    loader = DataLoader(dataset, batch_size=len(dataset))\n\n    assert len(dataset) == 1\n    assert str(dataset) == 'CiteSeer()'\n\n    for batch in loader:\n        assert batch.num_graphs == len(batch) == 1\n        assert batch.num_nodes == 3327\n        assert batch.num_edges / 2 == 4552\n\n        assert list(batch.x.size()) == [batch.num_nodes, 3703]\n        assert list(batch.y.size()) == [batch.num_nodes]\n        assert batch.y.max() + 1 == 6\n        assert batch.train_mask.sum() == 6 * 20\n        assert batch.val_mask.sum() == 500\n        assert batch.test_mask.sum() == 1000\n        assert (batch.train_mask & batch.val_mask & batch.test_mask).sum() == 0\n        assert list(batch.batch.size()) == [batch.num_nodes]\n        assert batch.ptr.tolist() == [0, batch.num_nodes]\n\n        assert batch.has_isolated_nodes()\n        assert not batch.has_self_loops()\n        assert batch.is_undirected()\n\n\n@onlyOnline\n@withPackage('scipy')\ndef test_citeseer_with_full_split(get_dataset):\n    dataset = get_dataset(name='CiteSeer', split='full')\n    data = dataset[0]\n    assert data.val_mask.sum() == 500\n    assert data.test_mask.sum() == 1000\n    assert data.train_mask.sum() == data.num_nodes - 1500\n    assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0\n\n\n@onlyOnline\n@withPackage('scipy')\ndef test_citeseer_with_random_split(get_dataset):\n    dataset = get_dataset(\n        name='CiteSeer',\n        split='random',\n        num_train_per_class=11,\n        num_val=29,\n        num_test=41,\n    )\n    data = dataset[0]\n    # from torch_geometric import EdgeIndex\n    # assert isinstance(data.edge_index, EdgeIndex)\n    # assert data.edge_index.sparse_size() == (data.num_nodes, 
data.num_nodes)\n    # assert data.edge_index.is_undirected\n    # assert data.edge_index.is_sorted_by_col\n\n    assert data.train_mask.sum() == dataset.num_classes * 11\n    assert data.val_mask.sum() == 29\n    assert data.test_mask.sum() == 41\n    assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0\n"
  },
  {
    "path": "test/datasets/test_protein_mpnn_dataset.py",
    "content": "from torch_geometric.datasets import ProteinMPNNDataset\nfrom torch_geometric.testing import onlyLinux, onlyOnline, withPackage\n\n\n@onlyLinux\n@onlyOnline\n@withPackage('pandas')\ndef test_protein_mpnn_dataset():\n    dataset = ProteinMPNNDataset(root='./data/ProteinMPNN')\n\n    assert len(dataset) == 150\n    assert dataset[0].x.size() == (229, 4, 3)\n    assert dataset[0].chain_seq_label.size() == (229, )\n    assert dataset[0].mask.size() == (229, )\n    assert dataset[0].chain_mask_all.size() == (229, )\n    assert dataset[0].residue_idx.size() == (229, )\n    assert dataset[0].chain_encoding_all.size() == (229, )\n"
  },
  {
    "path": "test/datasets/test_snap_dataset.py",
    "content": "from torch_geometric.testing import onlyFullTest, onlyOnline\n\n\n@onlyOnline\n@onlyFullTest\ndef test_ego_facebook_snap_dataset(get_dataset):\n    import warnings\n\n    import torch\n    from packaging import version\n\n    if version.parse(torch.__version__) >= version.parse(\"2.2.0\"):\n        try:\n            from torch.serialization import add_safe_globals\n\n            from torch_geometric.datasets.snap_dataset import EgoData\n\n            add_safe_globals([EgoData])\n        except ImportError:\n            warnings.warn(\n                \"add_safe_globals is expected but not found in \"\n                \"torch.serialization.\", stacklevel=2)\n    else:\n        warnings.warn(\n            \"add_safe_globals is not available in this version \"\n            \"of PyTorch; continuing without it.\", stacklevel=2)\n\n    dataset = get_dataset(name='ego-facebook')\n    assert str(dataset) == 'SNAP-ego-facebook(10)'\n    assert len(dataset) == 10\n\n\n@onlyOnline\n@onlyFullTest\ndef test_soc_slashdot_snap_dataset(get_dataset):\n    dataset = get_dataset(name='soc-Slashdot0811')\n    assert str(dataset) == 'SNAP-soc-slashdot0811(1)'\n    assert len(dataset) == 1\n\n\n@onlyOnline\n@onlyFullTest\ndef test_wiki_vote_snap_dataset(get_dataset):\n    dataset = get_dataset(name='wiki-vote')\n    assert str(dataset) == 'SNAP-wiki-vote(1)'\n    assert len(dataset) == 1\n"
  },
  {
    "path": "test/datasets/test_suite_sparse.py",
    "content": "from torch_geometric.testing import onlyFullTest, onlyOnline\n\n\n@onlyOnline\n@onlyFullTest\ndef test_suite_sparse_dataset(get_dataset):\n    dataset = get_dataset(group='DIMACS10', name='citationCiteseer')\n    assert str(dataset) == ('SuiteSparseMatrixCollection('\n                            'group=DIMACS10, name=citationCiteseer)')\n    assert len(dataset) == 1\n\n\n@onlyOnline\n@onlyFullTest\ndef test_illc1850_suite_sparse_dataset(get_dataset):\n    dataset = get_dataset(group='HB', name='illc1850')\n    assert str(dataset) == ('SuiteSparseMatrixCollection('\n                            'group=HB, name=illc1850)')\n    assert len(dataset) == 1\n"
  },
  {
    "path": "test/datasets/test_tag_dataset.py",
    "content": "from torch_geometric.datasets import TAGDataset\nfrom torch_geometric.testing import onlyFullTest, withPackage\n\n\n@onlyFullTest\n@withPackage('ogb')\ndef test_tag_dataset() -> None:\n    from ogb.nodeproppred import PygNodePropPredDataset\n\n    root = './data/ogb'\n    hf_model = 'prajjwal1/bert-tiny'\n    token_on_disk = True\n\n    dataset = PygNodePropPredDataset('ogbn-arxiv', root=root)\n    tag_dataset = TAGDataset(root, dataset, hf_model,\n                             token_on_disk=token_on_disk)\n\n    assert 169343 == tag_dataset[0].num_nodes \\\n        == len(tag_dataset.text) \\\n        == len(tag_dataset.llm_explanation)\n    assert 1166243 == tag_dataset[0].num_edges\n"
  },
  {
    "path": "test/datasets/test_teeth3ds.py",
    "content": "from torch_geometric.data import Data\nfrom torch_geometric.datasets import Teeth3DS\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('trimesh', 'fpsample')\ndef test_teeth3ds(tmp_path) -> None:\n    dataset = Teeth3DS(root=tmp_path, split='sample', train=True)\n\n    assert len(dataset) > 0\n    data = dataset[0]\n    assert isinstance(data, Data)\n    assert data.pos.size(1) == 3\n    assert data.x.size(0) == data.pos.size(0)\n    assert data.y.size(0) == data.pos.size(0)\n    assert isinstance(data.jaw, str)\n"
  },
  {
    "path": "test/datasets/test_web_qsp_dataset.py",
    "content": "import os\nimport random\nimport string\n\nimport pytest\n\nfrom torch_geometric.datasets import WebQSPDataset\nfrom torch_geometric.datasets.web_qsp_dataset import KGQABaseDataset\nfrom torch_geometric.testing import onlyFullTest, onlyOnline, withPackage\n\n\n@pytest.mark.skip(reason=\"Times out\")\n@onlyOnline\n@onlyFullTest\n@withPackage(\"datasets\", \"pandas\")\ndef test_web_qsp_dataset(tmp_path):\n    # Split for this dataset is 2826 train | 246 val | 1628 test\n    # default split is train\n    dataset_val = WebQSPDataset(root=tmp_path, split=\"val\")\n    assert len(dataset_val) == 246\n    assert str(dataset_val) == \"WebQSPDataset(246)\"\n\n\nclass MockSentenceTransformer:\n    def __init__(self, *args, **kwargs):\n        pass\n\n    def to(self, device):\n        return self\n\n    def eval(self):\n        return self\n\n    def encode(self, sentences, batch_size=None, output_device=None):\n        import torch\n\n        def string_to_tensor(s: str) -> torch.Tensor:\n            return torch.ones(1024).float()\n\n        if isinstance(sentences, str):\n            return string_to_tensor(sentences)\n        return torch.stack([string_to_tensor(s) for s in sentences])\n\n\ndef create_mock_graphs(tmp_path: str, train_size: int, val_size: int,\n                       test_size: int, num_nodes: int, num_edge_types: int,\n                       num_trips: int, seed: int = 42):\n    random.seed(seed)\n    strkeys = string.ascii_letters + string.digits\n    qa_strkeys = string.ascii_letters + string.digits + \" \"\n\n    def create_mock_triplets(num_nodes: int, num_edges: int, num_trips: int):\n        nodes = list(\n            {\"\".join(random.sample(strkeys, 10))\n             for i in range(num_nodes)})\n        edges = list(\n            {\"\".join(random.sample(strkeys, 10))\n             for i in range(num_edges)})\n        triplets = []\n\n        for _ in range(num_trips):\n            h = random.randint(0, num_nodes - 1)\n           
 t = random.randint(0, num_nodes - 1)\n            r = random.randint(0, num_edge_types - 1)\n            triplets.append((nodes[h], edges[r], nodes[t]))\n        return triplets\n\n    train_triplets = [\n        create_mock_triplets(num_nodes, num_edge_types, num_trips)\n        for _ in range(train_size)\n    ]\n    val_triplets = [\n        create_mock_triplets(num_nodes, num_edge_types, num_trips)\n        for _ in range(val_size)\n    ]\n    test_triplets = [\n        create_mock_triplets(num_nodes, num_edge_types, num_trips)\n        for _ in range(test_size)\n    ]\n\n    train_questions = [\n        \"\".join(random.sample(qa_strkeys, 10)) for _ in range(train_size)\n    ]\n    val_questions = [\n        \"\".join(random.sample(qa_strkeys, 10)) for _ in range(val_size)\n    ]\n    test_questions = [\n        \"\".join(random.sample(qa_strkeys, 10)) for _ in range(test_size)\n    ]\n\n    train_answers = [\n        \"\".join(random.sample(qa_strkeys, 10)) for _ in range(train_size)\n    ]\n    val_answers = [\n        \"\".join(random.sample(qa_strkeys, 10)) for _ in range(val_size)\n    ]\n    test_answers = [\n        \"\".join(random.sample(qa_strkeys, 10)) for _ in range(test_size)\n    ]\n\n    train_graphs = {\n        \"graph\": train_triplets,\n        \"question\": train_questions,\n        \"answer\": train_answers\n    }\n    val_graphs = {\n        \"graph\": val_triplets,\n        \"question\": val_questions,\n        \"answer\": val_answers\n    }\n    test_graphs = {\n        \"graph\": test_triplets,\n        \"question\": test_questions,\n        \"answer\": test_answers\n    }\n\n    from datasets import Dataset, DatasetDict, load_from_disk\n\n    ds_train = Dataset.from_dict(train_graphs, split=\"train\")\n    ds_val = Dataset.from_dict(val_graphs, split=\"validation\")\n    ds_test = Dataset.from_dict(test_graphs, split=\"test\")\n\n    ds = DatasetDict({\n        \"train\": ds_train,\n        \"validation\": ds_val,\n        \"test\": 
ds_test\n    })\n\n    def mock_load_dataset(path: str):\n        # Save the dataset and then load it to emulate downloading from HF\n        DATASET_CACHE_DIR = os.path.join(tmp_path,\n                                         \".cache/huggingface/datasets\", path)\n        os.makedirs(DATASET_CACHE_DIR, exist_ok=True)\n\n        ds.save_to_disk(DATASET_CACHE_DIR)\n        dataset_remote = load_from_disk(DATASET_CACHE_DIR)\n        return dataset_remote\n\n    return mock_load_dataset, ds\n\n\n@pytest.mark.rag\n@withPackage(\"datasets\", \"pandas\")\ndef test_kgqa_base_dataset(tmp_path, monkeypatch):\n\n    num_nodes = 500\n    num_edge_types = 25\n    num_trips = 5000\n\n    # Mock the dataset graphs\n    mock_load_dataset_func, expected_result = create_mock_graphs(\n        tmp_path, train_size=10, val_size=5, test_size=5, num_nodes=num_nodes,\n        num_edge_types=num_edge_types, num_trips=num_trips)\n\n    import datasets\n\n    monkeypatch.setattr(datasets, \"load_dataset\", mock_load_dataset_func)\n\n    # Mock the SentenceTransformer\n    import torch_geometric.datasets.web_qsp_dataset\n    monkeypatch.setattr(torch_geometric.datasets.web_qsp_dataset,\n                        \"SentenceTransformer\", MockSentenceTransformer)\n\n    dataset_train = KGQABaseDataset(root=tmp_path, dataset_name=\"TestDataset\",\n                                    split=\"train\", use_pcst=False)\n    assert len(dataset_train) == 10\n    assert str(dataset_train) == \"KGQABaseDataset(10)\"\n    for graph in dataset_train:\n        assert graph.x.shape == (num_nodes, 1024)\n        assert graph.edge_index.shape == (2, num_trips)\n        assert graph.edge_attr.shape == (\n            num_trips, 1024)  # Reminder: edge_attr encodes the entire triplet\n\n    dataset_val = KGQABaseDataset(root=tmp_path, dataset_name=\"TestDataset\",\n                                  split=\"val\", use_pcst=False)\n    assert len(dataset_val) == 5\n    assert str(dataset_val) == 
\"KGQABaseDataset(5)\"\n\n    dataset_test = KGQABaseDataset(root=tmp_path, dataset_name=\"TestDataset\",\n                                   split=\"test\", use_pcst=False)\n    assert len(dataset_test) == 5\n    assert str(dataset_test) == \"KGQABaseDataset(5)\"\n\n    # TODO(zaristei): More rigorous tests to validate that values are correct\n    # TODO(zaristei): Proper tests for PCST and CWQ\n"
  },
  {
    "path": "test/distributed/test_dist_link_neighbor_loader.py",
    "content": "import socket\nfrom typing import Tuple\n\nimport pytest\nimport torch\nimport torch.multiprocessing as mp\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.datasets import FakeDataset, FakeHeteroDataset\nfrom torch_geometric.distributed import (\n    DistContext,\n    DistLinkNeighborLoader,\n    DistNeighborSampler,\n    LocalFeatureStore,\n    LocalGraphStore,\n    Partitioner,\n)\nfrom torch_geometric.testing import onlyDistributedTest, withMETIS\nfrom torch_geometric.testing.distributed import ProcArgs, assert_run_mproc\n\n\ndef create_dist_data(tmp_path: str, rank: int):\n    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)\n    feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)\n\n    return feat_store, graph_store\n\n\ndef dist_link_neighbor_loader_homo(\n    world_size: int,\n    tmp_path: str,\n    rank: int,\n    master_addr: str,\n    master_port: int,\n    num_workers: int,\n    async_sampling: bool,\n    neg_ratio: float,\n):\n    part_data = create_dist_data(tmp_path, rank)\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-loader-test',\n    )\n\n    edge_label_index = part_data[1].get_edge_index(None, 'coo')\n    edge_label = torch.randint(high=2, size=(edge_label_index.size(1), ))\n\n    loader = DistLinkNeighborLoader(\n        data=part_data,\n        edge_label_index=(None, edge_label_index),\n        edge_label=edge_label if neg_ratio is not None else None,\n        num_neighbors=[1],\n        batch_size=10,\n        num_workers=num_workers,\n        master_addr=master_addr,\n        master_port=master_port,\n        current_ctx=current_ctx,\n        concurrency=10,\n        drop_last=True,\n        async_sampling=async_sampling,\n    )\n\n    assert str(loader).startswith('DistLinkNeighborLoader')\n    assert str(mp.current_process().pid) in 
str(loader)\n    assert isinstance(loader.dist_sampler, DistNeighborSampler)\n    assert not part_data[0].meta['is_hetero']\n\n    for batch in loader:\n        assert isinstance(batch, Data)\n        assert batch.n_id.size() == (batch.num_nodes, )\n        assert batch.edge_index.min() >= 0\n        assert batch.edge_index.max() < batch.num_nodes\n    assert loader.channel.empty()\n\n\ndef dist_link_neighbor_loader_hetero(\n    world_size: int,\n    tmp_path: str,\n    rank: int,\n    master_addr: str,\n    master_port: int,\n    num_workers: int,\n    async_sampling: bool,\n    neg_ratio: float,\n    edge_type: Tuple[str, str, str],\n):\n    part_data = create_dist_data(tmp_path, rank)\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name=\"dist-loader-test\",\n    )\n\n    edge_label_index = part_data[1].get_edge_index(edge_type, 'coo')\n    edge_label = torch.randint(high=2, size=(edge_label_index.size(1), ))\n\n    loader = DistLinkNeighborLoader(\n        data=part_data,\n        edge_label_index=(edge_type, edge_label_index),\n        edge_label=edge_label if neg_ratio is not None else None,\n        num_neighbors=[1],\n        batch_size=10,\n        num_workers=num_workers,\n        master_addr=master_addr,\n        master_port=master_port,\n        current_ctx=current_ctx,\n        concurrency=10,\n        drop_last=True,\n        async_sampling=async_sampling,\n    )\n\n    assert str(loader).startswith('DistLinkNeighborLoader')\n    assert str(mp.current_process().pid) in str(loader)\n    assert isinstance(loader.dist_sampler, DistNeighborSampler)\n    assert part_data[0].meta['is_hetero']\n\n    for batch in loader:\n        assert isinstance(batch, HeteroData)\n        assert len(batch.node_types) == 2\n        for node_type in batch.node_types:\n            assert torch.equal(batch[node_type].x, batch.x_dict[node_type])\n            
assert batch.x_dict[node_type].size(0) >= 0\n            assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes\n\n        assert len(batch.edge_types) == 4\n        for key in batch.edge_types:\n            if key[-1] == 'v0':\n                assert batch[key].num_sampled_edges[0] > 0\n                assert batch[key].edge_attr.size(0) == batch[key].num_edges\n            else:\n                assert batch[key].num_sampled_edges[0] == 0\n    assert loader.channel.empty()\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('num_parts', [2])\n@pytest.mark.parametrize('num_workers', [0])\n@pytest.mark.parametrize('async_sampling', [True])\n@pytest.mark.parametrize('neg_ratio', [None])\ndef test_dist_link_neighbor_loader_homo(\n    tmp_path,\n    num_parts,\n    num_workers,\n    async_sampling,\n    neg_ratio,\n):\n    addr = '127.0.0.1'\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        edge_dim=2,\n    )[0]\n    partitioner = Partitioner(data, num_parts, tmp_path)\n    partitioner.generate_partition()\n\n    procs = [\n        ProcArgs(\n            target=dist_link_neighbor_loader_homo,\n            args=(tmp_path, part, addr, port, num_workers, async_sampling,\n                  neg_ratio),\n        ) for part in range(num_parts)\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('num_parts', [2])\n@pytest.mark.parametrize('num_workers', [0])\n@pytest.mark.parametrize('async_sampling', [True])\n@pytest.mark.parametrize('neg_ratio', [None])\n@pytest.mark.parametrize('edge_type', [('v0', 'e0', 'v0')])\ndef test_dist_link_neighbor_loader_hetero(\n    
tmp_path,\n    num_parts,\n    num_workers,\n    async_sampling,\n    neg_ratio,\n    edge_type,\n):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    addr = '127.0.0.1'\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        edge_dim=2,\n    )[0]\n    partitioner = Partitioner(data, num_parts, tmp_path)\n    partitioner.generate_partition()\n\n    procs = [\n        ProcArgs(\n            target=dist_link_neighbor_loader_hetero,\n            args=(tmp_path, part, addr, port, num_workers, async_sampling,\n                  neg_ratio, edge_type),\n        ) for part in range(num_parts)\n    ]\n    assert_run_mproc(mp_context, procs)\n"
  },
  {
    "path": "test/distributed/test_dist_link_neighbor_sampler.py",
    "content": "import atexit\nimport socket\nfrom typing import Optional\n\nimport pytest\nimport torch\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets import FakeHeteroDataset\nfrom torch_geometric.distributed import (\n    DistNeighborSampler,\n    LocalFeatureStore,\n    LocalGraphStore,\n    Partitioner,\n)\nfrom torch_geometric.distributed.dist_context import DistContext\nfrom torch_geometric.distributed.event_loop import ConcurrentEventLoop\nfrom torch_geometric.distributed.rpc import init_rpc, shutdown_rpc\nfrom torch_geometric.sampler import EdgeSamplerInput, NeighborSampler\nfrom torch_geometric.sampler.neighbor_sampler import edge_sample\nfrom torch_geometric.testing import onlyDistributedTest, withMETIS\nfrom torch_geometric.testing.distributed import ProcArgs, assert_run_mproc\nfrom torch_geometric.typing import EdgeType\n\n\ndef create_data(rank, world_size, time_attr: Optional[str] = None):\n    if rank == 0:  # Partition 0:\n        node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9])\n        edge_index = torch.tensor([  # Sorted by destination.\n            [1, 2, 3, 4, 5, 0, 0],\n            [0, 1, 2, 3, 4, 4, 9],\n        ])\n    else:  # Partition 1:\n        node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9])\n        edge_index = torch.tensor([  # Sorted by destination.\n            [5, 6, 7, 8, 9, 5, 0],\n            [4, 5, 6, 7, 8, 9, 9],\n        ])\n\n    feature_store = LocalFeatureStore.from_data(node_id)\n    graph_store = LocalGraphStore.from_data(\n        edge_id=None,\n        edge_index=edge_index,\n        num_nodes=10,\n        is_sorted=True,\n    )\n\n    graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n    graph_store.meta.update({'num_parts': 2})\n    graph_store.partition_idx = rank\n    graph_store.num_partitions = world_size\n\n    edge_index = torch.tensor([  # Create reference data:\n        [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0],\n        [0, 1, 2, 3, 
4, 4, 9, 5, 6, 7, 8, 9],\n    ])\n    data = Data(x=None, y=None, edge_index=edge_index, num_nodes=10)\n\n    if time_attr == 'time':  # Create node-level time data:\n        data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])\n        feature_store.put_tensor(data.time, group_name=None, attr_name='time')\n\n    elif time_attr == 'edge_time':  # Create edge-level time data:\n        data.edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 11])\n\n        if rank == 0:\n            edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 11])\n        if rank == 1:\n            edge_time = torch.tensor([4, 7, 7, 7, 7, 7, 11])\n\n        feature_store.put_tensor(edge_time, group_name=None,\n                                 attr_name=time_attr)\n\n    return (feature_store, graph_store), data\n\n\ndef create_hetero_data(tmp_path: str, rank: int):\n    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)\n    other_graph_store = LocalGraphStore.from_partition(tmp_path, int(not rank))\n    feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)\n\n    return (feature_store, graph_store), other_graph_store\n\n\ndef dist_link_neighbor_sampler(\n    world_size: int,\n    rank: int,\n    master_port: int,\n    disjoint: bool = False,\n):\n    dist_data, data = create_data(rank, world_size)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        num_neighbors=[-1, -1],\n        shuffle=False,\n        disjoint=disjoint,\n    )\n\n    # Close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n    dist_sampler.init_sampler_instance()\n    
dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    if rank == 0:  # Seed edges:\n        input_row = torch.tensor([1, 6], dtype=torch.int64)\n        input_col = torch.tensor([2, 7], dtype=torch.int64)\n    else:\n        input_row = torch.tensor([4, 9], dtype=torch.int64)\n        input_col = torch.tensor([5, 0], dtype=torch.int64)\n\n    inputs = EdgeSamplerInput(\n        input_id=None,\n        row=input_row,\n        col=input_col,\n        input_type=None,\n    )\n\n    # Evaluate distributed edge sample function:\n    out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample(\n        inputs, dist_sampler.node_sample, data.num_nodes, disjoint))\n\n    sampler = NeighborSampler(data=data, num_neighbors=[-1, -1],\n                              disjoint=disjoint)\n\n    # Evaluate edge sample function:\n    out = edge_sample(\n        inputs,\n        sampler._sample,\n        data.num_nodes,\n        disjoint,\n        node_time=None,\n        neg_sampling=None,\n    )\n\n    # Compare distributed output with single machine output:\n    assert torch.equal(out_dist.node, out.node)\n    assert torch.equal(out_dist.row, out.row)\n    assert torch.equal(out_dist.col, out.col)\n    if disjoint:\n        assert torch.equal(out_dist.batch, out.batch)\n    assert out_dist.num_sampled_nodes == out.num_sampled_nodes\n    assert out_dist.num_sampled_edges == out.num_sampled_edges\n\n\ndef dist_link_neighbor_sampler_temporal(\n    world_size: int,\n    rank: int,\n    master_port: int,\n    seed_time: Optional[torch.Tensor] = None,\n    temporal_strategy: str = 'uniform',\n    time_attr: str = 'time',\n):\n    dist_data, data = create_data(rank, world_size, time_attr)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    
num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1]\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        num_neighbors=num_neighbors,\n        shuffle=False,\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n\n    # Close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n    dist_sampler.init_sampler_instance()\n    dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    if rank == 0:  # Seed edges:\n        input_row = torch.tensor([1, 6], dtype=torch.int64)\n        input_col = torch.tensor([2, 7], dtype=torch.int64)\n    else:\n        input_row = torch.tensor([4, 9], dtype=torch.int64)\n        input_col = torch.tensor([5, 0], dtype=torch.int64)\n\n    inputs = EdgeSamplerInput(\n        input_id=None,\n        row=input_row,\n        col=input_col,\n        time=seed_time,\n    )\n\n    # Evaluate distributed edge sample function\n    out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample(\n        inputs, dist_sampler.node_sample, data.num_nodes, disjoint=True,\n        node_time=seed_time, neg_sampling=None))\n\n    sampler = NeighborSampler(\n        data=data,\n        num_neighbors=num_neighbors,\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n\n    # Evaluate edge sample function\n    out = edge_sample(\n        inputs,\n        sampler._sample,\n        data.num_nodes,\n        disjoint=True,\n        node_time=seed_time,\n        neg_sampling=None,\n    )\n\n    # Compare distributed output with single machine output\n    assert torch.equal(out_dist.node, out.node)\n    assert torch.equal(out_dist.row, out.row)\n    assert 
torch.equal(out_dist.col, out.col)\n    assert torch.equal(out_dist.batch, out.batch)\n    assert out_dist.num_sampled_nodes == out.num_sampled_nodes\n    assert out_dist.num_sampled_edges == out.num_sampled_edges\n\n\ndef dist_link_neighbor_sampler_hetero(\n    world_size: int,\n    data: FakeHeteroDataset,\n    tmp_path: str,\n    rank: int,\n    master_port: int,\n    input_type: EdgeType,\n    disjoint: bool = False,\n):\n    dist_data, other_graph_store = create_hetero_data(tmp_path, rank)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        rpc_worker_names={},\n        num_neighbors=[-1],\n        shuffle=False,\n        disjoint=disjoint,\n    )\n\n    # close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n    dist_sampler.init_sampler_instance()\n    dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    # Create input rows/cols such that pairs belong to different partitions.\n    # Edge from the current partition:\n    edge_label_index1 = dist_data[1]._edge_index[(input_type, 'coo')]\n    row_0 = edge_label_index1[0][0]\n    col_0 = edge_label_index1[1][0]\n    # Edge from the other partition:\n    edge_label_index2 = other_graph_store._edge_index[(input_type, 'coo')]\n    row_1 = edge_label_index2[0][0]\n    col_1 = edge_label_index2[1][0]\n\n    # Seed edges:\n    input_row = torch.tensor([row_0, row_1])\n    input_col = torch.tensor([col_0, col_1])\n\n    inputs = EdgeSamplerInput(\n        input_id=None,\n        row=input_row,\n        col=input_col,\n        
input_type=input_type,\n    )\n\n    # Evaluate distributed edge sample function:\n    out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample(\n        inputs, dist_sampler.node_sample, data.num_nodes, disjoint))\n\n    sampler = NeighborSampler(\n        data=data,\n        num_neighbors=[-1],\n        disjoint=disjoint,\n    )\n\n    # Evaluate edge sample function:\n    out = edge_sample(\n        inputs,\n        sampler._sample,\n        data.num_nodes,\n        disjoint,\n    )\n\n    # Compare distributed output with single machine output:\n    for k in data.node_types:\n        assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])\n        assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]\n        if disjoint:\n            assert torch.equal(\n                out_dist.batch[k].sort()[0],\n                out.batch[k].sort()[0],\n            )\n\n\ndef dist_link_neighbor_sampler_temporal_hetero(\n    world_size: int,\n    data: FakeHeteroDataset,\n    tmp_path: str,\n    rank: int,\n    master_port: int,\n    input_type: EdgeType,\n    seed_time: Optional[torch.Tensor] = None,\n    temporal_strategy: str = 'uniform',\n    time_attr: str = 'time',\n):\n    dist_data, other_graph_store = create_hetero_data(tmp_path, rank)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        rpc_worker_names={},\n        num_neighbors=[-1],\n        shuffle=False,\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n\n    # Close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n    
dist_sampler.init_sampler_instance()\n    dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    # Create input rows/cols such that pairs belong to different partitions.\n    # Edge from the current partition:\n    edge_label_index1 = dist_data[1]._edge_index[(input_type, 'coo')]\n    row_0 = edge_label_index1[0][0]\n    col_0 = edge_label_index1[1][0]\n    # Edge from the other partition:\n    edge_label_index2 = other_graph_store._edge_index[(input_type, 'coo')]\n    row_1 = edge_label_index2[0][0]\n    col_1 = edge_label_index2[1][0]\n\n    # Seed edges:\n    input_row = torch.tensor([row_0, row_1], dtype=torch.int64)\n    input_col = torch.tensor([col_0, col_1], dtype=torch.int64)\n\n    inputs = EdgeSamplerInput(\n        input_id=None,\n        row=input_row,\n        col=input_col,\n        time=seed_time,\n        input_type=input_type,\n    )\n\n    # Evaluate distributed edge sample function:\n    out_dist = dist_sampler.event_loop.run_task(coro=dist_sampler.edge_sample(\n        inputs, dist_sampler.node_sample, data.num_nodes, disjoint=True))\n\n    sampler = NeighborSampler(\n        data=data,\n        num_neighbors=[-1],\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n\n    # Evaluate edge sample function:\n    out = edge_sample(\n        inputs,\n        sampler._sample,\n        data.num_nodes,\n        disjoint=True,\n        node_time=seed_time,\n        neg_sampling=None,\n    )\n\n    # Compare distributed output with single machine output:\n    for k in data.node_types:\n        assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])\n        assert torch.equal(out_dist.batch[k].sort()[0], out.batch[k].sort()[0])\n        assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]\n\n\n@onlyDistributedTest\n@pytest.mark.parametrize('disjoint', [False, True])\ndef 
test_dist_link_neighbor_sampler(disjoint):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    procs = [\n        ProcArgs(target=dist_link_neighbor_sampler, args=(0, port, disjoint)),\n        ProcArgs(target=dist_link_neighbor_sampler, args=(1, port, disjoint)),\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])])\n@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])\ndef test_dist_link_neighbor_sampler_temporal(seed_time, temporal_strategy):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    procs = [\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal,\n            args=(0, port, seed_time, temporal_strategy, 'time'),\n        ),\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal,\n            args=(1, port, seed_time, temporal_strategy, 'time'),\n        ),\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [[1, 1], [3, 7]])\n@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])\ndef test_dist_link_neighbor_sampler_edge_level_temporal(\n    seed_time,\n    temporal_strategy,\n):\n    seed_time = torch.tensor(seed_time)\n\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = 
s.getsockname()[1]\n\n    procs = [\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal,\n            args=(0, port, seed_time, temporal_strategy, 'edge_time'),\n        ),\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal,\n            args=(1, port, seed_time, temporal_strategy, 'edge_time'),\n        ),\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('disjoint', [False, True])\ndef test_dist_link_neighbor_sampler_hetero(tmp_path, disjoint):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        edge_dim=2,\n    )[0]\n    data = T.ToUndirected()(data)\n\n    procs = [\n        ProcArgs(\n            target=dist_link_neighbor_sampler_hetero,\n            args=(data, tmp_path, 0, port, ('v0', 'e0', 'v0'), disjoint),\n        ),\n        ProcArgs(\n            target=dist_link_neighbor_sampler_hetero,\n            args=(data, tmp_path, 1, port, ('v1', 'e0', 'v0'), disjoint),\n        ),\n    ]\n\n    partitioner = Partitioner(data, len(procs), tmp_path)\n    partitioner.generate_partition()\n\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [None, [0, 0], [3, 3]])\n@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])\ndef test_dist_link_neighbor_sampler_temporal_hetero(\n    tmp_path,\n    seed_time,\n    temporal_strategy,\n):\n    if seed_time is not None:\n        seed_time = torch.tensor(seed_time)\n\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        edge_dim=2,\n    )[0]\n    data = T.ToUndirected()(data)\n\n    # Add time information to the data:\n    data['v0'].time = torch.ones(data['v0'].num_nodes, dtype=torch.int64)\n    data['v1'].time = torch.full((data['v1'].num_nodes, ), 2).long()\n\n    procs = [\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 0, port, ('v0', 'e0', 'v0'), seed_time,\n                  temporal_strategy, 'time'),\n        ),\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 1, port, ('v1', 'e0', 'v0'), seed_time,\n                  temporal_strategy, 'time'),\n        ),\n    ]\n\n    partitioner = Partitioner(data, len(procs), tmp_path)\n    partitioner.generate_partition()\n\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [[0, 0], [3, 3]])\n@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])\ndef test_dist_link_neighbor_sampler_edge_level_temporal_hetero(\n    tmp_path,\n    seed_time,\n    temporal_strategy,\n):\n    seed_time = torch.tensor(seed_time)\n\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        
edge_dim=2,\n    )[0]\n    data = T.ToUndirected()(data)\n\n    # Add time information to the data:\n    for i, edge_type in enumerate(data.edge_types):\n        data[edge_type].edge_time = torch.full(  #\n            (data[edge_type].num_edges, ), i, dtype=torch.int64)\n\n    procs = [\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 0, port, ('v0', 'e0', 'v0'), seed_time,\n                  temporal_strategy, 'edge_time'),\n        ),\n        ProcArgs(\n            target=dist_link_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 1, port, ('v0', 'e0', 'v1'), seed_time,\n                  temporal_strategy, 'edge_time'),\n        ),\n    ]\n\n    partitioner = Partitioner(data, len(procs), tmp_path)\n    partitioner.generate_partition()\n\n    assert_run_mproc(mp_context, procs)\n"
  },
  {
    "path": "test/distributed/test_dist_neighbor_loader.py",
    "content": "import socket\nimport warnings\n\nimport pytest\nimport torch\nimport torch.multiprocessing as mp\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.datasets import FakeDataset, FakeHeteroDataset\nfrom torch_geometric.distributed import (\n    DistContext,\n    DistNeighborLoader,\n    DistNeighborSampler,\n    LocalFeatureStore,\n    LocalGraphStore,\n    Partitioner,\n)\nfrom torch_geometric.testing import onlyDistributedTest, withMETIS\nfrom torch_geometric.testing.distributed import ProcArgs, assert_run_mproc\n\n\ndef create_dist_data(tmp_path: str, rank: int):\n    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)\n    feat_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)\n\n    return feat_store, graph_store\n\n\ndef dist_neighbor_loader_homo(\n    world_size: int,\n    tmp_path: str,\n    rank: int,\n    master_addr: str,\n    master_port: int,\n    num_workers: int,\n    async_sampling: bool,\n):\n    part_data = create_dist_data(tmp_path, rank)\n    input_nodes = part_data[0].get_global_id(None)\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-loader-test',\n    )\n\n    loader = DistNeighborLoader(\n        part_data,\n        num_neighbors=[1],\n        batch_size=10,\n        num_workers=num_workers,\n        input_nodes=input_nodes,\n        master_addr=master_addr,\n        master_port=master_port,\n        current_ctx=current_ctx,\n        concurrency=10,\n        drop_last=True,\n        async_sampling=async_sampling,\n    )\n\n    edge_index = part_data[1]._edge_index[(None, 'coo')]\n\n    assert str(loader).startswith('DistNeighborLoader')\n    assert str(mp.current_process().pid) in str(loader)\n    assert isinstance(loader.dist_sampler, DistNeighborSampler)\n    assert not part_data[0].meta['is_hetero']\n\n    for batch in loader:\n        assert 
isinstance(batch, Data)\n        assert batch.n_id.size() == (batch.num_nodes, )\n        assert batch.input_id.numel() == batch.batch_size == 10\n        assert batch.edge_index.min() >= 0\n        assert batch.edge_index.max() < batch.num_nodes\n        assert torch.equal(\n            batch.n_id[batch.edge_index],\n            edge_index[:, batch.e_id],\n        )\n    assert loader.channel.empty()\n\n\ndef dist_neighbor_loader_hetero(\n    world_size: int,\n    tmp_path: str,\n    rank: int,\n    master_addr: str,\n    master_port: int,\n    num_workers: int,\n    async_sampling: bool,\n):\n    part_data = create_dist_data(tmp_path, rank)\n    input_nodes = ('v0', part_data[0].get_global_id('v0'))\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-loader-test',\n    )\n\n    loader = DistNeighborLoader(\n        part_data,\n        num_neighbors=[1],\n        batch_size=10,\n        num_workers=num_workers,\n        input_nodes=input_nodes,\n        master_addr=master_addr,\n        master_port=master_port,\n        current_ctx=current_ctx,\n        concurrency=10,\n        drop_last=True,\n        async_sampling=async_sampling,\n    )\n\n    assert str(loader).startswith('DistNeighborLoader')\n    assert str(mp.current_process().pid) in str(loader)\n    assert isinstance(loader.dist_sampler, DistNeighborSampler)\n    assert part_data[0].meta['is_hetero']\n\n    for batch in loader:\n        assert isinstance(batch, HeteroData)\n        assert batch['v0'].input_id.numel() == batch['v0'].batch_size == 10\n\n        assert len(batch.node_types) == 2\n        for node_type in batch.node_types:\n            assert torch.equal(batch[node_type].x, batch.x_dict[node_type])\n            assert batch.x_dict[node_type].size(0) >= 0\n            assert batch[node_type].n_id.size(0) == batch[node_type].num_nodes\n\n        assert 
len(batch.edge_types) == 4\n        for edge_type in batch.edge_types:\n            num_edges = batch[edge_type].edge_index.size(1)\n\n            if num_edges > 0:  # Test edge mapping:\n                assert batch[edge_type].edge_attr.size(0) == num_edges\n                src, _, dst = edge_type\n                edge_index = part_data[1]._edge_index[(edge_type, \"coo\")]\n                global_edge_index1 = torch.stack([\n                    batch[src].n_id[batch[edge_type].edge_index[0]],\n                    batch[dst].n_id[batch[edge_type].edge_index[1]],\n                ], dim=0)\n\n                # TODO There is a current known flake, which we need to fix:\n                e_id = batch[edge_type].e_id\n                if e_id.numel() > 0 and e_id.max() >= edge_index.size(1):\n                    warnings.warn(\"Known test flake\", stacklevel=2)\n                else:\n                    global_edge_index2 = edge_index[:, e_id]\n                    if not torch.equal(global_edge_index1, global_edge_index2):\n                        warnings.warn(\"Known test flake\", stacklevel=2)\n\n    assert loader.channel.empty()\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('num_parts', [2])\n@pytest.mark.parametrize('num_workers', [0])\n@pytest.mark.parametrize('async_sampling', [True])\ndef test_dist_neighbor_loader_homo(\n    tmp_path,\n    num_parts,\n    num_workers,\n    async_sampling,\n):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    addr = '127.0.0.1'\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        edge_dim=2,\n    )[0]\n    partitioner = Partitioner(data, num_parts, tmp_path)\n    partitioner.generate_partition()\n\n    procs = [\n        
ProcArgs(\n            target=dist_neighbor_loader_homo,\n            args=(tmp_path, part, addr, port, num_workers, async_sampling),\n        ) for part in range(num_parts)\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('num_parts', [2])\n@pytest.mark.parametrize('num_workers', [0])\n@pytest.mark.parametrize('async_sampling', [True])\ndef test_dist_neighbor_loader_hetero(\n    tmp_path,\n    num_parts,\n    num_workers,\n    async_sampling,\n):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    addr = '127.0.0.1'\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        edge_dim=2,\n    )[0]\n    partitioner = Partitioner(data, num_parts, tmp_path)\n    partitioner.generate_partition()\n\n    procs = [\n        ProcArgs(\n            target=dist_neighbor_loader_hetero,\n            args=(tmp_path, part, addr, port, num_workers, async_sampling),\n        ) for part in range(num_parts)\n    ]\n    assert_run_mproc(mp_context, procs)\n"
  },
  {
    "path": "test/distributed/test_dist_neighbor_sampler.py",
    "content": "import atexit\nimport socket\nfrom typing import Optional\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets import FakeHeteroDataset\nfrom torch_geometric.distributed import (\n    DistNeighborSampler,\n    LocalFeatureStore,\n    LocalGraphStore,\n    Partitioner,\n)\nfrom torch_geometric.distributed.dist_context import DistContext\nfrom torch_geometric.distributed.event_loop import ConcurrentEventLoop\nfrom torch_geometric.distributed.rpc import init_rpc, shutdown_rpc\nfrom torch_geometric.sampler import NeighborSampler, NodeSamplerInput\nfrom torch_geometric.sampler.neighbor_sampler import node_sample\nfrom torch_geometric.testing import onlyDistributedTest, withMETIS\nfrom torch_geometric.testing.distributed import ProcArgs, assert_run_mproc\n\n\ndef create_data(rank: int, world_size: int, time_attr: Optional[str] = None):\n    if rank == 0:  # Partition 0:\n        node_id = torch.tensor([0, 1, 2, 3, 4, 5, 9])\n        edge_index = torch.tensor([  # Sorted by destination.\n            [1, 2, 3, 4, 5, 0, 0],\n            [0, 1, 2, 3, 4, 4, 9],\n        ])\n    else:  # Partition 1:\n        node_id = torch.tensor([0, 4, 5, 6, 7, 8, 9])\n        edge_index = torch.tensor([  # Sorted by destination.\n            [5, 6, 7, 8, 9, 5, 0],\n            [4, 5, 6, 7, 8, 9, 9],\n        ])\n    feature_store = LocalFeatureStore.from_data(node_id)\n    graph_store = LocalGraphStore.from_data(\n        edge_id=None,\n        edge_index=edge_index,\n        num_nodes=10,\n        is_sorted=True,\n    )\n\n    graph_store.node_pb = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n    graph_store.meta.update({'num_parts': 2})\n    graph_store.partition_idx = rank\n    graph_store.num_partitions = world_size\n\n    edge_index = torch.tensor([  # Create reference data:\n        [1, 2, 3, 4, 5, 0, 5, 6, 7, 8, 9, 0],\n        [0, 1, 2, 3, 4, 4, 9, 5, 6, 7, 8, 9],\n    ])\n    data = Data(x=None, y=None, 
edge_index=edge_index, num_nodes=10)\n\n    if time_attr == 'time':  # Create node-level time data:\n        data.time = torch.tensor([5, 0, 1, 3, 3, 4, 4, 4, 4, 4])\n        feature_store.put_tensor(data.time, group_name=None,\n                                 attr_name=time_attr)\n\n    elif time_attr == 'edge_time':  # Create edge-level time data:\n        data.edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 7, 7, 7, 7, 7, 11])\n\n        if rank == 0:\n            edge_time = torch.tensor([0, 1, 2, 3, 4, 5, 11])\n        if rank == 1:\n            edge_time = torch.tensor([4, 7, 7, 7, 7, 7, 11])\n\n        feature_store.put_tensor(edge_time, group_name=None,\n                                 attr_name=time_attr)\n\n    return (feature_store, graph_store), data\n\n\ndef create_hetero_data(\n    tmp_path: str,\n    rank: int,\n):\n    graph_store = LocalGraphStore.from_partition(tmp_path, pid=rank)\n    feature_store = LocalFeatureStore.from_partition(tmp_path, pid=rank)\n\n    return feature_store, graph_store\n\n\ndef dist_neighbor_sampler(\n    world_size: int,\n    rank: int,\n    master_port: int,\n    disjoint: bool = False,\n):\n    dist_data, data = create_data(rank, world_size)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        num_neighbors=[-1, -1],\n        shuffle=False,\n        disjoint=disjoint,\n    )\n    # Close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n    dist_sampler.init_sampler_instance()\n    dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    if rank == 0: 
 # Seed nodes:\n        input_node = torch.tensor([1, 6])\n    else:\n        input_node = torch.tensor([4, 9])\n\n    inputs = NodeSamplerInput(input_id=None, node=input_node)\n\n    # Evaluate distributed node sample function:\n    out_dist = dist_sampler.event_loop.run_task(\n        coro=dist_sampler.node_sample(inputs))\n\n    sampler = NeighborSampler(\n        data=data,\n        num_neighbors=[-1, -1],\n        disjoint=disjoint,\n    )\n\n    # Evaluate node sample function:\n    out = node_sample(inputs, sampler._sample)\n\n    # Compare distributed output with single machine output:\n    assert torch.equal(out_dist.node, out.node)\n    assert torch.equal(out_dist.row, out.row)\n    assert torch.equal(out_dist.col, out.col)\n    if disjoint:\n        assert torch.equal(out_dist.batch, out.batch)\n    assert out_dist.num_sampled_nodes == out.num_sampled_nodes\n    assert out_dist.num_sampled_edges == out.num_sampled_edges\n\n\ndef dist_neighbor_sampler_temporal(\n    world_size: int,\n    rank: int,\n    master_port: int,\n    seed_time: torch.tensor = None,\n    temporal_strategy: str = 'uniform',\n    time_attr: str = 'time',\n):\n    dist_data, data = create_data(rank, world_size, time_attr)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    num_neighbors = [-1, -1] if temporal_strategy == 'uniform' else [1, 1]\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        num_neighbors=num_neighbors,\n        shuffle=False,\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n    # Close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n    
dist_sampler.init_sampler_instance()\n    dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    if rank == 0:  # Seed nodes:\n        input_node = torch.tensor([1, 6], dtype=torch.int64)\n    else:\n        input_node = torch.tensor([4, 9], dtype=torch.int64)\n\n    inputs = NodeSamplerInput(\n        input_id=None,\n        node=input_node,\n        time=seed_time,\n    )\n\n    # Evaluate distributed node sample function:\n    out_dist = dist_sampler.event_loop.run_task(\n        coro=dist_sampler.node_sample(inputs))\n\n    sampler = NeighborSampler(\n        data=data,\n        num_neighbors=num_neighbors,\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n\n    # Evaluate node sample function:\n    out = node_sample(inputs, sampler._sample)\n\n    # Compare distributed output with single machine output:\n    assert torch.equal(out_dist.node, out.node)\n    assert torch.equal(out_dist.row, out.row)\n    assert torch.equal(out_dist.col, out.col)\n    assert torch.equal(out_dist.batch, out.batch)\n    assert out_dist.num_sampled_nodes == out.num_sampled_nodes\n    assert out_dist.num_sampled_edges == out.num_sampled_edges\n\n\ndef dist_neighbor_sampler_hetero(\n    world_size: int,\n    data: FakeHeteroDataset,\n    tmp_path: str,\n    rank: int,\n    master_port: int,\n    input_type: str,\n    disjoint: bool = False,\n):\n    dist_data = create_hetero_data(tmp_path, rank)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    num_neighbors = [-1, -1]\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        rpc_worker_names={},\n        num_neighbors=num_neighbors,\n        shuffle=False,\n        disjoint=disjoint,\n    
)\n\n    # Close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n    dist_sampler.init_sampler_instance()\n    dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    # Create inputs nodes such that each belongs to a different partition\n    node_pb_list = dist_data[1].node_pb[input_type].tolist()\n    node_0 = node_pb_list.index(0)\n    node_1 = node_pb_list.index(1)\n\n    input_node = torch.tensor([node_0, node_1], dtype=torch.int64)\n\n    inputs = NodeSamplerInput(\n        input_id=None,\n        node=input_node,\n        input_type=input_type,\n    )\n\n    # Evaluate distributed node sample function:\n    out_dist = dist_sampler.event_loop.run_task(\n        coro=dist_sampler.node_sample(inputs))\n\n    sampler = NeighborSampler(\n        data=data,\n        num_neighbors=num_neighbors,\n        disjoint=disjoint,\n    )\n\n    # Evaluate node sample function:\n    out = node_sample(inputs, sampler._sample)\n\n    # Compare distributed output with single machine output:\n    for k in data.node_types:\n        assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])\n        assert out_dist.num_sampled_nodes[k] == out.num_sampled_nodes[k]\n        if disjoint:\n            assert torch.equal(\n                out_dist.batch[k].sort()[0],\n                out.batch[k].sort()[0],\n            )\n\n\ndef dist_neighbor_sampler_temporal_hetero(\n    world_size: int,\n    data: FakeHeteroDataset,\n    tmp_path: str,\n    rank: int,\n    master_port: int,\n    input_type: str,\n    seed_time: torch.tensor = None,\n    temporal_strategy: str = 'uniform',\n    time_attr: str = 'time',\n):\n    dist_data = create_hetero_data(tmp_path, rank)\n\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        
world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-sampler-test',\n    )\n\n    dist_sampler = DistNeighborSampler(\n        data=dist_data,\n        current_ctx=current_ctx,\n        rpc_worker_names={},\n        num_neighbors=[-1, -1],\n        shuffle=False,\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n\n    # Close RPC & worker group at exit:\n    atexit.register(shutdown_rpc)\n\n    init_rpc(\n        current_ctx=current_ctx,\n        rpc_worker_names={},\n        master_addr='localhost',\n        master_port=master_port,\n    )\n\n    dist_sampler.init_sampler_instance()\n    dist_sampler.register_sampler_rpc()\n    dist_sampler.event_loop = ConcurrentEventLoop(2)\n    dist_sampler.event_loop.start_loop()\n\n    # Create inputs nodes such that each belongs to a different partition:\n    node_pb_list = dist_data[1].node_pb[input_type].tolist()\n    node_0 = node_pb_list.index(0)\n    node_1 = node_pb_list.index(1)\n\n    input_node = torch.tensor([node_0, node_1], dtype=torch.int64)\n\n    inputs = NodeSamplerInput(\n        input_id=None,\n        node=input_node,\n        time=seed_time,\n        input_type=input_type,\n    )\n\n    # Evaluate distributed node sample function:\n    out_dist = dist_sampler.event_loop.run_task(\n        coro=dist_sampler.node_sample(inputs))\n\n    sampler = NeighborSampler(\n        data=data,\n        num_neighbors=[-1, -1],\n        disjoint=True,\n        temporal_strategy=temporal_strategy,\n        time_attr=time_attr,\n    )\n\n    # Evaluate node sample function:\n    out = node_sample(inputs, sampler._sample)\n\n    # Compare distributed output with single machine output:\n    for k in data.node_types:\n        assert torch.equal(out_dist.node[k].sort()[0], out.node[k].sort()[0])\n        assert torch.equal(out_dist.batch[k].sort()[0], out.batch[k].sort()[0])\n        assert out_dist.num_sampled_nodes[k] == 
out.num_sampled_nodes[k]\n\n\n@onlyDistributedTest\n@pytest.mark.parametrize('disjoint', [False, True])\ndef test_dist_neighbor_sampler(disjoint):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    procs = [\n        ProcArgs(target=dist_neighbor_sampler, args=(0, port, disjoint)),\n        ProcArgs(target=dist_neighbor_sampler, args=(1, port, disjoint)),\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [None, torch.tensor([3, 6])])\n@pytest.mark.parametrize('temporal_strategy', ['uniform'])\ndef test_dist_neighbor_sampler_temporal(seed_time, temporal_strategy):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    procs = [\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal,\n            args=(0, port, seed_time, temporal_strategy, 'time'),\n        ),\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal,\n            args=(1, port, seed_time, temporal_strategy, 'time'),\n        ),\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [[3, 7]])\n@pytest.mark.parametrize('temporal_strategy', ['last'])\ndef test_dist_neighbor_sampler_edge_level_temporal(\n    seed_time,\n    temporal_strategy,\n):\n    seed_time = torch.tensor(seed_time)\n\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n  
      s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    procs = [\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal,\n            args=(0, port, seed_time, temporal_strategy, 'edge_time'),\n        ),\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal,\n            args=(1, port, seed_time, temporal_strategy, 'edge_time'),\n        ),\n    ]\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('disjoint', [False, True])\ndef test_dist_neighbor_sampler_hetero(tmp_path, disjoint):\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        edge_dim=2,\n    )[0]\n\n    procs = [\n        ProcArgs(\n            target=dist_neighbor_sampler_hetero,\n            args=(data, tmp_path, 0, port, 'v0', disjoint),\n        ),\n        ProcArgs(\n            target=dist_neighbor_sampler_hetero,\n            args=(data, tmp_path, 1, port, 'v1', disjoint),\n        ),\n    ]\n\n    partitioner = Partitioner(data, len(procs), tmp_path)\n    partitioner.generate_partition()\n\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [None, [0, 0], [2, 2]])\n@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])\ndef test_dist_neighbor_sampler_temporal_hetero(\n    tmp_path,\n    seed_time,\n    temporal_strategy,\n):\n    if seed_time is not None:\n        seed_time = torch.tensor(seed_time)\n\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        edge_dim=2,\n    )[0]\n\n    data['v0'].time = torch.full((data.num_nodes_dict['v0'], ), 1,\n                                 dtype=torch.int64)\n    data['v1'].time = torch.full((data.num_nodes_dict['v1'], ), 2,\n                                 dtype=torch.int64)\n\n    procs = [\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 0, port, 'v0', seed_time, temporal_strategy,\n                  'time'),\n        ),\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 1, port, 'v1', seed_time, temporal_strategy,\n                  'time'),\n        ),\n    ]\n\n    partitioner = Partitioner(data, len(procs), tmp_path)\n    partitioner.generate_partition()\n\n    assert_run_mproc(mp_context, procs)\n\n\n@withMETIS\n@onlyDistributedTest\n@pytest.mark.parametrize('seed_time', [[0, 0], [1, 2]])\n@pytest.mark.parametrize('temporal_strategy', ['uniform', 'last'])\ndef test_dist_neighbor_sampler_edge_level_temporal_hetero(\n    tmp_path,\n    seed_time,\n    temporal_strategy,\n):\n    seed_time = torch.tensor(seed_time)\n\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('', 0))\n        port = s.getsockname()[1]\n\n    data = FakeHeteroDataset(\n        num_graphs=1,\n        avg_num_nodes=100,\n        avg_degree=3,\n        num_node_types=2,\n        num_edge_types=4,\n        edge_dim=2,\n    )[0]\n\n    for i, edge_type in enumerate(data.edge_types):\n        
data[edge_type].edge_time = torch.full(\n            (data[edge_type].edge_index.size(1), ), i, dtype=torch.int64)\n\n    procs = [\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 0, port, 'v0', seed_time, temporal_strategy,\n                  'edge_time'),\n        ),\n        ProcArgs(\n            target=dist_neighbor_sampler_temporal_hetero,\n            args=(data, tmp_path, 1, port, 'v1', seed_time, temporal_strategy,\n                  'edge_time'),\n        ),\n    ]\n\n    partitioner = Partitioner(data, len(procs), tmp_path)\n    partitioner.generate_partition()\n\n    assert_run_mproc(mp_context, procs)\n"
  },
  {
    "path": "test/distributed/test_dist_utils.py",
    "content": "import torch\n\nfrom torch_geometric.distributed.utils import remove_duplicates\nfrom torch_geometric.sampler import SamplerOutput\nfrom torch_geometric.testing import onlyDistributedTest\n\n\n@onlyDistributedTest\ndef test_remove_duplicates():\n    node = torch.tensor([0, 1, 2, 3])\n    out_node = torch.tensor([0, 4, 1, 5, 1, 6, 2, 7, 3, 8])\n\n    out = SamplerOutput(out_node, None, None, None)\n\n    src, node, _, _ = remove_duplicates(out, node)\n\n    assert src.tolist() == [4, 5, 6, 7, 8]\n    assert node.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]\n\n\n@onlyDistributedTest\ndef test_remove_duplicates_disjoint():\n    node = torch.tensor([0, 1, 2, 3])\n    batch = torch.tensor([0, 1, 2, 3])\n\n    out_node = torch.tensor([0, 4, 1, 5, 1, 6, 2, 6, 7, 3, 8])\n    out_batch = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])\n\n    out = SamplerOutput(out_node, None, None, None, out_batch)\n\n    src, node, src_batch, batch = remove_duplicates(out, node, batch,\n                                                    disjoint=True)\n\n    assert src.tolist() == [4, 5, 6, 7, 8]\n    assert node.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]\n\n    assert src_batch.tolist() == [0, 1, 2, 3, 3]\n    assert batch.tolist() == [0, 1, 2, 3, 0, 1, 2, 3, 3]\n"
  },
  {
    "path": "test/distributed/test_local_feature_store.py",
    "content": "import torch\n\nfrom torch_geometric.distributed import LocalFeatureStore\nfrom torch_geometric.testing import onlyDistributedTest\n\n\n@onlyDistributedTest\ndef test_local_feature_store_global_id():\n    store = LocalFeatureStore()\n\n    feat = torch.tensor([\n        [0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0],\n        [2.0, 2.0, 2.0],\n        [3.0, 3.0, 3.0],\n        [4.0, 4.0, 4.0],\n        [5.0, 5.0, 5.0],\n        [6.0, 6.0, 6.0],\n        [7.0, 7.0, 7.0],\n        [8.0, 8.0, 8.0],\n    ])\n\n    paper_global_id = torch.tensor([1, 2, 3, 5, 8, 4])\n    paper_feat = feat[paper_global_id]\n\n    store.put_global_id(paper_global_id, group_name='paper')\n    store.put_tensor(paper_feat, group_name='paper', attr_name='feat')\n\n    out = store.get_tensor_from_global_id(group_name='paper', attr_name='feat',\n                                          index=torch.tensor([3, 8, 4]))\n    assert torch.equal(out, feat[torch.tensor([3, 8, 4])])\n\n\n@onlyDistributedTest\ndef test_local_feature_store_utils():\n    store = LocalFeatureStore()\n\n    feat = torch.tensor([\n        [0.0, 0.0, 0.0],\n        [1.0, 1.0, 1.0],\n        [2.0, 2.0, 2.0],\n        [3.0, 3.0, 3.0],\n        [4.0, 4.0, 4.0],\n        [5.0, 5.0, 5.0],\n        [6.0, 6.0, 6.0],\n        [7.0, 7.0, 7.0],\n        [8.0, 8.0, 8.0],\n    ])\n\n    paper_global_id = torch.tensor([1, 2, 3, 5, 8, 4])\n    paper_feat = feat[paper_global_id]\n\n    store.put_tensor(paper_feat, group_name='paper', attr_name='feat')\n\n    assert len(store.get_all_tensor_attrs()) == 1\n    attr = store.get_all_tensor_attrs()[0]\n    assert attr.group_name == 'paper'\n    assert attr.attr_name == 'feat'\n    assert attr.index is None\n    assert store.get_tensor_size(attr) == (6, 3)\n\n\n@onlyDistributedTest\ndef test_homogeneous_feature_store():\n    node_id = torch.randperm(6)\n    x = torch.randn(6, 32)\n    y = torch.randint(0, 2, (6, ))\n    edge_id = torch.randperm(12)\n    edge_attr = torch.randn(12, 
16)\n\n    store = LocalFeatureStore.from_data(node_id, x, y, edge_id, edge_attr)\n\n    assert len(store.get_all_tensor_attrs()) == 3\n    attrs = store.get_all_tensor_attrs()\n\n    assert attrs[0].group_name is None\n    assert attrs[0].attr_name == 'x'\n    assert attrs[1].group_name is None\n    assert attrs[1].attr_name == 'y'\n    assert attrs[2].group_name == (None, None)\n    assert attrs[2].attr_name == 'edge_attr'\n\n    assert torch.equal(store.get_global_id(group_name=None), node_id)\n    assert torch.equal(store.get_tensor(group_name=None, attr_name='x'), x)\n    assert torch.equal(store.get_tensor(group_name=None, attr_name='y'), y)\n    assert torch.equal(store.get_global_id(group_name=(None, None)), edge_id)\n    assert torch.equal(\n        store.get_tensor(group_name=(None, None), attr_name='edge_attr'),\n        edge_attr,\n    )\n\n\n@onlyDistributedTest\ndef test_heterogeneous_feature_store():\n    node_type = 'paper'\n    edge_type = ('paper', 'to', 'paper')\n    node_id_dict = {node_type: torch.randperm(6)}\n    x_dict = {node_type: torch.randn(6, 32)}\n    y_dict = {node_type: torch.randint(0, 2, (6, ))}\n    edge_id_dict = {edge_type: torch.randperm(12)}\n    edge_attr_dict = {edge_type: torch.randn(12, 16)}\n\n    store = LocalFeatureStore.from_hetero_data(\n        node_id_dict,\n        x_dict,\n        y_dict,\n        edge_id_dict,\n        edge_attr_dict,\n    )\n\n    assert len(store.get_all_tensor_attrs()) == 3\n    attrs = store.get_all_tensor_attrs()\n\n    assert attrs[0].group_name == node_type\n    assert attrs[0].attr_name == 'x'\n    assert attrs[1].group_name == node_type\n    assert attrs[1].attr_name == 'y'\n    assert attrs[2].group_name == edge_type\n    assert attrs[2].attr_name == 'edge_attr'\n\n    assert torch.equal(\n        store.get_global_id(group_name=node_type),\n        node_id_dict[node_type],\n    )\n    assert torch.equal(\n        store.get_tensor(group_name=node_type, attr_name='x'),\n        
x_dict[node_type],\n    )\n    assert torch.equal(\n        store.get_tensor(group_name=node_type, attr_name='y'),\n        y_dict[node_type],\n    )\n    assert torch.equal(\n        store.get_global_id(group_name=edge_type),\n        edge_id_dict[edge_type],\n    )\n    assert torch.equal(\n        store.get_tensor(group_name=edge_type, attr_name='edge_attr'),\n        edge_attr_dict[edge_type],\n    )\n"
  },
  {
    "path": "test/distributed/test_local_graph_store.py",
    "content": "import torch\n\nfrom torch_geometric.distributed import LocalGraphStore\nfrom torch_geometric.testing import get_random_edge_index, onlyDistributedTest\n\n\n@onlyDistributedTest\ndef test_local_graph_store():\n    graph_store = LocalGraphStore()\n\n    edge_index = get_random_edge_index(100, 100, 300)\n    edge_id = torch.tensor([1, 2, 3, 5, 8, 4])\n\n    graph_store.put_edge_index(\n        edge_index,\n        edge_type=None,\n        layout='coo',\n        size=(100, 100),\n    )\n\n    graph_store.put_edge_id(\n        edge_id,\n        edge_type=None,\n        layout='coo',\n        size=(100, 100),\n    )\n\n    assert len(graph_store.get_all_edge_attrs()) == 1\n    edge_attr = graph_store.get_all_edge_attrs()[0]\n    assert torch.equal(graph_store.get_edge_index(edge_attr), edge_index)\n    assert torch.equal(graph_store.get_edge_id(edge_attr), edge_id)\n    assert not graph_store.is_sorted\n    graph_store.remove_edge_index(edge_attr)\n    graph_store.remove_edge_id(edge_attr)\n    assert len(graph_store.get_all_edge_attrs()) == 0\n\n\n@onlyDistributedTest\ndef test_homogeneous_graph_store():\n    edge_id = torch.randperm(300)\n    edge_index = get_random_edge_index(100, 100, 300)\n    edge_index[1] = torch.sort(edge_index[1])[0]\n\n    graph_store = LocalGraphStore.from_data(\n        edge_id,\n        edge_index,\n        num_nodes=100,\n        is_sorted=True,\n    )\n\n    assert len(graph_store.get_all_edge_attrs()) == 1\n    edge_attr = graph_store.get_all_edge_attrs()[0]\n    assert edge_attr.edge_type is None\n    assert edge_attr.layout.value == 'coo'\n    assert edge_attr.is_sorted\n    assert edge_attr.size == (100, 100)\n\n    assert torch.equal(\n        graph_store.get_edge_id(edge_type=None, layout='coo'),\n        edge_id,\n    )\n    assert torch.equal(\n        graph_store.get_edge_index(edge_type=None, layout='coo'),\n        edge_index,\n    )\n\n\n@onlyDistributedTest\ndef test_heterogeneous_graph_store():\n    edge_type 
= ('paper', 'to', 'paper')\n    edge_id_dict = {edge_type: torch.randperm(300)}\n    edge_index = get_random_edge_index(100, 100, 300)\n    edge_index[1] = torch.sort(edge_index[1])[0]\n    edge_index_dict = {edge_type: edge_index}\n\n    graph_store = LocalGraphStore.from_hetero_data(\n        edge_id_dict,\n        edge_index_dict,\n        num_nodes_dict={'paper': 100},\n        is_sorted=True,\n    )\n\n    assert len(graph_store.get_all_edge_attrs()) == 1\n    edge_attr = graph_store.get_all_edge_attrs()[0]\n    assert edge_attr.edge_type == edge_type\n    assert edge_attr.layout.value == 'coo'\n    assert edge_attr.is_sorted\n    assert edge_attr.size == (100, 100)\n\n    assert torch.equal(\n        graph_store.get_edge_id(edge_type, layout='coo'),\n        edge_id_dict[edge_type],\n    )\n    assert torch.equal(\n        graph_store.get_edge_index(edge_type, layout='coo'),\n        edge_index_dict[edge_type],\n    )\n\n\n@onlyDistributedTest\ndef test_sorted_graph_store():\n    edge_index_sorted = torch.tensor([[1, 7, 5, 6, 1], [0, 0, 1, 1, 2]])\n    edge_id_sorted = torch.tensor([0, 1, 2, 3, 4])\n\n    edge_index = torch.tensor([[1, 5, 7, 1, 6], [0, 1, 0, 2, 1]])\n    edge_id = torch.tensor([0, 2, 1, 4, 3])\n\n    graph_store = LocalGraphStore.from_data(\n        edge_id,\n        edge_index,\n        num_nodes=8,\n        is_sorted=False,\n    )\n    assert torch.equal(\n        graph_store.get_edge_index(edge_type=None, layout='coo'),\n        edge_index_sorted,\n    )\n    assert torch.equal(\n        graph_store.get_edge_id(edge_type=None, layout='coo'),\n        edge_id_sorted,\n    )\n\n    edge_type = ('paper', 'to', 'paper')\n    edge_index_dict = {edge_type: edge_index}\n    edge_id_dict = {edge_type: edge_id}\n\n    graph_store = LocalGraphStore.from_hetero_data(\n        edge_id_dict,\n        edge_index_dict,\n        num_nodes_dict={'paper': 8},\n        is_sorted=False,\n    )\n    assert torch.equal(\n        
graph_store.get_edge_index(edge_type, layout='coo'),\n        edge_index_sorted,\n    )\n    assert torch.equal(\n        graph_store.get_edge_id(edge_type, layout='coo'),\n        edge_id_sorted,\n    )\n"
  },
  {
    "path": "test/distributed/test_partition.py",
    "content": "import os.path as osp\n\nimport torch\n\nfrom torch_geometric.datasets import FakeDataset, FakeHeteroDataset\nfrom torch_geometric.distributed import (\n    LocalFeatureStore,\n    LocalGraphStore,\n    Partitioner,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.testing import onlyDistributedTest, withMETIS\nfrom torch_geometric.typing import EdgeTypeStr\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_partition_data(tmp_path):\n    data = FakeDataset()[0]\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    node_map_path = osp.join(tmp_path, 'node_map.pt')\n    assert osp.exists(node_map_path)\n    node_map = fs.torch_load(node_map_path)\n    assert node_map.numel() == data.num_nodes\n\n    edge_map_path = osp.join(tmp_path, 'edge_map.pt')\n    assert osp.exists(edge_map_path)\n    edge_map = fs.torch_load(edge_map_path)\n    assert edge_map.numel() == data.num_edges\n\n    meta_path = osp.join(tmp_path, 'META.json')\n    assert osp.exists(meta_path)\n\n    graph0_path = osp.join(tmp_path, 'part_0', 'graph.pt')\n    assert osp.exists(graph0_path)\n    graph0 = fs.torch_load(graph0_path)\n    assert len({'edge_id', 'row', 'col', 'size'} & set(graph0.keys())) == 4\n\n    graph1_path = osp.join(tmp_path, 'part_1', 'graph.pt')\n    assert osp.exists(graph1_path)\n    graph1 = fs.torch_load(graph1_path)\n    assert len({'edge_id', 'row', 'col', 'size'} & set(graph1.keys())) == 4\n\n    node_feats0_path = osp.join(tmp_path, 'part_0', 'node_feats.pt')\n    assert osp.exists(node_feats0_path)\n    node_feats0 = fs.torch_load(node_feats0_path)\n\n    node_feats1_path = osp.join(tmp_path, 'part_1', 'node_feats.pt')\n    assert osp.exists(node_feats1_path)\n    node_feats1 = fs.torch_load(node_feats1_path)\n\n    assert (node_feats0['feats']['x'].size(0) +\n            node_feats1['feats']['x'].size(0) == data.num_nodes)\n    assert torch.equal(data.x[node_feats0['global_id']],\n           
            node_feats0['feats']['x'])\n    assert torch.equal(data.x[node_feats1['global_id']],\n                       node_feats1['feats']['x'])\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_partition_hetero_data(tmp_path):\n    data = FakeHeteroDataset()[0]\n\n    num_parts = 2\n    partitioner = Partitioner(data, num_parts=num_parts, root=tmp_path)\n    partitioner.generate_partition()\n\n    meta_path = osp.join(tmp_path, 'META.json')\n    assert osp.exists(meta_path)\n\n    for edge_type, num_edges in data.num_edges_dict.items():\n        assert len(edge_type) == 3\n        edge_name = EdgeTypeStr(edge_type)\n        edge_map_path = osp.join(tmp_path, 'edge_map', f'{edge_name}.pt')\n        assert osp.exists(edge_map_path)\n        edge_map = fs.torch_load(edge_map_path)\n        assert edge_map.numel() == num_edges\n\n    for node_type, num_nodes in data.num_nodes_dict.items():\n        node_map_path = osp.join(tmp_path, 'node_map', f'{node_type}.pt')\n        assert osp.exists(node_map_path)\n        node_map = fs.torch_load(node_map_path)\n        assert node_map.numel() == num_nodes\n\n    for pid in range(num_parts):\n        graph_path = osp.join(tmp_path, f'part_{pid}', 'graph.pt')\n        assert osp.exists(graph_path)\n        node_feats_path = osp.join(tmp_path, f'part_{pid}', 'node_feats.pt')\n        assert osp.exists(node_feats_path)\n        edge_feats_path = osp.join(tmp_path, f'part_{pid}', 'edge_feats.pt')\n        assert osp.exists(edge_feats_path)\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_partition_data_temporal(tmp_path):\n    data = FakeDataset()[0]\n    data.time = torch.arange(data.num_nodes)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    node_feats0_path = osp.join(tmp_path, 'part_0', 'node_feats.pt')\n    assert osp.exists(node_feats0_path)\n    node_feats0 = fs.torch_load(node_feats0_path)\n\n    node_feats1_path = osp.join(tmp_path, 'part_1', 
'node_feats.pt')\n    assert osp.exists(node_feats1_path)\n    node_feats1 = fs.torch_load(node_feats1_path)\n\n    assert torch.equal(data.time, node_feats0['time'])\n    assert torch.equal(data.time, node_feats1['time'])\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_partition_data_edge_level_temporal(tmp_path):\n    data = FakeDataset(edge_dim=2)[0]\n    data.edge_time = torch.arange(data.num_edges)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    edge_feats0_path = osp.join(tmp_path, 'part_0', 'edge_feats.pt')\n    assert osp.exists(edge_feats0_path)\n    edge_feats0 = fs.torch_load(edge_feats0_path)\n\n    edge_feats1_path = osp.join(tmp_path, 'part_1', 'edge_feats.pt')\n    assert osp.exists(edge_feats1_path)\n    edge_feats1 = fs.torch_load(edge_feats1_path)\n\n    assert torch.equal(data.edge_time[edge_feats0['global_id']],\n                       edge_feats0['edge_time'])\n    assert torch.equal(data.edge_time[edge_feats1['global_id']],\n                       edge_feats1['edge_time'])\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_partition_hetero_data_temporal(tmp_path):\n    data = FakeHeteroDataset()[0]\n\n    for key in data.node_types:\n        data[key].time = torch.arange(data[key].num_nodes)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    node_feats0_path = osp.join(tmp_path, 'part_0', 'node_feats.pt')\n    assert osp.exists(node_feats0_path)\n    node_feats0 = fs.torch_load(node_feats0_path)\n\n    node_feats1_path = osp.join(tmp_path, 'part_1', 'node_feats.pt')\n    assert osp.exists(node_feats1_path)\n    node_feats1 = fs.torch_load(node_feats1_path)\n\n    for key in data.node_types:\n        assert torch.equal(data[key].time, node_feats0[key]['time'])\n        assert torch.equal(data[key].time, node_feats1[key]['time'])\n\n\n@withMETIS\n@onlyDistributedTest\ndef 
test_partition_hetero_data_edge_level_temporal(tmp_path):\n    data = FakeHeteroDataset(edge_dim=2)[0]\n\n    for key in data.edge_types:\n        data[key].edge_time = torch.arange(data[key].num_edges)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    edge_feats0_path = osp.join(tmp_path, 'part_0', 'edge_feats.pt')\n    assert osp.exists(edge_feats0_path)\n    edge_feats0 = fs.torch_load(edge_feats0_path)\n\n    edge_feats1_path = osp.join(tmp_path, 'part_1', 'edge_feats.pt')\n    assert osp.exists(edge_feats1_path)\n    edge_feats1 = fs.torch_load(edge_feats1_path)\n\n    for key in data.edge_types:\n        assert torch.equal(\n            data[key].edge_time[edge_feats0[key]['global_id']],\n            edge_feats0[key]['edge_time'],\n        )\n        assert torch.equal(\n            data[key].edge_time[edge_feats1[key]['global_id']],\n            edge_feats1[key]['edge_time'],\n        )\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_from_partition_data(tmp_path):\n    data = FakeDataset()[0]\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    graph_store1 = LocalGraphStore.from_partition(tmp_path, pid=0)\n    graph_store2 = LocalGraphStore.from_partition(tmp_path, pid=1)\n\n    attr1 = graph_store1.get_all_edge_attrs()[0]\n    (row1, col1) = graph_store1.get_edge_index(attr1)\n    attr2 = graph_store2.get_all_edge_attrs()[0]\n    (row2, col2) = graph_store2.get_edge_index(attr2)\n    assert row1.size(0) + row2.size(0) == data.num_edges\n\n    feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0)\n    feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1)\n\n    node_attr1 = feat_store1.get_all_tensor_attrs()[0]\n    assert node_attr1.attr_name == 'x'\n    x1 = feat_store1.get_tensor(node_attr1)\n    id1 = feat_store1.get_global_id(node_attr1.group_name)\n\n    node_attr2 = feat_store2.get_all_tensor_attrs()[0]\n    
assert node_attr2.attr_name == 'x'\n    x2 = feat_store2.get_tensor(node_attr2)\n    id2 = feat_store2.get_global_id(node_attr2.group_name)\n\n    assert x1.size(0) + x2.size(0) == data.num_nodes\n    assert torch.allclose(data.x[id1], x1)\n    assert torch.allclose(data.x[id2], x2)\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_from_partition_hetero_data(tmp_path):\n    data = FakeHeteroDataset()[0]\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    graph_store1 = LocalGraphStore.from_partition(tmp_path, pid=0)\n    graph_store2 = LocalGraphStore.from_partition(tmp_path, pid=1)\n\n    attrs1 = graph_store1.get_all_edge_attrs()\n    attrs2 = graph_store2.get_all_edge_attrs()\n    assert len(data.edge_types) == len(attrs1) == len(attrs2)\n\n    node_types = set()\n    for attr in attrs1:\n        node_types.add(attr.edge_type[0])\n        node_types.add(attr.edge_type[2])\n    assert node_types == set(data.node_types)\n\n    node_types = set()\n    for attr in attrs2:\n        node_types.add(attr.edge_type[0])\n        node_types.add(attr.edge_type[2])\n    assert node_types == set(data.node_types)\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_from_partition_temporal_data(tmp_path):\n    data = FakeDataset()[0]\n    data.time = torch.arange(data.num_nodes)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0)\n    feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1)\n\n    time_attr1 = feat_store1.get_all_tensor_attrs()[1]\n    assert time_attr1.attr_name == 'time'\n    time1 = feat_store1.get_tensor(time_attr1)\n\n    time_attr2 = feat_store2.get_all_tensor_attrs()[1]\n    assert time_attr2.attr_name == 'time'\n    time2 = feat_store2.get_tensor(time_attr2)\n\n    assert time1.size(0) == data.num_nodes\n    assert time2.size(0) == data.num_nodes\n    assert 
torch.equal(time1, data.time)\n    assert torch.equal(time2, data.time)\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_from_partition_edge_level_temporal_data(tmp_path):\n    data = FakeDataset(edge_dim=2)[0]\n    data.edge_time = torch.arange(data.num_edges)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0)\n    feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1)\n\n    time_attr1 = feat_store1.get_all_tensor_attrs()[2]\n    assert time_attr1.attr_name == 'edge_time'\n    time1 = feat_store1.get_tensor(time_attr1)\n\n    time_attr2 = feat_store2.get_all_tensor_attrs()[2]\n    assert time_attr2.attr_name == 'edge_time'\n    time2 = feat_store2.get_tensor(time_attr2)\n\n    edge_id1 = feat_store1.get_global_id(group_name=(None, None))\n    edge_id2 = feat_store2.get_global_id(group_name=(None, None))\n\n    assert time1.size(0) + time2.size(0) == data.edge_index.size(1)\n    assert torch.equal(data.edge_time[edge_id1], time1)\n    assert torch.equal(data.edge_time[edge_id2], time2)\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_from_partition_hetero_temporal_data(tmp_path):\n    data = FakeHeteroDataset()[0]\n\n    for key in data.node_types:\n        data[key].time = torch.arange(data[key].num_nodes)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0)\n    feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1)\n\n    attrs1 = feat_store1.get_all_tensor_attrs()\n    attrs2 = feat_store2.get_all_tensor_attrs()\n\n    times1 = {\n        attr.group_name: feat_store1.get_tensor(attr)\n        for attr in attrs1 if attr.attr_name == 'time'\n    }\n    times2 = {\n        attr.group_name: feat_store2.get_tensor(attr)\n        for attr in attrs2 if attr.attr_name == 'time'\n    }\n\n    for key in 
data.node_types:\n        assert times1[key].size(0) == data[key].num_nodes\n        assert times2[key].size(0) == data[key].num_nodes\n        assert torch.equal(times1[key], data[key].time)\n        assert torch.equal(times2[key], data[key].time)\n\n\n@withMETIS\n@onlyDistributedTest\ndef test_from_partition_hetero_edge_level_temporal_data(tmp_path):\n    data = FakeHeteroDataset(edge_dim=2)[0]\n\n    for key in data.edge_types:\n        data[key].edge_time = torch.arange(data[key].num_edges)\n\n    partitioner = Partitioner(data, num_parts=2, root=tmp_path)\n    partitioner.generate_partition()\n\n    feat_store1 = LocalFeatureStore.from_partition(tmp_path, pid=0)\n    feat_store2 = LocalFeatureStore.from_partition(tmp_path, pid=1)\n\n    attrs1 = feat_store1.get_all_tensor_attrs()\n    attrs2 = feat_store2.get_all_tensor_attrs()\n\n    times1 = {\n        attr.group_name: feat_store1.get_tensor(attr)\n        for attr in attrs1 if attr.attr_name == 'edge_time'\n    }\n    times2 = {\n        attr.group_name: feat_store2.get_tensor(attr)\n        for attr in attrs2 if attr.attr_name == 'edge_time'\n    }\n\n    for key in data.edge_types:\n        edge_id1 = feat_store1.get_global_id(group_name=key)\n        edge_id2 = feat_store2.get_global_id(group_name=key)\n        assert times1[key].size(0) + times2[key].size(0) == data[key].num_edges\n        assert torch.equal(data[key].edge_time[edge_id1], times1[key])\n        assert torch.equal(data[key].edge_time[edge_id2], times2[key])\n"
  },
  {
    "path": "test/distributed/test_rpc.py",
    "content": "import socket\n\nimport torch\n\nimport torch_geometric.distributed.rpc as rpc\nfrom torch_geometric.distributed import LocalFeatureStore\nfrom torch_geometric.distributed.dist_context import DistContext\nfrom torch_geometric.distributed.rpc import RPCRouter\nfrom torch_geometric.testing import onlyDistributedTest\n\n\ndef run_rpc_feature_test(\n    world_size: int,\n    rank: int,\n    feature: LocalFeatureStore,\n    partition_book: torch.Tensor,\n    master_port: int,\n):\n    # 1) Initialize the context info:\n    current_ctx = DistContext(\n        rank=rank,\n        global_rank=rank,\n        world_size=world_size,\n        global_world_size=world_size,\n        group_name='dist-feature-test',\n    )\n\n    rpc.init_rpc(\n        current_ctx=current_ctx,\n        master_addr='localhost',\n        master_port=master_port,\n    )\n\n    # 2) Collect all workers:\n    partition_to_workers = rpc.rpc_partition_to_workers(\n        current_ctx, world_size, rank)\n\n    assert partition_to_workers == [\n        ['dist-feature-test-0'],\n        ['dist-feature-test-1'],\n    ]\n\n    # 3) Find the mapping between worker and partition ID:\n    rpc_router = RPCRouter(partition_to_workers)\n\n    assert rpc_router.get_to_worker(partition_idx=0) == 'dist-feature-test-0'\n    assert rpc_router.get_to_worker(partition_idx=1) == 'dist-feature-test-1'\n\n    meta = {\n        'edge_types': None,\n        'is_hetero': False,\n        'node_types': None,\n        'num_parts': 2,\n    }\n\n    feature.num_partitions = world_size\n    feature.partition_idx = rank\n    feature.node_feat_pb = partition_book\n    feature.meta = meta\n    feature.local_only = False\n    feature.set_rpc_router(rpc_router)\n\n    # Global node IDs:\n    global_id0 = torch.arange(128 * 2)\n    global_id1 = torch.arange(128 * 2) + 128 * 2\n\n    # Lookup the features from stores including locally and remotely:\n    tensor0 = feature.lookup_features(global_id0)\n    tensor1 = 
feature.lookup_features(global_id1)\n\n    # Expected lookup results:\n    cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2])\n    cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)])\n\n    # Verify the results:\n    assert torch.allclose(cpu_tensor0, tensor0.wait())\n    assert torch.allclose(cpu_tensor1, tensor1.wait())\n\n    rpc.shutdown_rpc()\n    assert rpc.rpc_is_initialized() is False\n\n\n@onlyDistributedTest\ndef test_dist_feature_lookup():\n    cpu_tensor0 = torch.cat([torch.ones(128, 1024), torch.ones(128, 1024) * 2])\n    cpu_tensor1 = torch.cat([torch.zeros(128, 1024), torch.zeros(128, 1024)])\n\n    # Global node IDs:\n    global_id0 = torch.arange(128 * 2)\n    global_id1 = torch.arange(128 * 2) + 128 * 2\n\n    # Set the partition book for two features (partition 0 and 1):\n    partition_book = torch.cat([\n        torch.zeros(128 * 2, dtype=torch.long),\n        torch.ones(128 * 2, dtype=torch.long),\n    ])\n\n    # Put the test tensor into the different feature stores with IDs:\n    feature0 = LocalFeatureStore()\n    feature0.put_global_id(global_id0, group_name=None)\n    feature0.put_tensor(cpu_tensor0, group_name=None, attr_name='x')\n\n    feature1 = LocalFeatureStore()\n    feature1.put_global_id(global_id1, group_name=None)\n    feature1.put_tensor(cpu_tensor1, group_name=None, attr_name='x')\n\n    mp_context = torch.multiprocessing.get_context('spawn')\n    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:\n        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)\n        s.settimeout(1)\n        s.bind(('127.0.0.1', 0))\n        port = s.getsockname()[1]\n\n    w0 = mp_context.Process(target=run_rpc_feature_test,\n                            args=(2, 0, feature0, partition_book, port))\n    w1 = mp_context.Process(target=run_rpc_feature_test,\n                            args=(2, 1, feature1, partition_book, port))\n\n    w0.start()\n    w1.start()\n    w0.join()\n    w1.join()\n"
  },
  {
    "path": "test/explain/algorithm/test_attention_explainer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import (\n    AttentionExplainer,\n    Explainer,\n    HeteroExplanation,\n)\nfrom torch_geometric.explain.config import (\n    ExplanationType,\n    MaskType,\n    ModelConfig,\n    ModelMode,\n)\nfrom torch_geometric.nn import (\n    AttentiveFP,\n    GATConv,\n    GATv2Conv,\n    TransformerConv,\n    to_hetero,\n)\nfrom torch_geometric.nn.conv import HeteroConv\n\n\nclass AttentionGNN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GATConv(3, 16, heads=4)\n        self.conv2 = GATv2Conv(4 * 16, 16, heads=2)\n        self.conv3 = TransformerConv(2 * 16, 7, heads=1)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        x = self.conv3(x, edge_index)\n        return x\n\n\nclass HeteroAttentionGNN(torch.nn.Module):\n    def __init__(self, metadata, model_config=None):\n        super().__init__()\n        self.model_config = model_config\n\n        # Create a single BaseGNN that uses all three attention mechanisms\n        class BaseGNN(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                # Use different attention mechanisms in sequence\n                self.conv1 = GATConv((-1, -1), 16, heads=2,\n                                     add_self_loops=False)\n                self.conv2 = GATv2Conv((-1, -1), 16, heads=2,\n                                       add_self_loops=False)\n                self.conv3 = TransformerConv((-1, -1), 32, heads=1)\n\n            def forward(self, x, edge_index):\n                x = self.conv1(x, edge_index).relu()\n                x = self.conv2(x, edge_index).relu()\n                x = self.conv3(x, edge_index)\n                return x\n\n        # Convert to heterogeneous model with a single to_hetero call\n        self.gnn = to_hetero(BaseGNN(), metadata, debug=False)\n\n        # 
Output dimension based on model config\n        out_channels = 7 if (model_config and model_config.mode\n                             == ModelMode.multiclass_classification) else 1\n        self.lin = torch.nn.Linear(32, out_channels)\n\n    def forward(self, x_dict, edge_index_dict, **kwargs):\n        # Process through the heterogeneous GNN\n        out_dict = self.gnn(x_dict, edge_index_dict)\n\n        # Project paper node embeddings for classification/regression\n        x = self.lin(out_dict['paper'])\n\n        # Apply appropriate output transformation based on model config\n        if self.model_config:\n            if self.model_config.mode == ModelMode.binary_classification:\n                if self.model_config.return_type == 'probs':\n                    x = x.sigmoid()\n            elif self.model_config.mode == ModelMode.multiclass_classification:\n                if self.model_config.return_type == 'probs':\n                    x = x.softmax(dim=-1)\n                elif self.model_config.return_type == 'log_probs':\n                    x = x.log_softmax(dim=-1)\n\n        return x\n\n\nclass HeteroConvAttentionGNN(torch.nn.Module):\n    def __init__(self, metadata, model_config=None):\n        super().__init__()\n        self.model_config = model_config\n\n        # Determine output channels based on model_config\n        self.out_channels = 1\n        if (model_config\n                and model_config.mode == ModelMode.multiclass_classification):\n            self.out_channels = 7\n\n        # Initialize node type-specific layers\n        self.lin_dict = torch.nn.ModuleDict()\n        self.initialized = False\n\n        # Create a dictionary of attention-based convolutions for each edge\n        # type\n        conv_dict = {}\n        for edge_type in metadata[1]:  # metadata[1] contains edge types\n            src_type, _, dst_type = edge_type\n            if src_type == dst_type:\n                # For same node type, use GATConv with 
add_self_loops=False\n                # Use concat=False to avoid dimension issues\n                conv_dict[edge_type] = GATConv(\n                    (-1, -1), 32, heads=2, add_self_loops=False, concat=False)\n            else:\n                # For different node types, use GATv2Conv with\n                # add_self_loops=False. Use concat=False to avoid dimension\n                # issues\n                conv_dict[edge_type] = GATv2Conv(\n                    (-1, -1), 32, heads=2, add_self_loops=False, concat=False)\n\n        # Create the HeteroConv layer\n        self.conv = HeteroConv(conv_dict, aggr='sum')\n\n        # Output layer will be initialized in forward pass\n        self.out_lin = None\n\n    def _initialize_layers(self, x_dict):\n        \"\"\"Initialize layers with correct dimensions when we first see the\n        data.\n        \"\"\"\n        if not self.initialized:\n            # Initialize input projections\n            for node_type, x in x_dict.items():\n                in_channels = x.size(-1)\n                self.lin_dict[node_type] = torch.nn.Linear(in_channels,\n                                                           32).to(x.device)\n\n            # Initialize output projection\n            self.out_lin = torch.nn.Linear(32, self.out_channels).to(\n                x_dict['paper'].device)\n\n            self.initialized = True\n\n    def forward(self, x_dict, edge_index_dict):\n        # Initialize layers if not done yet\n        self._initialize_layers(x_dict)\n\n        # Apply node type-specific transformations\n        h_dict = {}\n        for node_type, x in x_dict.items():\n            h_dict[node_type] = self.lin_dict[node_type](x).relu_()\n\n        # Apply heterogeneous convolution\n        out_dict = self.conv(h_dict, edge_index_dict)\n\n        # Final transformation for paper nodes\n        out = self.out_lin(out_dict['paper'])\n\n        # Apply transformations based on model_config if available\n        if 
self.model_config:\n            if self.model_config.mode == ModelMode.binary_classification:\n                if self.model_config.return_type == 'probs':\n                    out = out.sigmoid()\n            elif self.model_config.mode == ModelMode.multiclass_classification:\n                if self.model_config.return_type == 'probs':\n                    out = out.softmax(dim=-1)\n                elif self.model_config.return_type == 'log_probs':\n                    out = out.log_softmax(dim=-1)\n\n        return out\n\n\nx = torch.randn(8, 3)\nedge_index = torch.tensor([\n    [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n    [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n])\nedge_attr = torch.randn(edge_index.size(1), 5)\nbatch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])\n\n\n@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])\ndef test_attention_explainer(index, check_explanation):\n    explainer = Explainer(\n        model=AttentionGNN(),\n        algorithm=AttentionExplainer(),\n        explanation_type='model',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='multiclass_classification',\n            task_level='node',\n            return_type='raw',\n        ),\n    )\n\n    explanation = explainer(x, edge_index, index=index)\n    check_explanation(explanation, None, explainer.edge_mask_type)\n\n\n@pytest.mark.parametrize('explanation_type', [e for e in ExplanationType])\n@pytest.mark.parametrize('node_mask_type', [m for m in MaskType])\ndef test_attention_explainer_supports(explanation_type, node_mask_type):\n    with pytest.raises(ValueError, match=\"not support the given explanation\"):\n        Explainer(\n            model=AttentionGNN(),\n            algorithm=AttentionExplainer(),\n            explanation_type=explanation_type,\n            node_mask_type=node_mask_type,\n            edge_mask_type='object',\n            model_config=dict(\n                mode='multiclass_classification',\n                
task_level='node',\n                return_type='raw',\n            ),\n        )\n\n\ndef test_attention_explainer_attentive_fp(check_explanation):\n    model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=AttentionExplainer(),\n        explanation_type='model',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='binary_classification',\n            task_level='node',\n            return_type='raw',\n        ),\n    )\n\n    explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)\n    check_explanation(explanation, None, explainer.edge_mask_type)\n\n\n@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])\ndef test_attention_explainer_hetero(index, hetero_data,\n                                    check_explanation_hetero):\n    # Create model configuration\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type='raw',\n    )\n\n    # Get metadata from hetero_data\n    metadata = hetero_data.metadata()\n\n    # Create the hetero attention model\n    model = HeteroAttentionGNN(metadata, model_config)\n\n    # Create the explainer\n    explainer = Explainer(\n        model=model,\n        algorithm=AttentionExplainer(),\n        explanation_type='model',\n        edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    # Generate the explanation\n    explanation = explainer(\n        hetero_data.x_dict,\n        hetero_data.edge_index_dict,\n        index=index,\n    )\n\n    # Check that the explanation is correct\n    assert isinstance(explanation, HeteroExplanation)\n    check_explanation_hetero(explanation, None, explainer.edge_mask_type,\n                             hetero_data)\n\n\n@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])\ndef test_attention_explainer_hetero_conv(index, hetero_data,\n                            
             check_explanation_hetero):\n    \"\"\"Test AttentionExplainer with HeteroConv using attention-based layers.\"\"\"\n    # Create model configuration\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type='raw',\n    )\n\n    # Get metadata from hetero_data\n    metadata = hetero_data.metadata()\n\n    # Create the hetero conv attention model\n    model = HeteroConvAttentionGNN(metadata, model_config)\n\n    # Create the explainer\n    explainer = Explainer(\n        model=model,\n        algorithm=AttentionExplainer(),\n        explanation_type='model',\n        edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    # Generate the explanation\n    explanation = explainer(\n        hetero_data.x_dict,\n        hetero_data.edge_index_dict,\n        index=index,\n    )\n\n    # Check that the explanation is correct\n    assert isinstance(explanation, HeteroExplanation)\n    check_explanation_hetero(explanation, None, explainer.edge_mask_type,\n                             hetero_data)\n"
  },
  {
    "path": "test/explain/algorithm/test_captum.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.explain.algorithm.captum import to_captum_input\nfrom torch_geometric.nn import GAT, GCN, SAGEConv\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.models import to_captum_model\nfrom torch_geometric.testing import withPackage\n\nx = torch.randn(8, 3, requires_grad=True)\nedge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n                           [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])\n\nGCN = GCN(3, 16, 2, 7, dropout=0.5)\nGAT = GAT(3, 16, 2, 7, heads=2, concat=False)\nmask_types = ['edge', 'node_and_edge', 'node']\nmethods = [\n    'Saliency',\n    'InputXGradient',\n    'Deconvolution',\n    'FeatureAblation',\n    'ShapleyValueSampling',\n    'IntegratedGradients',\n    'GradientShap',\n    'Occlusion',\n    'GuidedBackprop',\n    'KernelShap',\n    'Lime',\n]\n\n\n@pytest.mark.parametrize('mask_type', mask_types)\n@pytest.mark.parametrize('model', [GCN, GAT])\n@pytest.mark.parametrize('output_idx', [None, 1])\ndef test_to_captum(model, mask_type, output_idx):\n    captum_model = to_captum_model(model, mask_type=mask_type,\n                                   output_idx=output_idx)\n    pre_out = model(x, edge_index)\n    if mask_type == 'node':\n        mask = x * 0.0\n        out = captum_model(mask.unsqueeze(0), edge_index)\n    elif mask_type == 'edge':\n        mask = torch.ones(edge_index.shape[1], dtype=torch.float,\n                          requires_grad=True) * 0.5\n        out = captum_model(mask.unsqueeze(0), x, edge_index)\n    elif mask_type == 'node_and_edge':\n        node_mask = x * 0.0\n        edge_mask = torch.ones(edge_index.shape[1], dtype=torch.float,\n                               requires_grad=True) * 0.5\n        out = captum_model(node_mask.unsqueeze(0), edge_mask.unsqueeze(0),\n                           edge_index)\n\n    if output_idx is not None:\n        
assert out.shape == (1, 7)\n        assert torch.any(out != pre_out[[output_idx]])\n    else:\n        assert out.shape == (8, 7)\n        assert torch.any(out != pre_out)\n\n\n@withPackage('captum', 'sklearn')\n@pytest.mark.parametrize('mask_type', mask_types)\n@pytest.mark.parametrize('method', methods)\ndef test_captum_attribution_methods(mask_type, method):\n    from captum import attr  # noqa\n\n    captum_model = to_captum_model(GCN, mask_type, 0)\n    explainer = getattr(attr, method)(captum_model)\n    data = Data(x, edge_index)\n    input, additional_forward_args = to_captum_input(data.x, data.edge_index,\n                                                     mask_type)\n    if mask_type == 'node':\n        sliding_window_shapes = (3, 3)\n    elif mask_type == 'edge':\n        sliding_window_shapes = (5, )\n    elif mask_type == 'node_and_edge':\n        sliding_window_shapes = ((3, 3), (5, ))\n\n    if method == 'IntegratedGradients':\n        attributions, delta = explainer.attribute(\n            input, target=0, internal_batch_size=1,\n            additional_forward_args=additional_forward_args,\n            return_convergence_delta=True)\n    elif method == 'GradientShap':\n        attributions, delta = explainer.attribute(\n            input, target=0, return_convergence_delta=True, baselines=input,\n            n_samples=1, additional_forward_args=additional_forward_args)\n    elif method == 'DeepLiftShap' or method == 'DeepLift':\n        attributions, delta = explainer.attribute(\n            input, target=0, return_convergence_delta=True, baselines=input,\n            additional_forward_args=additional_forward_args)\n    elif method == 'Occlusion':\n        attributions = explainer.attribute(\n            input, target=0, sliding_window_shapes=sliding_window_shapes,\n            additional_forward_args=additional_forward_args)\n    else:\n        attributions = explainer.attribute(\n            input, target=0, 
additional_forward_args=additional_forward_args)\n    if mask_type == 'node':\n        assert attributions[0].shape == (1, 8, 3)\n    elif mask_type == 'edge':\n        assert attributions[0].shape == (1, 14)\n    else:\n        assert attributions[0].shape == (1, 8, 3)\n        assert attributions[1].shape == (1, 14)\n\n\ndef test_custom_explain_message():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n\n    conv = SAGEConv(8, 32)\n\n    def explain_message(self, inputs, x_i, x_j):\n        assert isinstance(self, SAGEConv)\n        assert inputs.size() == (6, 8)\n        assert inputs.size() == x_i.size() == x_j.size()\n        assert torch.allclose(inputs, x_j)\n        self.x_i = x_i\n        self.x_j = x_j\n        return inputs\n\n    conv.explain_message = explain_message.__get__(conv, MessagePassing)\n    conv.explain = True\n\n    conv(x, edge_index)\n\n    assert torch.allclose(conv.x_i, x[edge_index[1]])\n    assert torch.allclose(conv.x_j, x[edge_index[0]])\n\n\n@withPackage('captum')\n@pytest.mark.parametrize('mask_type', ['node', 'edge', 'node_and_edge'])\ndef test_to_captum_input(mask_type):\n    num_nodes = x.shape[0]\n    num_node_feats = x.shape[1]\n    num_edges = edge_index.shape[1]\n\n    # Check for Data:\n    data = Data(x, edge_index)\n    args = 'test_args'\n    inputs, additional_forward_args = to_captum_input(data.x, data.edge_index,\n                                                      mask_type, args)\n    if mask_type == 'node':\n        assert len(inputs) == 1\n        assert inputs[0].shape == (1, num_nodes, num_node_feats)\n        assert len(additional_forward_args) == 2\n        assert torch.allclose(additional_forward_args[0], edge_index)\n    elif mask_type == 'edge':\n        assert len(inputs) == 1\n        assert inputs[0].shape == (1, num_edges)\n        assert inputs[0].sum() == num_edges\n        assert len(additional_forward_args) == 3\n        assert 
torch.allclose(additional_forward_args[0], x)\n        assert torch.allclose(additional_forward_args[1], edge_index)\n    else:\n        assert len(inputs) == 2\n        assert inputs[0].shape == (1, num_nodes, num_node_feats)\n        assert inputs[1].shape == (1, num_edges)\n        assert inputs[1].sum() == num_edges\n        assert len(additional_forward_args) == 2\n        assert torch.allclose(additional_forward_args[0], edge_index)\n\n    # Check for HeteroData:\n    data = HeteroData()\n    x2 = torch.rand(8, 3)\n    data['paper'].x = x\n    data['author'].x = x2\n    data['paper', 'to', 'author'].edge_index = edge_index\n    data['author', 'to', 'paper'].edge_index = edge_index.flip([0])\n    inputs, additional_forward_args = to_captum_input(data.x_dict,\n                                                      data.edge_index_dict,\n                                                      mask_type, args)\n    if mask_type == 'node':\n        assert len(inputs) == 2\n        assert inputs[0].shape == (1, num_nodes, num_node_feats)\n        assert inputs[1].shape == (1, num_nodes, num_node_feats)\n        assert len(additional_forward_args) == 2\n        for key in data.edge_types:\n            assert torch.allclose(additional_forward_args[0][key],\n                                  data[key].edge_index)\n    elif mask_type == 'edge':\n        assert len(inputs) == 2\n        assert inputs[0].shape == (1, num_edges)\n        assert inputs[1].shape == (1, num_edges)\n        assert inputs[1].sum() == inputs[0].sum() == num_edges\n        assert len(additional_forward_args) == 3\n        for key in data.node_types:\n            assert torch.allclose(additional_forward_args[0][key], data[key].x)\n        for key in data.edge_types:\n            assert torch.allclose(additional_forward_args[1][key],\n                                  data[key].edge_index)\n    else:\n        assert len(inputs) == 4\n        assert inputs[0].shape == (1, num_nodes, num_node_feats)\n        assert inputs[1].shape == 
(1, num_nodes, num_node_feats)\n        assert inputs[2].shape == (1, num_edges)\n        assert inputs[3].shape == (1, num_edges)\n        assert inputs[3].sum() == inputs[2].sum() == num_edges\n        assert len(additional_forward_args) == 2\n        for key in data.edge_types:\n            assert torch.allclose(additional_forward_args[0][key],\n                                  data[key].edge_index)\n"
  },
  {
    "path": "test/explain/algorithm/test_captum_explainer.py",
    "content": "from typing import Optional\n\nimport pytest\nimport torch\n\nfrom torch_geometric.explain import Explainer, Explanation\nfrom torch_geometric.explain.algorithm import CaptumExplainer\nfrom torch_geometric.explain.config import (\n    MaskType,\n    ModelConfig,\n    ModelMode,\n    ModelReturnType,\n    ModelTaskLevel,\n)\nfrom torch_geometric.nn import GCNConv, global_add_pool\nfrom torch_geometric.testing import withPackage\n\nmethods = [\n    'Saliency',\n    'InputXGradient',\n    'Deconvolution',\n    'ShapleyValueSampling',\n    'IntegratedGradients',\n    'GuidedBackprop',\n]\n\nunsupported_methods = [\n    'FeatureAblation',\n    'Occlusion',\n    'DeepLift',\n    'DeepLiftShap',\n    'GradientShap',\n    'KernelShap',\n    'Lime',\n]\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, model_config: ModelConfig):\n        super().__init__()\n        self.model_config = model_config\n\n        if model_config.mode == ModelMode.multiclass_classification:\n            out_channels = 7\n        else:\n            out_channels = 1\n\n        self.conv1 = GCNConv(3, 16)\n        self.conv2 = GCNConv(16, out_channels)\n\n        # Add unused parameter:\n        self.param = torch.nn.Parameter(torch.empty(1))\n\n    def forward(self, x, edge_index, batch=None, edge_label_index=None):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n\n        if self.model_config.task_level == ModelTaskLevel.graph:\n            x = global_add_pool(x, batch)\n        elif self.model_config.task_level == ModelTaskLevel.edge:\n            assert edge_label_index is not None\n            x = x[edge_label_index[0]] * x[edge_label_index[1]]\n\n        if self.model_config.mode == ModelMode.binary_classification:\n            if self.model_config.return_type == ModelReturnType.probs:\n                x = x.sigmoid()\n        elif self.model_config.mode == ModelMode.multiclass_classification:\n            if 
self.model_config.return_type == ModelReturnType.probs:\n                x = x.softmax(dim=-1)\n            elif self.model_config.return_type == ModelReturnType.log_probs:\n                x = x.log_softmax(dim=-1)\n\n        return x\n\n\nnode_mask_types = [MaskType.attributes, None]\nedge_mask_types = [MaskType.object, None]\ntask_levels = [ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph]\nindices = [1, torch.arange(2)]\n\n\ndef check_explanation(\n    explanation: Explanation,\n    node_mask_type: Optional[MaskType],\n    edge_mask_type: Optional[MaskType],\n):\n    if node_mask_type == MaskType.attributes:\n        assert explanation.node_mask.size() == explanation.x.size()\n    elif node_mask_type is None:\n        assert 'node_mask' not in explanation\n\n    if edge_mask_type == MaskType.object:\n        assert explanation.edge_mask.size() == (explanation.num_edges, )\n    elif edge_mask_type is None:\n        assert 'edge_mask' not in explanation\n\n\n@withPackage('captum')\n@pytest.mark.parametrize('method', unsupported_methods)\ndef test_unsupported_methods(method):\n    model_config = ModelConfig(mode='regression', task_level='node')\n\n    with pytest.raises(ValueError, match=\"does not support attribution\"):\n        Explainer(\n            GCN(model_config),\n            algorithm=CaptumExplainer(method),\n            explanation_type='model',\n            edge_mask_type='object',\n            node_mask_type='attributes',\n            model_config=model_config,\n        )\n\n\n@withPackage('captum')\n@pytest.mark.parametrize('method', ['IntegratedGradients'])\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('task_level', task_levels)\n@pytest.mark.parametrize('index', indices)\ndef test_captum_explainer_binary_classification(\n    method,\n    data,\n    node_mask_type,\n    edge_mask_type,\n    task_level,\n    index,\n):\n    if 
node_mask_type is None and edge_mask_type is None:\n        return\n\n    batch = torch.tensor([0, 0, 1, 1])\n    edge_label_index = torch.tensor([[0, 1, 2], [2, 3, 1]])\n\n    model_config = ModelConfig(\n        mode='binary_classification',\n        task_level=task_level,\n        return_type='probs',\n    )\n\n    explainer = Explainer(\n        GCN(model_config),\n        algorithm=CaptumExplainer(method),\n        explanation_type='model',\n        edge_mask_type=edge_mask_type,\n        node_mask_type=node_mask_type,\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        data.x,\n        data.edge_index,\n        index=index,\n        batch=batch,\n        edge_label_index=edge_label_index,\n    )\n    check_explanation(explanation, node_mask_type, edge_mask_type)\n\n\n@withPackage('captum')\n@pytest.mark.parametrize('method', methods)\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('task_level', task_levels)\n@pytest.mark.parametrize('index', indices)\ndef test_captum_explainer_multiclass_classification(\n    method,\n    data,\n    node_mask_type,\n    edge_mask_type,\n    task_level,\n    index,\n):\n    if node_mask_type is None and edge_mask_type is None:\n        return\n\n    batch = torch.tensor([0, 0, 1, 1])\n    edge_label_index = torch.tensor([[0, 1, 2], [2, 3, 1]])\n\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level=task_level,\n        return_type='probs',\n    )\n\n    explainer = Explainer(\n        GCN(model_config),\n        algorithm=CaptumExplainer(method),\n        explanation_type='model',\n        edge_mask_type=edge_mask_type,\n        node_mask_type=node_mask_type,\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        data.x,\n        data.edge_index,\n        index=index,\n        batch=batch,\n        
edge_label_index=edge_label_index,\n    )\n    check_explanation(explanation, node_mask_type, edge_mask_type)\n\n\n@withPackage('captum')\n@pytest.mark.parametrize(\n    'method',\n    [m for m in methods if m != 'ShapleyValueSampling'],\n)\n@pytest.mark.parametrize(\n    'node_mask_type',\n    [nm for nm in node_mask_types if nm is not None],\n)\n@pytest.mark.parametrize(\n    'edge_mask_type',\n    [em for em in edge_mask_types if em is not None],\n)\n@pytest.mark.parametrize('index', [1, torch.arange(2)])\ndef test_captum_hetero_data(method, node_mask_type, edge_mask_type, index,\n                            hetero_data, hetero_model):\n\n    model_config = ModelConfig(mode='regression', task_level='node')\n\n    explainer = Explainer(\n        hetero_model(hetero_data.metadata()),\n        algorithm=CaptumExplainer(method),\n        edge_mask_type=edge_mask_type,\n        node_mask_type=node_mask_type,\n        model_config=model_config,\n        explanation_type='model',\n    )\n\n    explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict,\n                            index=index)\n\n    explanation.validate(raise_on_error=True)\n\n\n@withPackage('captum')\n@pytest.mark.parametrize('node_mask_type', [\n    MaskType.object,\n    MaskType.common_attributes,\n])\ndef test_captum_explainer_supports(node_mask_type):\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type='probs',\n    )\n\n    with pytest.raises(ValueError, match=\"not support the given explanation\"):\n        Explainer(\n            GCN(model_config),\n            algorithm=CaptumExplainer('IntegratedGradients'),\n            edge_mask_type=MaskType.object,\n            node_mask_type=node_mask_type,\n            model_config=model_config,\n            explanation_type='model',\n        )\n"
  },
  {
    "path": "test/explain/algorithm/test_captum_hetero.py",
    "content": "import pytest\n\nfrom torch_geometric.explain.algorithm.captum import (\n    CaptumHeteroModel,\n    captum_output_to_dicts,\n    to_captum_input,\n)\nfrom torch_geometric.nn import to_captum_model\nfrom torch_geometric.testing import withPackage\n\nmask_types = [\n    'node',\n    'edge',\n    'node_and_edge',\n]\n\nmethods = [\n    'Saliency',\n    'InputXGradient',\n    'Deconvolution',\n    'FeatureAblation',\n    'ShapleyValueSampling',\n    'IntegratedGradients',\n    'GradientShap',\n    'Occlusion',\n    'GuidedBackprop',\n    'KernelShap',\n    'Lime',\n]\n\n\n@withPackage('captum', 'sklearn')\n@pytest.mark.parametrize('mask_type', mask_types)\n@pytest.mark.parametrize('method', methods)\ndef test_captum_attribution_methods_hetero(mask_type, method, hetero_data,\n                                           hetero_model):\n    from captum import attr  # noqa\n    data = hetero_data\n    metadata = data.metadata()\n    model = hetero_model(metadata)\n    captum_model = to_captum_model(model, mask_type, 0, metadata)\n    explainer = getattr(attr, method)(captum_model)\n    assert isinstance(captum_model, CaptumHeteroModel)\n\n    inputs, additional_forward_args = to_captum_input(\n        data.x_dict,\n        data.edge_index_dict,\n        mask_type,\n        'additional_arg',\n    )\n\n    if mask_type == 'node':\n        sliding_window_shapes = ((3, 3), (3, 3))\n    elif mask_type == 'edge':\n        sliding_window_shapes = ((5, ), (5, ), (5, ))\n    else:\n        sliding_window_shapes = ((3, 3), (3, 3), (5, ), (5, ), (5, ))\n\n    if method == 'IntegratedGradients':\n        attributions, delta = explainer.attribute(\n            inputs, target=0, internal_batch_size=1,\n            additional_forward_args=additional_forward_args,\n            return_convergence_delta=True)\n    elif method == 'GradientShap':\n        attributions, delta = explainer.attribute(\n            inputs, target=0, return_convergence_delta=True, 
baselines=inputs,\n            n_samples=1, additional_forward_args=additional_forward_args)\n    elif method == 'DeepLiftShap' or method == 'DeepLift':\n        attributions, delta = explainer.attribute(\n            inputs, target=0, return_convergence_delta=True, baselines=inputs,\n            additional_forward_args=additional_forward_args)\n    elif method == 'Occlusion':\n        attributions = explainer.attribute(\n            inputs, target=0, sliding_window_shapes=sliding_window_shapes,\n            additional_forward_args=additional_forward_args)\n    else:\n        attributions = explainer.attribute(\n            inputs, target=0, additional_forward_args=additional_forward_args)\n\n    if mask_type == 'node':\n        assert len(attributions) == len(metadata[0])\n        x_attr_dict, _ = captum_output_to_dicts(attributions, mask_type,\n                                                metadata)\n        for node_type in metadata[0]:\n            num_nodes = data[node_type].num_nodes\n            num_node_feats = data[node_type].x.shape[1]\n            assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats)\n    elif mask_type == 'edge':\n        assert len(attributions) == len(metadata[1])\n        _, edge_attr_dict = captum_output_to_dicts(attributions, mask_type,\n                                                   metadata)\n        for edge_type in metadata[1]:\n            num_edges = data[edge_type].edge_index.shape[1]\n            assert edge_attr_dict[edge_type].shape == (num_edges, )\n    else:\n        assert len(attributions) == len(metadata[0]) + len(metadata[1])\n        x_attr_dict, edge_attr_dict = captum_output_to_dicts(\n            attributions, mask_type, metadata)\n        for edge_type in metadata[1]:\n            num_edges = data[edge_type].edge_index.shape[1]\n            assert edge_attr_dict[edge_type].shape == (num_edges, )\n\n        for node_type in metadata[0]:\n            num_nodes = data[node_type].num_nodes\n      
      num_node_feats = data[node_type].x.shape[1]\n            assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats)\n"
  },
  {
    "path": "test/explain/algorithm/test_explain_algorithm_utils.py",
    "content": "import torch\n\nfrom torch_geometric.explain.algorithm.utils import (\n    clear_masks,\n    set_hetero_masks,\n)\nfrom torch_geometric.nn import GCNConv, HeteroConv, SAGEConv, to_hetero\n\n\nclass HeteroModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.conv1 = HeteroConv({\n            ('paper', 'to', 'paper'):\n            GCNConv(-1, 32),\n            ('author', 'to', 'paper'):\n            SAGEConv((-1, -1), 32),\n            ('paper', 'to', 'author'):\n            SAGEConv((-1, -1), 32),\n        })\n\n        self.conv2 = HeteroConv({\n            ('paper', 'to', 'paper'):\n            GCNConv(-1, 32),\n            ('author', 'to', 'paper'):\n            SAGEConv((-1, -1), 32),\n            ('paper', 'to', 'author'):\n            SAGEConv((-1, -1), 32),\n        })\n\n\nclass GraphSAGE(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), 32)\n        self.conv2 = SAGEConv((-1, -1), 32)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        return self.conv2(x, edge_index)\n\n\ndef test_set_clear_mask(hetero_data):\n    edge_mask_dict = {\n        ('paper', 'to', 'paper'): torch.ones(200),\n        ('author', 'to', 'paper'): torch.ones(100),\n        ('paper', 'to', 'author'): torch.ones(100),\n    }\n\n    model = HeteroModel()\n\n    set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict)\n    for edge_type in hetero_data.edge_types:\n        # Check that masks are correctly set:\n        assert torch.allclose(model.conv1.convs[edge_type]._edge_mask,\n                              edge_mask_dict[edge_type])\n        assert model.conv1.convs[edge_type].explain\n\n    clear_masks(model)\n    for edge_type in hetero_data.edge_types:\n        assert model.conv1.convs[edge_type]._edge_mask is None\n        assert not model.conv1.convs[edge_type].explain\n\n    model = to_hetero(GraphSAGE(), 
hetero_data.metadata(), debug=False)\n\n    set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict)\n    for edge_type in hetero_data.edge_types:\n        # Check that masks are correctly set:\n        str_edge_type = '__'.join(edge_type)\n        assert torch.allclose(model.conv1[str_edge_type]._edge_mask,\n                              edge_mask_dict[edge_type])\n        assert model.conv1[str_edge_type].explain\n\n    clear_masks(model)\n    for edge_type in hetero_data.edge_types:\n        str_edge_type = '__'.join(edge_type)\n        assert model.conv1[str_edge_type]._edge_mask is None\n        assert not model.conv1[str_edge_type].explain\n"
  },
  {
    "path": "test/explain/algorithm/test_gnn_explainer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import Explainer, GNNExplainer, HeteroExplanation\nfrom torch_geometric.explain.config import (\n    ExplanationType,\n    MaskType,\n    ModelConfig,\n    ModelMode,\n    ModelReturnType,\n    ModelTaskLevel,\n)\nfrom torch_geometric.nn import (\n    AttentiveFP,\n    ChebConv,\n    GCNConv,\n    TransformerConv,\n    global_add_pool,\n)\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self, Conv, model_config: ModelConfig):\n        super().__init__()\n        self.model_config = model_config\n\n        if model_config.mode == ModelMode.multiclass_classification:\n            out_channels = 7\n        else:\n            out_channels = 1\n\n        self.conv1 = Conv(3, 16)\n        self.conv2 = Conv(16, out_channels)\n\n        # Add unused parameter:\n        self.param = torch.nn.Parameter(torch.empty(1))\n\n    def forward(self, x, edge_index, batch=None, edge_label_index=None):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n\n        if self.model_config.task_level == ModelTaskLevel.graph:\n            x = global_add_pool(x, batch)\n        elif self.model_config.task_level == ModelTaskLevel.edge:\n            assert edge_label_index is not None\n            x = x[edge_label_index[0]] * x[edge_label_index[1]]\n\n        if self.model_config.mode == ModelMode.binary_classification:\n            if self.model_config.return_type == ModelReturnType.probs:\n                x = x.sigmoid()\n        elif self.model_config.mode == ModelMode.multiclass_classification:\n            if self.model_config.return_type == ModelReturnType.probs:\n                x = x.softmax(dim=-1)\n            elif self.model_config.return_type == ModelReturnType.log_probs:\n                x = x.log_softmax(dim=-1)\n\n        return x\n\n\nnode_mask_types = [\n    MaskType.object,\n    MaskType.common_attributes,\n    MaskType.attributes,\n]\nedge_mask_types = 
[MaskType.object, None]\nexplanation_types = [ExplanationType.model, ExplanationType.phenomenon]\ntask_levels = [ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph]\nindices = [None, 2, torch.arange(3)]\n\nx = torch.randn(8, 3)\nedge_index = torch.tensor([\n    [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n    [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n])\nedge_attr = torch.randn(edge_index.size(1), 5)\nbatch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])\nedge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]])\n\n\n@pytest.mark.parametrize('Conv', [GCNConv, TransformerConv])\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('explanation_type', explanation_types)\n@pytest.mark.parametrize('task_level', task_levels)\n@pytest.mark.parametrize('return_type', [\n    ModelReturnType.probs,\n    ModelReturnType.raw,\n])\n@pytest.mark.parametrize('index', indices)\ndef test_gnn_explainer_binary_classification(\n    Conv,\n    edge_mask_type,\n    node_mask_type,\n    explanation_type,\n    task_level,\n    return_type,\n    index,\n    check_explanation,\n):\n    model_config = ModelConfig(\n        mode='binary_classification',\n        task_level=task_level,\n        return_type=return_type,\n    )\n\n    model = GNN(Conv, model_config)\n\n    target = None\n    if explanation_type == ExplanationType.phenomenon:\n        with torch.no_grad():\n            out = model(x, edge_index, batch, edge_label_index)\n            if model_config.return_type == ModelReturnType.raw:\n                target = (out > 0).long().view(-1)\n            if model_config.return_type == ModelReturnType.probs:\n                target = (out > 0.5).long().view(-1)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GNNExplainer(epochs=2),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n        
model_config=model_config,\n    )\n\n    explanation = explainer(\n        x,\n        edge_index,\n        target=target,\n        index=index,\n        batch=batch,\n        edge_label_index=edge_label_index,\n    )\n\n    assert explainer.algorithm.node_mask is None\n    assert explainer.algorithm.edge_mask is None\n\n    check_explanation(explanation, node_mask_type, edge_mask_type)\n\n\n@pytest.mark.parametrize('Conv', [GCNConv])\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('explanation_type', explanation_types)\n@pytest.mark.parametrize('task_level', task_levels)\n@pytest.mark.parametrize('return_type', [\n    ModelReturnType.log_probs,\n    ModelReturnType.probs,\n    ModelReturnType.raw,\n])\n@pytest.mark.parametrize('index', indices)\ndef test_gnn_explainer_multiclass_classification(\n    Conv,\n    edge_mask_type,\n    node_mask_type,\n    explanation_type,\n    task_level,\n    return_type,\n    index,\n    check_explanation,\n):\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level=task_level,\n        return_type=return_type,\n    )\n\n    model = GNN(Conv, model_config)\n\n    target = None\n    if explanation_type == ExplanationType.phenomenon:\n        with torch.no_grad():\n            target = model(x, edge_index, batch, edge_label_index).argmax(-1)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GNNExplainer(epochs=2),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        x,\n        edge_index,\n        target=target,\n        index=index,\n        batch=batch,\n        edge_label_index=edge_label_index,\n    )\n\n    assert explainer.algorithm.node_mask is None\n    assert explainer.algorithm.edge_mask is None\n\n    
check_explanation(explanation, node_mask_type, edge_mask_type)\n\n\n@pytest.mark.parametrize('Conv', [GCNConv])\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('explanation_type', explanation_types)\n@pytest.mark.parametrize('task_level', task_levels)\n@pytest.mark.parametrize('index', indices)\ndef test_gnn_explainer_regression(\n    Conv,\n    edge_mask_type,\n    node_mask_type,\n    explanation_type,\n    task_level,\n    index,\n    check_explanation,\n):\n    model_config = ModelConfig(\n        mode='regression',\n        task_level=task_level,\n    )\n\n    model = GNN(Conv, model_config)\n\n    target = None\n    if explanation_type == ExplanationType.phenomenon:\n        with torch.no_grad():\n            target = model(x, edge_index, batch, edge_label_index)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GNNExplainer(epochs=2),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        x,\n        edge_index,\n        target=target,\n        index=index,\n        batch=batch,\n        edge_label_index=edge_label_index,\n    )\n\n    assert explainer.algorithm.node_mask is None\n    assert explainer.algorithm.edge_mask is None\n\n    check_explanation(explanation, node_mask_type, edge_mask_type)\n\n\ndef test_gnn_explainer_cheb_conv(check_explanation):\n    explainer = Explainer(\n        model=ChebConv(3, 1, K=2),\n        algorithm=GNNExplainer(epochs=2),\n        explanation_type='model',\n        node_mask_type='object',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='binary_classification',\n            task_level='node',\n            return_type='raw',\n        ),\n    )\n\n    explanation = explainer(x, edge_index)\n\n    assert 
explainer.algorithm.node_mask is None\n    assert explainer.algorithm.edge_mask is None\n\n    check_explanation(explanation, MaskType.object, MaskType.object)\n\n\ndef test_gnn_explainer_attentive_fp(check_explanation):\n    model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GNNExplainer(epochs=2),\n        explanation_type='model',\n        node_mask_type='object',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='binary_classification',\n            task_level='node',\n            return_type='raw',\n        ),\n    )\n\n    explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)\n\n    assert explainer.algorithm.node_mask is None\n    assert explainer.algorithm.edge_mask is None\n\n    check_explanation(explanation, MaskType.object, MaskType.object)\n\n\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('explanation_type', explanation_types)\n@pytest.mark.parametrize('task_level', task_levels)\n@pytest.mark.parametrize('return_type', [\n    ModelReturnType.log_probs,\n    ModelReturnType.probs,\n    ModelReturnType.raw,\n])\n@pytest.mark.parametrize('index', indices)\ndef test_gnn_explainer_hetero(\n    node_mask_type,\n    edge_mask_type,\n    explanation_type,\n    task_level,\n    return_type,\n    index,\n    hetero_data,\n    hetero_model,\n    check_explanation_hetero,\n):\n    if node_mask_type is None and edge_mask_type is None:\n        return\n\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level=task_level,\n        return_type=return_type,\n    )\n\n    metadata = hetero_data.metadata()\n    model = hetero_model(metadata, model_config)\n\n    target = None\n    if explanation_type == ExplanationType.phenomenon:\n        with torch.no_grad():\n            target = 
model(hetero_data.x_dict,\n                           hetero_data.edge_index_dict).argmax(-1)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GNNExplainer(epochs=2),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        hetero_data.x_dict,\n        hetero_data.edge_index_dict,\n        target=target,\n        index=index,\n    )\n\n    assert isinstance(explanation, HeteroExplanation)\n    check_explanation_hetero(explanation, node_mask_type, edge_mask_type,\n                             hetero_data)\n"
  },
  {
    "path": "test/explain/algorithm/test_graphmask_explainer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import Explainer, Explanation, GraphMaskExplainer\nfrom torch_geometric.explain.config import (\n    MaskType,\n    ModelConfig,\n    ModelMode,\n    ModelReturnType,\n    ModelTaskLevel,\n)\nfrom torch_geometric.nn import GCNConv, global_add_pool\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, model_config: ModelConfig):\n        super().__init__()\n        self.model_config = model_config\n\n        if model_config.mode == ModelMode.multiclass_classification:\n            out_channels = 7\n        else:\n            out_channels = 1\n\n        self.conv1 = GCNConv(3, 16)\n        self.conv2 = GCNConv(16, out_channels)\n\n    def forward(self, x, edge_index, batch=None, edge_label_index=None):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n\n        if self.model_config.task_level == ModelTaskLevel.graph:\n            x = global_add_pool(x, batch)\n        elif self.model_config.task_level == ModelTaskLevel.edge:\n            assert edge_label_index is not None\n            x = x[edge_label_index[0]] * x[edge_label_index[1]]\n\n        if self.model_config.mode == ModelMode.binary_classification:\n            if self.model_config.return_type == ModelReturnType.probs:\n                x = x.sigmoid()\n        elif self.model_config.mode == ModelMode.multiclass_classification:\n            if self.model_config.return_type == ModelReturnType.probs:\n                x = x.softmax(dim=-1)\n            elif self.model_config.return_type == ModelReturnType.log_probs:\n                x = x.log_softmax(dim=-1)\n\n        return x\n\n\ndef check_explanation(\n    edge_mask_type: MaskType,\n    node_mask_type: MaskType,\n    explanation: Explanation,\n):\n    if node_mask_type == MaskType.attributes:\n        assert explanation.node_mask.size() == explanation.x.size()\n        assert explanation.node_mask.min() >= 0\n        assert 
explanation.node_mask.max() <= 1\n    elif node_mask_type == MaskType.object:\n        assert explanation.node_mask.size() == (explanation.num_nodes, 1)\n        assert explanation.node_mask.min() >= 0\n        assert explanation.node_mask.max() <= 1\n    elif node_mask_type == MaskType.common_attributes:\n        assert explanation.node_mask.size() == (1, explanation.num_features)\n        assert explanation.node_mask.min() >= 0\n        assert explanation.node_mask.max() <= 1\n\n    if edge_mask_type == MaskType.object:\n        assert explanation.edge_mask.size() == (explanation.num_edges, )\n        assert explanation.edge_mask.min() >= 0\n        assert explanation.edge_mask.max() <= 1\n\n\nnode_mask_types = [\n    MaskType.object,\n    MaskType.common_attributes,\n    MaskType.attributes,\n]\nedge_mask_types = [\n    MaskType.object,\n    None,\n]\n\nx = torch.randn(8, 3)\nedge_index = torch.tensor([\n    [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n    [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n])\nbatch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])\nedge_label_index = torch.tensor([[0, 1, 2], [3, 4, 5]])\n\n\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])\n@pytest.mark.parametrize('task_level', ['node', 'edge', 'graph'])\n@pytest.mark.parametrize('return_type', ['probs', 'raw'])\n@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])\ndef test_graph_mask_explainer_binary_classification(\n    edge_mask_type,\n    node_mask_type,\n    explanation_type,\n    task_level,\n    return_type,\n    index,\n):\n    model_config = ModelConfig(\n        mode='binary_classification',\n        task_level=task_level,\n        return_type=return_type,\n    )\n\n    model = GCN(model_config)\n\n    target = None\n    if explanation_type == 'phenomenon':\n        with torch.no_grad():\n            out = model(x, edge_index, 
batch, edge_label_index)\n            if model_config.return_type == ModelReturnType.raw:\n                target = (out > 0).long().view(-1)\n            if model_config.return_type == ModelReturnType.probs:\n                target = (out > 0.5).long().view(-1)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GraphMaskExplainer(2, epochs=5, log=False),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        x,\n        edge_index,\n        target=target,\n        index=index,\n        batch=batch,\n        edge_label_index=edge_label_index,\n    )\n\n    check_explanation(edge_mask_type, node_mask_type, explanation)\n\n\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])\n@pytest.mark.parametrize('task_level', ['node', 'edge', 'graph'])\n@pytest.mark.parametrize('return_type', ['log_probs', 'probs', 'raw'])\n@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])\ndef test_graph_mask_explainer_multiclass_classification(\n    edge_mask_type,\n    node_mask_type,\n    explanation_type,\n    task_level,\n    return_type,\n    index,\n):\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level=task_level,\n        return_type=return_type,\n    )\n\n    model = GCN(model_config)\n\n    target = None\n    if explanation_type == 'phenomenon':\n        with torch.no_grad():\n            target = model(x, edge_index, batch, edge_label_index).argmax(-1)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GraphMaskExplainer(2, epochs=5, log=False),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n        
model_config=model_config,\n    )\n\n    explanation = explainer(\n        x,\n        edge_index,\n        target=target,\n        index=index,\n        batch=batch,\n        edge_label_index=edge_label_index,\n    )\n\n    check_explanation(edge_mask_type, node_mask_type, explanation)\n\n\n@pytest.mark.parametrize('edge_mask_type', edge_mask_types)\n@pytest.mark.parametrize('node_mask_type', node_mask_types)\n@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])\n@pytest.mark.parametrize('task_level', ['node', 'edge', 'graph'])\n@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])\ndef test_graph_mask_explainer_regression(\n    edge_mask_type,\n    node_mask_type,\n    explanation_type,\n    task_level,\n    index,\n):\n    model_config = ModelConfig(\n        mode='regression',\n        task_level=task_level,\n    )\n\n    model = GCN(model_config)\n\n    target = None\n    if explanation_type == 'phenomenon':\n        with torch.no_grad():\n            target = model(x, edge_index, batch, edge_label_index)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=GraphMaskExplainer(2, epochs=5, log=False),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n        model_config=model_config,\n    )\n\n    explanation = explainer(\n        x,\n        edge_index,\n        target=target,\n        index=index,\n        batch=batch,\n        edge_label_index=edge_label_index,\n    )\n\n    check_explanation(edge_mask_type, node_mask_type, explanation)\n"
  },
  {
    "path": "test/explain/algorithm/test_pg_explainer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import Explainer, HeteroExplanation, PGExplainer\nfrom torch_geometric.explain.config import (\n    ExplanationType,\n    ModelConfig,\n    ModelMode,\n    ModelReturnType,\n    ModelTaskLevel,\n)\nfrom torch_geometric.nn import GCNConv, global_add_pool\nfrom torch_geometric.testing import withCUDA\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self, model_config: ModelConfig):\n        super().__init__()\n        self.model_config = model_config\n\n        if model_config.mode == ModelMode.multiclass_classification:\n            out_channels = 7\n        else:\n            out_channels = 1\n\n        self.conv1 = GCNConv(3, 16)\n        self.conv2 = GCNConv(16, out_channels)\n\n    def forward(self, x, edge_index, batch=None, edge_label_index=None):\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index)\n        if self.model_config.task_level == ModelTaskLevel.graph:\n            x = global_add_pool(x, batch)\n        return x\n\n\n@withCUDA\n@pytest.mark.parametrize('mode', [\n    ModelMode.binary_classification,\n    ModelMode.multiclass_classification,\n    ModelMode.regression,\n])\ndef test_pg_explainer_node(device, check_explanation, mode):\n    x = torch.randn(8, 3, device=device)\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n    ], device=device)\n\n    if mode == ModelMode.binary_classification:\n        target = torch.randint(2, (x.size(0), ), device=device)\n    elif mode == ModelMode.multiclass_classification:\n        target = torch.randint(7, (x.size(0), ), device=device)\n    elif mode == ModelMode.regression:\n        target = torch.randn((x.size(0), 1), device=device)\n\n    model_config = ModelConfig(mode=mode, task_level='node', return_type='raw')\n\n    model = GCN(model_config).to(device)\n\n    explainer = Explainer(\n        model=model,\n  
      algorithm=PGExplainer(epochs=2).to(device),\n        explanation_type='phenomenon',\n        edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    with pytest.raises(ValueError, match=\"not yet fully trained\"):\n        explainer(x, edge_index, target=target)\n\n    explainer.algorithm.reset_parameters()\n    for epoch in range(2):\n        for index in range(x.size(0)):\n            loss = explainer.algorithm.train(epoch, model, x, edge_index,\n                                             target=target, index=index)\n            assert loss >= 0.0\n\n    explanation = explainer(x, edge_index, target=target, index=0)\n\n    check_explanation(explanation, None, explainer.edge_mask_type)\n\n\n@withCUDA\n@pytest.mark.parametrize('mode', [\n    ModelMode.binary_classification,\n    ModelMode.multiclass_classification,\n    ModelMode.regression,\n])\ndef test_pg_explainer_graph(device, check_explanation, mode):\n    x = torch.randn(8, 3, device=device)\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n    ], device=device)\n\n    if mode == ModelMode.binary_classification:\n        target = torch.randint(2, (1, ), device=device)\n    elif mode == ModelMode.multiclass_classification:\n        target = torch.randint(7, (1, ), device=device)\n    elif mode == ModelMode.regression:\n        target = torch.randn((1, 1), device=device)\n\n    model_config = ModelConfig(mode=mode, task_level='graph',\n                               return_type='raw')\n\n    model = GCN(model_config).to(device)\n\n    explainer = Explainer(\n        model=model,\n        algorithm=PGExplainer(epochs=2).to(device),\n        explanation_type='phenomenon',\n        edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    with pytest.raises(ValueError, match=\"not yet fully trained\"):\n        explainer(x, edge_index, target=target)\n\n    
explainer.algorithm.reset_parameters()\n    for epoch in range(2):\n        loss = explainer.algorithm.train(epoch, model, x, edge_index,\n                                         target=target)\n        assert loss >= 0.0\n\n    explanation = explainer(x, edge_index, target=target)\n\n    check_explanation(explanation, None, explainer.edge_mask_type)\n\n\n@withCUDA\n@pytest.mark.parametrize('mode', [\n    ModelMode.binary_classification,\n    ModelMode.multiclass_classification,\n    ModelMode.regression,\n])\n@pytest.mark.parametrize('task_level', [\n    ModelTaskLevel.node,\n    ModelTaskLevel.graph,\n])\ndef test_pg_explainer_hetero(device, hetero_data, hetero_model,\n                             check_explanation_hetero, mode, task_level):\n    # Move data to device\n    hetero_data = hetero_data.to(device)\n\n    # Prepare target based on mode and task level\n    index = 0 if task_level == ModelTaskLevel.node else None\n\n    # Create model config\n    model_config = ModelConfig(\n        mode=mode,\n        task_level=task_level,\n        return_type=ModelReturnType.raw,\n    )\n\n    # Create and initialize model\n    metadata = hetero_data.metadata()\n    model = hetero_model(metadata, model_config).to(device)\n\n    with torch.no_grad():\n        raw_output = model(hetero_data.x_dict, hetero_data.edge_index_dict)\n        if mode == ModelMode.multiclass_classification:\n            # For multiclass, use class indices (long tensor)\n            target = raw_output.argmax(dim=-1)\n        elif mode == ModelMode.binary_classification:\n            # For binary, convert to binary targets (long tensor)\n            target = (raw_output > 0).long()\n        else:  # regression\n            # For regression, use raw outputs (float tensor)\n            target = raw_output.float()\n\n    # Create explainer\n    explainer = Explainer(\n        model=model,\n        algorithm=PGExplainer(epochs=2).to(device),\n        explanation_type=ExplanationType.phenomenon,\n   
     edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    # Should raise error when not fully trained\n    with pytest.raises(ValueError, match=\"not yet fully trained\"):\n        explainer(\n            hetero_data.x_dict,\n            hetero_data.edge_index_dict,\n            target=target,\n            index=index if task_level == ModelTaskLevel.node else None,\n        )\n\n    # Train the explainer\n    explainer.algorithm.reset_parameters()\n    for epoch in range(2):\n        if task_level == ModelTaskLevel.node:\n            # For node-level, train on a single node\n            loss = explainer.algorithm.train(\n                epoch,\n                model,\n                hetero_data.x_dict,\n                hetero_data.edge_index_dict,\n                target=target,\n                index=index,\n            )\n        else:\n            # For graph-level, train on the whole graph\n            loss = explainer.algorithm.train(\n                epoch,\n                model,\n                hetero_data.x_dict,\n                hetero_data.edge_index_dict,\n                target=target,\n            )\n        assert isinstance(loss, float)\n\n    # Get explanation\n    explanation = explainer(\n        hetero_data.x_dict,\n        hetero_data.edge_index_dict,\n        target=target,\n        index=index if task_level == ModelTaskLevel.node else None,\n    )\n\n    # Check if the explanation is valid\n    assert isinstance(explanation, HeteroExplanation)\n    # Run through the standard explanation checker\n    check_explanation_hetero(explanation, None, explainer.edge_mask_type,\n                             hetero_data)\n\n\ndef test_pg_explainer_supports():\n    # Test unsupported model task level:\n    with pytest.raises(ValueError, match=\"not support the given explanation\"):\n        model_config = ModelConfig(\n            mode='binary_classification',\n            task_level='edge',\n            return_type='raw',\n       
 )\n        Explainer(\n            model=GCN(model_config),\n            algorithm=PGExplainer(epochs=2),\n            explanation_type='phenomenon',\n            edge_mask_type='object',\n            model_config=model_config,\n        )\n\n    # Test unsupported explanation type:\n    with pytest.raises(ValueError, match=\"not support the given explanation\"):\n        model_config = ModelConfig(\n            mode='binary_classification',\n            task_level='node',\n            return_type='raw',\n        )\n        Explainer(\n            model=GCN(model_config),\n            algorithm=PGExplainer(epochs=2),\n            explanation_type='model',\n            edge_mask_type='object',\n            model_config=model_config,\n        )\n\n    # Test unsupported node mask:\n    with pytest.raises(ValueError, match=\"not support the given explanation\"):\n        model_config = ModelConfig(\n            mode='binary_classification',\n            task_level='node',\n            return_type='raw',\n        )\n        Explainer(\n            model=GCN(model_config),\n            algorithm=PGExplainer(epochs=2),\n            explanation_type='model',\n            node_mask_type='object',\n            edge_mask_type='object',\n            model_config=model_config,\n        )\n\n\n@withCUDA\n@pytest.mark.parametrize('conv_type', ['HGTConv', 'HANConv'])\n@pytest.mark.parametrize('mode', [\n    ModelMode.binary_classification,\n    ModelMode.multiclass_classification,\n])\n@pytest.mark.parametrize('task_level', [\n    ModelTaskLevel.node,\n    ModelTaskLevel.graph,\n])\ndef test_pg_explainer_native_hetero(device, hetero_data, hetero_model_native,\n                                    check_explanation_hetero, conv_type, mode,\n                                    task_level):\n    \"\"\"Test PGExplainer with native heterogeneous GNNs\n    (not created by to_hetero).\n    \"\"\"\n    # Move data to device\n    hetero_data = hetero_data.to(device)\n\n    # Create model 
config\n    model_config = ModelConfig(\n        mode=mode,\n        task_level=task_level,\n        return_type=ModelReturnType.raw,\n    )\n\n    # Create and initialize model\n    metadata = hetero_data.metadata()\n    model = hetero_model_native(metadata, model_config,\n                                conv_type=conv_type).to(device)\n\n    # Generate target\n    with torch.no_grad():\n        raw_output = model(hetero_data.x_dict, hetero_data.edge_index_dict)\n        if mode == ModelMode.multiclass_classification:\n            # For multiclass, use class indices (long tensor)\n            target = raw_output.argmax(dim=-1)\n        else:  # binary classification\n            # For binary, convert to binary targets (long tensor)\n            target = (raw_output > 0).long()\n\n    # Setup index for node-level tasks\n    index = 0 if task_level == ModelTaskLevel.node else None\n\n    # Create explainer\n    explainer = Explainer(\n        model=model,\n        algorithm=PGExplainer(epochs=2).to(device),\n        explanation_type=ExplanationType.phenomenon,\n        edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    # Should raise error when not fully trained\n    with pytest.raises(ValueError, match=\"not yet fully trained\"):\n        explainer(\n            hetero_data.x_dict,\n            hetero_data.edge_index_dict,\n            target=target,\n            index=index if task_level == ModelTaskLevel.node else None,\n        )\n\n    # Train the explainer\n    explainer.algorithm.reset_parameters()\n    for epoch in range(2):\n        if task_level == ModelTaskLevel.node:\n            # For node-level, train on a single node\n            loss = explainer.algorithm.train(\n                epoch,\n                model,\n                hetero_data.x_dict,\n                hetero_data.edge_index_dict,\n                target=target,\n                index=index,\n            )\n        else:\n            # For graph-level, train on the 
whole graph\n            loss = explainer.algorithm.train(\n                epoch,\n                model,\n                hetero_data.x_dict,\n                hetero_data.edge_index_dict,\n                target=target,\n            )\n        assert isinstance(loss, float)\n\n    # Get explanation\n    explanation = explainer(\n        hetero_data.x_dict,\n        hetero_data.edge_index_dict,\n        target=target,\n        index=index if task_level == ModelTaskLevel.node else None,\n    )\n\n    # Check if the explanation is valid\n    assert isinstance(explanation, HeteroExplanation)\n    # Run through the standard explanation checker\n    check_explanation_hetero(explanation, None, explainer.edge_mask_type,\n                             hetero_data)\n\n\n@withCUDA\n@pytest.mark.parametrize('mode', [\n    ModelMode.binary_classification,\n    ModelMode.multiclass_classification,\n])\n@pytest.mark.parametrize('task_level', [\n    ModelTaskLevel.node,\n    ModelTaskLevel.graph,\n])\ndef test_pg_explainer_hetero_conv(device, hetero_data, hetero_model_custom,\n                                  check_explanation_hetero, mode, task_level):\n    \"\"\"Test PGExplainer with the built-in HeteroConv model.\"\"\"\n    # Move data to device\n    hetero_data = hetero_data.to(device)\n\n    # Create model config\n    model_config = ModelConfig(\n        mode=mode,\n        task_level=task_level,\n        return_type=ModelReturnType.raw,\n    )\n\n    # Create and initialize model\n    metadata = hetero_data.metadata()\n    model = hetero_model_custom(metadata, model_config).to(device)\n\n    # Generate target\n    with torch.no_grad():\n        raw_output = model(hetero_data.x_dict, hetero_data.edge_index_dict)\n        if mode == ModelMode.multiclass_classification:\n            # For multiclass, use class indices (long tensor)\n            target = raw_output.argmax(dim=-1)\n        else:  # binary classification\n            # For binary, convert to binary targets (long 
tensor)\n            target = (raw_output > 0).long()\n\n    # Setup index for node-level tasks\n    index = 0 if task_level == ModelTaskLevel.node else None\n\n    # Create explainer\n    explainer = Explainer(\n        model=model,\n        algorithm=PGExplainer(epochs=2).to(device),\n        explanation_type=ExplanationType.phenomenon,\n        edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    # Should raise error when not fully trained\n    with pytest.raises(ValueError, match=\"not yet fully trained\"):\n        explainer(\n            hetero_data.x_dict,\n            hetero_data.edge_index_dict,\n            target=target,\n            index=index if task_level == ModelTaskLevel.node else None,\n        )\n\n    # Train the explainer\n    explainer.algorithm.reset_parameters()\n    for epoch in range(2):\n        if task_level == ModelTaskLevel.node:\n            # For node-level, train on a single node\n            loss = explainer.algorithm.train(\n                epoch,\n                model,\n                hetero_data.x_dict,\n                hetero_data.edge_index_dict,\n                target=target,\n                index=index,\n            )\n        else:\n            # For graph-level, train on the whole graph\n            loss = explainer.algorithm.train(\n                epoch,\n                model,\n                hetero_data.x_dict,\n                hetero_data.edge_index_dict,\n                target=target,\n            )\n        assert isinstance(loss, float)\n\n    # Get explanation\n    explanation = explainer(\n        hetero_data.x_dict,\n        hetero_data.edge_index_dict,\n        target=target,\n        index=index if task_level == ModelTaskLevel.node else None,\n    )\n\n    # Check if the explanation is valid\n    assert isinstance(explanation, HeteroExplanation)\n    # Run through the standard explanation checker\n    check_explanation_hetero(explanation, None, explainer.edge_mask_type,\n            
                 hetero_data)\n"
  },
  {
    "path": "test/explain/conftest.py",
    "content": "from typing import Optional\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.config import (\n    MaskType,\n    ModelConfig,\n    ModelMode,\n    ModelReturnType,\n    ModelTaskLevel,\n)\nfrom torch_geometric.nn import (\n    HANConv,\n    HGTConv,\n    SAGEConv,\n    global_add_pool,\n    to_hetero,\n)\nfrom torch_geometric.nn.conv import GCNConv, HeteroConv\nfrom torch_geometric.testing import get_random_edge_index\n\n\n@pytest.fixture()\ndef data():\n    return Data(\n        x=torch.randn(4, 3),\n        edge_index=get_random_edge_index(4, 4, num_edges=6),\n        edge_attr=torch.randn(6, 3),\n    )\n\n\n@pytest.fixture()\ndef hetero_data():\n    data = HeteroData()\n    data['paper'].x = torch.randn(8, 16)\n    data['author'].x = torch.randn(10, 8)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10)\n    data['paper', 'paper'].edge_attr = torch.randn(10, 16)\n    data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10)\n    data['paper', 'author'].edge_attr = torch.randn(10, 8)\n    data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10)\n    data['author', 'paper'].edge_attr = torch.randn(10, 8)\n\n    return data\n\n\n@pytest.fixture()\ndef hetero_model():\n    return HeteroSAGE\n\n\n@pytest.fixture()\ndef hetero_model_custom():\n    return HeteroConvModel\n\n\nclass GraphSAGE(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), 32)\n        self.conv2 = SAGEConv((-1, -1), 32)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        return self.conv2(x, edge_index)\n\n\nclass HeteroSAGE(torch.nn.Module):\n    def __init__(self, metadata, model_config: Optional[ModelConfig] = None):\n        super().__init__()\n        self.model_config = model_config\n        
self.graph_sage = to_hetero(GraphSAGE(), metadata, debug=False)\n\n        # Determine output channels based on model_config\n        out_channels = 1\n        if (model_config\n                and model_config.mode == ModelMode.multiclass_classification):\n            out_channels = 7\n\n        self.lin = torch.nn.Linear(32, out_channels)\n\n    def forward(self, x_dict, edge_index_dict,\n                additonal_arg=None) -> torch.Tensor:\n        x = self.lin(self.graph_sage(x_dict, edge_index_dict)['paper'])\n\n        # Apply transformations based on model_config if available\n        if self.model_config:\n            if self.model_config.mode == ModelMode.binary_classification:\n                if self.model_config.return_type == ModelReturnType.probs:\n                    x = x.sigmoid()\n            elif self.model_config.mode == ModelMode.multiclass_classification:\n                if self.model_config.return_type == ModelReturnType.probs:\n                    x = x.softmax(dim=-1)\n                elif (self.model_config.return_type ==\n                      ModelReturnType.log_probs):\n                    x = x.log_softmax(dim=-1)\n\n        return x\n\n\n@pytest.fixture()\ndef check_explanation():\n    def _check_explanation(\n        explanation: Explanation,\n        node_mask_type: Optional[MaskType],\n        edge_mask_type: Optional[MaskType],\n    ):\n        if node_mask_type == MaskType.attributes:\n            assert explanation.node_mask.size() == explanation.x.size()\n            assert explanation.node_mask.min() >= 0\n            assert explanation.node_mask.max() <= 1\n        elif node_mask_type == MaskType.object:\n            assert explanation.node_mask.size() == (explanation.num_nodes, 1)\n            assert explanation.node_mask.min() >= 0\n            assert explanation.node_mask.max() <= 1\n        elif node_mask_type == MaskType.common_attributes:\n            assert explanation.node_mask.size() == (1, explanation.x.size(-1))\n 
           assert explanation.node_mask.min() >= 0\n            assert explanation.node_mask.max() <= 1\n        elif node_mask_type is None:\n            assert 'node_mask' not in explanation\n\n        if edge_mask_type == MaskType.object:\n            assert explanation.edge_mask.size() == (explanation.num_edges, )\n            assert explanation.edge_mask.min() >= 0\n            assert explanation.edge_mask.max() <= 1\n        elif edge_mask_type is None:\n            assert 'edge_mask' not in explanation\n\n    return _check_explanation\n\n\n@pytest.fixture()\ndef check_explanation_hetero():\n    def _check_explanation_hetero(\n        explanation: HeteroExplanation,\n        node_mask_type: Optional[MaskType],\n        edge_mask_type: Optional[MaskType],\n        hetero_data: HeteroData,\n    ):\n        # Validate the explanation\n        explanation.validate(raise_on_error=True)\n\n        # Check node masks for different node types\n        if node_mask_type is not None:\n            for node_type in hetero_data.node_types:\n                assert explanation[node_type].get('node_mask') is not None\n                assert explanation[node_type].get('node_mask').min() >= 0\n                assert explanation[node_type].get('node_mask').max() <= 1\n\n                # Check dimensions based on mask type\n                if node_mask_type == MaskType.attributes:\n                    mask = explanation[node_type].get('node_mask')\n                    assert mask.size() == hetero_data.x_dict[node_type].size()\n                elif node_mask_type == MaskType.object:\n                    mask = explanation[node_type].get('node_mask')\n                    assert mask.size() == (\n                        hetero_data.x_dict[node_type].size(0), 1)\n                elif node_mask_type == MaskType.common_attributes:\n                    mask = explanation[node_type].get('node_mask')\n                    assert mask.size() == (\n                        1, 
hetero_data.x_dict[node_type].size(1))\n\n        # Check edge masks for different edge types\n        if edge_mask_type is not None:\n            for edge_type in hetero_data.edge_types:\n                assert explanation[edge_type].get('edge_mask') is not None\n                assert explanation[edge_type].get('edge_mask').min() >= 0\n                assert explanation[edge_type].get('edge_mask').max() <= 1\n\n    return _check_explanation_hetero\n\n\nclass NativeHeteroGNN(torch.nn.Module):\n    def __init__(self, metadata, model_config: Optional[ModelConfig] = None,\n                 conv_type: str = 'HGTConv', hidden_channels: int = 32):\n        super().__init__()\n        self.model_config = model_config\n        self.conv_type = conv_type\n        self.hidden_channels = hidden_channels\n        self.metadata = metadata\n\n        # Determine output size based on model_config\n        self.out_channels = 1\n        if (model_config\n                and model_config.mode == ModelMode.multiclass_classification):\n            self.out_channels = 7\n\n        # Initialize dictionaries to store the layers\n        self.lin_dict = torch.nn.ModuleDict()\n        self.initialized = False\n\n        # Heterogeneous convolution layer\n        if conv_type == 'HGTConv':\n            self.conv = HGTConv(hidden_channels, hidden_channels, metadata,\n                                heads=2)\n        elif conv_type == 'HANConv':\n            self.conv = HANConv(hidden_channels, hidden_channels, metadata,\n                                heads=2)\n        else:\n            raise ValueError(f\"Unsupported conv_type: {conv_type}\")\n\n        # Output projection will be initialized in forward pass\n        self.out_lin = None\n\n    def _initialize_layers(self, x_dict):\n        \"\"\"Initialize layers with correct dimensions when we first see\n        the data.\n        \"\"\"\n        if not self.initialized:\n            # Initialize input projections\n            for 
node_type, x in x_dict.items():\n                in_channels = x.size(-1)\n                self.lin_dict[node_type] = torch.nn.Linear(\n                    in_channels, self.hidden_channels).to(x.device)\n\n            # Initialize output projection\n            self.out_lin = torch.nn.Linear(self.hidden_channels,\n                                           self.out_channels).to(x.device)\n\n            self.initialized = True\n\n    def forward(self, x_dict, edge_index_dict):\n        # Initialize layers if not done yet\n        self._initialize_layers(x_dict)\n\n        # Apply input projections\n        x_dict = {\n            node_type: self.lin_dict[node_type](x).relu_()\n            for node_type, x in x_dict.items()\n        }\n\n        # Apply heterogeneous convolution\n        x_dict = self.conv(x_dict, edge_index_dict)\n\n        # Get paper node features for prediction\n        x = x_dict['paper']\n\n        # Apply output projection\n        out = self.out_lin(x)\n\n        # For graph-level tasks, perform global pooling\n        if (self.model_config\n                and self.model_config.task_level == ModelTaskLevel.graph):\n            # Since we don't have batch information in the fixture,\n            # we'll treat the whole graph as a single graph\n            batch_size = x.size(0)\n            batch = torch.zeros(batch_size, dtype=torch.long, device=x.device)\n            out = global_add_pool(out, batch)\n\n        return out\n\n\n@pytest.fixture()\ndef hetero_model_native():\n    return NativeHeteroGNN\n\n\nclass HeteroConvModel(torch.nn.Module):\n    def __init__(self, metadata, model_config: Optional[ModelConfig] = None):\n        super().__init__()\n        self.model_config = model_config\n\n        # Create a HeteroConv model\n        conv_dict = {}\n        for edge_type in metadata[1]:  # metadata[1] contains edge types\n            src_type, _, dst_type = edge_type\n            if src_type == dst_type:\n                
conv_dict[edge_type] = GCNConv(-1, 32)\n            else:\n                # For different node types, use SAGEConv\n                conv_dict[edge_type] = SAGEConv((-1, -1), 32)\n\n        self.conv = HeteroConv(conv_dict, aggr='sum')\n\n        # Determine output channels based on model_config\n        out_channels = 1\n        if (model_config\n                and model_config.mode == ModelMode.multiclass_classification):\n            out_channels = 7\n\n        # Output layer\n        self.out_lin = torch.nn.Linear(32, out_channels)\n\n    def forward(self, x_dict, edge_index_dict):\n        # Apply heterogeneous convolution\n        out_dict = self.conv(x_dict, edge_index_dict)\n\n        # Final transformation for paper nodes\n        out = self.out_lin(out_dict['paper'])\n\n        # Apply transformations based on model_config if available\n        if self.model_config:\n            if self.model_config.mode == ModelMode.binary_classification:\n                if self.model_config.return_type == ModelReturnType.probs:\n                    out = out.sigmoid()\n            elif self.model_config.mode == ModelMode.multiclass_classification:\n                if self.model_config.return_type == ModelReturnType.probs:\n                    out = out.softmax(dim=-1)\n                elif (self.model_config.return_type ==\n                      ModelReturnType.log_probs):\n                    out = out.log_softmax(dim=-1)\n\n        return out\n"
  },
  {
    "path": "test/explain/metric/test_basic_metric.py",
    "content": "import warnings\n\nimport torch\n\nfrom torch_geometric.explain import groundtruth_metrics\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('torchmetrics>=0.10.0')\ndef test_groundtruth_metrics():\n    pred_mask = torch.rand(10)\n    target_mask = torch.rand(10)\n\n    accuracy, recall, precision, f1_score, auroc = groundtruth_metrics(\n        pred_mask, target_mask)\n\n    assert accuracy >= 0.0 and accuracy <= 1.0\n    assert recall >= 0.0 and recall <= 1.0\n    assert precision >= 0.0 and precision <= 1.0\n    assert f1_score >= 0.0 and f1_score <= 1.0\n    assert auroc >= 0.0 and auroc <= 1.0\n\n\n@withPackage('torchmetrics>=0.10.0')\ndef test_perfect_groundtruth_metrics():\n    pred_mask = target_mask = torch.rand(10)\n\n    accuracy, recall, precision, f1_score, auroc = groundtruth_metrics(\n        pred_mask, target_mask)\n\n    assert round(accuracy, 6) == 1.0\n    assert round(recall, 6) == 1.0\n    assert round(precision, 6) == 1.0\n    assert round(f1_score, 6) == 1.0\n    assert round(auroc, 6) == 1.0\n\n\n@withPackage('torchmetrics>=0.10.0')\ndef test_groundtruth_true_negative():\n    warnings.filterwarnings('ignore', '.*No positive samples in targets.*')\n    pred_mask = target_mask = torch.zeros(10)\n\n    accuracy, recall, precision, f1_score, auroc = groundtruth_metrics(\n        pred_mask, target_mask)\n\n    assert round(accuracy, 6) == 1.0\n    assert round(recall, 6) == 0.0\n    assert round(precision, 6) == 0.0\n    assert round(f1_score, 6) == 0.0\n    assert round(auroc, 6) == 0.0\n"
  },
  {
    "path": "test/explain/metric/test_faithfulness.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import (\n    DummyExplainer,\n    Explainer,\n    ModelConfig,\n    unfaithfulness,\n)\n\n\nclass DummyModel(torch.nn.Module):\n    def __init__(self, model_config: ModelConfig):\n        super().__init__()\n        self.model_config = model_config\n\n    def forward(self, x, edge_index):\n        if self.model_config.return_type.value == 'probs':\n            x = x.softmax(dim=-1)\n        elif self.model_config.return_type.value == 'log_probs':\n            x = x.log_softmax(dim=-1)\n        return x\n\n\n@pytest.mark.parametrize('top_k', [None, 2])\n@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])\n@pytest.mark.parametrize('node_mask_type', ['common_attributes', 'attributes'])\n@pytest.mark.parametrize('return_type', ['raw', 'probs', 'log_probs'])\ndef test_unfaithfulness(top_k, explanation_type, node_mask_type, return_type):\n    x = torch.randn(8, 4)\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n    ])\n\n    model_config = ModelConfig(\n        mode='multiclass_classification',\n        task_level='node',\n        return_type=return_type,\n    )\n\n    explainer = Explainer(\n        DummyModel(model_config),\n        algorithm=DummyExplainer(),\n        explanation_type=explanation_type,\n        node_mask_type=node_mask_type,\n        edge_mask_type='object',\n        model_config=model_config,\n    )\n\n    target = None\n    if explanation_type == 'phenomenon':\n        target = torch.randint(0, x.size(1), (x.size(0), ))\n\n    explanation = explainer(x, edge_index, target=target,\n                            index=torch.arange(4))\n\n    metric = unfaithfulness(explainer, explanation, top_k)\n    assert metric >= 0. and metric <= 1.\n"
  },
  {
    "path": "test/explain/metric/test_fidelity.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import (\n    DummyExplainer,\n    Explainer,\n    characterization_score,\n    fidelity,\n    fidelity_curve_auc,\n)\n\n\nclass DummyModel(torch.nn.Module):\n    def forward(self, x, edge_index):\n        return x\n\n\n@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon'])\ndef test_fidelity(explanation_type):\n    x = torch.randn(8, 4)\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],\n        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],\n    ])\n\n    explainer = Explainer(\n        DummyModel(),\n        algorithm=DummyExplainer(),\n        explanation_type=explanation_type,\n        node_mask_type='object',\n        edge_mask_type='object',\n        model_config=dict(\n            mode='multiclass_classification',\n            return_type='raw',\n            task_level='node',\n        ),\n    )\n\n    target = None\n    if explanation_type == 'phenomenon':\n        target = torch.randint(0, x.size(1), (x.size(0), ))\n\n    explanation = explainer(x, edge_index, target=target,\n                            index=torch.arange(4))\n\n    pos_fidelity, neg_fidelity = fidelity(explainer, explanation)\n    assert pos_fidelity == 0.0 and neg_fidelity == 0.0\n\n\ndef test_characterization_score():\n    out = characterization_score(\n        pos_fidelity=torch.tensor([1.0, 0.6, 0.5, 1.0]),\n        neg_fidelity=torch.tensor([0.0, 0.2, 0.5, 1.0]),\n        pos_weight=0.2,\n        neg_weight=0.8,\n    )\n    assert out.tolist() == [1.0, 0.75, 0.5, 0.0]\n\n\ndef test_fidelity_curve_auc():\n    pos_fidelity = torch.tensor([1.0, 1.0, 0.5, 1.0])\n    neg_fidelity = torch.tensor([0.0, 0.5, 0.5, 0.9])\n\n    x = torch.tensor([0, 1, 2, 3])\n    out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4)\n    assert out == 8.5\n\n    x = torch.tensor([10, 11, 12, 13])\n    out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 
4)\n    assert out == 8.5\n\n    x = torch.tensor([0, 1, 2, 5])\n    out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4)\n    assert out == 19.5\n"
  },
  {
    "path": "test/explain/test_explain_config.py",
    "content": "import pytest\n\nfrom torch_geometric.explain.config import ExplainerConfig, ThresholdConfig\n\n\n@pytest.mark.parametrize('threshold_pairs', [\n    ('hard', 0.5, True),\n    ('hard', 1.1, False),\n    ('hard', -1, False),\n    ('topk', 1, True),\n    ('topk', 0, False),\n    ('topk', -1, False),\n    ('topk', 0.5, False),\n    ('invalid', None, False),\n    ('hard', None, False),\n])\ndef test_threshold_config(threshold_pairs):\n    threshold_type, threshold_value, valid = threshold_pairs\n    if valid:\n        threshold = ThresholdConfig(threshold_type, threshold_value)\n        assert threshold.type.value == threshold_type\n        assert threshold.value == threshold_value\n    else:\n        with pytest.raises(ValueError):\n            ThresholdConfig(threshold_type, threshold_value)\n\n\n@pytest.mark.parametrize('explanation_type', [\n    'model',\n    'phenomenon',\n    'invalid',\n])\n@pytest.mark.parametrize('mask_type', [\n    None,\n    'object',\n    'common_attributes',\n    'attributes',\n    'invalid',\n])\ndef test_configuration_config(explanation_type, mask_type):\n    if (explanation_type != 'invalid' and mask_type is not None\n            and mask_type != 'invalid'):\n        config = ExplainerConfig(explanation_type, mask_type, None)\n        assert config.explanation_type.value == explanation_type\n        assert config.node_mask_type.value == mask_type\n    else:\n        with pytest.raises(ValueError):\n            ExplainerConfig(explanation_type, mask_type, mask_type)\n"
  },
  {
    "path": "test/explain/test_explainer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import DummyExplainer, Explainer, Explanation\nfrom torch_geometric.explain.config import ExplanationType\n\n\nclass DummyModel(torch.nn.Module):\n    def forward(self, x, edge_index):\n        return x.mean().view(-1)\n\n\ndef test_get_prediction(data):\n    model = DummyModel()\n    assert model.training\n\n    explainer = Explainer(\n        model,\n        algorithm=DummyExplainer(),\n        explanation_type='phenomenon',\n        node_mask_type='object',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n    )\n    pred = explainer.get_prediction(data.x, data.edge_index)\n    assert model.training\n    assert pred.size() == (1, )\n\n\n@pytest.mark.parametrize('target', [None, torch.randn(2)])\n@pytest.mark.parametrize('explanation_type', [x for x in ExplanationType])\ndef test_forward(data, target, explanation_type):\n    model = DummyModel()\n    assert model.training\n\n    explainer = Explainer(\n        model,\n        algorithm=DummyExplainer(),\n        explanation_type=explanation_type,\n        node_mask_type='attributes',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n    )\n\n    if target is None and explanation_type == ExplanationType.phenomenon:\n        with pytest.raises(ValueError):\n            explainer(data.x, data.edge_index, target=target)\n    else:\n        explanation = explainer(\n            data.x,\n            data.edge_index,\n            target=target\n            if explanation_type == ExplanationType.phenomenon else None,\n        )\n        assert model.training\n        assert isinstance(explanation, Explanation)\n        assert 'x' in explanation\n        assert 'edge_index' in explanation\n        assert 'target' in explanation\n        assert 'node_mask' in explanation.available_explanations\n        assert explanation.node_mask.size() 
== data.x.size()\n\n\n@pytest.mark.parametrize('threshold_value', [0.2, 0.5, 0.8])\n@pytest.mark.parametrize('node_mask_type', ['object', 'attributes'])\ndef test_hard_threshold(data, threshold_value, node_mask_type):\n    explainer = Explainer(\n        DummyModel(),\n        algorithm=DummyExplainer(),\n        explanation_type='model',\n        node_mask_type=node_mask_type,\n        edge_mask_type='object',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n        threshold_config=('hard', threshold_value),\n    )\n    explanation = explainer(data.x, data.edge_index)\n\n    assert 'node_mask' in explanation.available_explanations\n    assert 'edge_mask' in explanation.available_explanations\n\n    for key in explanation.available_explanations:\n        mask = explanation[key]\n        assert set(mask.unique().tolist()).issubset({0, 1})\n\n\n@pytest.mark.parametrize('threshold_value', [1, 5, 10])\n@pytest.mark.parametrize('threshold_type', ['topk', 'topk_hard'])\n@pytest.mark.parametrize('node_mask_type', ['object', 'attributes'])\ndef test_topk_threshold(data, threshold_value, threshold_type, node_mask_type):\n    explainer = Explainer(\n        DummyModel(),\n        algorithm=DummyExplainer(),\n        explanation_type='model',\n        node_mask_type=node_mask_type,\n        edge_mask_type='object',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n        threshold_config=(threshold_type, threshold_value),\n    )\n    explanation = explainer(data.x, data.edge_index)\n\n    assert 'node_mask' in explanation.available_explanations\n    assert 'edge_mask' in explanation.available_explanations\n\n    for key in explanation.available_explanations:\n        mask = explanation[key]\n        if threshold_type == 'topk':\n            assert (mask > 0).sum() == min(mask.numel(), threshold_value)\n            assert ((mask == 0).sum() == mask.numel() -\n     
               min(mask.numel(), threshold_value))\n        else:\n            assert (mask == 1).sum() == min(mask.numel(), threshold_value)\n            assert ((mask == 0).sum() == mask.numel() -\n                    min(mask.numel(), threshold_value))\n"
  },
  {
    "path": "test/explain/test_explanation.py",
    "content": "import os.path as osp\nfrom typing import Optional, Union\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.explain import Explanation\nfrom torch_geometric.explain.config import MaskType\nfrom torch_geometric.testing import withPackage\n\n\ndef create_random_explanation(\n    data: Data,\n    node_mask_type: Optional[Union[MaskType, str]] = None,\n    edge_mask_type: Optional[Union[MaskType, str]] = None,\n):\n    if node_mask_type is not None:\n        node_mask_type = MaskType(node_mask_type)\n    if edge_mask_type is not None:\n        edge_mask_type = MaskType(edge_mask_type)\n\n    if node_mask_type == MaskType.object:\n        node_mask = torch.rand(data.x.size(0), 1)\n    elif node_mask_type == MaskType.common_attributes:\n        node_mask = torch.rand(1, data.x.size(1))\n    elif node_mask_type == MaskType.attributes:\n        node_mask = torch.rand_like(data.x)\n    else:\n        node_mask = None\n\n    if edge_mask_type == MaskType.object:\n        edge_mask = torch.rand(data.edge_index.size(1))\n    else:\n        edge_mask = None\n\n    return Explanation(  # Create explanation.\n        node_mask=node_mask,\n        edge_mask=edge_mask,\n        x=data.x,\n        edge_index=data.edge_index,\n        edge_attr=data.edge_attr,\n    )\n\n\n@pytest.mark.parametrize('node_mask_type',\n                         [None, 'object', 'common_attributes', 'attributes'])\n@pytest.mark.parametrize('edge_mask_type', [None, 'object'])\ndef test_available_explanations(data, node_mask_type, edge_mask_type):\n    expected = []\n    if node_mask_type is not None:\n        expected.append('node_mask')\n    if edge_mask_type is not None:\n        expected.append('edge_mask')\n\n    explanation = create_random_explanation(\n        data,\n        node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n    )\n\n    assert set(explanation.available_explanations) == set(expected)\n\n\ndef 
test_validate_explanation(data):\n    explanation = create_random_explanation(data)\n    explanation.validate(raise_on_error=True)\n\n    with pytest.raises(ValueError, match=\"with 5 nodes\"):\n        explanation = create_random_explanation(data, node_mask_type='object')\n        explanation.x = torch.randn(5, 5)\n        explanation.validate(raise_on_error=True)\n\n    with pytest.raises(ValueError, match=\"with 4 features\"):\n        explanation = create_random_explanation(data, 'attributes')\n        explanation.x = torch.randn(4, 4)\n        explanation.validate(raise_on_error=True)\n\n    with pytest.raises(ValueError, match=\"with 7 edges\"):\n        explanation = create_random_explanation(data, edge_mask_type='object')\n        explanation.edge_index = torch.randint(0, 4, (2, 7))\n        explanation.validate(raise_on_error=True)\n\n\ndef test_node_mask(data):\n    node_mask = torch.tensor([[1.], [0.], [1.], [0.]])\n\n    explanation = Explanation(\n        node_mask=node_mask,\n        x=data.x,\n        edge_index=data.edge_index,\n        edge_attr=data.edge_attr,\n    )\n    explanation.validate(raise_on_error=True)\n\n    out = explanation.get_explanation_subgraph()\n    assert out.node_mask.size() == (2, 1)\n    assert (out.node_mask > 0.0).sum() == 2\n    assert out.x.size() == (2, 3)\n    assert out.edge_index.size(1) <= 6\n    assert out.edge_index.size(1) == out.edge_attr.size(0)\n\n    out = explanation.get_complement_subgraph()\n    assert out.node_mask.size() == (2, 1)\n    assert (out.node_mask == 0.0).sum() == 2\n    assert out.x.size() == (2, 3)\n    assert out.edge_index.size(1) <= 6\n    assert out.edge_index.size(1) == out.edge_attr.size(0)\n\n\ndef test_edge_mask(data):\n    edge_mask = torch.tensor([1., 0., 1., 0., 0., 1.])\n\n    explanation = Explanation(\n        edge_mask=edge_mask,\n        x=data.x,\n        edge_index=data.edge_index,\n        edge_attr=data.edge_attr,\n    )\n    explanation.validate(raise_on_error=True)\n\n  
  out = explanation.get_explanation_subgraph()\n    assert out.x.size() == (4, 3)\n    assert out.edge_mask.size() == (3, )\n    assert (out.edge_mask > 0.0).sum() == 3\n    assert out.edge_index.size() == (2, 3)\n    assert out.edge_attr.size() == (3, 3)\n\n    out = explanation.get_complement_subgraph()\n    assert out.x.size() == (4, 3)\n    assert out.edge_mask.size() == (3, )\n    assert (out.edge_mask == 0.0).sum() == 3\n    assert out.edge_index.size() == (2, 3)\n    assert out.edge_attr.size() == (3, 3)\n\n\n@withPackage('matplotlib', 'pandas')\n@pytest.mark.parametrize('top_k', [2, None])\n@pytest.mark.parametrize('node_mask_type', [None, 'attributes'])\ndef test_visualize_feature_importance(tmp_path, data, top_k, node_mask_type):\n    explanation = create_random_explanation(data, node_mask_type)\n\n    path = osp.join(tmp_path, 'feature_importance.png')\n\n    if node_mask_type is None:\n        with pytest.raises(ValueError, match=\"node_mask' is not\"):\n            explanation.visualize_feature_importance(path, top_k=top_k)\n    else:\n        explanation.visualize_feature_importance(path, top_k=top_k)\n        assert osp.exists(path)\n"
  },
  {
    "path": "test/explain/test_hetero_explainer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.explain import (\n    DummyExplainer,\n    Explainer,\n    HeteroExplanation,\n)\nfrom torch_geometric.explain.config import ExplanationType\n\n\nclass DummyModel(torch.nn.Module):\n    def forward(self, x_dict, edge_index_dict, *args) -> torch.Tensor:\n        return x_dict['paper'].mean().view(-1)\n\n\ndef test_get_prediction(hetero_data):\n    model = DummyModel()\n    assert model.training\n\n    explainer = Explainer(\n        model,\n        algorithm=DummyExplainer(),\n        explanation_type='phenomenon',\n        node_mask_type='object',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n    )\n    pred = explainer.get_prediction(hetero_data.x_dict,\n                                    hetero_data.edge_index_dict)\n    assert model.training\n    assert pred.size() == (1, )\n\n\n@pytest.mark.parametrize('target', [None, torch.randn(2)])\n@pytest.mark.parametrize('explanation_type', [x for x in ExplanationType])\ndef test_forward(hetero_data, target, explanation_type):\n    model = DummyModel()\n\n    explainer = Explainer(\n        model,\n        algorithm=DummyExplainer(),\n        explanation_type=explanation_type,\n        node_mask_type='attributes',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n    )\n\n    if target is None and explanation_type == ExplanationType.phenomenon:\n        with pytest.raises(ValueError):\n            explainer(hetero_data.x_dict, hetero_data.edge_index_dict,\n                      target=target)\n    else:\n        explanation = explainer(\n            hetero_data.x_dict,\n            hetero_data.edge_index_dict,\n            target=target\n            if explanation_type == ExplanationType.phenomenon else None,\n        )\n        assert model.training\n        assert isinstance(explanation, HeteroExplanation)\n        assert 'node_mask' in 
explanation.available_explanations\n        for key in explanation.node_types:\n            expected_size = hetero_data[key].x.size()\n            assert explanation[key].node_mask.size() == expected_size\n\n\n@pytest.mark.parametrize('threshold_value', [0.2, 0.5, 0.8])\n@pytest.mark.parametrize('node_mask_type', ['object', 'attributes'])\ndef test_hard_threshold(hetero_data, threshold_value, node_mask_type):\n\n    explainer = Explainer(\n        DummyModel(),\n        algorithm=DummyExplainer(),\n        explanation_type='model',\n        node_mask_type=node_mask_type,\n        edge_mask_type='object',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n        threshold_config=('hard', threshold_value),\n    )\n    explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict)\n    assert 'node_mask' in explanation.available_explanations\n    assert 'edge_mask' in explanation.available_explanations\n\n    for key in explanation.available_explanations:\n        for mask in explanation.collect(key).values():\n            assert set(mask.unique().tolist()).issubset({0, 1})\n\n\n@pytest.mark.parametrize('threshold_value', [1, 5, 10])\n@pytest.mark.parametrize('threshold_type', ['topk', 'topk_hard'])\n@pytest.mark.parametrize('node_mask_type', ['object', 'attributes'])\ndef test_topk_threshold(hetero_data, threshold_value, threshold_type,\n                        node_mask_type):\n    explainer = Explainer(\n        DummyModel(),\n        algorithm=DummyExplainer(),\n        explanation_type='model',\n        node_mask_type=node_mask_type,\n        edge_mask_type='object',\n        model_config=dict(\n            mode='regression',\n            task_level='graph',\n        ),\n        threshold_config=(threshold_type, threshold_value),\n    )\n    explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict)\n\n    assert 'node_mask' in explanation.available_explanations\n    assert 
'edge_mask' in explanation.available_explanations\n\n    for key in explanation.available_explanations:\n        for mask in explanation.collect(key).values():\n            if threshold_type == 'topk':\n                assert (mask > 0).sum() == min(mask.numel(), threshold_value)\n                assert ((mask == 0).sum() == mask.numel() -\n                        min(mask.numel(), threshold_value))\n            else:\n                assert (mask == 1).sum() == min(mask.numel(), threshold_value)\n                assert ((mask == 0).sum() == mask.numel() -\n                        min(mask.numel(), threshold_value))\n"
  },
  {
    "path": "test/explain/test_hetero_explanation.py",
    "content": "import os.path as osp\nfrom typing import Optional, Union\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.explain import HeteroExplanation\nfrom torch_geometric.explain.config import MaskType\nfrom torch_geometric.testing import withPackage\n\n\ndef create_random_explanation(\n    hetero_data: HeteroData,\n    node_mask_type: Optional[Union[MaskType, str]] = None,\n    edge_mask_type: Optional[Union[MaskType, str]] = None,\n):\n    if node_mask_type is not None:\n        node_mask_type = MaskType(node_mask_type)\n    if edge_mask_type is not None:\n        edge_mask_type = MaskType(edge_mask_type)\n\n    out = HeteroExplanation()\n\n    for key in ['paper', 'author']:\n        out[key].x = hetero_data[key].x\n        if node_mask_type == MaskType.object:\n            out[key].node_mask = torch.rand(hetero_data[key].num_nodes, 1)\n        elif node_mask_type == MaskType.common_attributes:\n            out[key].node_mask = torch.rand(1, hetero_data[key].num_features)\n        elif node_mask_type == MaskType.attributes:\n            out[key].node_mask = torch.rand_like(hetero_data[key].x)\n\n    for key in [('paper', 'paper'), ('paper', 'author')]:\n        out[key].edge_index = hetero_data[key].edge_index\n        out[key].edge_attr = hetero_data[key].edge_attr\n        if edge_mask_type == MaskType.object:\n            out[key].edge_mask = torch.rand(hetero_data[key].num_edges)\n\n    return out\n\n\n@pytest.mark.parametrize('node_mask_type',\n                         [None, 'object', 'common_attributes', 'attributes'])\n@pytest.mark.parametrize('edge_mask_type', [None, 'object'])\ndef test_available_explanations(hetero_data, node_mask_type, edge_mask_type):\n    expected = []\n    if node_mask_type:\n        expected.append('node_mask')\n    if edge_mask_type:\n        expected.append('edge_mask')\n\n    explanation = create_random_explanation(\n        hetero_data,\n        
node_mask_type=node_mask_type,\n        edge_mask_type=edge_mask_type,\n    )\n\n    assert set(explanation.available_explanations) == set(expected)\n\n\ndef test_validate_explanation(hetero_data):\n    explanation = create_random_explanation(hetero_data)\n    explanation.validate(raise_on_error=True)\n\n    with pytest.raises(ValueError, match=\"with 8 nodes\"):\n        explanation = create_random_explanation(hetero_data)\n        explanation['paper'].node_mask = torch.rand(5, 5)\n        explanation.validate(raise_on_error=True)\n\n    with pytest.raises(ValueError, match=\"with 5 features\"):\n        explanation = create_random_explanation(hetero_data, 'attributes')\n        explanation['paper'].x = torch.randn(8, 5)\n        explanation.validate(raise_on_error=True)\n\n    with pytest.raises(ValueError, match=\"with 10 edges\"):\n        explanation = create_random_explanation(hetero_data)\n        explanation['paper', 'paper'].edge_mask = torch.randn(5)\n        explanation.validate(raise_on_error=True)\n\n\ndef test_node_mask():\n    explanation = HeteroExplanation()\n    explanation['paper'].node_mask = torch.tensor([[1.], [0.], [1.], [1.]])\n    explanation['author'].node_mask = torch.tensor([[1.], [0.], [1.], [1.]])\n    with pytest.warns(UserWarning, match=\"are isolated\"):\n        explanation.validate(raise_on_error=True)\n\n    out = explanation.get_explanation_subgraph()\n    assert out['paper'].node_mask.size() == (3, 1)\n    assert out['author'].node_mask.size() == (3, 1)\n\n    out = explanation.get_complement_subgraph()\n    assert out['paper'].node_mask.size() == (1, 1)\n    assert out['author'].node_mask.size() == (1, 1)\n\n\ndef test_edge_mask():\n    explanation = HeteroExplanation()\n    explanation['paper'].num_nodes = 4\n    explanation['author'].num_nodes = 4\n    explanation['paper', 'author'].edge_index = torch.tensor([\n        [0, 1, 2, 3],\n        [0, 1, 2, 3],\n    ])\n    explanation['paper', 'author'].edge_mask = 
torch.tensor([1., 0., 1., 1.])\n\n    out = explanation.get_explanation_subgraph()\n    assert out['paper'].num_nodes == 4\n    assert out['author'].num_nodes == 4\n    assert out['paper', 'author'].edge_mask.size() == (3, )\n    assert torch.equal(out['paper', 'author'].edge_index,\n                       torch.tensor([[0, 2, 3], [0, 2, 3]]))\n\n    out = explanation.get_complement_subgraph()\n    assert out['paper'].num_nodes == 4\n    assert out['author'].num_nodes == 4\n    assert out['paper', 'author'].edge_mask.size() == (1, )\n    assert torch.equal(out['paper', 'author'].edge_index,\n                       torch.tensor([[1], [1]]))\n\n\n@withPackage('matplotlib', 'pandas')\n@pytest.mark.parametrize('top_k', [2, None])\n@pytest.mark.parametrize('node_mask_type', [None, 'attributes'])\ndef test_visualize_feature_importance(\n    top_k,\n    node_mask_type,\n    tmp_path,\n    hetero_data,\n):\n    explanation = create_random_explanation(\n        hetero_data,\n        node_mask_type=node_mask_type,\n    )\n\n    path = osp.join(tmp_path, 'feature_importance.png')\n\n    if node_mask_type is None:\n        with pytest.raises(KeyError, match=\"Tried to collect 'node_mask'\"):\n            explanation.visualize_feature_importance(path, top_k=top_k)\n    else:\n        explanation.visualize_feature_importance(path, top_k=top_k)\n        assert osp.exists(path)\n\n\n@withPackage('matplotlib', 'networkx')\ndef test_hetero_visualize_graph(tmp_path, hetero_data):\n    # Create explanation with both node and edge masks\n    explanation = create_random_explanation(hetero_data,\n                                            node_mask_type='object',\n                                            edge_mask_type='object')\n\n    path = osp.join(tmp_path, 'explanation_graph.png')\n\n    # Test with default parameters\n    explanation.visualize_graph(path=path)\n    assert osp.exists(path)\n\n    # Test with custom visualization parameters\n    
explanation.visualize_graph(path=path, node_size_range=(20, 400),\n                                node_opacity_range=(0.3, 0.9),\n                                edge_width_range=(0.2, 3.0),\n                                edge_opacity_range=(0.3, 0.9))\n    assert osp.exists(path)\n\n    # Test with node labels\n    node_labels = {\n        'paper': [f'Paper {i}' for i in range(hetero_data['paper'].num_nodes)],\n        'author':\n        [f'Author {i}' for i in range(hetero_data['author'].num_nodes)],\n    }\n    explanation.visualize_graph(path=path, node_labels=node_labels)\n    assert osp.exists(path)\n\n    # Test with invalid number of labels\n    invalid_labels = {\n        'paper': ['Paper 0'],  # Too few labels\n        'author': ['Author 0', 'Author 1'],  # Too few labels\n    }\n    with pytest.raises(ValueError, match=\"Number of labels\"):\n        explanation.visualize_graph(node_labels=invalid_labels)\n\n    # Test with invalid node type in labels\n    invalid_labels = {\n        'paper': [f'Paper {i}' for i in range(hetero_data['paper'].num_nodes)],\n        'author':\n        [f'Author {i}' for i in range(hetero_data['author'].num_nodes)],\n        'invalid_type': ['Invalid 0', 'Invalid 1'],  # Invalid node type\n    }\n    with pytest.raises(ValueError, match=\"Node type\"):\n        explanation.visualize_graph(node_labels=invalid_labels)\n"
  },
  {
    "path": "test/graphgym/example_node.yml",
    "content": "tensorboard_each_run: false\ntensorboard_agg: false\ndataset:\n  format: PyG\n  name: Cora\n  task: node\n  task_type: classification\n  node_encoder: false\n  node_encoder_name: Atom\n  edge_encoder: false\n  edge_encoder_name: Bond\ntrain:\n  batch_size: 128\n  eval_period: 2\n  ckpt_period: 100\n  enable_ckpt: false\n  skip_train_eval: true\n  sampler: full_batch\nmodel:\n  type: gnn\n  loss_fun: cross_entropy\n  edge_decoding: dot\n  graph_pooling: add\ngnn:\n  layers_pre_mp: 2\n  layers_mp: 2\n  layers_post_mp: 1\n  dim_inner: 16\n  layer_type: gcnconv\n  stage_type: stack\n  batchnorm: false\n  act: prelu\n  dropout: 0.1\n  agg: mean\n  normalize_adj: false\noptim:\n  optimizer: adam\n  base_lr: 0.01\n  max_epoch: 6\n"
  },
  {
    "path": "test/graphgym/test_config.py",
    "content": "from dataclasses import dataclass\n\nfrom torch_geometric.graphgym.config import from_config\n\n\n@dataclass\nclass MyConfig:\n    a: int\n    b: int = 4\n\n\ndef my_func(a: int, b: int = 2) -> str:\n    return f'a={a},b={b}'\n\n\ndef test_from_config():\n    assert my_func(a=1) == 'a=1,b=2'\n\n    assert my_func.__name__ == from_config(my_func).__name__\n    assert from_config(my_func)(cfg=MyConfig(a=1)) == 'a=1,b=4'\n    assert from_config(my_func)(cfg=MyConfig(a=1, b=1)) == 'a=1,b=1'\n    assert from_config(my_func)(2, cfg=MyConfig(a=1, b=3)) == 'a=2,b=3'\n    assert from_config(my_func)(cfg=MyConfig(a=1), b=3) == 'a=1,b=3'\n"
  },
  {
    "path": "test/graphgym/test_graphgym.py",
    "content": "import os.path as osp\nimport warnings\nfrom collections import namedtuple\n\nimport pytest\nimport torch\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.graphgym import register\nfrom torch_geometric.graphgym.checkpoint import get_ckpt_dir\nfrom torch_geometric.graphgym.config import (\n    cfg,\n    dump_cfg,\n    load_cfg,\n    set_out_dir,\n    set_run_dir,\n)\nfrom torch_geometric.graphgym.loader import create_loader\nfrom torch_geometric.graphgym.logger import set_printing\nfrom torch_geometric.graphgym.model_builder import create_model\nfrom torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage\nfrom torch_geometric.graphgym.models.head import GNNNodeHead\nfrom torch_geometric.graphgym.train import GraphGymDataModule, train\nfrom torch_geometric.graphgym.utils import (\n    agg_runs,\n    auto_select_device,\n    params_count,\n)\nfrom torch_geometric.testing import onlyLinux, onlyOnline, withPackage\n\nnum_trivial_metric_calls = 0\n\nArgs = namedtuple('Args', ['cfg_file', 'opts'])\nroot = osp.join(osp.dirname(osp.realpath(__file__)))\nargs = Args(osp.join(root, 'example_node.yml'), [])\n\n\ndef trivial_metric(true, pred, task_type):\n    global num_trivial_metric_calls\n    num_trivial_metric_calls += 1\n    return 1\n\n\n@onlyOnline\n@withPackage('yacs', 'pytorch_lightning')\n@pytest.mark.parametrize('auto_resume', [True, False])\n@pytest.mark.parametrize('skip_train_eval', [True, False])\n@pytest.mark.parametrize('use_trivial_metric', [True, False])\ndef test_run_single_graphgym(tmp_path, capfd, auto_resume, skip_train_eval,\n                             use_trivial_metric):\n    warnings.filterwarnings('ignore', \".*does not have many workers.*\")\n    warnings.filterwarnings('ignore', \".*lower value for log_every_n_steps.*\")\n\n    load_cfg(cfg, args)\n    cfg.out_dir = osp.join(tmp_path, 'out_dir')\n    cfg.run_dir = osp.join(tmp_path, 'run_dir')\n    cfg.dataset.dir = osp.join('/', 'tmp', 
'pyg_test_datasets', 'Planetoid')\n    cfg.train.auto_resume = auto_resume\n\n    set_out_dir(cfg.out_dir, args.cfg_file)\n    dump_cfg(cfg)\n    set_printing()\n\n    seed_everything(cfg.seed)\n    auto_select_device()\n    set_run_dir(cfg.out_dir)\n\n    cfg.train.skip_train_eval = skip_train_eval\n    cfg.train.enable_ckpt = use_trivial_metric and skip_train_eval\n    if use_trivial_metric:\n        if 'trivial' not in register.metric_dict:\n            register.register_metric('trivial', trivial_metric)\n        global num_trivial_metric_calls\n        num_trivial_metric_calls = 0\n        cfg.metric_best = 'trivial'\n        cfg.custom_metrics = ['trivial']\n    else:\n        cfg.metric_best = 'auto'\n        cfg.custom_metrics = []\n\n    datamodule = GraphGymDataModule()\n    assert len(datamodule.loaders) == 3\n\n    model = create_model()\n    assert isinstance(model, torch.nn.Module)\n    assert isinstance(model.encoder, FeatureEncoder)\n    assert isinstance(model.mp, GNNStackStage)\n    assert isinstance(model.post_mp, GNNNodeHead)\n    assert len(list(model.pre_mp.children())) == cfg.gnn.layers_pre_mp\n\n    optimizer, scheduler = model.configure_optimizers()\n    assert isinstance(optimizer[0], torch.optim.Adam)\n    assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)\n\n    cfg.params = params_count(model)\n    assert cfg.params == 23883\n\n    train(model, datamodule, logger=True,\n          trainer_config={\"enable_progress_bar\": False})\n\n    assert osp.isdir(get_ckpt_dir()) is cfg.train.enable_ckpt\n\n    agg_runs(cfg.out_dir, cfg.metric_best)\n\n    out, _ = capfd.readouterr()\n    assert \"train: {'epoch': 0,\" in out\n    assert \"val: {'epoch': 0,\" in out\n    assert \"train: {'epoch': 5,\" in out\n    assert \"val: {'epoch': 5,\" in out\n\n\n@onlyOnline\n@withPackage('yacs', 'pytorch_lightning')\ndef test_graphgym_module(tmp_path):\n    import pytorch_lightning as pl\n\n    load_cfg(cfg, args)\n    cfg.out_dir = 
osp.join(tmp_path, 'out_dir')\n    cfg.run_dir = osp.join(tmp_path, 'run_dir')\n    cfg.dataset.dir = osp.join('/', 'tmp', 'pyg_test_datasets', 'Planetoid')\n\n    set_out_dir(cfg.out_dir, args.cfg_file)\n    dump_cfg(cfg)\n    set_printing()\n\n    seed_everything(cfg.seed)\n    auto_select_device()\n    set_run_dir(cfg.out_dir)\n\n    loaders = create_loader()\n    assert len(loaders) == 3\n\n    model = create_model()\n    assert isinstance(model, pl.LightningModule)\n\n    optimizer, scheduler = model.configure_optimizers()\n    assert isinstance(optimizer[0], torch.optim.Adam)\n    assert isinstance(scheduler[0], torch.optim.lr_scheduler.CosineAnnealingLR)\n\n    cfg.params = params_count(model)\n    assert cfg.params == 23883\n\n    keys = {\"loss\", \"true\", \"pred_score\", \"step_end_time\"}\n    # test training step\n    batch = next(iter(loaders[0]))\n    batch.to(model.device)\n    outputs = model.training_step(batch)\n    assert keys == set(outputs.keys())\n    assert isinstance(outputs[\"loss\"], torch.Tensor)\n\n    # test validation step\n    batch = next(iter(loaders[1]))\n    batch.to(model.device)\n    outputs = model.validation_step(batch)\n    assert keys == set(outputs.keys())\n    assert isinstance(outputs[\"loss\"], torch.Tensor)\n\n    # test test step\n    batch = next(iter(loaders[2]))\n    batch.to(model.device)\n    outputs = model.test_step(batch)\n    assert keys == set(outputs.keys())\n    assert isinstance(outputs[\"loss\"], torch.Tensor)\n\n\n@pytest.fixture\ndef destroy_process_group():\n    yield\n    if torch.distributed.is_initialized():\n        torch.distributed.destroy_process_group()\n\n\n@onlyOnline\n@onlyLinux\n@withPackage('yacs', 'pytorch_lightning')\ndef test_train(destroy_process_group, tmp_path, capfd):\n    warnings.filterwarnings('ignore', \".*does not have many workers.*\")\n\n    import pytorch_lightning as pl\n\n    load_cfg(cfg, args)\n    cfg.out_dir = osp.join(tmp_path, 'out_dir')\n    cfg.run_dir = 
osp.join(tmp_path, 'run_dir')\n    cfg.dataset.dir = osp.join('/', 'tmp', 'pyg_test_datasets', 'Planetoid')\n\n    set_out_dir(cfg.out_dir, args.cfg_file)\n    dump_cfg(cfg)\n    set_printing()\n\n    seed_everything(cfg.seed)\n    auto_select_device()\n    set_run_dir(cfg.out_dir)\n\n    loaders = create_loader()\n    model = create_model()\n    cfg.params = params_count(model)\n\n    # --- minimal logger callback that collects logs ---\n    class LoggerCallback(pl.Callback):\n        def __init__(self):\n            super().__init__()\n            self.logged = []\n\n        def on_train_batch_end(self, trainer, pl_module, outputs, batch,\n                               batch_idx):\n            self.logged.append({\"type\": \"train\", \"step\": trainer.global_step})\n\n        def on_validation_batch_end(self, trainer, pl_module, outputs, batch,\n                                    batch_idx, dataloader_idx=0):\n            self.logged.append({\"type\": \"val\", \"step\": trainer.global_step})\n\n    logger = LoggerCallback()\n    trainer = pl.Trainer(max_epochs=2, max_steps=4, callbacks=[logger],\n                         log_every_n_steps=1, enable_progress_bar=False)\n    train_loader, val_loader = loaders[0], loaders[1]\n    trainer.fit(model, train_loader, val_loader)\n\n    assert trainer.current_epoch > 0\n    # ensure both train and val batches were seen\n    types = {entry[\"type\"] for entry in logger.logged}\n    assert \"val\" in types, \"Validation did not run\"\n    assert \"train\" in types, \"Training did not run\"\n"
  },
  {
    "path": "test/graphgym/test_logger.py",
    "content": "from torch_geometric.graphgym.config import set_run_dir\nfrom torch_geometric.graphgym.loader import create_loader\nfrom torch_geometric.graphgym.logger import Logger, LoggerCallback\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('yacs', 'pytorch_lightning')\ndef test_logger_callback():\n    loaders = create_loader()\n    assert len(loaders) == 3\n\n    set_run_dir('.')\n    logger = LoggerCallback()\n    assert isinstance(logger.train_logger, Logger)\n    assert isinstance(logger.val_logger, Logger)\n    assert isinstance(logger.test_logger, Logger)\n"
  },
  {
    "path": "test/graphgym/test_register.py",
    "content": "import torch\n\nimport torch_geometric.graphgym.register as register\nfrom torch_geometric.testing import withPackage\n\n\n@register.register_act('identity')\ndef identity_act(x: torch.Tensor) -> torch.Tensor:\n    return x\n\n\n@withPackage('yacs')\ndef test_register():\n    assert len(register.act_dict) == 8\n    assert list(register.act_dict.keys()) == [\n        'relu', 'selu', 'prelu', 'elu', 'lrelu_01', 'lrelu_025', 'lrelu_05',\n        'identity'\n    ]\n    assert str(register.act_dict['relu']()) == 'ReLU()'\n\n    register.register_act('lrelu_03', torch.nn.LeakyReLU(0.3))\n    assert len(register.act_dict) == 9\n    assert 'lrelu_03' in register.act_dict\n"
  },
  {
    "path": "test/io/example1.off",
    "content": "OFF\n4 2 0\n0.0 0.0 0.0\n0.0 1.0 0.0\n1.0 0.0 0.0\n1.0 1.0 0.0\n3 0 1 2\n3 1 2 3\n"
  },
  {
    "path": "test/io/example2.off",
    "content": "OFF\n4 1 0\n0.0 0.0 0.0\n0.0 1.0 0.0\n1.0 0.0 0.0\n1.0 1.0 0.0\n4 0 1 2 3\n"
  },
  {
    "path": "test/io/test_fs.py",
    "content": "import zipfile\nfrom os import path as osp\n\nimport fsspec\nimport pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.data import extract_zip\nfrom torch_geometric.io import fs\nfrom torch_geometric.testing import noWindows\n\nif torch_geometric.typing.WITH_WINDOWS:  # FIXME\n    params = ['file']\nelse:\n    params = ['file', 'memory']\n\n\n@pytest.fixture(params=params)\ndef tmp_fs_path(request, tmp_path) -> str:\n    if request.param == 'file':\n        return tmp_path.resolve().as_posix()\n    elif request.param == 'memory':\n        return f'memory://{tmp_path}'\n    raise NotImplementedError\n\n\ndef test_get_fs():\n    assert 'file' in fs.get_fs('/tmp/test').protocol\n    assert 'memory' in fs.get_fs('memory:///tmp/test').protocol\n\n\n@noWindows\ndef test_normpath():\n    assert fs.normpath('////home') == '/home'\n    assert fs.normpath('memory:////home') == 'memory:////home'\n\n\ndef test_exists(tmp_fs_path):\n    path = osp.join(tmp_fs_path, 'file.txt')\n    assert not fs.exists(path)\n    with fsspec.open(path, 'w') as f:\n        f.write('here')\n    assert fs.exists(path)\n\n\ndef test_makedirs(tmp_fs_path):\n    path = osp.join(tmp_fs_path, '1', '2')\n    assert not fs.isdir(path)\n    fs.makedirs(path)\n    assert fs.isdir(path)\n\n\n@pytest.mark.parametrize('detail', [False, True])\ndef test_ls(tmp_fs_path, detail):\n    for i in range(2):\n        with fsspec.open(osp.join(tmp_fs_path, str(i)), 'w') as f:\n            f.write('here')\n    res = fs.ls(tmp_fs_path, detail)\n    assert len(res) == 2\n    expected_protocol = fs.get_fs(tmp_fs_path).protocol\n    for output in res:\n        if detail:\n            output = output['name']\n        assert fs.get_fs(output).protocol == expected_protocol\n\n\ndef test_cp(tmp_fs_path):\n    src = osp.join(tmp_fs_path, 'src')\n    for i in range(2):\n        with fsspec.open(osp.join(src, str(i)), 'w') as f:\n            f.write('here')\n    assert fs.exists(src)\n\n   
 dst = osp.join(tmp_fs_path, 'dst')\n    assert not fs.exists(dst)\n\n    # Can copy a file to new name:\n    fs.cp(osp.join(src, '1'), dst)\n    assert fs.isfile(dst)\n    fs.rm(dst)\n\n    # Can copy a single file to directory:\n    fs.makedirs(dst)\n    fs.cp(osp.join(src, '1'), dst)\n    assert len(fs.ls(dst)) == 1\n\n    # Can copy multiple files to directory:\n    fs.cp(src, dst)\n    assert len(fs.ls(dst)) == 2\n    for i in range(2):\n        fs.exists(osp.join(dst, str(i)))\n\n\ndef test_extract(tmp_fs_path):\n    def make_zip(path: str):\n        with fsspec.open(path, mode='wb') as f:\n            with zipfile.ZipFile(f, mode='w') as z:\n                z.writestr('1', b'data')\n                z.writestr('2', b'data')\n\n    src = osp.join(tmp_fs_path, 'src', 'test.zip')\n    make_zip(src)\n    assert len(fsspec.open_files(f'zip://*::{src}')) == 2\n\n    dst = osp.join(tmp_fs_path, 'dst')\n    assert not fs.exists(dst)\n\n    # Can copy and extract afterwards:\n    if fs.isdisk(tmp_fs_path):\n        fs.cp(src, osp.join(dst, 'test.zip'))\n        assert fs.exists(osp.join(dst, 'test.zip'))\n        extract_zip(osp.join(dst, 'test.zip'), dst)\n        assert len(fs.ls(dst)) == 3\n        for i in range(2):\n            fs.exists(osp.join(dst, str(i)))\n        fs.rm(dst)\n\n    # Can copy and extract:\n    fs.cp(src, dst, extract=True)\n    assert len(fs.ls(dst)) == 2\n    for i in range(2):\n        fs.exists(osp.join(dst, str(i)))\n\n\ndef test_torch_save_load(tmp_fs_path):\n    x = torch.randn(5, 5)\n    path = osp.join(tmp_fs_path, 'x.pt')\n\n    fs.torch_save(x, path)\n    out = fs.torch_load(path)\n    assert torch.equal(x, out)\n"
  },
  {
    "path": "test/io/test_off.py",
    "content": "import os\nimport os.path as osp\nimport random\nimport sys\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.io import read_off, write_off\n\n\ndef test_read_off():\n    root_dir = osp.join(osp.dirname(osp.realpath(__file__)))\n\n    data = read_off(osp.join(root_dir, 'example1.off'))\n    assert len(data) == 2\n    assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]\n    assert data.face.tolist() == [[0, 1], [1, 2], [2, 3]]\n\n    data = read_off(osp.join(root_dir, 'example2.off'))\n    assert len(data) == 2\n    assert data.pos.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]\n    assert data.face.tolist() == [[0, 0], [1, 2], [2, 3]]\n\n\ndef test_write_off():\n    pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]])\n    face = torch.tensor([[0, 1], [1, 2], [2, 3]])\n\n    name = str(random.randrange(sys.maxsize))\n    path = osp.join('/', 'tmp', f'{name}.off')\n    write_off(Data(pos=pos, face=face), path)\n    data = read_off(path)\n    os.unlink(path)\n\n    assert data.pos.tolist() == pos.tolist()\n    assert data.face.tolist() == face.tolist()\n"
  },
  {
    "path": "test/llm/conftest.py",
    "content": "import pathlib\n\nimport pytest\n\nLLM_DIR = pathlib.Path(__file__).parent\n\n\ndef pytest_collection_modifyitems(items):\n    for item in items:\n        if pathlib.Path(item.fspath).is_relative_to(LLM_DIR):\n            item.add_marker(pytest.mark.rag)\n"
  },
  {
    "path": "test/llm/models/test_g_retriever.py",
    "content": "import gc\nfrom contextlib import nullcontext\nfrom types import SimpleNamespace\n\nimport pytest\nimport torch\nfrom torch import nn\n\nfrom torch_geometric.llm.models import LLM, GRetriever\nfrom torch_geometric.nn import GAT\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('transformers', 'sentencepiece', 'accelerate', 'peft')\n@pytest.mark.parametrize('use_lora', [True, False])\ndef test_g_retriever(use_lora: bool) -> None:\n    llm = LLM(model_name='Qwen/Qwen3-0.6B', dtype=torch.float32,\n              sys_prompt=\"You're an agent, answer my questions.\")\n\n    gnn = GAT(\n        in_channels=1024,\n        out_channels=1024,\n        hidden_channels=1024,\n        num_layers=2,\n        heads=4,\n        norm='batch_norm',\n    )\n\n    model = GRetriever(\n        llm=llm,\n        gnn=gnn,\n        use_lora=use_lora,\n    )\n    assert str(model) == ('GRetriever(\\n'\n                          '  llm=LLM(Qwen/Qwen3-0.6B),\\n'\n                          '  gnn=GAT(1024, 1024, num_layers=2),\\n'\n                          ')')\n\n    x = torch.randn(10, 1024)\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],\n    ])\n    edge_attr = torch.randn(edge_index.size(1), 1024)\n    batch = torch.zeros(x.size(0), dtype=torch.long)\n\n    question = [\"Is PyG the best open-source GNN library?\"]\n    label = [\"yes!\"]\n\n    # Test train:\n    loss = model(question, x, edge_index, batch, label, edge_attr)\n    assert loss >= 0\n\n    # Test inference:\n    pred = model.inference(question, x, edge_index, batch, edge_attr)\n    assert len(pred) == 1\n    del model, llm, gnn\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\n@withPackage('transformers', 'sentencepiece', 'accelerate', 'peft')\ndef test_g_retriever_many_tokens() -> None:\n    llm = LLM(model_name='Qwen/Qwen3-0.6B', dtype=torch.float32,\n              sys_prompt=\"You're an agent, answer my 
questions.\")\n\n    gnn = GAT(\n        in_channels=1024,\n        out_channels=1024,\n        hidden_channels=1024,\n        num_layers=2,\n        heads=4,\n        norm='batch_norm',\n    )\n\n    model = GRetriever(\n        llm=llm,\n        gnn=gnn,\n        mlp_out_tokens=2,\n        use_lora=True,\n    )\n    assert str(model) == ('GRetriever(\\n'\n                          '  llm=LLM(Qwen/Qwen3-0.6B),\\n'\n                          '  gnn=GAT(1024, 1024, num_layers=2),\\n'\n                          ')')\n\n    x = torch.randn(10, 1024)\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],\n    ])\n    edge_attr = torch.randn(edge_index.size(1), 1024)\n    batch = torch.zeros(x.size(0), dtype=torch.long)\n\n    question = [\"Is PyG the best open-source GNN library?\"]\n    label = [\"yes!\"]\n\n    # Test train:\n    loss = model(question, x, edge_index, batch, label, edge_attr)\n    assert loss >= 0\n\n    # Test inference:\n    pred = model.inference(question, x, edge_index, batch, edge_attr)\n    assert len(pred) == 1\n    del model, llm, gnn\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\nclass DummyHFModel(nn.Module):\n    def __init__(self, vocab_size=10):\n        super().__init__()\n        self.vocab_size = vocab_size\n        self.dummy = nn.Parameter(torch.zeros(1))\n\n    def forward(self, inputs_embeds=None, **kwargs):\n        B, T, _ = inputs_embeds.shape\n        logits = torch.randn(B, T, self.vocab_size,\n                             device=inputs_embeds.device)\n        loss = torch.tensor(0.0, device=inputs_embeds.device)\n        loss.logits = logits\n        return SimpleNamespace(\n            logits=logits,\n            loss=loss,\n        )\n\n\nclass DummyLLM:\n    def __init__(self, hidden_dim):\n        self.word_embedding = nn.Embedding(100, hidden_dim)\n        self.llm = DummyHFModel()\n        self.device = torch.device(\"cpu\")\n        self.autocast_context 
= nullcontext()\n\n    def _get_embeds(self, question, *args):\n        batch_size = len(question)\n        seq_len = 4\n        hidden = self.word_embedding.embedding_dim\n\n        inputs_embeds = torch.randn(batch_size, seq_len, hidden)\n        attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)\n\n        return inputs_embeds, attention_mask, None\n\n\nclass DummyGNN(nn.Module):\n    \"\"\"Simple GNN stub returning node embeddings.\"\"\"\n    def __init__(self, in_channels=4, out_channels=8):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.lin = nn.Linear(in_channels, out_channels)\n\n    def forward(self, *args, **kwargs):\n        x = args[0]\n        return self.lin(x)\n\n\n@pytest.mark.parametrize(\"batch_size\", [1, 3])\ndef test_gretriever_prefix_embedding_injection(batch_size):\n    hidden_dim = 8\n    num_nodes = 5\n\n    llm = DummyLLM(hidden_dim)\n    gnn = DummyGNN(in_channels=4, out_channels=8)\n\n    model = GRetriever(\n        llm=llm,\n        gnn=gnn,\n        mlp_out_tokens=2,\n    )\n\n    # graph inputs\n    x = torch.randn(num_nodes, 4)\n    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])\n    batch = torch.zeros(num_nodes, dtype=torch.long)\n\n    # text inputs (raw strings, one question/label pair per batch element)\n    questions = [\"What is this graph?\"] * batch_size\n    labels = [\"dummy answer\"] * batch_size\n\n    out = model(\n        x=x,\n        edge_index=edge_index,\n        batch=batch,\n        question=questions,\n        label=labels,\n    )\n\n    # basic correctness assertions\n    assert hasattr(out, \"logits\")\n    assert out.logits.shape[0] == batch_size\n"
  },
  {
    "path": "test/llm/models/test_git_mol.py",
    "content": "import torch\n\nfrom torch_geometric.llm.models import GITMol\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('transformers', 'sentencepiece', 'accelerate')\ndef test_git_mol():\n    model = GITMol()\n\n    x = torch.ones(10, 16, dtype=torch.long)\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n        [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],\n    ])\n    edge_attr = torch.zeros(edge_index.size(1), 16, dtype=torch.long)\n    # batch size = 1\n    batch = torch.tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0])\n    smiles = ['CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O']\n    captions = ['The molecule is the (R)-(-)-enantiomer of columbianetin.']\n    images = torch.randn(1, 3, 224, 224)\n    loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)\n    assert loss >= 0\n\n    # batch size > 1\n    batch = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n    smiles = [\n        'CC(C)([C@H]1CC2=C(O1)C=CC3=C2OC(=O)C=C3)O',\n        'CCOc1ccccc1',\n    ]\n    captions = [\n        'The molecule is the (R)-(-)-enantiomer of columbianetin.',\n        'Ethoxybenzene is an aromatic ether.',\n    ]\n    images = torch.randn(2, 3, 224, 224)\n    loss = model(x, edge_index, batch, edge_attr, smiles, images, captions)\n    assert loss >= 0\n"
  },
  {
    "path": "test/llm/models/test_glem.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.llm.models.glem import deal_nan\n\n\ndef test_deal_nan_tensor_replaces_nans():\n    x = torch.tensor([1.0, float('nan'), 3.0])\n    result = deal_nan(x)\n\n    expected = torch.tensor([1.0, 0.0, 3.0])\n    assert torch.allclose(result, expected, equal_nan=True)\n    assert isinstance(result, torch.Tensor)\n    assert not torch.isnan(result).any()\n\n\ndef test_deal_nan_non_tensor_passthrough():\n    assert deal_nan(42.0) == 42.0\n    assert deal_nan(\"foo\") == \"foo\"\n\n\n@pytest.mark.parametrize(\"dtype\", [torch.float32, torch.bfloat16])\ndef test_deal_nan_tensor_dtypes(dtype):\n    # Create a tensor with one NaN value\n    x = torch.tensor([1.0, float('nan'), 3.0], dtype=dtype)\n    result = deal_nan(x)\n\n    expected = torch.tensor([1.0, 0.0, 3.0], dtype=dtype)\n\n    # `bfloat16` doesn't support `allclose` directly on CPU,\n    # so we cast to float32 for comparison\n    if dtype == torch.bfloat16:\n        assert torch.allclose(result.to(torch.float32),\n                              expected.to(torch.float32), atol=1e-2)\n    else:\n        assert torch.allclose(result, expected, equal_nan=True)\n\n    assert isinstance(result, torch.Tensor)\n    assert not torch.isnan(result).any()\n    assert result.dtype == dtype\n\n\ndef test_deal_nan_is_non_mutating():\n    x = torch.tensor([1.0, float('nan'), 3.0])\n    x_copy = x.clone()\n    _ = deal_nan(x)\n    assert torch.isnan(x).any()  # Original still contains NaN\n    assert torch.allclose(x, x_copy, equal_nan=True)\n"
  },
  {
    "path": "test/llm/models/test_llm.py",
    "content": "import gc\n\nimport pytest\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.llm.models import LLM\nfrom torch_geometric.llm.models.llm import get_llm_kwargs\nfrom torch_geometric.testing import withPackage\n\n\ndef test_get_llm_kwargs():\n    kwargs = get_llm_kwargs(required_memory=640)\n    assert kwargs == {'revision': 'main'}\n\n\n@withPackage('transformers', 'accelerate')\n@pytest.mark.parametrize('sys_prompt',\n                         ['You are an agent, answer my questions.', None])\n@pytest.mark.parametrize('context', [['This is context.'], None])\n@pytest.mark.parametrize('use_embedding', [True, False])\ndef test_llm(sys_prompt, context, use_embedding) -> None:\n    question = [\"Is PyG the best open-source GNN library?\"]\n    answer = [\"yes!\"]\n\n    model = LLM(\n        model_name='Qwen/Qwen3-0.6B',\n        num_params=1,\n        sys_prompt=sys_prompt,\n    )\n    assert str(model) == 'LLM(Qwen/Qwen3-0.6B)'\n\n    embedding = [torch.randn(1, 1024, dtype=torch.bfloat16).to(model.device)\n                 ] if use_embedding else None\n    loss = model(question, answer, context=context, embedding=embedding)\n    assert isinstance(loss, Tensor)\n    assert loss.dim() == 0\n    assert loss >= 0.0\n\n    pred = model.inference(question)\n    assert len(pred) == 1\n    del model\n    gc.collect()\n    torch.cuda.empty_cache()\n\n\nclass DummyBatch(dict):\n    \"\"\"Mimics HuggingFace BatchEncoding.\"\"\"\n    def to(self, device):\n        return self\n\n\nclass DummyTokenizer:\n    pad_token_id = 0\n    padding_side = \"left\"\n\n    def __call__(self, texts, return_tensors=None, padding=True):\n        lengths = [len(t) for t in texts]\n        max_len = max(lengths)\n\n        ids = []\n        mask = []\n\n        for seq_len in lengths:\n            padding = max_len - seq_len\n            ids.append([0] * padding + list(range(1, seq_len + 1)))\n            mask.append([0] * padding + [1] * seq_len)\n\n        return 
DummyBatch({\n            \"input_ids\": torch.tensor(ids),\n            \"attention_mask\": torch.tensor(mask)\n        })\n\n\nclass DummyModel(torch.nn.Module):\n    def get_input_embeddings(self):\n        return torch.nn.Embedding(100, 8)\n\n    def forward(self, inputs_embeds=None, attention_mask=None, **kwargs):\n        batch, seq, dim = inputs_embeds.shape\n\n        class Out:\n            pass\n\n        out = Out()\n        out.logits = torch.zeros(batch, seq, 10)\n        return out\n\n\n@pytest.fixture\ndef dummy_llm():\n    llm = LLM.__new__(LLM)\n    torch.nn.Module.__init__(llm)\n    llm.device = torch.device(\"cpu\")\n    llm.tokenizer = DummyTokenizer()\n    llm.model = DummyModel()\n    return llm\n\n\ndef test_llm_prepare_inputs(dummy_llm):\n    prompts = [\"hello\", \"hi\"]\n\n    encoded = dummy_llm.tokenizer(prompts)\n\n    input_ids = encoded[\"input_ids\"]\n    attention_mask = encoded[\"attention_mask\"]\n\n    emb = dummy_llm.model.get_input_embeddings()\n    inputs_embeds = emb(input_ids)\n\n    out = dummy_llm.model(inputs_embeds=inputs_embeds,\n                          attention_mask=attention_mask)\n\n    assert inputs_embeds.shape[0] == 2\n    assert attention_mask.shape == input_ids.shape\n    assert hasattr(out, \"logits\")\n    assert out.logits.shape[:2] == inputs_embeds.shape[:2]\n\n\ndef test_llm_single_prompt(dummy_llm):\n    encoded = dummy_llm.tokenizer([\"test\"])\n\n    assert encoded[\"input_ids\"].shape[0] == 1\n\n\ndef test_llm_variable_lengths(dummy_llm):\n    prompts = [\"a\", \"abcdef\", \"abc\"]\n\n    encoded = dummy_llm.tokenizer(prompts)\n\n    input_ids = encoded[\"input_ids\"]\n\n    assert input_ids.shape[0] == 3\n    assert input_ids.shape[1] == max(len(p) for p in prompts)\n"
  },
  {
    "path": "test/llm/models/test_llm_judge.py",
    "content": "import numpy as np\n\nfrom torch_geometric.llm.models import LLMJudge\n\n\ndef test_llm_judge():\n    judge = LLMJudge()\n    assert judge._process_score('1234') == 1.0\n    assert judge._average_scores(1, 3) == 2\n    assert judge._average_scores(-1, 3) == 3\n\n    assert np.isnan(judge.score('question', 'model_pred', 'correct_answer'))\n"
  },
  {
    "path": "test/llm/models/test_molecule_gpt.py",
    "content": "import torch\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.llm.models import LLM, MoleculeGPT, SentenceTransformer\nfrom torch_geometric.nn import GINEConv\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('transformers', 'sentencepiece', 'accelerate')\ndef test_molecule_gpt() -> None:\n    llm = LLM(\n        # model_name='lmsys/vicuna-7b-v1.5',\n        model_name='Qwen/Qwen3-0.6B',\n        num_params=1,\n        dtype=torch.float32)\n    graph_encoder = GINEConv(nn=Seq(Lin(16, 16), ReLU(), Lin(16, 16)),\n                             train_eps=True, edge_dim=16)\n\n    smiles_encoder = SentenceTransformer(\n        model_name='DeepChem/ChemBERTa-77M-MTR',\n        pooling_strategy='last_hidden_state',\n    )\n\n    model = MoleculeGPT(\n        llm=llm,\n        graph_encoder=graph_encoder,\n        smiles_encoder=smiles_encoder,\n    )\n\n    assert str(model) == (\n        'MoleculeGPT(\\n'\n        '  llm=LLM(Qwen/Qwen3-0.6B),\\n'\n        '  graph=GINEConv,\\n'\n        '  smiles=SentenceTransformer(model_name=DeepChem/ChemBERTa-77M-MTR),\\n'  # noqa: E501\n        ')')\n\n    x = torch.randn(10, 16)\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n        [1, 2, 3, 4, 5, 6, 7, 8, 9, 0],\n    ])\n    edge_attr = torch.randn(edge_index.size(1), 16)\n    batch = torch.zeros(x.size(0), dtype=torch.long)\n    smiles = ['CCCCCCCCCC']\n    instructions = ['What is ∼ functional related to?']\n    label = ['I do not know!']\n\n    # Test train:\n    loss = model(x, edge_index, batch, edge_attr, smiles, instructions, label)\n    assert loss >= 0\n\n    # Test inference:\n    pred = model.inference(x, edge_index, batch, edge_attr, smiles,\n                           instructions)\n    assert len(pred) == 1\n"
  },
  {
    "path": "test/llm/models/test_protein_mpnn.py",
    "content": "import torch\n\nfrom torch_geometric.llm.models import ProteinMPNN\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('torch_cluster')\ndef test_protein_mpnn():\n    num_nodes = 10\n    vocab_size = 21\n\n    model = ProteinMPNN(vocab_size=vocab_size)\n    x = torch.randn(num_nodes, 4, 3)\n    chain_seq_label = torch.randint(0, vocab_size, (num_nodes, ))\n    mask = torch.ones(num_nodes)\n    chain_mask_all = torch.ones(num_nodes)\n    residue_idx = torch.randint(0, 10, (num_nodes, ))\n    chain_encoding_all = torch.ones(num_nodes)\n    batch = torch.zeros(num_nodes, dtype=torch.long)\n\n    logits = model(x, chain_seq_label, mask, chain_mask_all, residue_idx,\n                   chain_encoding_all, batch)\n    assert logits.size() == (num_nodes, vocab_size)\n"
  },
  {
    "path": "test/llm/models/test_sentence_transformer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.llm.models import SentenceTransformer\nfrom torch_geometric.llm.models.sentence_transformer import (\n    last_pooling,\n    mean_pooling,\n)\nfrom torch_geometric.testing import withCUDA, withPackage\n\n\n@withCUDA\n@withPackage('transformers')\n@pytest.mark.parametrize('batch_size', [None, 1])\n@pytest.mark.parametrize('pooling_strategy', ['mean', 'last', 'cls'])\n@pytest.mark.parametrize('verbose', [True, False])\ndef test_sentence_transformer(batch_size, pooling_strategy, device, verbose):\n\n    model_name = 'bert-base-uncased'\n    model = SentenceTransformer(\n        model_name=model_name,\n        pooling_strategy=pooling_strategy,\n    ).to(device)\n    assert model.device == device\n    assert str(model) == f'SentenceTransformer(model_name={model_name})'\n\n    text = [\n        \"this is a basic english text\",\n        \"PyG is the best open-source GNN library :)\",\n    ]\n\n    model_embedding_dim = model.model.config.hidden_size\n\n    out = model.encode(text, batch_size=batch_size, verbose=verbose)\n    assert out.device == device\n    assert out.shape == (2, model_embedding_dim)\n\n    out = model.encode(text, batch_size=batch_size, output_device='cpu',\n                       verbose=verbose)\n    assert out.is_cpu\n    assert out.shape == (2, model_embedding_dim)\n\n    out = model.encode([], batch_size=batch_size, verbose=verbose)\n    assert out.device == device\n    assert out.shape == (0, model_embedding_dim)\n\n\ndef test_mean_pooling():\n    x = torch.randn(2, 1, 2)\n    attention_mask = torch.zeros(2, 1)\n\n    result = mean_pooling(x, attention_mask)\n    expected = torch.zeros_like(x)\n    assert torch.allclose(result, expected, atol=1e-6)\n\n\n@pytest.mark.parametrize('mask', [torch.ones, torch.zeros])\ndef test_last_pooling(mask):\n    x = torch.randn(2, 1, 2)\n    attention_mask = mask(2, 1, dtype=torch.long)\n    out = last_pooling(x, attention_mask)\n    assert 
torch.allclose(out, x[:, 0, :], atol=1e-6)\n"
  },
  {
    "path": "test/llm/models/test_txt2kg.py",
    "content": "import sys\nimport types\n\nimport pytest\n\nimport torch_geometric.llm.models.txt2kg as txt2kg\nfrom torch_geometric.llm.models.txt2kg import (\n    TXT2KG,\n    _chunk_text,\n    _merge_triples_deterministically,\n    _multiproc_helper,\n    _parse_n_check_triples,\n)\n\n\ndef test_init_local_lm_flag():\n    model = TXT2KG(local_LM=True, chunk_size=20)\n    assert model.local_LM is True\n    assert model.initd_LM is False\n\n\ndef test_parse_n_check_triples_formats():\n    s = \"(A, rel, B)\\n(C, rel2, D)\"\n    parsed = _parse_n_check_triples(s)\n    assert (\"A\", \"rel\", \"B\") in parsed\n    assert (\"C\", \"rel2\", \"D\") in parsed\n\n\ndef test_chunk_text_simple_sentence():\n    text = \"Hello world. Another sentence!\"\n    chunks = _chunk_text(text, chunk_size=10)\n    # Only makes chunks at sentence boundaries\n    assert any(\"Hello\" in c for c in chunks)\n\n\nclass DummyLLM:\n    def __init__(self):\n        pass\n\n    def inference(self, *args, **kwargs):\n        return [\"(X,edge,Y)\"]\n\n\ndef test_local_lm_integration(monkeypatch):\n    model = TXT2KG(local_LM=True)\n\n    model.model = DummyLLM()\n    model.initd_LM = True\n\n    # Simulate time progression\n    times = iter([100.0, 100.05])  # 0.05 sec elapsed\n    monkeypatch.setattr(\"time.time\", lambda: next(times))\n\n    out = model._chunk_to_triples_str_local(\"text\")\n\n    assert out == \"(X,edge,Y)\"\n    assert model.time_to_parse > 0\n\n\ndef test_add_doc_empty(monkeypatch):\n    model = TXT2KG(local_LM=True)\n    model.add_doc_2_KG(\"\", QA_pair=None)\n    assert model.relevant_triples[0] == []\n\n\n# Mock LLM + parsing on real text:\ndef test_add_doc_to_KG(monkeypatch):\n    model = TXT2KG(local_LM=True, chunk_size=10)\n\n    # Mock only the LLM output stage\n    monkeypatch.setattr(model, \"_chunk_to_triples_str_local\",\n                        lambda *_: \"(A,rel,B)\\n(C,rel,D)\")\n\n    model.add_doc_2_KG(\"Some text\")\n\n    triples = 
model.relevant_triples[0]\n\n    assert len(triples) == 2\n    assert (\"A\", \"rel\", \"B\") in triples\n    assert model.doc_id_counter == 1\n\n\ndef test_merge_triples_deterministically_basic():\n    # Simple case: multiple sublists, strings only\n    results = [\n        [[\"b\", \"rel\", \"c\"], [\"A\", \"rel\", \"d\"]],\n        [[\"a\", \"rel\", \"c\"]],\n    ]\n\n    merged = _merge_triples_deterministically(results)\n\n    # Expect deterministic, casefolded lexicographic order\n    expected = [\n        (\"a\", \"rel\", \"c\"),\n        (\"A\", \"rel\", \"d\"),\n        (\"b\", \"rel\", \"c\"),\n    ]\n    assert merged == expected\n\n\ndef test_merge_triples_deterministically_unicode_and_nonstring():\n    # Include unicode and a numeric element to cover else branch in lambda\n    results = [\n        [[\"ä\", 2, \"x\"], [\"A\", 1, \"y\"]],\n        [[\"a\", 3, \"z\"]],\n    ]\n\n    merged = _merge_triples_deterministically(results)\n\n    # Ensure tuples, unicode sorted, numeric untouched\n    expected = [\n        (\"A\", 1, \"y\"),\n        (\"a\", 3, \"z\"),\n        (\"ä\", 2, \"x\"),\n    ]\n    assert merged == expected\n\n\ndef test_merge_triples_deterministically_empty():\n    # Edge case: empty input\n    results = []\n\n    merged = _merge_triples_deterministically(results)\n    assert merged == []\n\n\ndef test_merge_triples_deterministically_singleton():\n    # Edge case: single sublist, single triple\n    results = [[[\"only\", \"one\", \"triple\"]]]\n\n    merged = _merge_triples_deterministically(results)\n    assert merged == [(\"only\", \"one\", \"triple\")]\n\n\ndef test_chunk_to_triples_str_cloud(monkeypatch):\n    # Fake streaming chunk object\n    class DummyChunk:\n        class Choice:\n            class Delta:\n                content = \"A\"\n\n            delta = Delta()\n\n        choices = [Choice()]\n\n    class DummyCompletion:\n        def __iter__(self):\n            return iter([DummyChunk()])\n\n    class DummyClient:\n  
      class Chat:\n            class Completions:\n                def create(self, **kwargs):\n                    return DummyCompletion()\n\n            completions = Completions()\n\n        chat = Chat()\n\n    class DummyOpenAI:\n        def __init__(self, *args, **kwargs):\n            pass\n\n        chat = DummyClient.chat\n\n    fake_openai = types.ModuleType(\"openai\")\n    fake_openai.OpenAI = DummyOpenAI\n\n    monkeypatch.setitem(sys.modules, \"openai\", fake_openai)\n\n    txt2kg.CLIENT_INITD = False\n\n    out = txt2kg._chunk_to_triples_str_cloud(\"text\")\n    assert isinstance(out, str)\n\n\ndef dummy_multiproc_helper(\n    rank,\n    chunks,\n    py_fn,\n    llm_fn,\n    NIM_KEY,\n    NIM_MODEL,\n    ENDPOINT_URL,\n    max_retries=3,\n    base_delay=0,\n):\n    return [(\"A\", \"rel\", \"B\")]\n\n\ndef test_extract_relevant_triples_cloud(monkeypatch):\n\n    model = TXT2KG(local_LM=False, chunk_size=10)\n\n    # Mock the multiproc helper (module-level)\n    monkeypatch.setattr(txt2kg, \"_multiproc_helper\", dummy_multiproc_helper)\n\n    triples = model._extract_relevant_triples(\"Some text\")\n    assert (\"A\", \"rel\", \"B\") in triples\n\n\ndef test_multiproc_helper_success(monkeypatch):\n    # Dummy LLM/Python parser\n    def dummy_llm_fn(x, **kwargs):\n        return [\"llm:\" + str(x)]\n\n    def dummy_py_fn(x):\n        return [\"py:\" + str(i) for i in x]\n\n    # Patch _llm_then_python_parse\n    monkeypatch.setattr(\n        \"torch_geometric.llm.models.txt2kg._llm_then_python_parse\",\n        lambda chunks, py_fn, llm_fn, **kwargs: [\"PARSED:\" + str(chunks)])\n\n    # Input chunks for rank 0\n    chunks_for_rank = [\"chunk0\", \"chunk1\"]\n\n    result = _multiproc_helper(\n        rank=0,\n        chunks_for_rank=chunks_for_rank,\n        py_fn=dummy_py_fn,\n        llm_fn=dummy_llm_fn,\n        NIM_KEY=\"dummy\",\n        NIM_MODEL=\"dummy\",\n        ENDPOINT_URL=\"dummy\",\n        max_retries=3,\n        base_delay=0.01  # 
keep backoff small in tests\n    )\n\n    assert result == [\"PARSED:['chunk0', 'chunk1']\"]\n\n\ndef test_multiproc_helper_retry(monkeypatch):\n    attempts = []\n\n    def failing_parse(chunks, py_fn, llm_fn, **kwargs):\n        attempts.append(1)\n        if len(attempts) < 3:\n            raise RuntimeError(\"fail\")\n        return [\"SUCCESS\"]\n\n    monkeypatch.setattr(\n        \"torch_geometric.llm.models.txt2kg._llm_then_python_parse\",\n        failing_parse)\n\n    result = _multiproc_helper(\n        rank=0,\n        chunks_for_rank=[\"chunk\"],\n        py_fn=lambda x: x,\n        llm_fn=lambda x: x,\n        NIM_KEY=\"dummy\",\n        NIM_MODEL=\"dummy\",\n        ENDPOINT_URL=\"dummy\",\n        max_retries=5,\n        base_delay=0  # instant retries for test\n    )\n\n    assert result == [\"SUCCESS\"]\n    assert len(attempts) == 3  # retried twice, succeeded on 3rd\n\n\ndef test_add_doc_empty_text():\n    kg = TXT2KG(local_LM=True)\n\n    kg.add_doc_2_KG(txt=\"\")\n\n    # first doc uses doc_id_counter=0 as key\n    assert 0 in kg.relevant_triples\n    assert kg.relevant_triples[0] == []\n\n    # doc counter should increment\n    assert kg.doc_id_counter == 1\n\n\ndef test_add_doc_empty_text_with_QA_pair():\n    kg = TXT2KG(local_LM=True)\n\n    qa = (\"What is PyG?\", \"Graph ML library\")\n\n    kg.add_doc_2_KG(txt=\"\", QA_pair=qa)\n\n    assert qa in kg.relevant_triples\n    assert kg.relevant_triples[qa] == []\n\n\n@pytest.fixture\ndef kg_cpu():\n    # TXT2KG instance using CPU local LLM mode\n    return TXT2KG(local_LM=True)\n\n\ndef test_add_doc_empty_text_cpu(kg_cpu):\n    \"\"\"Cover the empty text branch (lines 194-201).\"\"\"\n    kg_cpu.add_doc_2_KG(txt=\"\")\n    # doc_id_counter starts at 0\n    assert kg_cpu.relevant_triples[0] == []\n    assert kg_cpu.doc_id_counter == 1\n\n\ndef test_add_doc_empty_text_with_QA_pair_cpu(kg_cpu):\n    \"\"\"Cover QA_pair key path with empty text.\"\"\"\n    qa = (\"What is PyG?\", \"Graph ML 
library\")\n    kg_cpu.add_doc_2_KG(txt=\"\", QA_pair=qa)\n    assert qa in kg_cpu.relevant_triples\n    assert kg_cpu.relevant_triples[qa] == []\n\n\ndef test_add_doc_nonempty_text_placeholder(kg_cpu, monkeypatch):\n    \"\"\"Minimal coverage for non-empty text branch.\n    Avoids importing the real LLM.\n    \"\"\"\n    # Patch the module-level function _llm_then_python_parse\n    monkeypatch.setattr(txt2kg, \"_llm_then_python_parse\",\n                        lambda chunks, *args, **kwargs: [])\n\n    # Call add_doc_2_KG with non-empty text\n    kg_cpu.add_doc_2_KG(txt=\"some text\")\n\n    # Ensure doc_id_counter incremented and key exists\n    key = kg_cpu.doc_id_counter - 1\n    assert key in kg_cpu.relevant_triples\n"
  },
  {
    "path": "test/llm/models/test_vision_transformer.py",
    "content": "import torch\n\nfrom torch_geometric.llm.models import VisionTransformer\nfrom torch_geometric.testing import onlyFullTest, withCUDA, withPackage\n\n\n@withCUDA\n@onlyFullTest\n@withPackage('transformers')\ndef test_vision_transformer(device):\n    model = VisionTransformer(\n        model_name='microsoft/swin-base-patch4-window7-224', ).to(device)\n    assert model.device == device\n    assert str(\n        model\n    ) == 'VisionTransformer(model_name=microsoft/swin-base-patch4-window7-224)'\n\n    images = torch.randn(2, 3, 224, 224).to(device)\n\n    out = model(images)\n    assert out.device == device\n    assert out.size() == (2, 49, 1024)\n\n    out = model(images, output_device='cpu')\n    assert out.is_cpu\n    assert out.size() == (2, 49, 1024)\n"
  },
  {
    "path": "test/llm/test_large_graph_indexer.py",
    "content": "import random\nimport string\nfrom typing import List\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.llm.large_graph_indexer import (\n    EDGE_PID,\n    EDGE_RELATION,\n    NODE_PID,\n    LargeGraphIndexer,\n    TripletLike,\n    get_features_for_triplets,\n)\nfrom torch_geometric.llm.utils.backend_utils import preprocess_triplet\nfrom torch_geometric.typing import WITH_PT20\n\n# create possible nodes and edges for graph\nstrkeys = string.ascii_letters + string.digits\nNODE_POOL = list(\n    {\"\".join(random.sample(strkeys, 10)).lower()\n     for i in range(1000)})\nEDGE_POOL = list(\n    {\"\".join(random.sample(strkeys, 10)).lower()\n     for i in range(50)})\n\n\ndef featurize(s: str) -> int:\n    return int.from_bytes(s.encode(), 'little')\n\n\ndef sample_triplets(amount: int = 1) -> List[TripletLike]:\n    trips = []\n    for _ in range(amount):\n        h, t = random.sample(NODE_POOL, k=2)\n        r = random.sample(EDGE_POOL, k=1)[0]\n        trips.append(tuple([h, r, t]))\n    return trips\n\n\ndef test_basic_collate():\n    graphs = [sample_triplets(1000) for i in range(2)]\n\n    indexer_0 = LargeGraphIndexer.from_triplets(\n        graphs[0], pre_transform=preprocess_triplet)\n    indexer_1 = LargeGraphIndexer.from_triplets(\n        graphs[1], pre_transform=preprocess_triplet)\n\n    big_indexer = LargeGraphIndexer.collate([indexer_0, indexer_1])\n\n    assert len(indexer_0._nodes) + len(\n        indexer_1._nodes) - len(indexer_0._nodes.keys()\n                                & indexer_1._nodes.keys()) == len(\n                                    big_indexer._nodes)\n    assert len(indexer_0._edges) + len(\n        indexer_1._edges) - len(indexer_0._edges.keys()\n                                & indexer_1._edges.keys()) == len(\n                                    big_indexer._edges)\n\n    assert len(set(big_indexer._nodes.values())) == len(big_indexer._nodes)\n    assert 
len(set(big_indexer._edges.values())) == len(big_indexer._edges)\n\n    for node in (indexer_0._nodes.keys() | indexer_1._nodes.keys()):\n        assert big_indexer.node_attr[NODE_PID][\n            big_indexer._nodes[node]] == node\n\n\ndef test_large_graph_index():\n    graphs = [sample_triplets(1000) for i in range(100)]\n\n    # Preprocessing of trips lowercases nodes but not edges\n    node_feature_vecs = {s.lower(): featurize(s.lower()) for s in NODE_POOL}\n    edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL}\n\n    def encode_graph_from_trips(triplets: List[TripletLike]) -> Data:\n        seen_nodes = dict()\n        edge_attrs = list()\n        edge_idx = []\n        for trip in triplets:\n            trip = preprocess_triplet(trip)\n            h, r, t = trip\n            seen_nodes[h] = len(\n                seen_nodes) if h not in seen_nodes else seen_nodes[h]\n            seen_nodes[t] = len(\n                seen_nodes) if t not in seen_nodes else seen_nodes[t]\n            edge_attrs.append(edge_feature_vecs[r])\n            edge_idx.append((seen_nodes[h], seen_nodes[t]))\n\n        x = torch.Tensor([node_feature_vecs[n] for n in seen_nodes.keys()])\n        edge_idx = torch.LongTensor(edge_idx).T\n        edge_attrs = torch.Tensor(edge_attrs)\n        return Data(x=x, edge_index=edge_idx, edge_attr=edge_attrs)\n\n    naive_graph_ds = [\n        encode_graph_from_trips(triplets=trips) for trips in graphs\n    ]\n\n    indexer = LargeGraphIndexer.collate([\n        LargeGraphIndexer.from_triplets(g, pre_transform=preprocess_triplet)\n        for g in graphs\n    ])\n    indexer_nodes = indexer.get_unique_node_features()\n    indexer_node_vals = torch.Tensor(\n        [node_feature_vecs[n] for n in indexer_nodes])\n    indexer_edges = indexer.get_unique_edge_features(\n        feature_name=EDGE_RELATION)\n    indexer_edge_vals = torch.Tensor(\n        [edge_feature_vecs[e] for e in indexer_edges])\n    indexer.add_node_feature('x', 
indexer_node_vals)\n    indexer.add_edge_feature('edge_attr', indexer_edge_vals,\n                             map_from_feature=EDGE_RELATION)\n    large_graph_ds = [\n        get_features_for_triplets(indexer=indexer, triplets=g,\n                                  node_feature_name='x',\n                                  edge_feature_name='edge_attr',\n                                  pre_transform=preprocess_triplet)\n        for g in graphs\n    ]\n\n    for ds in large_graph_ds:\n        assert NODE_PID in ds\n        assert EDGE_PID in ds\n        assert \"node_idx\" in ds\n        assert \"edge_idx\" in ds\n\n    def results_are_close_enough(ground_truth: Data, new_method: Data,\n                                 thresh=.99):\n        def _sorted_tensors_are_close(tensor1, tensor2):\n            return torch.all(\n                torch.isclose(tensor1.sort()[0],\n                              tensor2.sort()[0]) > thresh)\n\n        def _graphs_are_same(tensor1, tensor2):\n            if not WITH_PT20:\n                pytest.skip(\n                    \"This test requires a PyG version with NetworkX as a \" +\n                    \"dependency.\")\n            import networkx as nx\n            return nx.weisfeiler_lehman_graph_hash(nx.Graph(\n                tensor1.T)) == nx.weisfeiler_lehman_graph_hash(\n                    nx.Graph(tensor2.T))\n        return _sorted_tensors_are_close(\n            ground_truth.x, new_method.x) \\\n            and _sorted_tensors_are_close(\n                ground_truth.edge_attr, new_method.edge_attr) \\\n            and _graphs_are_same(\n                ground_truth.edge_index, new_method.edge_index)\n\n    for dsets in zip(naive_graph_ds, large_graph_ds):\n        assert results_are_close_enough(*dsets)\n\n\ndef test_save_load(tmp_path):\n    graph = sample_triplets(1000)\n\n    node_feature_vecs = {s: featurize(s) for s in NODE_POOL}\n    edge_feature_vecs = {s: featurize(s) for s in EDGE_POOL}\n\n    indexer = 
LargeGraphIndexer.from_triplets(graph)\n    indexer_nodes = indexer.get_unique_node_features()\n    indexer_node_vals = torch.Tensor(\n        [node_feature_vecs[n] for n in indexer_nodes])\n    indexer_edges = indexer.get_unique_edge_features(\n        feature_name=EDGE_RELATION)\n    indexer_edge_vals = torch.Tensor(\n        [edge_feature_vecs[e] for e in indexer_edges])\n    indexer.add_node_feature('x', indexer_node_vals)\n    indexer.add_edge_feature('edge_attr', indexer_edge_vals,\n                             map_from_feature=EDGE_RELATION)\n\n    indexer.save(str(tmp_path))\n    assert indexer == LargeGraphIndexer.from_disk(str(tmp_path))\n"
  },
  {
    "path": "test/llm/test_rag_loader.py",
    "content": "import os\nfrom typing import Any, Dict\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.llm.models import SentenceTransformer\nfrom torch_geometric.llm.rag_loader import RAGQueryLoader\nfrom torch_geometric.llm.utils.backend_utils import (\n    create_graph_from_triples,\n    create_remote_backend_from_graph_data,\n)\nfrom torch_geometric.llm.utils.feature_store import KNNRAGFeatureStore\nfrom torch_geometric.llm.utils.graph_store import NeighborSamplingRAGGraphStore\nfrom torch_geometric.llm.utils.vectorrag import VectorRetriever\nfrom torch_geometric.sampler import SamplerOutput\nfrom torch_geometric.testing import withPackage\n\n\nclass MockRAGFeatureStore:\n    \"\"\"Mock implementation of RAGFeatureStore protocol for testing.\"\"\"\n    def __init__(self):\n        self._config = {}\n        self.x = torch.randn(10, 64)  # Sample node features\n\n    def retrieve_seed_nodes(self, query: Any, **kwargs):\n        \"\"\"Mock retrieve_seed_nodes method.\"\"\"\n        seed_nodes = torch.tensor([0, 1, 2, 3, 4])\n        query_enc = torch.randn(1, 64)\n        return seed_nodes, query_enc\n\n    @property\n    def config(self) -> Dict[str, Any]:\n        return self._config\n\n    @config.setter\n    def config(self, config: Dict[str, Any]):\n        if config is None:\n            raise ValueError(\"Config cannot be None\")\n        if 'a' not in config:\n            raise ValueError(\"Required config parameter 'a' not found\")\n        self._config = config\n\n    def retrieve_seed_edges(self, query: Any, **kwargs):\n        \"\"\"Mock retrieve_seed_edges method.\"\"\"\n        return torch.tensor([[0, 1], [1, 2], [2, 3]])\n\n    def load_subgraph(self, sample):\n        \"\"\"Mock load_subgraph method.\"\"\"\n        data = Data()\n        data.edge_idx = torch.tensor([0, 1, 2])\n        data.node_idx = torch.tensor([0, 1, 2, 3, 4])\n        return data\n\n\nclass 
MockRAGGraphStore:\n    \"\"\"Mock implementation of RAGGraphStore protocol for testing.\"\"\"\n    def __init__(self):\n        self._config = {}\n        self.edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 0]])\n\n    def sample_subgraph(self, seed_nodes, seed_edges=None, **kwargs):\n        \"\"\"Mock sample_subgraph method.\"\"\"\n        return SamplerOutput(node=seed_nodes, row=torch.tensor([0, 1, 2]),\n                             col=torch.tensor([1, 2, 3]),\n                             edge=torch.tensor([0, 1, 2]), batch=None)\n\n    @property\n    def config(self) -> Dict[str, Any]:\n        return self._config\n\n    @config.setter\n    def config(self, config: Dict[str, Any]):\n        if config is None:\n            raise ValueError(\"Config cannot be None\")\n        if 'b' not in config:\n            raise ValueError(\"Required config parameter 'b' not found\")\n        self._config = config\n\n    def register_feature_store(self, feature_store):\n        \"\"\"Mock register_feature_store method.\"\"\"\n\n\nclass TestRAGQueryLoader:\n    \"\"\"Test suite for RAGQueryLoader.\"\"\"\n    def setup_method(self):\n        \"\"\"Set up test fixtures before each test method.\"\"\"\n        self.mock_feature_store = MockRAGFeatureStore()\n        self.mock_graph_store = MockRAGGraphStore()\n        self.graph_data = (self.mock_feature_store, self.mock_graph_store)\n\n        # Sample config\n        self.sample_config = {\"a\": 5, \"b\": [10, 5], \"c\": \"test_value\"}\n\n    def test_initialization_basic(self):\n        \"\"\"Test basic initialization of RAGQueryLoader.\"\"\"\n        loader = RAGQueryLoader(self.graph_data, config=self.sample_config)\n\n        assert loader.feature_store == self.mock_feature_store\n        assert loader.graph_store == self.mock_graph_store\n        assert loader.vector_retriever is None\n        assert loader.augment_query is False\n        assert loader.subgraph_filter is None\n        assert loader.config == 
self.sample_config\n\n    def test_initialization_with_all_params(self):\n        \"\"\"Test initialization with all parameters.\"\"\"\n        mock_vector_retriever = Mock(spec=VectorRetriever)\n        mock_subgraph_filter = Mock()\n\n        loader = RAGQueryLoader(graph_data=self.graph_data,\n                                subgraph_filter=mock_subgraph_filter,\n                                augment_query=True,\n                                vector_retriever=mock_vector_retriever,\n                                config=self.sample_config)\n\n        assert loader.feature_store == self.mock_feature_store\n        assert loader.graph_store == self.mock_graph_store\n        assert loader.vector_retriever == mock_vector_retriever\n        assert loader.augment_query is True\n        assert loader.subgraph_filter == mock_subgraph_filter\n        assert loader.config == self.sample_config\n\n    def test_bad_config(self):\n        \"\"\"Test bad config initialization.\"\"\"\n        with pytest.raises(ValueError):\n            RAGQueryLoader(self.graph_data)\n        with pytest.raises(ValueError):\n            RAGQueryLoader(self.graph_data, config={'d': 'foobar'})\n\n    def test_config_propagation(self):\n        \"\"\"Test that config is propagated during initialization.\"\"\"\n        loader = RAGQueryLoader(self.graph_data, config=self.sample_config)\n\n        assert loader.feature_store.config == self.sample_config\n        assert loader.graph_store.config == self.sample_config\n\n    def test_basic_query_without_vector_retriever(self):\n        \"\"\"Test basic query functionality without vector retriever.\"\"\"\n        loader = RAGQueryLoader(self.graph_data, config=self.sample_config)\n\n        query = \"test query\"\n        result = loader.query(query)\n\n        # Verify result is a Data object\n        assert isinstance(result, Data)\n\n        # Verify the data has expected attributes\n        assert hasattr(result, 'node_idx')\n        assert 
hasattr(result, 'num_nodes')\n        assert hasattr(result, 'x')\n        assert hasattr(result, 'edge_index')\n\n    def test_query_with_vector_retriever(self):\n        \"\"\"Test query functionality with vector retriever.\"\"\"\n        mock_vector_retriever = Mock(spec=VectorRetriever)\n        mock_vector_retriever.query.return_value = [\n            \"retrieved doc 1\", \"retrieved doc 2\"\n        ]\n\n        loader = RAGQueryLoader(self.graph_data,\n                                vector_retriever=mock_vector_retriever,\n                                config=self.sample_config)\n\n        query = \"test query\"\n        result = loader.query(query)\n\n        # Verify vector retriever was called\n        mock_vector_retriever.query.assert_called_once_with(query)\n\n        # Verify result has text_context\n        assert hasattr(result, 'text_context')\n        assert result.text_context == [\"retrieved doc 1\", \"retrieved doc 2\"]\n\n    def test_query_with_subgraph_filter(self):\n        \"\"\"Test query functionality with subgraph filter.\"\"\"\n        mock_filter_result = Data()\n        mock_filter_result.filtered = True\n\n        mock_subgraph_filter = Mock(return_value=mock_filter_result)\n\n        loader = RAGQueryLoader(self.graph_data,\n                                subgraph_filter=mock_subgraph_filter,\n                                config=self.sample_config)\n\n        query = \"test query\"\n        result = loader.query(query)\n\n        # Verify subgraph filter was called\n        mock_subgraph_filter.assert_called_once()\n        call_args = mock_subgraph_filter.call_args[0]\n        assert len(call_args) == 2\n        assert call_args[1] == query\n\n        # Verify result is the filtered result\n        assert result == mock_filter_result\n        assert hasattr(result, 'filtered')\n        assert result.filtered is True\n\n\n@withPackage('pyg_lib', 'torch_sparse')\ndef test_rag_loader_integration(tmp_path):\n    \"\"\"Test 
RAGQueryLoader with real feature and graph stores from triples.\"\"\"\n    # Define test triplets - simple knowledge graph about cities/countries\n    triplets = [\n        [\"Paris\", \"capital_of\", \"France\"],\n        [\"London\", \"capital_of\", \"UK\"],\n        [\"Berlin\", \"capital_of\", \"Germany\"],\n        [\"France\", \"in_continent\", \"Europe\"],\n        [\"UK\", \"in_continent\", \"Europe\"],\n        [\"Germany\", \"in_continent\", \"Europe\"],\n        [\"Rome\", \"capital_of\", \"Italy\"],\n        [\"Italy\", \"in_continent\", \"Europe\"],\n        [\"Madrid\", \"capital_of\", \"Spain\"],\n        [\"Spain\", \"in_continent\", \"Europe\"],\n    ]\n\n    encoder_model = SentenceTransformer('bert-base-uncased')\n    # Create graph from triplets\n    graph_data = create_graph_from_triples(triplets, encoder_model.encode)\n\n    save_path = os.path.join(tmp_path, \"test_graph.pt\")\n    loader = create_remote_backend_from_graph_data(\n        graph_data=graph_data, path=save_path, n_parts=1,\n        graph_db=NeighborSamplingRAGGraphStore, feature_db=KNNRAGFeatureStore)\n    feature_store, graph_store = loader.load()\n\n    # Configuration\n    config = {\n        \"k_nodes\": 1,\n        \"encoder_model\": encoder_model,\n        \"num_neighbors\": [10]  # 10 neighbors only one hop\n    }\n\n    # Create RAG loader\n    rag_data = (feature_store, graph_store)\n    loader = RAGQueryLoader(rag_data, config=config)\n\n    # Test query about European countries\n    query = \"countries in Europe\"\n    result = loader.query(query)\n\n    # Verify result structure\n    assert isinstance(result, Data)\n    assert torch.equal(result.edge_index,\n                       torch.tensor([[1, 2, 3, 4, 5], [0, 0, 0, 0, 0]]))\n    expected_x = encoder_model.encode(\n        [\"Europe\", \"France\", \"UK\", \"Germany\", \"Italy\", \"Spain\"]).cpu()\n    expected_edge_attr = encoder_model.encode([\"in_continent\"] * 5).cpu()\n    assert torch.allclose(result.x, 
expected_x, atol=1e-6)\n    assert torch.allclose(result.edge_attr, expected_edge_attr, atol=1e-6)\n"
  },
  {
    "path": "test/llm/utils/test_rag_backend_utils.py",
    "content": "import os\nimport tempfile\nfrom typing import List, Tuple\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.llm.utils.backend_utils import (\n    batch_knn,\n    create_graph_from_triples,\n    create_remote_backend_from_graph_data,\n    make_pcst_filter,\n    preprocess_triplet,\n    retrieval_via_pcst,\n)\nfrom torch_geometric.testing import onlyLinux\n\n\ndef test_preprocess_triplet():\n    triplet = ('Alice', 'works with', 'Bob')\n    processed = preprocess_triplet(triplet)\n    assert processed == ('alice', 'works with', 'bob')\n\n\ndef test_batch_knn():\n    query_embeddings = torch.randn(2, 64)\n    candidate_embeddings = torch.randn(10, 64)\n    k = 3\n    top_k_indices, top_k_scores = batch_knn(\n        query_embeddings,\n        candidate_embeddings,\n        k,\n    )\n    assert top_k_indices[0].size() == (k, )\n    assert top_k_indices[1].size() == (1, 64)\n    assert top_k_scores[0].size() == (k, )\n    assert top_k_scores[1].size() == (1, 64)\n\n\n\"\"\"Test retrieval_via_pcst\"\"\"\n\n\ndef create_mock_data(num_nodes=3, num_edges=2):\n    import pandas as pd\n    x = torch.randn(num_nodes, 16)\n    edge_index = torch.tensor([[0, 1], [1, 2]], dtype=torch.long)\n    edge_attr = torch.randn(num_edges, 16)\n    node_idx = list(range(num_nodes))\n    edge_idx = list(range(num_edges))\n\n    textual_nodes = pd.DataFrame({\n        'node_id': node_idx,\n        'text': [f\"Node {i}\" for i in node_idx]\n    })\n    textual_edges = pd.DataFrame({\n        'src': edge_index[0].tolist(),\n        'dst': edge_index[1].tolist(),\n        'edge_attr': [f\"Edge {i}\" for i in edge_idx]\n    })\n\n    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,\n                node_idx=node_idx,\n                edge_idx=edge_idx), textual_nodes, textual_edges\n\n\n@onlyLinux\ndef test_empty_graph():\n    import pandas as pd\n\n    # w/o node and edge\n    bad_data = Data(x=None, edge_index=None, edge_attr=None)\n    
textual_nodes = pd.DataFrame({'node_id': [], 'text': []})\n    textual_edges = pd.DataFrame({'src': [], 'dst': [], 'edge_attr': []})\n    q_emb = torch.randn(1, 16)\n\n    result_data, desc = retrieval_via_pcst(bad_data, q_emb, textual_nodes,\n                                           textual_edges)\n\n    assert result_data == bad_data\n    assert desc.strip() == 'node_id,text\\n\\nsrc,edge_attr,dst'\n\n\ndef test_topk_zero():\n    data, textual_nodes, textual_edges = create_mock_data()\n    q_emb = torch.randn(1, 16)\n\n    result_data, desc = retrieval_via_pcst(data, q_emb, textual_nodes,\n                                           textual_edges, topk=0, topk_e=0)\n\n    assert isinstance(result_data, Data)\n\n\n\"\"\"Test make_pcst_filter\"\"\"\n\n\nclass MockSentenceTransformer:\n    def encode(self, sentences, **kwargs):\n        return torch.randn(len(sentences), 32)\n\n\ndef create_mock_graph_and_triples():\n    triples: List[Tuple[str, str, str]] = [(\"Alice\", \"works_at\", \"Google\"),\n                                           (\"Bob\", \"works_at\", \"Meta\"),\n                                           (\"Alice\", \"knows\", \"Bob\")]\n\n    # Alice=0, Bob=1, Google=2, Meta=3\n    node_idx = [0, 1, 2, 3]\n    edge_idx = [0, 1, 2]\n\n    x = torch.randn(4, 32)\n    edge_index = torch.tensor([[0, 1, 0], [2, 3, 1]], dtype=torch.long)\n    edge_attr = torch.randn(3, 32)\n\n    graph = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,\n                 node_idx=node_idx, edge_idx=edge_idx)\n\n    return triples, graph\n\n\n@onlyLinux\ndef test_apply_retrieval_via_pcst_isolated_node():\n    triples, graph = create_mock_graph_and_triples()\n    model = MockSentenceTransformer()\n\n    mock_out_graph = Data(x=graph.x[:1],\n                          edge_index=torch.empty(2, 0, dtype=torch.long),\n                          edge_attr=torch.empty(0, 32))\n    mock_out_graph.node_idx = [0]\n    mock_out_graph.edge_idx = []\n\n    filter_fn = 
make_pcst_filter(triples, model)\n    result = filter_fn(graph, \"Who is Alice?\")\n\n    assert result.triples == [\n        ('Alice', 'works_at', 'Google'),\n        ('Bob', 'works_at', 'Meta'),\n        ('Alice', 'knows', 'Bob'),\n    ]\n\n\nclass MockEmbeddingModel:\n    \"\"\"Mock embedding model for testing.\"\"\"\n    def __init__(self, embed_dim: int = 64):\n        self.embed_dim = embed_dim\n\n    def __call__(self, texts: List[str], **kwargs) -> torch.Tensor:\n        \"\"\"Mock embedding generation - creates deterministic embeddings.\"\"\"\n        # Create simple hash-based embeddings for reproducible testing\n        if len(texts) == 0:\n            return torch.empty(0, self.embed_dim)\n        embeddings = []\n        for text in texts:\n            # Simple deterministic embedding based on text hash\n            hash_val = hash(text)\n            # Use the hash to create a reproducible embedding\n            torch.manual_seed(abs(hash_val) % 2**31)\n            embedding = torch.randn(self.embed_dim)\n            embeddings.append(embedding)\n        return torch.stack(embeddings)\n\n\nclass TestCreateGraphFromTriples:\n    \"\"\"Test suite for create_graph_from_triples function.\"\"\"\n    def setup_method(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        self.sample_triples = [('Alice', 'works with', 'Bob'),\n                               ('Alice', 'leads', 'Carol'),\n                               ('Carol', 'works with', 'Dave')]\n        self.mock_embedding_model = MockEmbeddingModel(embed_dim=32)\n\n    def test_create_graph_basic_functionality(self):\n        \"\"\"Test basic functionality of create_graph_from_triples.\"\"\"\n        result = create_graph_from_triples(\n            triples=self.sample_triples,\n            embedding_model=self.mock_embedding_model)\n\n        # Verify result is a Data object\n        assert isinstance(result, Data)\n\n        x = result.x\n        edge_attr = result.edge_attr\n        assert x.shape 
== (4, 32)\n        assert edge_attr.shape == (3, 32)\n        for t in self.sample_triples:\n            assert self.mock_embedding_model([t[0]]) in x\n            assert self.mock_embedding_model([t[2]]) in x\n            assert self.mock_embedding_model([t[1]]) in edge_attr\n\n        expected_edge_index = torch.tensor([[0, 0, 2], [1, 2, 3]])\n        assert torch.allclose(result.edge_index, expected_edge_index)\n\n    def test_create_graph_empty_triples(self):\n        \"\"\"Test create_graph_from_triples with empty triples list.\"\"\"\n        empty_triples = []\n\n        result = create_graph_from_triples(\n            triples=empty_triples, embedding_model=self.mock_embedding_model)\n\n        # Should create an empty graph\n        assert isinstance(result, Data)\n        assert result.num_nodes == 0\n        assert result.num_edges == 0\n\n\nclass TestCreateRemoteBackendFromGraphData:\n    \"\"\"Test suite for create_remote_backend_from_graph_data function.\"\"\"\n    def setup_method(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        self.sample_triples = [('Alice', 'works with', 'Bob'),\n                               ('Alice', 'leads', 'Carol'),\n                               ('Carol', 'works with', 'Dave')]\n        self.mock_embedding_model = MockEmbeddingModel(embed_dim=32)\n\n        # Create sample graph data using create_graph_from_triples\n        self.sample_graph_data = create_graph_from_triples(\n            triples=self.sample_triples,\n            embedding_model=self.mock_embedding_model)\n\n    def test_create_backend_data_load(self):\n        \"\"\"Test that data integrity is preserved in backend creation.\"\"\"\n        with tempfile.TemporaryDirectory() as temp_dir:\n            save_path = os.path.join(temp_dir, \"test_graph.pt\")\n\n            loader = create_remote_backend_from_graph_data(\n                graph_data=self.sample_graph_data, path=save_path, n_parts=1)\n\n            # Load and verify data\n            
feature_store, graph_store = loader.load()\n\n            # Check that the original graph structure is preserved\n            loaded_data = torch.load(save_path, weights_only=False)\n\n            # Verify basic properties match\n            assert loaded_data.num_nodes == self.sample_graph_data.num_nodes\n            assert loaded_data.num_edges == self.sample_graph_data.num_edges\n\n            # Verify tensors match\n            assert torch.allclose(loaded_data.x, self.sample_graph_data.x)\n            assert torch.allclose(loaded_data.edge_index,\n                                  self.sample_graph_data.edge_index)\n            assert torch.allclose(loaded_data.edge_attr,\n                                  self.sample_graph_data.edge_attr)\n"
  },
  {
    "path": "test/llm/utils/test_rag_feature_store.py",
    "content": "from unittest.mock import Mock, patch\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.llm.utils.feature_store import KNNRAGFeatureStore\nfrom torch_geometric.sampler import SamplerOutput\n\n\nclass TestKNNRAGFeatureStore:\n    \"\"\"Test suite for KNNRAGFeatureStore methods.\"\"\"\n    def setup_method(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        self.mock_encoder = Mock()\n        self.mock_encoder.encode = Mock()\n        self.mock_encoder.to = Mock(return_value=self.mock_encoder)\n        self.mock_encoder.eval = Mock()\n\n        self.config = {\"k_nodes\": 5, \"encoder_model\": self.mock_encoder}\n        self.sample_x = torch.randn(40, 128)  # 40 nodes, 128 features\n        self.sample_edge_attr = torch.randn(40, 64)  # 40 edges, 64 features\n\n    def test_bad_config(self):\n        \"\"\"Test bad config initialization.\"\"\"\n        with pytest.raises(ValueError, match=\"Required config parameter\"):\n            store = KNNRAGFeatureStore()\n            store.config = {}\n\n    def create_feature_store(self):\n        \"\"\"Create a FeatureStore with mocked dependencies.\"\"\"\n        store = KNNRAGFeatureStore()\n\n        store.config = self.config\n\n        # Mock the tensor storage\n        store.put_tensor(self.sample_x, group_name=None, attr_name='x')\n        store.put_tensor(self.sample_edge_attr, group_name=(None, None),\n                         attr_name='edge_attr')\n\n        return store\n\n    def test_retrieve_seed_nodes_single_query(self):\n        \"\"\"Test retrieve_seed_nodes with a single query.\"\"\"\n        store = self.create_feature_store()\n\n        # Mock the encoder output and batch_knn\n        query_text = \"test query\"\n        mock_query_enc = torch.randn(1, 128)\n        self.mock_encoder.encode.return_value = mock_query_enc\n\n        expected_indices = torch.tensor([0, 3, 7, 2, 9])\n\n        with 
patch('torch_geometric.llm.utils.feature_store.batch_knn'\n                   ) as mock_batch_knn:\n            # Mock batch_knn to return an iterator\n            def mock_generator():\n                yield (expected_indices, mock_query_enc)\n\n            mock_batch_knn.return_value = mock_generator()\n\n            result, query_enc = store.retrieve_seed_nodes(query_text)\n\n            # Verify encoder was called correctly\n            self.mock_encoder.encode.assert_called_once_with([query_text])\n\n            # Verify batch_knn was called correctly\n            mock_batch_knn.assert_called_once()\n            args = mock_batch_knn.call_args[0]\n            assert torch.equal(args[0], mock_query_enc)\n            assert torch.equal(args[1], self.sample_x)\n            assert args[2] == 5  # k_nodes\n\n            # Verify results\n            assert torch.equal(result, expected_indices)\n            assert torch.equal(query_enc, mock_query_enc)\n\n    def test_retrieve_seed_nodes_multiple_queries(self):\n        \"\"\"Test retrieve_seed_nodes with multiple queries.\"\"\"\n        store = self.create_feature_store()\n\n        queries = [\"query 1\", \"query 2\"]\n        mock_query_enc = torch.randn(2, 128)\n        self.mock_encoder.encode.return_value = mock_query_enc\n\n        expected_indices = [\n            torch.tensor([1, 4, 6, 8, 0]),\n            torch.tensor([0, 3, 7, 2, 9])\n        ]\n\n        with patch('torch_geometric.llm.utils.feature_store.batch_knn'\n                   ) as mock_batch_knn:\n\n            def mock_generator():\n                for i in range(len(expected_indices)):\n                    yield (expected_indices[i], mock_query_enc[i])\n\n            mock_batch_knn.return_value = mock_generator()\n\n            out_dict = store.retrieve_seed_nodes(queries)\n\n            # Verify encoder was called with the list directly\n            self.mock_encoder.encode.assert_called_once_with(queries)\n\n            # Verify results\n   
         for i, query in enumerate(queries):\n                result, query_enc = out_dict[query]\n                assert torch.equal(result, expected_indices[i])\n                assert torch.equal(query_enc, mock_query_enc[i])\n\n    @pytest.mark.parametrize(\"induced\", [True, False])\n    def test_load_subgraph_valid_sample(self, induced):\n        \"\"\"Test load_subgraph with valid SamplerOutput.\"\"\"\n        store = self.create_feature_store()\n\n        # Create a mock SamplerOutput\n        sample = SamplerOutput(node=torch.tensor([6, 7, 8, 9]),\n                               row=torch.tensor([0, 1, 2]),\n                               col=torch.tensor([1, 2, 3]),\n                               edge=torch.tensor([0, 1, 2]), batch=None)\n\n        expected_edge_indices = torch.tensor([[0, 1, 2], [1, 2, 3]]) \\\n            if induced else torch.tensor([[6, 7, 8], [7, 8, 9]])\n\n        result = store.load_subgraph(sample, induced=induced)\n\n        # Verify result is a Data object\n        assert isinstance(result, Data)\n\n        # Verify edge attributes are correctly extracted\n        expected_edge_attr = self.sample_edge_attr[torch.tensor([0, 1, 2])]\n        assert torch.equal(result.edge_attr, expected_edge_attr)\n        assert torch.equal(result.edge_index, expected_edge_indices)\n        if induced:\n            assert torch.equal(result.node_idx, sample.node)\n            assert torch.equal(result.edge_idx, sample.edge)\n"
  },
  {
    "path": "test/llm/utils/test_rag_graph_store.py",
    "content": "from unittest.mock import Mock, patch\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import FeatureStore\nfrom torch_geometric.llm.utils.graph_store import NeighborSamplingRAGGraphStore\nfrom torch_geometric.sampler import BidirectionalNeighborSampler, SamplerOutput\n\n\ndef setup_test_fixtures():\n    \"\"\"Set up test fixtures.\"\"\"\n    feature_store = Mock(spec=FeatureStore)\n    config = {\"num_neighbors\": [10, 5]}\n    return feature_store, config\n\n\ndef test_sample_subgraph_with_valid_tensor_input():\n    \"\"\"Test sample_subgraph with valid tensor input.\"\"\"\n    # Create graph store and set config\n    feature_store, config = setup_test_fixtures()\n    graph_store = NeighborSamplingRAGGraphStore(replace=True, disjoint=False)\n    graph_store.register_feature_store(feature_store=feature_store)\n    graph_store.config = config\n    assert graph_store.put_edge_id(torch.tensor([10]), edge_type=None,\n                                   layout='coo')\n\n    # Create mock sampler and its output\n    mock_sampler = Mock(spec=BidirectionalNeighborSampler)\n    expected_output = SamplerOutput(node=torch.tensor([0, 1, 2, 3]),\n                                    row=torch.tensor([0, 1, 1]),\n                                    col=torch.tensor([1, 2, 3]),\n                                    edge=torch.tensor([0, 1, 2]), batch=None,\n                                    num_sampled_nodes=[2, 2],\n                                    num_sampled_edges=[3])\n    mock_sampler.sample_from_nodes.return_value = expected_output\n\n    # Intentionally not sorted\n    graph_store.edge_index = torch.tensor([[3, 1, 1, 0], [4, 2, 3, 1]])\n\n    # Initially sampler should not be initialized\n    assert not graph_store._sampler_is_initialized\n\n    # Mock the _init_sampler method to set our mock sampler\n    with patch.object(graph_store, '_init_sampler') as mock_init:\n\n        def set_sampler():\n            graph_store.sampler = 
mock_sampler\n            graph_store._sampler_is_initialized = True\n\n        mock_init.side_effect = set_sampler\n\n        # Test input\n        seed_nodes = torch.tensor([0])\n        result = graph_store.sample_subgraph(seed_nodes)\n\n        # Verify sampler was initialized\n        mock_init.assert_called_once()\n\n        # Verify sample_from_nodes was called with correct input\n        mock_sampler.sample_from_nodes.assert_called_once()\n        assert result == expected_output\n\n\ndef test_bad_config():\n    \"\"\"Test bad config initialization.\"\"\"\n    with pytest.raises(ValueError, match=\"Required config parameter\"):\n        store = NeighborSamplingRAGGraphStore()\n        store.config = {}\n"
  },
  {
    "path": "test/llm/utils/test_vectorrag.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.llm.utils.vectorrag import DocumentRetriever\n\n\n@pytest.fixture\ndef sample_documents():\n    \"\"\"Fixture providing sample documents for testing.\"\"\"\n    return [\n        \"This is the first test document.\",\n        \"This is the second test document.\",\n        \"This is the third test document.\",\n    ]\n\n\n@pytest.fixture\ndef sample_model():\n    \"\"\"Fixture providing a mock model for testing.\"\"\"\n    from unittest.mock import Mock\n\n    mock_model = Mock()\n    # Mock the model to return a simple tensor when called\n    mock_model.side_effect = [\n        torch.zeros(1, 384),\n        torch.ones(1, 384),\n        torch.ones(1, 384) * 2,\n        torch.ones(1, 384) * 1,\n    ]\n\n    return mock_model\n\n\ndef test_save_load(sample_documents, sample_model, tmp_path):\n    \"\"\"Test whether saving/loading a DocumentRetriever maintains state.\"\"\"\n    retriever = DocumentRetriever(sample_documents, model=sample_model)\n    retriever.save(tmp_path / \"retriever.pth\")\n    loaded_retriever = DocumentRetriever.load(tmp_path / \"retriever.pth\",\n                                              sample_model)\n    assert retriever.raw_docs == loaded_retriever.raw_docs\n    assert torch.allclose(retriever.embedded_docs,\n                          loaded_retriever.embedded_docs)\n    assert retriever.k_for_docs == loaded_retriever.k_for_docs\n    assert retriever.model == loaded_retriever.model\n\n\ndef test_query(sample_documents, sample_model):\n    \"\"\"Test query functionality of DocumentRetriever.\"\"\"\n    retriever = DocumentRetriever(sample_documents, model=sample_model)\n    query = \"What is the first test document?\"\n    retrieved_docs = retriever.query(query)\n    assert retrieved_docs == [sample_documents[0]]\n"
  },
  {
    "path": "test/loader/test_cache.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import CachedLoader, NeighborLoader\nfrom torch_geometric.testing import withDevice, withPackage\n\n\n@withDevice\n@withPackage('pyg_lib')\ndef test_cached_loader(device):\n    x = torch.randn(14, 16)\n    edge_index = torch.tensor([\n        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],\n        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],\n    ])\n\n    loader = NeighborLoader(\n        Data(x=x, edge_index=edge_index),\n        num_neighbors=[2],\n        batch_size=10,\n        shuffle=False,\n    )\n    cached_loader = CachedLoader(loader, device=device)\n\n    assert len(cached_loader) == len(loader)\n    assert len(cached_loader._cache) == 0\n\n    cache = []\n    for i, batch in enumerate(cached_loader):\n        assert len(cached_loader._cache) == i + 1\n        assert batch.x.device == device\n        assert batch.edge_index.device == device\n\n        cache.append(batch)\n\n    for i, batch in enumerate(cached_loader):\n        assert batch == cache[i]\n\n    cached_loader.clear()\n    assert len(cached_loader._cache) == 0\n\n\n@withDevice\n@withPackage('pyg_lib')\ndef test_cached_loader_transform(device):\n    x = torch.randn(14, 16)\n    edge_index = torch.tensor([\n        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],\n        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],\n    ])\n\n    loader = NeighborLoader(\n        Data(x=x, edge_index=edge_index),\n        num_neighbors=[2],\n        batch_size=10,\n        shuffle=False,\n    )\n    cached_loader = CachedLoader(\n        loader,\n        device=device,\n        transform=lambda batch: batch.edge_index,\n    )\n\n    assert len(cached_loader) == len(loader)\n    assert len(cached_loader._cache) == 0\n\n    cache = []\n    for i, batch in enumerate(cached_loader):\n        assert len(cached_loader._cache) == i + 1\n        assert isinstance(batch, Tensor)\n        assert batch.dim() == 2 and batch.size(0) == 2\n        
assert batch.device == device\n\n        cache.append(batch)\n\n    for i, batch in enumerate(cached_loader):\n        assert torch.equal(batch, cache[i])\n"
  },
  {
    "path": "test/loader/test_cluster.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import ClusterData, ClusterLoader\nfrom torch_geometric.testing import onlyFullTest, onlyOnline, withMETIS\nfrom torch_geometric.utils import sort_edge_index\n\n\n@withMETIS\ndef test_cluster_gcn():\n    adj = torch.tensor([\n        [1, 1, 1, 0, 1, 0],\n        [1, 1, 0, 1, 0, 1],\n        [1, 0, 1, 0, 1, 0],\n        [0, 1, 0, 1, 0, 1],\n        [1, 0, 1, 0, 1, 0],\n        [0, 1, 0, 1, 0, 1],\n    ])\n\n    x = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 1.0],\n        [2.0, 2.0],\n        [3.0, 3.0],\n        [4.0, 4.0],\n        [5.0, 5.0],\n    ])\n    edge_index = adj.nonzero(as_tuple=False).t()\n    edge_attr = torch.arange(edge_index.size(1))\n    n_id = torch.arange(6)\n    data = Data(x=x, n_id=n_id, edge_index=edge_index, edge_attr=edge_attr)\n    data.num_nodes = 6\n\n    cluster_data = ClusterData(data, num_parts=2, log=False)\n\n    partition = cluster_data._partition(\n        edge_index, cluster=torch.tensor([0, 1, 0, 1, 0, 1]))\n    assert partition.partptr.tolist() == [0, 3, 6]\n    assert partition.node_perm.tolist() == [0, 2, 4, 1, 3, 5]\n    assert partition.edge_perm.tolist() == [\n        0, 2, 3, 1, 8, 9, 10, 14, 15, 16, 4, 5, 6, 7, 11, 12, 13, 17, 18, 19\n    ]\n\n    assert cluster_data.partition.partptr.tolist() == [0, 3, 6]\n    assert torch.equal(\n        cluster_data.partition.node_perm.sort()[0],\n        torch.arange(data.num_nodes),\n    )\n    assert torch.equal(\n        cluster_data.partition.edge_perm.sort()[0],\n        torch.arange(data.num_edges),\n    )\n\n    out = cluster_data[0]\n    expected = data.subgraph(out.n_id)\n    out.validate()\n    assert out.num_nodes == 3\n    assert out.n_id.size() == (3, )\n    assert torch.equal(out.x, expected.x)\n    tmp = sort_edge_index(expected.edge_index, expected.edge_attr)\n    assert torch.equal(out.edge_index, tmp[0])\n    assert torch.equal(out.edge_attr, 
tmp[1])\n\n    out = cluster_data[1]\n    out.validate()\n    assert out.num_nodes == 3\n    assert out.n_id.size() == (3, )\n    expected = data.subgraph(out.n_id)\n    assert torch.equal(out.x, expected.x)\n    tmp = sort_edge_index(expected.edge_index, expected.edge_attr)\n    assert torch.equal(out.edge_index, tmp[0])\n    assert torch.equal(out.edge_attr, tmp[1])\n\n    loader = ClusterLoader(cluster_data, batch_size=1)\n    iterator = iter(loader)\n\n    out = next(iterator)\n    out.validate()\n    assert out.num_nodes == 3\n    assert out.n_id.size() == (3, )\n    expected = data.subgraph(out.n_id)\n    assert torch.equal(out.x, expected.x)\n    tmp = sort_edge_index(expected.edge_index, expected.edge_attr)\n    assert torch.equal(out.edge_index, tmp[0])\n    assert torch.equal(out.edge_attr, tmp[1])\n\n    out = next(iterator)\n    out.validate()\n    assert out.num_nodes == 3\n    assert out.n_id.size() == (3, )\n    expected = data.subgraph(out.n_id)\n    assert torch.equal(out.x, expected.x)\n    tmp = sort_edge_index(expected.edge_index, expected.edge_attr)\n    assert torch.equal(out.edge_index, tmp[0])\n    assert torch.equal(out.edge_attr, tmp[1])\n\n    loader = ClusterLoader(cluster_data, batch_size=2, shuffle=False)\n    out = next(iter(loader))\n    out.validate()\n    assert out.num_nodes == 6\n    assert out.n_id.size() == (6, )\n    expected = data.subgraph(out.n_id)\n    assert torch.equal(out.x, expected.x)\n    tmp = sort_edge_index(expected.edge_index, expected.edge_attr)\n    assert torch.equal(out.edge_index, tmp[0])\n    assert torch.equal(out.edge_attr, tmp[1])\n\n\n@withMETIS\ndef test_keep_inter_cluster_edges():\n    adj = torch.tensor([\n        [1, 1, 1, 0, 1, 0],\n        [1, 1, 0, 1, 0, 1],\n        [1, 0, 1, 0, 1, 0],\n        [0, 1, 0, 1, 0, 1],\n        [1, 0, 1, 0, 1, 0],\n        [0, 1, 0, 1, 0, 1],\n    ])\n\n    x = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 1.0],\n        [2.0, 2.0],\n        [3.0, 3.0],\n        
[4.0, 4.0],\n        [5.0, 5.0],\n    ])\n    edge_index = adj.nonzero(as_tuple=False).t()\n    edge_attr = torch.arange(edge_index.size(1))\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n    data.num_nodes = 6\n\n    cluster_data = ClusterData(data, num_parts=2, log=False,\n                               keep_inter_cluster_edges=True)\n\n    data = cluster_data[0]\n    assert data.edge_index[0].min() == 0\n    assert data.edge_index[0].max() == 2\n    assert data.edge_index[1].min() == 0\n    assert data.edge_index[1].max() > 2\n    assert data.edge_index.size(1) == data.edge_attr.size(0)\n\n    data = cluster_data[1]\n    assert data.edge_index[0].min() == 0\n    assert data.edge_index[0].max() == 2\n    assert data.edge_index[1].min() == 0\n    assert data.edge_index[1].max() > 2\n    assert data.edge_index.size(1) == data.edge_attr.size(0)\n\n\n@withMETIS\n@onlyOnline\n@onlyFullTest\n@pytest.mark.parametrize('sparse_format', ['csr', 'csc'])\ndef test_cluster_gcn_correctness(get_dataset, sparse_format):\n    dataset = get_dataset('Cora')\n    data = dataset[0].clone()\n    data.n_id = torch.arange(data.num_nodes)\n    cluster_data = ClusterData(\n        data,\n        num_parts=10,\n        log=False,\n        sparse_format=sparse_format,\n    )\n    loader = ClusterLoader(cluster_data, batch_size=3, shuffle=False)\n\n    for batch1 in loader:\n        batch1.validate()\n        batch2 = data.subgraph(batch1.n_id)\n        assert batch1.num_nodes == batch2.num_nodes\n        assert batch1.num_edges == batch2.num_edges\n        assert torch.equal(batch1.x, batch2.x)\n        assert torch.equal(\n            batch1.edge_index,\n            sort_edge_index(\n                batch2.edge_index,\n                sort_by_row=sparse_format == 'csr',\n            ),\n        )\n\n\nif __name__ == '__main__':\n    import argparse\n\n    from ogb.nodeproppred import PygNodePropPredDataset\n    from tqdm import tqdm\n\n    parser = 
argparse.ArgumentParser()\n    parser.add_argument('--num_workers', type=int, default=0)\n    args = parser.parse_args()\n\n    data = PygNodePropPredDataset('ogbn-products', root='/tmp/ogb')[0]\n\n    loader = ClusterLoader(\n        ClusterData(data, num_parts=15_000, save_dir='/tmp/ogb/ogbn_products'),\n        batch_size=32,\n        shuffle=True,\n        num_workers=args.num_workers,\n    )\n\n    for _ in tqdm(loader):\n        pass\n"
  },
  {
    "path": "test/loader/test_dataloader.py",
    "content": "import multiprocessing\nimport sys\nfrom collections import namedtuple\n\nimport pytest\nimport torch\n\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.data import Data, HeteroData, OnDiskDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.testing import (\n    get_random_edge_index,\n    get_random_tensor_frame,\n    onlyLinux,\n    withDevice,\n    withPackage,\n)\n\nwith_mp = sys.platform not in ['win32']\nnum_workers_list = [0, 2] if with_mp else [0]\n\nif sys.platform == 'darwin':\n    multiprocessing.set_start_method('spawn')\n\n\n@withDevice\n@pytest.mark.parametrize('num_workers', num_workers_list)\ndef test_dataloader(num_workers, device):\n    if num_workers > 0 and device != torch.device('cpu'):\n        return\n\n    x = torch.tensor([[1.0], [1.0], [1.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    face = torch.tensor([[0], [1], [2]])\n    y = 2.\n    z = torch.tensor(0.)\n    name = 'data'\n\n    data = Data(x=x, edge_index=edge_index, y=y, z=z, name=name).to(device)\n    assert str(data) == (\"Data(x=[3, 1], edge_index=[2, 4], y=2.0, z=0.0, \"\n                         \"name='data')\")\n    data.face = face\n\n    loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False,\n                        num_workers=num_workers)\n    assert len(loader) == 2\n\n    for batch in loader:\n        assert batch.x.device == device\n        assert batch.edge_index.device == device\n        assert batch.z.device == device\n        assert batch.num_graphs == len(batch) == 2\n        assert batch.batch.tolist() == [0, 0, 0, 1, 1, 1]\n        assert batch.ptr.tolist() == [0, 3, 6]\n        assert batch.x.tolist() == [[1], [1], [1], [1], [1], [1]]\n        assert batch.edge_index.tolist() == [[0, 1, 1, 2, 3, 4, 4, 5],\n                                             [1, 0, 2, 1, 4, 3, 5, 4]]\n        assert batch.y.tolist() == [2.0, 2.0]\n        assert batch.z.tolist() == 
[0.0, 0.0]\n        assert batch.name == ['data', 'data']\n        assert batch.face.tolist() == [[0, 3], [1, 4], [2, 5]]\n\n        for store in batch.stores:\n            assert id(batch) == id(store._parent())\n\n    loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False,\n                        follow_batch=['edge_index'], num_workers=num_workers,\n                        collate_fn=None)\n    assert len(loader) == 2\n\n    for batch in loader:\n        assert batch.num_graphs == len(batch) == 2\n        assert batch.edge_index_batch.tolist() == [0, 0, 0, 0, 1, 1, 1, 1]\n\n\n@onlyLinux\n@pytest.mark.parametrize('num_workers', num_workers_list)\ndef test_dataloader_on_disk_dataset(tmp_path, num_workers):\n    dataset = OnDiskDataset(tmp_path)\n    data1 = Data(x=torch.randn(3, 8))\n    data2 = Data(x=torch.randn(4, 8))\n    dataset.extend([data1, data2])\n\n    loader = DataLoader(dataset, batch_size=2, num_workers=num_workers)\n    assert len(loader) == 1\n    batch = next(iter(loader))\n    assert batch.num_nodes == 7\n    assert torch.equal(batch.x, torch.cat([data1.x, data2.x], dim=0))\n    assert batch.batch.tolist() == [0, 0, 0, 1, 1, 1, 1]\n\n    dataset.close()\n\n\ndef test_dataloader_fallbacks():\n    # Test inputs of type List[torch.Tensor]:\n    data_list = [torch.ones(3) for _ in range(4)]\n    batch = next(iter(DataLoader(data_list, batch_size=4)))\n    assert torch.equal(batch, torch.ones(4, 3))\n\n    # Test inputs of type List[float]:\n    data_list = [1.0, 1.0, 1.0, 1.0]\n    batch = next(iter(DataLoader(data_list, batch_size=4)))\n    assert torch.equal(batch, torch.ones(4))\n\n    # Test inputs of type List[int]:\n    data_list = [1, 1, 1, 1]\n    batch = next(iter(DataLoader(data_list, batch_size=4)))\n    assert torch.equal(batch, torch.ones(4, dtype=torch.long))\n\n    # Test inputs of type List[str]:\n    data_list = ['test'] * 4\n    batch = next(iter(DataLoader(data_list, batch_size=4)))\n    assert batch == 
data_list\n\n    # Test inputs of type List[Mapping]:\n    data_list = [{'x': torch.ones(3), 'y': 1}] * 4\n    batch = next(iter(DataLoader(data_list, batch_size=4)))\n    assert torch.equal(batch['x'], torch.ones(4, 3))\n    assert torch.equal(batch['y'], torch.ones(4, dtype=torch.long))\n\n    # Test inputs of type List[Tuple]:\n    DataTuple = namedtuple('DataTuple', 'x y')\n    data_list = [DataTuple(0.0, 1)] * 4\n    batch = next(iter(DataLoader(data_list, batch_size=4)))\n    assert torch.equal(batch.x, torch.zeros(4))\n    assert torch.equal(batch[1], torch.ones(4, dtype=torch.long))\n\n    # Test inputs of type List[Sequence]:\n    data_list = [[0.0, 1]] * 4\n    batch = next(iter(DataLoader(data_list, batch_size=4)))\n    assert torch.equal(batch[0], torch.zeros(4))\n    assert torch.equal(batch[1], torch.ones(4, dtype=torch.long))\n\n    # Test that inputs of unsupported types raise an error:\n    class DummyClass:\n        pass\n\n    with pytest.raises(TypeError):\n        data_list = [DummyClass()] * 4\n        next(iter(DataLoader(data_list, batch_size=4)))\n\n\n@pytest.mark.skipif(not with_mp, reason='Multi-processing not available')\ndef test_multiprocessing():\n    queue = torch.multiprocessing.Manager().Queue()\n    data = Data(x=torch.randn(5, 16))\n    data_list = [data, data, data, data]\n    loader = DataLoader(data_list, batch_size=2)\n    for batch in loader:\n        queue.put(batch)\n\n    batch = queue.get()\n    assert batch.num_graphs == len(batch) == 2\n\n    batch = queue.get()\n    assert batch.num_graphs == len(batch) == 2\n\n\ndef test_pin_memory():\n    x = torch.randn(3, 16)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    data = Data(x=x, edge_index=edge_index)\n\n    loader = DataLoader([data] * 4, batch_size=2, pin_memory=True)\n    for batch in loader:\n        assert batch.x.is_pinned() or not torch.cuda.is_available()\n        assert batch.edge_index.is_pinned() or not 
torch.cuda.is_available()\n\n\n@pytest.mark.parametrize('num_workers', num_workers_list)\ndef test_heterogeneous_dataloader(num_workers):\n    data = HeteroData()\n    data['p'].x = torch.randn(100, 128)\n    data['a'].x = torch.randn(200, 128)\n    data['p', 'a'].edge_index = get_random_edge_index(100, 200, 500)\n    data['p', 'a'].edge_attr = torch.randn(500, 32)\n    data['a', 'p'].edge_index = get_random_edge_index(200, 100, 400)\n    data['a', 'p'].edge_attr = torch.randn(400, 32)\n\n    loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False,\n                        num_workers=num_workers)\n    assert len(loader) == 2\n\n    for batch in loader:\n        assert batch.num_graphs == len(batch) == 2\n        assert batch.num_nodes == 600\n\n        for store in batch.stores:\n            assert id(batch) == id(store._parent())\n\n\n@pytest.mark.parametrize('num_workers', num_workers_list)\ndef test_index_dataloader(num_workers):\n    index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)\n    index2 = Index([0, 1, 1, 2, 2, 3], dim_size=4, is_sorted=True)\n\n    data1 = Data(index=index1, num_nodes=3)\n    data2 = Data(index=index2, num_nodes=4)\n\n    loader = DataLoader(\n        [data1, data2, data1, data2],\n        batch_size=2,\n        num_workers=num_workers,\n    )\n    assert len(loader) == 2\n\n    for batch in loader:\n        assert isinstance(batch.index, Index)\n        assert batch.index.dtype == torch.long\n        assert batch.index.dim_size == 7\n        assert batch.index.is_sorted\n\n\n@pytest.mark.parametrize('num_workers', num_workers_list)\n@pytest.mark.parametrize('sort_order', [None, 'row', 'col'])\ndef test_edge_index_dataloader(num_workers, sort_order):\n    if sort_order == 'col':\n        edge_index = [[1, 0, 2, 1], [0, 1, 1, 2]]\n    else:\n        edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]]\n\n    edge_index = EdgeIndex(\n        edge_index,\n        sparse_size=(3, 3),\n        sort_order=sort_order,\n        
is_undirected=True,\n    )\n    data = Data(edge_index=edge_index)\n    assert data.num_nodes == 3\n\n    loader = DataLoader(\n        [data, data, data, data],\n        batch_size=2,\n        num_workers=num_workers,\n    )\n    assert len(loader) == 2\n\n    for batch in loader:\n        assert isinstance(batch.edge_index, EdgeIndex)\n        assert batch.edge_index.dtype == torch.long\n        assert batch.edge_index.sparse_size() == (6, 6)\n        assert batch.edge_index.sort_order == sort_order\n        assert batch.edge_index.is_undirected\n\n\n@withPackage('torch_frame')\ndef test_dataloader_tensor_frame():\n    tf = get_random_tensor_frame(num_rows=10)\n    loader = DataLoader([tf, tf, tf, tf], batch_size=2, shuffle=False)\n    assert len(loader) == 2\n\n    for batch in loader:\n        assert batch.num_rows == 20\n\n    data = Data(tf=tf, edge_index=get_random_edge_index(10, 10, 20))\n    loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False)\n    assert len(loader) == 2\n\n    for batch in loader:\n        assert batch.num_graphs == len(batch) == 2\n        assert batch.num_nodes == 20\n        assert batch.tf.num_rows == 20\n        assert batch.edge_index.max() >= 10\n\n\ndef test_dataloader_sparse():\n    adj_t = torch.sparse_coo_tensor(\n        indices=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),\n        values=torch.randn(4),\n        size=(3, 3),\n    )\n    data = Data(adj_t=adj_t)\n\n    loader = DataLoader([data, data], batch_size=2)\n    for batch in loader:\n        assert batch.adj_t.size() == (6, 6)\n\n\nif __name__ == '__main__':\n    import argparse\n    import time\n\n    from torch_geometric.datasets import QM9\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--num_workers', type=int, default=0)\n    args = parser.parse_args()\n\n    kwargs = dict(batch_size=128, shuffle=True, num_workers=args.num_workers)\n\n    in_memory_dataset = QM9('/tmp/QM9')\n    loader = DataLoader(in_memory_dataset, 
**kwargs)\n\n    print('In-Memory Dataset:')\n    for _ in range(2):\n        print(f'Start loading {len(loader)} mini-batches ... ', end='')\n        t = time.perf_counter()\n        for _ in loader:\n            pass\n        print(f'Done! [{time.perf_counter() - t:.4f}s]')\n\n    on_disk_dataset = in_memory_dataset.to_on_disk_dataset()\n    loader = DataLoader(on_disk_dataset, **kwargs)\n\n    print('On-Disk Dataset:')\n    for _ in range(2):\n        print(f'Start loading {len(loader)} mini-batches ... ', end='')\n        t = time.perf_counter()\n        for _ in loader:\n            pass\n        print(f'Done! [{time.perf_counter() - t:.4f}s]')\n\n    on_disk_dataset.close()\n"
  },
  {
    "path": "test/loader/test_dynamic_batch_sampler.py",
    "content": "from typing import List\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import DataLoader, DynamicBatchSampler\n\n\ndef test_dataloader_with_dynamic_batches():\n    data_list: List[Data] = []\n    for num_nodes in range(100, 110):\n        data_list.append(Data(num_nodes=num_nodes))\n\n    torch.manual_seed(12345)\n    batch_sampler = DynamicBatchSampler(data_list, 300, shuffle=True)\n    loader = DataLoader(data_list, batch_sampler=batch_sampler)\n\n    num_nodes_total = 0\n    for data in loader:\n        assert data.num_nodes <= 300\n        num_nodes_total += data.num_nodes\n    assert num_nodes_total == 1045\n\n    # Test skipping\n    data_list = [Data(num_nodes=400)] + data_list\n    batch_sampler = DynamicBatchSampler(data_list, 300, skip_too_big=True,\n                                        num_steps=2)\n    loader = DataLoader(data_list, batch_sampler=batch_sampler)\n\n    num_nodes_total = 0\n    for data in loader:\n        num_nodes_total += data.num_nodes\n    assert num_nodes_total == 404\n\n    with pytest.raises(ValueError, match=\"length of 'DynamicBatchSampler'\"):\n        len(DynamicBatchSampler(data_list, max_num=300))\n    assert len(DynamicBatchSampler(data_list, max_num=300, num_steps=2)) == 2\n"
  },
  {
    "path": "test/loader/test_graph_saint.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import (\n    GraphSAINTEdgeSampler,\n    GraphSAINTNodeSampler,\n    GraphSAINTRandomWalkSampler,\n)\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('torch_sparse')\ndef test_graph_saint():\n    adj = torch.tensor([\n        [+1, +2, +3, +0, +4, +0],\n        [+5, +6, +0, +7, +0, +8],\n        [+9, +0, 10, +0, 11, +0],\n        [+0, 12, +0, 13, +0, 14],\n        [15, +0, 16, +0, 17, +0],\n        [+0, 18, +0, 19, +0, 20],\n    ])\n\n    edge_index = adj.nonzero(as_tuple=False).t()\n    edge_id = adj[edge_index[0], edge_index[1]]\n    x = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 1.0],\n        [2.0, 2.0],\n        [3.0, 3.0],\n        [4.0, 4.0],\n        [5.0, 5.0],\n    ])\n    n_id = torch.arange(6)\n    data = Data(edge_index=edge_index, x=x, n_id=n_id, edge_id=edge_id,\n                num_nodes=6)\n\n    loader = GraphSAINTNodeSampler(data, batch_size=3, num_steps=4,\n                                   sample_coverage=10, log=False)\n\n    assert len(loader) == 4\n    for sample in loader:\n        assert sample.num_nodes <= data.num_nodes\n        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6\n        assert sample.num_nodes == sample.n_id.numel()\n        assert sample.x.tolist() == x[sample.n_id].tolist()\n        assert sample.edge_index.min() >= 0\n        assert sample.edge_index.max() < sample.num_nodes\n        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21\n        assert sample.edge_id.numel() == sample.num_edges\n        assert sample.node_norm.numel() == sample.num_nodes\n        assert sample.edge_norm.numel() == sample.num_edges\n\n    loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4,\n                                   sample_coverage=10, log=False)\n\n    assert len(loader) == 4\n    for sample in loader:\n        assert sample.num_nodes <= data.num_nodes\n        assert 
sample.n_id.min() >= 0 and sample.n_id.max() < 6\n        assert sample.num_nodes == sample.n_id.numel()\n        assert sample.x.tolist() == x[sample.n_id].tolist()\n        assert sample.edge_index.min() >= 0\n        assert sample.edge_index.max() < sample.num_nodes\n        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21\n        assert sample.edge_id.numel() == sample.num_edges\n        assert sample.node_norm.numel() == sample.num_nodes\n        assert sample.edge_norm.numel() == sample.num_edges\n\n    loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1,\n                                         num_steps=4, sample_coverage=10,\n                                         log=False)\n\n    assert len(loader) == 4\n    for sample in loader:\n        assert sample.num_nodes <= data.num_nodes\n        assert sample.n_id.min() >= 0 and sample.n_id.max() < 6\n        assert sample.num_nodes == sample.n_id.numel()\n        assert sample.x.tolist() == x[sample.n_id].tolist()\n        assert sample.edge_index.min() >= 0\n        assert sample.edge_index.max() < sample.num_nodes\n        assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21\n        assert sample.edge_id.numel() == sample.num_edges\n        assert sample.node_norm.numel() == sample.num_nodes\n        assert sample.edge_norm.numel() == sample.num_edges\n"
  },
  {
    "path": "test/loader/test_hgt_loader.py",
    "content": "import numpy as np\nimport torch\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.loader import HGTLoader\nfrom torch_geometric.nn import GraphConv, to_hetero\nfrom torch_geometric.testing import (\n    get_random_edge_index,\n    onlyOnline,\n    withPackage,\n)\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import k_hop_subgraph\n\n\ndef is_subset(subedge_index, edge_index, src_idx, dst_idx):\n    num_nodes = int(edge_index.max()) + 1\n    idx = num_nodes * edge_index[0] + edge_index[1]\n    subidx = num_nodes * src_idx[subedge_index[0]] + dst_idx[subedge_index[1]]\n    mask = torch.from_numpy(np.isin(subidx, idx))\n    return int(mask.sum()) == mask.numel()\n\n\n@withPackage('torch_sparse')\ndef test_hgt_loader():\n    torch.manual_seed(12345)\n\n    data = HeteroData()\n\n    data['paper'].x = torch.arange(100)\n    data['author'].x = torch.arange(100, 300)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500)\n    data['paper', 'paper'].edge_attr = torch.arange(500)\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000)\n    data['paper', 'author'].edge_attr = torch.arange(500, 1500)\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000)\n    data['author', 'paper'].edge_attr = torch.arange(1500, 2500)\n\n    r1, c1 = data['paper', 'paper'].edge_index\n    r2, c2 = data['paper', 'author'].edge_index + torch.tensor([[0], [100]])\n    r3, c3 = data['author', 'paper'].edge_index + torch.tensor([[100], [0]])\n    full_adj = SparseTensor(\n        row=torch.cat([r1, r2, r3]),\n        col=torch.cat([c1, c2, c3]),\n        value=torch.arange(2500),\n    )\n\n    batch_size = 20\n    loader = HGTLoader(data, num_samples=[5] * 4, batch_size=batch_size,\n                       input_nodes='paper')\n    assert str(loader) == 'HGTLoader()'\n    assert len(loader) == (100 + batch_size - 1) // batch_size\n\n    for batch in loader:\n 
       assert isinstance(batch, HeteroData)\n\n        # Test node and types:\n        assert set(batch.node_types) == {'paper', 'author'}\n        assert set(batch.edge_types) == set(data.edge_types)\n\n        assert len(batch['paper']) == 4\n        assert batch['paper'].n_id.size() == (batch['paper'].num_nodes, )\n        assert batch['paper'].x.size() == (40, )  # 20 + 4 * 5\n        assert batch['paper'].input_id.numel() == batch_size\n        assert batch['paper'].batch_size == batch_size\n        assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100\n\n        assert len(batch['author']) == 2\n        assert batch['author'].n_id.size() == (batch['author'].num_nodes, )\n        assert batch['author'].x.size() == (20, )  # 4 * 5\n        assert batch['author'].x.min() >= 100 and batch['author'].x.max() < 300\n\n        # Test edge type selection:\n        assert set(batch.edge_types) == {('paper', 'to', 'paper'),\n                                         ('paper', 'to', 'author'),\n                                         ('author', 'to', 'paper')}\n\n        assert len(batch['paper', 'paper']) == 3\n        num_edges = batch['paper', 'paper'].num_edges\n        assert batch['paper', 'paper'].e_id.size() == (num_edges, )\n        row, col = batch['paper', 'paper'].edge_index\n        value = batch['paper', 'paper'].edge_attr\n        adj = full_adj[batch['paper'].x, batch['paper'].x]\n        assert row.min() >= 0 and row.max() < 40\n        assert col.min() >= 0 and col.max() < 40\n        assert value.min() >= 0 and value.max() < 500\n        assert adj.nnz() == row.size(0)\n        assert torch.allclose(row.unique(), adj.storage.row().unique())\n        assert torch.allclose(col.unique(), adj.storage.col().unique())\n        assert torch.allclose(value.unique(), adj.storage.value().unique())\n\n        assert is_subset(batch['paper', 'paper'].edge_index,\n                         data['paper', 'paper'].edge_index, batch['paper'].x,\n          
               batch['paper'].x)\n\n        assert len(batch['paper', 'author']) == 3\n        num_edges = batch['paper', 'author'].num_edges\n        assert batch['paper', 'author'].e_id.size() == (num_edges, )\n        row, col = batch['paper', 'author'].edge_index\n        value = batch['paper', 'author'].edge_attr\n        adj = full_adj[batch['paper'].x, batch['author'].x]\n        assert row.min() >= 0 and row.max() < 40\n        assert col.min() >= 0 and col.max() < 20\n        assert value.min() >= 500 and value.max() < 1500\n        assert adj.nnz() == row.size(0)\n        assert torch.allclose(row.unique(), adj.storage.row().unique())\n        assert torch.allclose(col.unique(), adj.storage.col().unique())\n        assert torch.allclose(value.unique(), adj.storage.value().unique())\n\n        assert is_subset(batch['paper', 'author'].edge_index,\n                         data['paper', 'author'].edge_index, batch['paper'].x,\n                         batch['author'].x - 100)\n\n        assert len(batch['author', 'paper']) == 3\n        num_edges = batch['author', 'paper'].num_edges\n        assert batch['author', 'paper'].e_id.size() == (num_edges, )\n        row, col = batch['author', 'paper'].edge_index\n        value = batch['author', 'paper'].edge_attr\n        adj = full_adj[batch['author'].x, batch['paper'].x]\n        assert row.min() >= 0 and row.max() < 20\n        assert col.min() >= 0 and col.max() < 40\n        assert value.min() >= 1500 and value.max() < 2500\n        assert adj.nnz() == row.size(0)\n        assert torch.allclose(row.unique(), adj.storage.row().unique())\n        assert torch.allclose(col.unique(), adj.storage.col().unique())\n        assert torch.allclose(value.unique(), adj.storage.value().unique())\n\n        assert is_subset(batch['author', 'paper'].edge_index,\n                         data['author', 'paper'].edge_index,\n                         batch['author'].x - 100, batch['paper'].x)\n\n        # Test for isolated 
nodes (there shouldn't exist any):\n        n_id = torch.cat([batch['paper'].x, batch['author'].x])\n        row, col, _ = full_adj[n_id, n_id].coo()\n        assert torch.cat([row, col]).unique().numel() >= 59\n\n\n@onlyOnline\n@withPackage('torch_sparse')\ndef test_hgt_loader_on_cora(get_dataset):\n    dataset = get_dataset(name='Cora')\n    data = dataset[0]\n    data.edge_weight = torch.rand(data.num_edges)\n\n    hetero_data = HeteroData()\n    hetero_data['paper'].x = data.x\n    hetero_data['paper'].n_id = torch.arange(data.num_nodes)\n    hetero_data['paper', 'paper'].edge_index = data.edge_index\n    hetero_data['paper', 'paper'].edge_weight = data.edge_weight\n\n    split_idx = torch.arange(5, 8)\n\n    # Sample the complete two-hop neighborhood:\n    loader = HGTLoader(hetero_data, num_samples=[data.num_nodes] * 2,\n                       batch_size=split_idx.numel(),\n                       input_nodes=('paper', split_idx))\n    assert len(loader) == 1\n\n    hetero_batch = next(iter(loader))\n    batch_size = hetero_batch['paper'].batch_size\n\n    n_id, _, _, e_mask = k_hop_subgraph(split_idx, num_hops=2,\n                                        edge_index=data.edge_index,\n                                        num_nodes=data.num_nodes)\n\n    n_id = n_id.sort()[0]\n    assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist()\n    assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum())\n\n    class GNN(torch.nn.Module):\n        def __init__(self, in_channels, hidden_channels, out_channels):\n            super().__init__()\n            self.conv1 = GraphConv(in_channels, hidden_channels)\n            self.conv2 = GraphConv(hidden_channels, out_channels)\n\n        def forward(self, x, edge_index, edge_weight):\n            x = self.conv1(x, edge_index, edge_weight).relu()\n            x = self.conv2(x, edge_index, edge_weight).relu()\n            return x\n\n    model = GNN(dataset.num_features, 16, dataset.num_classes)\n   
 hetero_model = to_hetero(model, hetero_data.metadata())\n\n    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]\n    out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,\n                        hetero_batch.edge_weight_dict)['paper'][:batch_size]\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n\n@withPackage('torch_sparse')\ndef test_hgt_loader_disconnected():\n    data = HeteroData()\n\n    data['paper'].x = torch.randn(10, 16)\n    data['author'].x = torch.randn(10, 16)\n\n    # Paper nodes are disconnected from author nodes:\n    data['paper', 'paper'].edge_index = get_random_edge_index(10, 10, 15)\n    data['paper', 'paper'].edge_attr = torch.randn(15, 8)\n    data['author', 'author'].edge_index = get_random_edge_index(10, 10, 15)\n    data['author', 'author'].edge_attr = torch.randn(15, 8)\n\n    loader = HGTLoader(data, num_samples=[2], batch_size=2,\n                       input_nodes='paper')\n\n    for batch in loader:\n        assert isinstance(batch, HeteroData)\n\n        # Test node and edge types:\n        assert set(batch.node_types) == set(data.node_types)\n        assert set(batch.edge_types) == set(data.edge_types)\n\n        assert batch['author'].num_nodes == 0\n        assert batch['author'].x.size() == (0, 16)\n        assert batch['author', 'author'].num_edges == 0\n        assert batch['author', 'author'].edge_index.size() == (2, 0)\n        assert batch['author', 'author'].edge_attr.size() == (0, 8)\n"
  },
  {
    "path": "test/loader/test_ibmb_loader.py",
    "content": "import pytest\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.datasets import KarateClub\nfrom torch_geometric.loader.ibmb_loader import IBMBBatchLoader, IBMBNodeLoader\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.typing import SparseTensor\n\n\n@withPackage('python_tsp')\n@pytest.mark.parametrize(\n    'use_sparse_tensor',\n    [False] + [True] if torch_geometric.typing.WITH_TORCH_SPARSE else [])\n@pytest.mark.parametrize('kwargs', [\n    dict(num_partitions=4, batch_size=1),\n    dict(num_partitions=8, batch_size=2),\n])\ndef test_ibmb_batch_loader(use_sparse_tensor, kwargs):\n    data = KarateClub()[0]\n\n    loader = IBMBBatchLoader(\n        data,\n        batch_order='order',\n        input_nodes=torch.randperm(data.num_nodes)[:20],\n        return_edge_index_type='adj' if use_sparse_tensor else 'edge_index',\n        **kwargs,\n    )\n    assert str(loader) == 'IBMBBatchLoader()'\n    assert len(loader) == 4\n    assert sum([batch.output_node_mask.sum() for batch in loader]) == 20\n\n    for batch in loader:\n        if use_sparse_tensor:\n            assert isinstance(batch.edge_index, SparseTensor)\n        else:\n            assert isinstance(batch.edge_index, Tensor)\n\n\n@withPackage('python_tsp', 'numba')\n@pytest.mark.parametrize(\n    'use_sparse_tensor',\n    [False] + [True] if torch_geometric.typing.WITH_TORCH_SPARSE else [])\n@pytest.mark.parametrize('kwargs', [\n    dict(num_nodes_per_batch=4, batch_size=1),\n    dict(num_nodes_per_batch=2, batch_size=2),\n])\ndef test_ibmb_node_loader(use_sparse_tensor, kwargs):\n    data = KarateClub()[0]\n\n    loader = IBMBNodeLoader(\n        data,\n        batch_order='order',\n        input_nodes=torch.randperm(data.num_nodes)[:20],\n        num_auxiliary_nodes=4,\n        return_edge_index_type='adj' if use_sparse_tensor else 'edge_index',\n        **kwargs,\n    )\n    assert str(loader) == 'IBMBNodeLoader()'\n    
assert len(loader) == 5\n    assert sum([batch.output_node_mask.sum() for batch in loader]) == 20\n\n    for batch in loader:\n        if use_sparse_tensor:\n            assert isinstance(batch.edge_index, SparseTensor)\n        else:\n            assert isinstance(batch.edge_index, Tensor)\n"
  },
  {
    "path": "test/loader/test_imbalanced_sampler.py",
    "content": "from typing import List\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets import FakeDataset, FakeHeteroDataset\nfrom torch_geometric.loader import (\n    DataLoader,\n    ImbalancedSampler,\n    NeighborLoader,\n)\nfrom torch_geometric.testing import onlyNeighborSampler\n\n\ndef test_dataloader_with_imbalanced_sampler():\n    data_list: List[Data] = []\n    for _ in range(10):\n        data_list.append(Data(num_nodes=10, y=0))\n    for _ in range(90):\n        data_list.append(Data(num_nodes=10, y=1))\n\n    torch.manual_seed(12345)\n    sampler = ImbalancedSampler(data_list)\n    loader = DataLoader(data_list, batch_size=10, sampler=sampler)\n\n    y = torch.cat([batch.y for batch in loader])\n\n    histogram = y.bincount()\n    prob = histogram / histogram.sum()\n\n    assert histogram.sum() == len(data_list)\n    assert prob.min() > 0.4 and prob.max() < 0.6\n\n    # Test with label tensor as input:\n    torch.manual_seed(12345)\n    sampler = ImbalancedSampler(torch.tensor([data.y for data in data_list]))\n    loader = DataLoader(data_list, batch_size=10, sampler=sampler)\n\n    assert torch.allclose(y, torch.cat([batch.y for batch in loader]))\n\n    # Test with list of data objects as input where each y is a tensor:\n    torch.manual_seed(12345)\n    for data in data_list:\n        data.y = torch.tensor([data.y])\n    sampler = ImbalancedSampler(data_list)\n    loader = DataLoader(data_list, batch_size=100, sampler=sampler)\n\n    assert torch.allclose(y, torch.cat([batch.y for batch in loader]))\n\n\ndef test_in_memory_dataset_imbalanced_sampler():\n    torch.manual_seed(12345)\n    dataset = FakeDataset(num_graphs=100, avg_num_nodes=10, avg_degree=0,\n                          num_channels=0, num_classes=2)\n    sampler = ImbalancedSampler(dataset)\n    loader = DataLoader(dataset, batch_size=10, sampler=sampler)\n\n    y = torch.cat([batch.y for batch in loader])\n    histogram = y.bincount()\n    prob 
= histogram / histogram.sum()\n\n    assert histogram.sum() == len(dataset)\n    assert prob.min() > 0.4 and prob.max() < 0.6\n\n\n@onlyNeighborSampler\ndef test_neighbor_loader_with_imbalanced_sampler():\n    zeros = torch.zeros(10, dtype=torch.long)\n    ones = torch.ones(90, dtype=torch.long)\n\n    y = torch.cat([zeros, ones], dim=0)\n    edge_index = torch.empty((2, 0), dtype=torch.long)\n    data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))\n\n    torch.manual_seed(12345)\n    sampler = ImbalancedSampler(data)\n    loader = NeighborLoader(data, batch_size=10, sampler=sampler,\n                            num_neighbors=[-1])\n\n    y = torch.cat([batch.y for batch in loader])\n\n    histogram = y.bincount()\n    prob = histogram / histogram.sum()\n\n    assert histogram.sum() == data.num_nodes\n    assert prob.min() > 0.4 and prob.max() < 0.6\n\n    # Test with label tensor as input:\n    torch.manual_seed(12345)\n    sampler = ImbalancedSampler(data.y)\n    loader = NeighborLoader(data, batch_size=10, sampler=sampler,\n                            num_neighbors=[-1])\n\n    assert torch.allclose(y, torch.cat([batch.y for batch in loader]))\n\n\n@onlyNeighborSampler\ndef test_hetero_neighbor_loader_with_imbalanced_sampler():\n    torch.manual_seed(12345)\n    data = FakeHeteroDataset(num_classes=2)[0]\n\n    loader = NeighborLoader(\n        data,\n        batch_size=100,\n        input_nodes='v0',\n        num_neighbors=[-1],\n        sampler=ImbalancedSampler(data['v0'].y),\n    )\n\n    y = torch.cat([batch['v0'].y[:batch['v0'].batch_size] for batch in loader])\n\n    histogram = y.bincount()\n    prob = histogram / histogram.sum()\n\n    assert histogram.sum() == data['v0'].num_nodes\n    assert prob.min() > 0.4 and prob.max() < 0.6\n"
  },
  {
    "path": "test/loader/test_link_neighbor_loader.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.loader import LinkNeighborLoader\nfrom torch_geometric.testing import (\n    MyFeatureStore,\n    MyGraphStore,\n    get_random_edge_index,\n    onlyNeighborSampler,\n    withCUDA,\n    withPackage,\n)\n\n\ndef unique_edge_pairs(edge_index):\n    return set(map(tuple, edge_index.t().tolist()))\n\n\n@withCUDA\n@onlyNeighborSampler\n@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional'])\n@pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0])\n@pytest.mark.parametrize('filter_per_worker', [None, True, False])\ndef test_homo_link_neighbor_loader_basic(device, subgraph_type,\n                                         neg_sampling_ratio,\n                                         filter_per_worker):\n    pos_edge_index = get_random_edge_index(50, 50, 500, device=device)\n    neg_edge_index = get_random_edge_index(50, 50, 500, device=device)\n    neg_edge_index += 50\n\n    input_edges = torch.cat([pos_edge_index, neg_edge_index], dim=-1)\n    edge_label = torch.cat([\n        torch.ones(500, device=device),\n        torch.zeros(500, device=device),\n    ], dim=0)\n\n    data = Data()\n\n    data.edge_index = pos_edge_index\n    data.x = torch.arange(100, device=device)\n    data.edge_attr = torch.arange(500, device=device)\n\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        batch_size=20,\n        edge_label_index=input_edges,\n        edge_label=edge_label if neg_sampling_ratio is None else None,\n        subgraph_type=subgraph_type,\n        neg_sampling_ratio=neg_sampling_ratio,\n        shuffle=True,\n        filter_per_worker=filter_per_worker,\n    )\n\n    assert str(loader) == 'LinkNeighborLoader()'\n    assert len(loader) == 1000 / 20\n\n    batch = loader([0])\n    assert isinstance(batch, Data)\n    assert int(input_edges[0, 0]) in batch.n_id.tolist()\n    assert 
int(input_edges[1, 0]) in batch.n_id.tolist()\n\n    for batch in loader:\n        assert isinstance(batch, Data)\n\n        assert batch.n_id.size() == (batch.num_nodes, )\n        assert batch.e_id.size() == (batch.num_edges, )\n        assert batch.x.device == device\n        assert batch.x.size(0) <= 100\n        assert batch.x.min() >= 0 and batch.x.max() < 100\n        assert batch.input_id.numel() == 20\n        assert batch.edge_index.device == device\n        assert batch.edge_index.min() >= 0\n        assert batch.edge_index.max() < batch.num_nodes\n        assert batch.edge_attr.device == device\n        assert batch.edge_attr.min() >= 0\n        assert batch.edge_attr.max() < 500\n\n        if neg_sampling_ratio is None:\n            assert batch.edge_label_index.size(1) == 20\n\n            # Assert positive samples are present in the original graph:\n            edge_index = unique_edge_pairs(batch.edge_index)\n            edge_label_index = batch.edge_label_index[:, batch.edge_label == 1]\n            edge_label_index = unique_edge_pairs(edge_label_index)\n            assert len(edge_index | edge_label_index) == len(edge_index)\n\n            # Assert negative samples are not present in the original graph:\n            edge_index = unique_edge_pairs(batch.edge_index)\n            edge_label_index = batch.edge_label_index[:, batch.edge_label == 0]\n            edge_label_index = unique_edge_pairs(edge_label_index)\n            assert len(edge_index & edge_label_index) == 0\n\n        else:\n            assert batch.edge_label_index.size(1) == 40\n            assert torch.all(batch.edge_label[:20] == 1)\n            assert torch.all(batch.edge_label[20:] == 0)\n\n        # Ensure local `edge_label_index` correctly maps to input edges.\n        global_edge_label_index = batch.n_id[batch.edge_label_index]\n        global_edge_label_index = (\n            global_edge_label_index[:, batch.edge_label >= 1])\n        global_edge_label_index = 
unique_edge_pairs(global_edge_label_index)\n        assert (len(global_edge_label_index & unique_edge_pairs(input_edges))\n                == len(global_edge_label_index))\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional'])\n@pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0])\ndef test_hetero_link_neighbor_loader_basic(subgraph_type, neg_sampling_ratio):\n    data = HeteroData()\n\n    data['paper'].x = torch.arange(100)\n    data['author'].x = torch.arange(100, 300)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500)\n    data['paper', 'paper'].edge_attr = torch.arange(500)\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000)\n    data['paper', 'author'].edge_attr = torch.arange(500, 1500)\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000)\n    data['author', 'paper'].edge_attr = torch.arange(1500, 2500)\n\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        edge_label_index=('paper', 'author'),\n        batch_size=20,\n        subgraph_type=subgraph_type,\n        neg_sampling_ratio=neg_sampling_ratio,\n        shuffle=True,\n    )\n\n    assert str(loader) == 'LinkNeighborLoader()'\n    assert len(loader) == 1000 / 20\n\n    for batch in loader:\n        assert isinstance(batch, HeteroData)\n        assert batch.input_type == ('paper', 'to', 'author')\n\n        if neg_sampling_ratio is None:\n            # Assert only positive samples are present in the original graph:\n            edge_index = unique_edge_pairs(batch['paper', 'author'].edge_index)\n            edge_label_index = batch['paper', 'author'].edge_label_index\n            edge_label_index = unique_edge_pairs(edge_label_index)\n            assert len(edge_index | edge_label_index) == len(edge_index)\n\n        else:\n            assert batch['paper', 'author'].edge_label_index.size(1) == 40\n            assert 
torch.all(batch['paper', 'author'].edge_label[:20] == 1)\n            assert torch.all(batch['paper', 'author'].edge_label[20:] == 0)\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional'])\ndef test_hetero_link_neighbor_loader_loop(subgraph_type):\n    data = HeteroData()\n\n    data['paper'].x = torch.arange(100)\n    data['author'].x = torch.arange(100, 300)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500)\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000)\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000)\n\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        edge_label_index=('paper', 'paper'),\n        batch_size=20,\n        subgraph_type=subgraph_type,\n    )\n\n    for batch in loader:\n        assert batch['paper'].x.size(0) <= 100\n        assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100\n\n        # Assert positive samples are present in the original graph:\n        edge_index = unique_edge_pairs(batch['paper', 'paper'].edge_index)\n        edge_label_index = batch['paper', 'paper'].edge_label_index\n        edge_label_index = unique_edge_pairs(edge_label_index)\n        assert len(edge_index | edge_label_index) == len(edge_index)\n\n\n@onlyNeighborSampler\ndef test_link_neighbor_loader_edge_label():\n    edge_index = get_random_edge_index(100, 100, 500)\n    data = Data(edge_index=edge_index, x=torch.arange(100))\n\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        batch_size=10,\n        neg_sampling_ratio=1.0,\n    )\n\n    for batch in loader:\n        assert batch.edge_label.dtype == torch.float\n        assert torch.all(batch.edge_label[:10] == 1.0)\n        assert torch.all(batch.edge_label[10:] == 0.0)\n\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        batch_size=10,\n        
edge_label=torch.ones(500, dtype=torch.long),\n        neg_sampling_ratio=1.0,\n    )\n\n    for batch in loader:\n        assert batch.edge_label.dtype == torch.long\n        assert torch.all(batch.edge_label[:10] == 1)\n        assert torch.all(batch.edge_label[10:] == 0)\n\n\n@withPackage('pyg_lib')\n@pytest.mark.parametrize('batch_size', [1])\ndef test_temporal_homo_link_neighbor_loader(batch_size):\n    data = Data(\n        x=torch.randn(10, 5),\n        edge_index=torch.randint(0, 10, (2, 123)),\n        time=torch.arange(10),\n    )\n\n    # Ensure that nodes exist at the time of the `edge_label_time`:\n    edge_label_time = torch.max(\n        data.time[data.edge_index[0]],\n        data.time[data.edge_index[1]],\n    )\n\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1],\n        time_attr='time',\n        edge_label=torch.ones(data.num_edges),\n        edge_label_time=edge_label_time,\n        batch_size=batch_size,\n        shuffle=True,\n    )\n\n    for batch in loader:\n        assert batch.edge_label_index.size() == (2, batch_size)\n        assert batch.edge_label_time.size() == (batch_size, )\n        assert batch.edge_label.size() == (batch_size, )\n        assert torch.all(batch.time <= batch.edge_label_time)\n\n\n@withPackage('pyg_lib')\ndef test_temporal_hetero_link_neighbor_loader():\n    data = HeteroData()\n\n    data['paper'].x = torch.arange(100)\n    data['paper'].time = torch.arange(data['paper'].num_nodes) - 200\n    data['author'].x = torch.arange(100, 300)\n    data['author'].time = torch.arange(data['author'].num_nodes)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500)\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000)\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000)\n\n    with pytest.raises(ValueError, match=r\"'edge_label_time' is not set\"):\n        loader = LinkNeighborLoader(\n            data,\n            
num_neighbors=[-1] * 2,\n            edge_label_index=('paper', 'paper'),\n            batch_size=32,\n            time_attr='time',\n        )\n\n    # With edge_time:\n    edge_time = torch.arange(data['paper', 'paper'].edge_index.size(1))\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        edge_label_index=('paper', 'paper'),\n        edge_label_time=edge_time,\n        batch_size=32,\n        time_attr='time',\n        neg_sampling_ratio=0.5,\n        drop_last=True,\n    )\n    for batch in loader:\n        # Check if each seed edge has a different batch:\n        assert int(batch['paper'].batch.max()) + 1 == 32\n\n        author_max = batch['author'].time.max()\n        edge_max = batch['paper', 'paper'].edge_label_time.max()\n        assert edge_max >= author_max\n        author_min = batch['author'].time.min()\n        edge_min = batch['paper', 'paper'].edge_label_time.min()\n        assert edge_min >= author_min\n\n\n@onlyNeighborSampler\ndef test_custom_hetero_link_neighbor_loader():\n    data = HeteroData()\n    feature_store = MyFeatureStore()\n    graph_store = MyGraphStore()\n\n    # Set up node features:\n    x = torch.arange(100)\n    data['paper'].x = x\n    feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None)\n\n    x = torch.arange(100, 300)\n    data['author'].x = x\n    feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)\n\n    # Set up edge indices (GraphStore does not support `edge_attr` at the\n    # moment):\n    edge_index = get_random_edge_index(100, 100, 500)\n    data['paper', 'to', 'paper'].edge_index = edge_index\n    graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),\n                               edge_type=('paper', 'to', 'paper'),\n                               layout='coo', size=(100, 100))\n\n    edge_index = get_random_edge_index(100, 200, 1000)\n    data['paper', 'to', 'author'].edge_index = edge_index\n    
graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),\n                               edge_type=('paper', 'to', 'author'),\n                               layout='coo', size=(100, 200))\n\n    edge_index = get_random_edge_index(200, 100, 1000)\n    data['author', 'to', 'paper'].edge_index = edge_index\n    graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]),\n                               edge_type=('author', 'to', 'paper'),\n                               layout='coo', size=(200, 100))\n\n    loader1 = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        edge_label_index=('paper', 'to', 'author'),\n        batch_size=20,\n    )\n\n    loader2 = LinkNeighborLoader(\n        (feature_store, graph_store),\n        num_neighbors=[-1] * 2,\n        edge_label_index=('paper', 'to', 'author'),\n        batch_size=20,\n    )\n\n    assert str(loader1) == str(loader2)\n\n    for (batch1, batch2) in zip(loader1, loader2):\n        # Mapped indices of neighbors may be differently sorted:\n        assert torch.allclose(batch1['paper'].x.sort()[0],\n                              batch2['paper'].x.sort()[0])\n        assert torch.allclose(batch1['author'].x.sort()[0],\n                              batch2['author'].x.sort()[0])\n\n        # Assert that edge indices have the same size:\n        assert (batch1['paper', 'to', 'paper'].edge_index.size() == batch2[\n            'paper', 'to', 'paper'].edge_index.size())\n        assert (batch1['paper', 'to', 'author'].edge_index.size() == batch2[\n            'paper', 'to', 'author'].edge_index.size())\n        assert (batch1['author', 'to', 'paper'].edge_index.size() == batch2[\n            'author', 'to', 'paper'].edge_index.size())\n\n\n@onlyNeighborSampler\ndef test_homo_link_neighbor_loader_no_edges():\n    loader = LinkNeighborLoader(\n        Data(num_nodes=100),\n        num_neighbors=[],\n        batch_size=20,\n        edge_label_index=get_random_edge_index(100, 
100, 100),\n    )\n\n    for batch in loader:\n        assert isinstance(batch, Data)\n        assert batch.input_id.numel() == 20\n        assert batch.edge_label_index.size(1) == 20\n        assert batch.num_nodes == batch.edge_label_index.unique().numel()\n\n\n@onlyNeighborSampler\ndef test_hetero_link_neighbor_loader_no_edges():\n    loader = LinkNeighborLoader(\n        HeteroData(paper=dict(num_nodes=100)),\n        num_neighbors=[],\n        edge_label_index=(\n            ('paper', 'paper'),\n            get_random_edge_index(100, 100, 100),\n        ),\n        batch_size=20,\n    )\n\n    for batch in loader:\n        assert isinstance(batch, HeteroData)\n        assert batch['paper', 'paper'].input_id.numel() == 20\n        assert batch['paper', 'paper'].edge_label_index.size(1) == 20\n        assert batch['paper'].num_nodes == batch[\n            'paper', 'paper'].edge_label_index.unique().numel()\n\n\n@withPackage('pyg_lib')\n@pytest.mark.parametrize('disjoint', [False, True])\n@pytest.mark.parametrize('temporal', [False, True])\n@pytest.mark.parametrize('amount', [1, 2])\ndef test_homo_link_neighbor_loader_triplet(disjoint, temporal, amount):\n    if not disjoint and temporal:\n        return\n\n    data = Data()\n    data.x = torch.arange(100)\n    data.edge_index = get_random_edge_index(100, 100, 400)\n    data.edge_label_index = get_random_edge_index(100, 100, 500)\n    data.edge_attr = torch.arange(data.num_edges)\n\n    time_attr = edge_label_time = None\n    if temporal:\n        time_attr = 'time'\n        data.time = torch.arange(data.num_nodes)\n\n        edge_label_time = torch.max(data.time[data.edge_label_index[0]],\n                                    data.time[data.edge_label_index[1]])\n        edge_label_time = edge_label_time + 50\n\n    batch_size = 20\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        batch_size=batch_size,\n        edge_label_index=data.edge_label_index,\n        
edge_label_time=edge_label_time,\n        time_attr=time_attr,\n        disjoint=disjoint,\n        neg_sampling=dict(mode='triplet', amount=amount),\n        shuffle=True,\n    )\n\n    assert str(loader) == 'LinkNeighborLoader()'\n    assert len(loader) == 500 / batch_size\n\n    for batch in loader:\n        assert isinstance(batch, Data)\n\n        # Check that `src_index` and `dst_pos_index` point to valid edges:\n        assert torch.equal(batch.x[batch.src_index],\n                           data.edge_label_index[0, batch.input_id])\n        assert torch.equal(batch.x[batch.dst_pos_index],\n                           data.edge_label_index[1, batch.input_id])\n\n        # Check that `dst_neg_index` points to valid nodes in the batch:\n        if amount == 1:\n            assert batch.dst_neg_index.size() == (batch_size, )\n        else:\n            assert batch.dst_neg_index.size() == (batch_size, amount)\n        assert batch.dst_neg_index.min() >= 0\n        assert batch.dst_neg_index.max() < batch.num_nodes\n\n        if disjoint:\n            # In disjoint mode, seed nodes should always be placed first:\n            assert batch.src_index.min() == 0\n            assert batch.src_index.max() == batch_size - 1\n\n            assert batch.dst_pos_index.min() == batch_size\n            assert batch.dst_pos_index.max() == 2 * batch_size - 1\n\n            assert batch.dst_neg_index.min() == 2 * batch_size\n            max_seed_nodes = 2 * batch_size + batch_size * amount\n            assert batch.dst_neg_index.max() == max_seed_nodes - 1\n\n            assert batch.batch.min() == 0\n            assert batch.batch.max() == batch_size - 1\n\n            # Check that `batch` is always increasing:\n            for i in range(0, max_seed_nodes, batch_size):\n                batch_vector = batch.batch[i:i + batch_size]\n                assert torch.equal(batch_vector, torch.arange(batch_size))\n\n        if temporal:\n            for i in range(batch_size):\n       
         assert batch.time[batch.batch == i].max() <= batch.seed_time[i]\n\n\n@withPackage('pyg_lib')\n@pytest.mark.parametrize('disjoint', [False, True])\n@pytest.mark.parametrize('temporal', [False, True])\n@pytest.mark.parametrize('amount', [1, 2])\ndef test_hetero_link_neighbor_loader_triplet(disjoint, temporal, amount):\n    if not disjoint and temporal:\n        return\n\n    data = HeteroData()\n\n    data['paper'].x = torch.arange(100)\n    data['author'].x = torch.arange(100, 300)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 400)\n    edge_label_index = get_random_edge_index(100, 100, 500)\n    data['paper', 'paper'].edge_label_index = edge_label_index\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000)\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000)\n\n    time_attr = edge_label_time = None\n    if temporal:\n        time_attr = 'time'\n        data['paper'].time = torch.arange(data['paper'].num_nodes)\n        data['author'].time = torch.arange(data['author'].num_nodes)\n\n        edge_label_time = torch.max(\n            data['paper'].time[data['paper', 'paper'].edge_label_index[0]],\n            data['paper'].time[data['paper', 'paper'].edge_label_index[1]],\n        )\n        edge_label_time = edge_label_time + 50\n\n    weight = torch.rand(data['paper'].num_nodes) if not temporal else None\n\n    batch_size = 20\n    index = (('paper', 'paper'), data['paper', 'paper'].edge_label_index)\n    loader = LinkNeighborLoader(\n        data,\n        num_neighbors=[-1] * 2,\n        batch_size=batch_size,\n        edge_label_index=index,\n        edge_label_time=edge_label_time,\n        time_attr=time_attr,\n        disjoint=disjoint,\n        neg_sampling=dict(\n            mode='triplet',\n            amount=amount,\n            src_weight=weight,\n            dst_weight=weight,\n        ),\n        shuffle=True,\n    )\n\n    assert str(loader) == 
'LinkNeighborLoader()'\n    assert len(loader) == 500 / batch_size\n\n    for batch in loader:\n        assert isinstance(batch, HeteroData)\n\n        node_store = batch['paper']\n        edge_store = batch['paper', 'paper']\n\n        # Check that `src_index` and `dst_pos_index` point to valid edges:\n        assert torch.equal(\n            node_store.x[node_store.src_index],\n            data['paper', 'paper'].edge_label_index[0, edge_store.input_id])\n        assert torch.equal(\n            node_store.x[node_store.dst_pos_index],\n            data['paper', 'paper'].edge_label_index[1, edge_store.input_id])\n\n        # Check that `dst_neg_index` points to valid nodes in the batch:\n        if amount == 1:\n            assert node_store.dst_neg_index.size() == (batch_size, )\n        else:\n            assert node_store.dst_neg_index.size() == (batch_size, amount)\n        assert node_store.dst_neg_index.min() >= 0\n        assert node_store.dst_neg_index.max() < node_store.num_nodes\n\n        if disjoint:\n            # In disjoint mode, seed nodes should always be placed first:\n            assert node_store.src_index.min() == 0\n            assert node_store.src_index.max() == batch_size - 1\n\n            assert node_store.dst_pos_index.min() == batch_size\n            assert node_store.dst_pos_index.max() == 2 * batch_size - 1\n\n            assert node_store.dst_neg_index.min() == 2 * batch_size\n            max_seed_nodes = 2 * batch_size + batch_size * amount\n            assert node_store.dst_neg_index.max() == max_seed_nodes - 1\n\n            assert node_store.batch.min() == 0\n            assert node_store.batch.max() == batch_size - 1\n\n            # Check that `batch` is always increasing:\n            for i in range(0, max_seed_nodes, batch_size):\n                batch_vector = node_store.batch[i:i + batch_size]\n                assert torch.equal(batch_vector, torch.arange(batch_size))\n\n        if temporal:\n            for i in 
range(batch_size):\n                assert (node_store.time[node_store.batch == i].max()\n                        <= node_store.seed_time[i])\n\n\n@withPackage('pyg_lib')\ndef test_link_neighbor_loader_mapping():\n    edge_index = torch.tensor([\n        [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5],\n        [1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11],\n    ])\n    data = Data(edge_index=edge_index, num_nodes=12)\n\n    loader = LinkNeighborLoader(\n        data,\n        edge_label_index=data.edge_index,\n        num_neighbors=[1],\n        batch_size=2,\n        shuffle=True,\n    )\n\n    for batch in loader:\n        assert torch.equal(\n            batch.n_id[batch.edge_index],\n            data.edge_index[:, batch.e_id],\n        )\n"
  },
  {
    "path": "test/loader/test_mixin.py",
    "content": "import subprocess\nfrom time import sleep\n\nimport psutil\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.testing import onlyLinux, onlyNeighborSampler\n\n\n@pytest.mark.xfail(reason=\"TODO: Fix test\")\n@onlyLinux\n@onlyNeighborSampler\n@pytest.mark.skipif(\n    psutil.cpu_count(logical=False) == 1, reason=\"Requires multiple CPU cores\")\n@pytest.mark.parametrize('loader_cores', [None, [1, 2]])\ndef test_cpu_affinity_neighbor_loader(loader_cores, spawn_context):\n    data = Data(x=torch.randn(1, 1))\n    loader = NeighborLoader(data, num_neighbors=[-1], batch_size=1,\n                            num_workers=2)\n    out = []\n    with loader.enable_cpu_affinity(loader_cores):\n        iterator = loader._get_iterator()\n        workers = iterator._workers\n        sleep(3)  # Gives time for worker to initialize.\n        for worker in workers:\n            process = subprocess.Popen(\n                ['taskset', '-c', '-p', f'{worker.pid}'],\n                stdout=subprocess.PIPE)\n            stdout = process.communicate()[0].decode('utf-8')\n            # returns \"pid <pid>'s current affinity list <n>-<m>\"\n            out.append(stdout.split(':')[1].strip())\n        if loader_cores:\n            assert out == ['[1]', '[2]']\n        else:\n            assert out[0] != out[1]\n\n\ndef init_fn(worker_id):\n    assert torch.get_num_threads() == 2\n\n\n@onlyLinux\n@onlyNeighborSampler\n@pytest.mark.skipif(\n    psutil.cpu_count(logical=False) == 1, reason=\"Requires multiple CPU cores\")\ndef test_multithreading_neighbor_loader(spawn_context):\n    loader = NeighborLoader(\n        data=Data(x=torch.randn(1, 1)),\n        num_neighbors=[-1],\n        batch_size=1,\n        num_workers=2,\n        worker_init_fn=init_fn,\n    )\n\n    with loader.enable_multithreading(2):\n        loader._get_iterator()  # Runs assertion in `init_fn`.\n"
  },
  {
    "path": "test/loader/test_neighbor_loader.py",
    "content": "import os.path as osp\n\nimport numpy as np\nimport pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import GraphConv, to_hetero\nfrom torch_geometric.sampler.base import SubgraphType\nfrom torch_geometric.testing import (\n    MyFeatureStore,\n    MyGraphStore,\n    get_random_edge_index,\n    get_random_tensor_frame,\n    onlyLinux,\n    onlyNeighborSampler,\n    onlyOnline,\n    withCUDA,\n    withPackage,\n)\nfrom torch_geometric.typing import (\n    WITH_EDGE_TIME_NEIGHBOR_SAMPLE,\n    WITH_PYG_LIB,\n    WITH_TORCH_SPARSE,\n    WITH_WEIGHTED_NEIGHBOR_SAMPLE,\n    TensorFrame,\n)\nfrom torch_geometric.utils import (\n    is_undirected,\n    sort_edge_index,\n    to_torch_csr_tensor,\n    to_undirected,\n)\n\nDTYPES = [\n    pytest.param(torch.int64, id='int64'),\n    pytest.param(torch.int32, id='int32'),\n]\n\nSUBGRAPH_TYPES = [\n    pytest.param(SubgraphType.directional, id='directional'),\n    pytest.param(SubgraphType.bidirectional, id='bidirectional'),\n    pytest.param(SubgraphType.induced, id='induced'),\n]\n\nFILTER_PER_WORKERS = [\n    pytest.param(None, id='auto_filter'),\n    pytest.param(True, id='filter_per_worker'),\n    pytest.param(False, id='filter_in_main'),\n]\n\n\ndef is_subset(subedge_index, edge_index, src_idx, dst_idx):\n    num_nodes = int(edge_index.max()) + 1\n    idx = num_nodes * edge_index[0] + edge_index[1]\n    subidx = num_nodes * src_idx[subedge_index[0]] + dst_idx[subedge_index[1]]\n    mask = torch.from_numpy(np.isin(subidx.cpu().numpy(), idx.cpu().numpy()))\n    return int(mask.sum()) == mask.numel()\n\n\n@withCUDA\n@onlyNeighborSampler\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES)\n@pytest.mark.parametrize('filter_per_worker', FILTER_PER_WORKERS)\ndef test_homo_neighbor_loader_basic(\n  
  device,\n    subgraph_type,\n    dtype,\n    filter_per_worker,\n):\n    if dtype != torch.int64 and not torch_geometric.typing.WITH_PT20:\n        return\n    induced = SubgraphType.induced\n    if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE:\n        return\n    if dtype != torch.int64 and (not WITH_PYG_LIB or subgraph_type == induced):\n        return\n\n    torch.manual_seed(12345)\n\n    data = Data()\n\n    data.x = torch.arange(100, device=device)\n    data.edge_index = get_random_edge_index(100, 100, 500, dtype, device)\n    data.edge_attr = torch.arange(500, device=device)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[5] * 2,\n        batch_size=20,\n        subgraph_type=subgraph_type,\n        filter_per_worker=filter_per_worker,\n    )\n\n    assert str(loader) == 'NeighborLoader()'\n    assert len(loader) == 5\n\n    batch = loader([0])\n    assert isinstance(batch, Data)\n    assert batch.n_id[:1].tolist() == [0]\n\n    for i, batch in enumerate(loader):\n        assert isinstance(batch, Data)\n        assert batch.x.device == device\n        assert batch.x.size(0) <= 100\n        assert batch.n_id.size() == (batch.num_nodes, )\n        assert batch.input_id.numel() == batch.batch_size == 20\n        assert batch.x.min() >= 0 and batch.x.max() < 100\n        # TODO Re-enable once `EdgeIndex` is stable.\n        assert not isinstance(batch.edge_index, EdgeIndex)\n        # batch.edge_index.validate()\n        # size = (batch.num_nodes, batch.num_nodes)\n        # assert batch.edge_index.sparse_size() == size\n        # assert batch.edge_index.sort_order == 'col'\n        assert batch.edge_index.device == device\n        assert batch.edge_index.min() >= 0\n        assert batch.edge_index.max() < batch.num_nodes\n        assert batch.edge_attr.device == device\n        assert batch.edge_attr.size(0) == batch.edge_index.size(1)\n\n        # Input nodes are always sampled first:\n        assert torch.equal(\n  
          batch.x[:batch.batch_size],\n            torch.arange(i * batch.batch_size, (i + 1) * batch.batch_size,\n                         device=device),\n        )\n\n        if subgraph_type != SubgraphType.bidirectional:\n            assert batch.e_id.size() == (batch.num_edges, )\n            assert batch.edge_attr.min() >= 0\n            assert batch.edge_attr.max() < 500\n\n            assert is_subset(\n                batch.edge_index.to(torch.int64),\n                data.edge_index.to(torch.int64),\n                batch.x,\n                batch.x,\n            )\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES)\ndef test_hetero_neighbor_loader_basic(subgraph_type, dtype):\n    if dtype != torch.int64 and not torch_geometric.typing.WITH_PT20:\n        return\n    induced = SubgraphType.induced\n    if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE:\n        return\n    if dtype != torch.int64 and (not WITH_PYG_LIB or subgraph_type == induced):\n        return\n\n    torch.manual_seed(12345)\n\n    data = HeteroData()\n\n    data['paper'].x = torch.arange(100)\n    data['author'].x = torch.arange(100, 300)\n\n    edge_index = get_random_edge_index(100, 100, 500, dtype)\n    data['paper', 'paper'].edge_index = edge_index\n    data['paper', 'paper'].edge_attr = torch.arange(500)\n    edge_index = get_random_edge_index(100, 200, 1000, dtype)\n    data['paper', 'author'].edge_index = edge_index\n    data['paper', 'author'].edge_attr = torch.arange(500, 1500)\n    edge_index = get_random_edge_index(200, 100, 1000, dtype)\n    data['author', 'paper'].edge_index = edge_index\n    data['author', 'paper'].edge_attr = torch.arange(1500, 2500)\n\n    r1, c1 = data['paper', 'paper'].edge_index\n    r2, c2 = data['paper', 'author'].edge_index + torch.tensor([[0], [100]])\n    r3, c3 = data['author', 'paper'].edge_index + torch.tensor([[100], [0]])\n\n    batch_size = 
20\n\n    with pytest.raises(ValueError, match=\"hops must be the same across all\"):\n        loader = NeighborLoader(\n            data,\n            num_neighbors={\n                ('paper', 'to', 'paper'): [-1],\n                ('paper', 'to', 'author'): [-1, -1],\n                ('author', 'to', 'paper'): [-1, -1],\n            },\n            input_nodes='paper',\n            batch_size=batch_size,\n            subgraph_type=subgraph_type,\n        )\n        next(iter(loader))\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[10] * 2,\n        input_nodes='paper',\n        batch_size=batch_size,\n        subgraph_type=subgraph_type,\n    )\n\n    assert str(loader) == 'NeighborLoader()'\n    assert len(loader) == (100 + batch_size - 1) // batch_size\n\n    for batch in loader:\n        assert isinstance(batch, HeteroData)\n        assert batch.input_type == 'paper'\n\n        # Test node type selection:\n        assert set(batch.node_types) == {'paper', 'author'}\n\n        assert batch['paper'].n_id.size() == (batch['paper'].num_nodes, )\n        assert batch['paper'].x.size(0) <= 100\n        assert batch['paper'].input_id.numel() == batch_size\n        assert batch['paper'].batch_size == batch_size\n        assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100\n\n        assert batch['author'].n_id.size() == (batch['author'].num_nodes, )\n        assert batch['author'].x.size(0) <= 200\n        assert batch['author'].x.min() >= 100 and batch['author'].x.max() < 300\n\n        # Test edge type selection:\n        assert set(batch.edge_types) == {('paper', 'to', 'paper'),\n                                         ('paper', 'to', 'author'),\n                                         ('author', 'to', 'paper')}\n\n        for edge_type, edge_index in batch.edge_index_dict.items():\n            src, _, dst = edge_type\n            # TODO Re-enable once `EdgeIndex` is stable.\n            assert not isinstance(edge_index, 
EdgeIndex)\n            # edge_index.validate()\n            # size = (batch[src].num_nodes, batch[dst].num_nodes)\n            # assert edge_index.sparse_size() == size\n            # assert edge_index.sort_order == 'col'\n\n        row, col = batch['paper', 'paper'].edge_index\n        assert row.min() >= 0 and row.max() < batch['paper'].num_nodes\n        assert col.min() >= 0 and col.max() < batch['paper'].num_nodes\n\n        if subgraph_type != SubgraphType.bidirectional:\n            assert batch['paper', 'paper'].e_id.size() == (row.numel(), )\n            value = batch['paper', 'paper'].edge_attr\n            assert value.min() >= 0 and value.max() < 500\n\n            assert is_subset(\n                batch['paper', 'paper'].edge_index.to(torch.int64),\n                data['paper', 'paper'].edge_index.to(torch.int64),\n                batch['paper'].x,\n                batch['paper'].x,\n            )\n        elif subgraph_type != SubgraphType.directional:\n            assert 'e_id' not in batch['paper', 'paper']\n            assert 'edge_attr' not in batch['paper', 'paper']\n\n            assert is_undirected(batch['paper', 'paper'].edge_index)\n\n        row, col = batch['paper', 'author'].edge_index\n        assert row.min() >= 0 and row.max() < batch['paper'].num_nodes\n        assert col.min() >= 0 and col.max() < batch['author'].num_nodes\n\n        if subgraph_type != SubgraphType.bidirectional:\n            assert batch['paper', 'author'].e_id.size() == (row.numel(), )\n            value = batch['paper', 'author'].edge_attr\n            assert value.min() >= 500 and value.max() < 1500\n\n            assert is_subset(\n                batch['paper', 'author'].edge_index.to(torch.int64),\n                data['paper', 'author'].edge_index.to(torch.int64),\n                batch['paper'].x,\n                batch['author'].x - 100,\n            )\n        elif subgraph_type != SubgraphType.directional:\n            assert 'e_id' not in 
batch['paper', 'author']\n            assert 'edge_attr' not in batch['paper', 'author']\n\n            edge_index1 = batch['paper', 'author'].edge_index\n            edge_index2 = batch['author', 'paper'].edge_index\n            assert torch.equal(\n                edge_index1,\n                sort_edge_index(edge_index2.flip([0]), sort_by_row=False),\n            )\n\n        row, col = batch['author', 'paper'].edge_index\n        assert row.min() >= 0 and row.max() < batch['author'].num_nodes\n        assert col.min() >= 0 and col.max() < batch['paper'].num_nodes\n\n        if subgraph_type != SubgraphType.bidirectional:\n            assert batch['author', 'paper'].e_id.size() == (row.numel(), )\n            value = batch['author', 'paper'].edge_attr\n            assert value.min() >= 1500 and value.max() < 2500\n\n            assert is_subset(\n                batch['author', 'paper'].edge_index.to(torch.int64),\n                data['author', 'paper'].edge_index.to(torch.int64),\n                batch['author'].x - 100,\n                batch['paper'].x,\n            )\n        elif subgraph_type != SubgraphType.directional:\n            assert 'e_id' not in batch['author', 'paper']\n            assert 'edge_attr' not in batch['author', 'paper']\n\n            edge_index1 = batch['author', 'paper'].edge_index\n            edge_index2 = batch['paper', 'author'].edge_index\n            assert torch.equal(\n                edge_index1,\n                sort_edge_index(edge_index2.flip([0]), sort_by_row=False),\n            )\n\n        # Test for isolated nodes (there shouldn't exist any):\n        assert not batch.has_isolated_nodes()\n\n\n@onlyOnline\n@onlyNeighborSampler\n@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES)\ndef test_homo_neighbor_loader_on_karate(get_dataset, subgraph_type):\n    if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE:\n        return\n    dataset = get_dataset(name='karate')\n    data = dataset[0]\n\n    
mask = data.edge_index[0] < data.edge_index[1]\n    edge_index = data.edge_index[:, mask]\n    edge_weight = torch.rand(edge_index.size(1))\n    data.edge_index, data.edge_weight = to_undirected(edge_index, edge_weight)\n\n    split_idx = torch.arange(5, 8)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[-1, -1],\n        batch_size=split_idx.numel(),\n        input_nodes=split_idx,\n        subgraph_type=subgraph_type,\n    )\n    assert len(loader) == 1\n\n    batch = next(iter(loader))\n    batch_size = batch.batch_size\n\n    class GNN(torch.nn.Module):\n        def __init__(self, in_channels, hidden_channels, out_channels):\n            super().__init__()\n            self.conv1 = GraphConv(in_channels, hidden_channels)\n            self.conv2 = GraphConv(hidden_channels, out_channels)\n\n        def forward(self, x, edge_index, edge_weight):\n            x = self.conv1(x, edge_index, edge_weight).relu()\n            x = self.conv2(x, edge_index, edge_weight)\n            return x\n\n    model = GNN(dataset.num_features, 16, dataset.num_classes)\n\n    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]\n    out2 = model(batch.x, batch.edge_index, batch.edge_weight)[:batch_size]\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n\n@onlyOnline\n@onlyNeighborSampler\n@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES)\ndef test_hetero_neighbor_loader_on_karate(get_dataset, subgraph_type):\n    if subgraph_type == SubgraphType.induced and not WITH_TORCH_SPARSE:\n        return\n    dataset = get_dataset(name='karate')\n    data = dataset[0]\n\n    hetero_data = HeteroData()\n    hetero_data['v'].x = data.x\n    hetero_data['v', 'v'].edge_index = data.edge_index\n\n    split_idx = torch.arange(5, 8)\n\n    loader = NeighborLoader(\n        hetero_data,\n        num_neighbors=[-1, -1],\n        batch_size=split_idx.numel(),\n        input_nodes=('v', split_idx),\n        subgraph_type=subgraph_type,\n    )\n    assert 
len(loader) == 1\n\n    hetero_batch = next(iter(loader))\n    batch_size = hetero_batch['v'].batch_size\n\n    class GNN(torch.nn.Module):\n        def __init__(self, in_channels, hidden_channels, out_channels):\n            super().__init__()\n            self.conv1 = GraphConv(in_channels, hidden_channels)\n            self.conv2 = GraphConv(hidden_channels, out_channels)\n\n        def forward(self, x, edge_index):\n            x = self.conv1(x, edge_index).relu()\n            x = self.conv2(x, edge_index)\n            return x\n\n    model = GNN(dataset.num_features, 16, dataset.num_classes)\n    hetero_model = to_hetero(model, hetero_data.metadata())\n\n    out1 = model(data.x, data.edge_index)[split_idx]\n    out2 = hetero_model(hetero_batch.x_dict,\n                        hetero_batch.edge_index_dict)['v'][:batch_size]\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n\n@onlyOnline\n@withPackage('pyg_lib')\ndef test_temporal_hetero_neighbor_loader_on_karate(get_dataset):\n    dataset = get_dataset(name='karate')\n    data = dataset[0]\n\n    hetero_data = HeteroData()\n    hetero_data['v'].x = data.x\n    hetero_data['v'].time = torch.arange(data.num_nodes, 0, -1)\n    hetero_data['v', 'v'].edge_index = data.edge_index\n\n    loader = NeighborLoader(hetero_data, num_neighbors=[-1, -1],\n                            input_nodes='v', time_attr='time', batch_size=1)\n\n    for batch in loader:\n        mask = batch['v'].time[0] >= batch['v'].time[1:]\n        assert torch.all(mask)\n\n\n@onlyNeighborSampler\ndef test_custom_neighbor_loader():\n    # Initialize feature store, graph store, and reference:\n    feature_store = MyFeatureStore()\n    graph_store = MyGraphStore()\n\n    # Set up node features:\n    x = torch.arange(100, 300)\n    feature_store.put_tensor(x, group_name=None, attr_name='x', index=None)\n\n    y = torch.arange(100, 300)\n    feature_store.put_tensor(y, group_name=None, attr_name='y', index=None)\n\n    # COO:\n    edge_index = 
get_random_edge_index(100, 100, 500, coalesce=True)\n    edge_index = edge_index[:, torch.randperm(edge_index.size(1))]\n    coo = (edge_index[0], edge_index[1])\n    graph_store.put_edge_index(edge_index=coo, edge_type=None, layout='coo',\n                               size=(100, 100))\n\n    data = Data(x=x, edge_index=edge_index, y=y, num_nodes=200)\n\n    # Construct neighbor loaders:\n    loader1 = NeighborLoader(data, batch_size=20,\n                             input_nodes=torch.arange(100),\n                             num_neighbors=[-1] * 2)\n\n    loader2 = NeighborLoader((feature_store, graph_store), batch_size=20,\n                             input_nodes=torch.arange(100),\n                             num_neighbors=[-1] * 2)\n\n    assert str(loader1) == str(loader2)\n    assert len(loader1) == len(loader2)\n\n    for batch1, batch2 in zip(loader1, loader2):\n        assert len(batch1) == len(batch2)\n        assert batch1.num_nodes == batch2.num_nodes\n        assert batch1.num_edges == batch2.num_edges\n        assert batch1.batch_size == batch2.batch_size\n\n        # Mapped indices of neighbors may be differently sorted ...\n        assert torch.allclose(batch1.x.sort()[0], batch2.x.sort()[0])\n        assert torch.allclose(batch1.y.sort()[0], batch2.y.sort()[0])\n\n\n@onlyNeighborSampler\ndef test_custom_hetero_neighbor_loader():\n    # Initialize feature store, graph store, and reference:\n    feature_store = MyFeatureStore()\n    graph_store = MyGraphStore()\n    data = HeteroData()\n\n    # Set up node features:\n    x = torch.arange(100)\n    data['paper'].x = x\n    feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None)\n\n    x = torch.arange(100, 300)\n    data['author'].x = x\n    feature_store.put_tensor(x, group_name='author', attr_name='x', index=None)\n\n    # COO:\n    edge_index = get_random_edge_index(100, 100, 500, coalesce=True)\n    edge_index = edge_index[:, torch.randperm(edge_index.size(1))]\n    
data['paper', 'to', 'paper'].edge_index = edge_index\n    coo = (edge_index[0], edge_index[1])\n    graph_store.put_edge_index(edge_index=coo,\n                               edge_type=('paper', 'to', 'paper'),\n                               layout='coo', size=(100, 100))\n\n    # CSR:\n    edge_index = get_random_edge_index(100, 200, 1000, coalesce=True)\n    data['paper', 'to', 'author'].edge_index = edge_index\n    adj = to_torch_csr_tensor(edge_index, size=(100, 200))\n    csr = (adj.crow_indices(), adj.col_indices())\n    graph_store.put_edge_index(edge_index=csr,\n                               edge_type=('paper', 'to', 'author'),\n                               layout='csr', size=(100, 200))\n\n    # CSC:\n    edge_index = get_random_edge_index(200, 100, 1000, coalesce=True)\n    data['author', 'to', 'paper'].edge_index = edge_index\n    adj = to_torch_csr_tensor(edge_index.flip([0]), size=(100, 200))\n    csc = (adj.col_indices(), adj.crow_indices())\n    graph_store.put_edge_index(edge_index=csc,\n                               edge_type=('author', 'to', 'paper'),\n                               layout='csc', size=(200, 100))\n\n    # COO (sorted):\n    edge_index = get_random_edge_index(200, 200, 100, coalesce=True)\n    edge_index = edge_index[:, edge_index[1].argsort()]\n    data['author', 'to', 'author'].edge_index = edge_index\n    coo = (edge_index[0], edge_index[1])\n    graph_store.put_edge_index(edge_index=coo,\n                               edge_type=('author', 'to', 'author'),\n                               layout='coo', size=(200, 200), is_sorted=True)\n\n    # Construct neighbor loaders:\n    loader1 = NeighborLoader(data, batch_size=20,\n                             input_nodes=('paper', range(100)),\n                             num_neighbors=[-1] * 2)\n\n    loader2 = NeighborLoader((feature_store, graph_store), batch_size=20,\n                             input_nodes=('paper', range(100)),\n                             
num_neighbors=[-1] * 2)\n\n    assert str(loader1) == str(loader2)\n    assert len(loader1) == len(loader2)\n\n    for batch1, batch2 in zip(loader1, loader2):\n        # `loader2` explicitly adds `num_nodes` to the batch:\n        assert len(batch1) + 1 == len(batch2)\n        assert batch1['paper'].batch_size == batch2['paper'].batch_size\n\n        # Mapped indices of neighbors may be differently sorted ...\n        for node_type in data.node_types:\n            assert torch.allclose(\n                batch1[node_type].x.sort()[0],\n                batch2[node_type].x.sort()[0],\n            )\n\n        # ... but should sample the exact same number of edges:\n        for edge_type in data.edge_types:\n            assert batch1[edge_type].num_edges == batch2[edge_type].num_edges\n\n\n@onlyNeighborSampler\ndef test_custom_hetero_neighbor_loader_duplicate():\n    feature_store = MyFeatureStore()\n    graph_store = MyGraphStore()\n\n    x = torch.arange(10)\n    feature_store.put_tensor(x, group_name='user', attr_name='x', index=None)\n\n    edge_index = get_random_edge_index(10, 10, 20, coalesce=True)\n    graph_store.put_edge_index(\n        edge_index=(edge_index[0], edge_index[1]),\n        edge_type=('user', 'user', 'user'),\n        layout='coo',\n        size=(10, 10),\n    )\n\n    loader = NeighborLoader(\n        (feature_store, graph_store),\n        batch_size=10,\n        input_nodes=('user', range(10)),\n        num_neighbors=[-1] * 2,\n    )\n    batch = next(iter(loader))\n\n    assert batch.node_types == ['user']\n    assert batch['user'].num_nodes == 10\n    assert batch.edge_types == [('user', 'user', 'user')]\n    assert batch['user', 'user'].num_edges == edge_index.size(1)\n\n\n@onlyOnline\n@withPackage('pyg_lib')\ndef test_temporal_custom_neighbor_loader_on_karate(get_dataset):\n    dataset = get_dataset(name='karate')\n    data = dataset[0]\n    data.time = torch.arange(data.num_nodes, 0, -1)\n\n    # Initialize feature store, graph store, 
and reference:\n    feature_store = MyFeatureStore()\n    graph_store = MyGraphStore()\n    hetero_data = HeteroData()\n\n    feature_store.put_tensor(\n        data.x,\n        group_name='v',\n        attr_name='x',\n        index=None,\n    )\n    hetero_data['v'].x = data.x\n\n    feature_store.put_tensor(\n        data.time,\n        group_name='v',\n        attr_name='time',\n        index=None,\n    )\n    hetero_data['v'].time = data.time\n\n    # Sort according to time in local neighborhoods:\n    row, col = data.edge_index\n    perm = ((col * (data.num_nodes + 1)) + data.time[row]).argsort()\n    edge_index = data.edge_index[:, perm]\n\n    graph_store.put_edge_index(\n        edge_index,\n        edge_type=('v', 'to', 'v'),\n        layout='coo',\n        is_sorted=True,\n        size=(data.num_nodes, data.num_nodes),\n    )\n    hetero_data['v', 'to', 'v'].edge_index = data.edge_index\n\n    loader1 = NeighborLoader(\n        hetero_data,\n        num_neighbors=[-1, -1],\n        input_nodes='v',\n        time_attr='time',\n        batch_size=128,\n    )\n\n    loader2 = NeighborLoader(\n        (feature_store, graph_store),\n        num_neighbors=[-1, -1],\n        input_nodes='v',\n        time_attr='time',\n        batch_size=128,\n    )\n\n    for batch1, batch2 in zip(loader1, loader2):\n        assert torch.equal(batch1['v'].time, batch2['v'].time)\n\n\n@withPackage('pyg_lib', 'torch_sparse')\ndef test_pyg_lib_and_torch_sparse_homo_equality():\n    edge_index = get_random_edge_index(20, 20, 100)\n    adj = to_torch_csr_tensor(edge_index.flip([0]), size=(20, 20))\n    colptr, row = adj.crow_indices(), adj.col_indices()\n\n    seed = torch.arange(10)\n\n    sample = torch.ops.pyg.neighbor_sample\n    out1 = sample(colptr, row, seed, [-1, -1], None, None, None, None, True)\n    sample = torch.ops.torch_sparse.neighbor_sample\n    out2 = sample(colptr, row, seed, [-1, -1], False, True)\n\n    row1, col1, node_id1, edge_id1 = out1[:4]\n    node_id2, 
row2, col2, edge_id2 = out2\n    assert torch.equal(node_id1, node_id2)\n    assert torch.equal(row1, row2)\n    assert torch.equal(col1, col2)\n    assert torch.equal(edge_id1, edge_id2)\n\n\n@withPackage('pyg_lib', 'torch_sparse')\ndef test_pyg_lib_and_torch_sparse_hetero_equality():\n    edge_index = get_random_edge_index(20, 10, 50)\n    adj = to_torch_csr_tensor(edge_index.flip([0]), size=(10, 20))\n    colptr1, row1 = adj.crow_indices(), adj.col_indices()\n\n    edge_index = get_random_edge_index(10, 20, 50)\n    adj = to_torch_csr_tensor(edge_index.flip([0]), size=(20, 10))\n    colptr2, row2 = adj.crow_indices(), adj.col_indices()\n\n    node_types = ['paper', 'author']\n    edge_types = [('paper', 'to', 'author'), ('author', 'to', 'paper')]\n    colptr_dict = {\n        'paper__to__author': colptr1,\n        'author__to__paper': colptr2,\n    }\n    row_dict = {\n        'paper__to__author': row1,\n        'author__to__paper': row2,\n    }\n    seed_dict = {'paper': torch.arange(1)}\n    num_neighbors_dict = {\n        'paper__to__author': [-1, -1],\n        'author__to__paper': [-1, -1],\n    }\n\n    sample = torch.ops.pyg.hetero_neighbor_sample\n    out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,\n                  num_neighbors_dict, None, None, None, None, True, False,\n                  True, False, \"uniform\", True)\n    sample = torch.ops.torch_sparse.hetero_neighbor_sample\n    out2 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict,\n                  num_neighbors_dict, 2, False, True)\n\n    row1_dict, col1_dict, node_id1_dict, edge_id1_dict = out1[:4]\n    node_id2_dict, row2_dict, col2_dict, edge_id2_dict = out2\n    assert len(node_id1_dict) == len(node_id2_dict)\n    for key in node_id1_dict.keys():\n        assert torch.equal(node_id1_dict[key], node_id2_dict[key])\n    assert len(row1_dict) == len(row2_dict)\n    for key in row1_dict.keys():\n        assert torch.equal(row1_dict[key], 
row2_dict[key])\n    assert len(col1_dict) == len(col2_dict)\n    for key in col1_dict.keys():\n        assert torch.equal(col1_dict[key], col2_dict[key])\n    assert len(edge_id1_dict) == len(edge_id2_dict)\n    for key in edge_id1_dict.keys():\n        assert torch.equal(edge_id1_dict[key], edge_id2_dict[key])\n\n\n@onlyLinux\n@onlyNeighborSampler\ndef test_memmap_neighbor_loader(tmp_path):\n    path = osp.join(tmp_path, 'x.npy')\n    x = np.memmap(path, dtype=np.float32, mode='w+', shape=(100, 32))\n    x[:] = np.random.randn(100, 32)\n\n    data = Data()\n    data.x = np.memmap(path, dtype=np.float32, mode='r', shape=(100, 32))\n    data.edge_index = get_random_edge_index(100, 100, 500)\n\n    assert str(data) == 'Data(x=[100, 32], edge_index=[2, 500])'\n    assert data.num_nodes == 100\n\n    loader = NeighborLoader(data, num_neighbors=[5] * 2, batch_size=20,\n                            num_workers=2)\n    batch = next(iter(loader))\n    assert batch.num_nodes <= 100\n    assert isinstance(batch.x, torch.Tensor)\n    assert batch.x.size() == (batch.num_nodes, 32)\n\n\n@withPackage('pyg_lib')\ndef test_homo_neighbor_loader_sampled_info():\n    edge_index = torch.tensor([\n        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],\n        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],\n    ])\n\n    data = Data(edge_index=edge_index, num_nodes=14)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[1, 2, 4],\n        batch_size=2,\n        shuffle=False,\n    )\n    batch = next(iter(loader))\n\n    assert batch.num_sampled_nodes == [2, 2, 3, 4]\n    assert batch.num_sampled_edges == [2, 4, 4]\n\n\n@withPackage('pyg_lib')\ndef test_hetero_neighbor_loader_sampled_info():\n    edge_index = torch.tensor([\n        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],\n        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],\n    ])\n\n    data = HeteroData()\n    data['paper'].num_nodes = data['author'].num_nodes = 14\n    data['paper', 'paper'].edge_index = edge_index\n    data['paper', 'author'].edge_index 
= edge_index\n    data['author', 'paper'].edge_index = edge_index\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[1, 2, 4],\n        batch_size=2,\n        input_nodes='paper',\n        shuffle=False,\n    )\n    batch = next(iter(loader))\n\n    expected_num_sampled_nodes = {\n        'paper': [2, 2, 3, 4],\n        'author': [0, 2, 3, 4],\n    }\n    expected_num_sampled_edges = {\n        ('paper', 'to', 'paper'): [2, 4, 4],\n        ('paper', 'to', 'author'): [0, 4, 4],\n        ('author', 'to', 'paper'): [2, 4, 4],\n    }\n\n    for node_type in batch.node_types:\n        assert (batch[node_type].num_sampled_nodes ==\n                expected_num_sampled_nodes[node_type])\n    for edge_type in batch.edge_types:\n        assert (batch[edge_type].num_sampled_edges ==\n                expected_num_sampled_edges[edge_type])\n\n\n@withPackage('pyg_lib')\ndef test_neighbor_loader_mapping():\n    edge_index = torch.tensor([\n        [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5],\n        [1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11],\n    ])\n    data = Data(edge_index=edge_index, num_nodes=12)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[1],\n        batch_size=2,\n        shuffle=True,\n    )\n\n    for batch in loader:\n        assert torch.equal(\n            batch.n_id[batch.edge_index],\n            data.edge_index[:, batch.e_id],\n        )\n\n\n@pytest.mark.skipif(\n    not WITH_WEIGHTED_NEIGHBOR_SAMPLE,\n    reason=\"'pyg-lib' does not support weighted neighbor sampling\",\n)\ndef test_weighted_homo_neighbor_loader():\n    edge_index = torch.tensor([\n        [1, 3, 0, 4],\n        [2, 2, 1, 3],\n    ])\n    edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0])\n\n    data = Data(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight)\n\n    loader = NeighborLoader(\n        data,\n        input_nodes=torch.tensor([2]),\n        num_neighbors=[1] * 2,\n        batch_size=1,\n        weight_attr='edge_weight',\n    )\n    
assert len(loader) == 1\n\n    batch = next(iter(loader))\n\n    assert batch.num_nodes == 3\n    assert batch.n_id.tolist() == [2, 3, 4]\n    assert batch.num_edges == 2\n    assert batch.n_id[batch.edge_index].tolist() == [[3, 4], [2, 3]]\n\n\n@pytest.mark.skipif(\n    not WITH_WEIGHTED_NEIGHBOR_SAMPLE,\n    reason=\"'pyg-lib' does not support weighted neighbor sampling\",\n)\ndef test_weighted_hetero_neighbor_loader():\n    edge_index = torch.tensor([\n        [1, 3, 0, 4],\n        [2, 2, 1, 3],\n    ])\n    edge_weight = torch.tensor([0.0, 1.0, 0.0, 1.0])\n\n    data = HeteroData()\n    data['paper'].num_nodes = 5\n    data['paper', 'to', 'paper'].edge_index = edge_index\n    data['paper', 'to', 'paper'].edge_weight = edge_weight\n\n    loader = NeighborLoader(\n        data,\n        input_nodes=('paper', torch.tensor([2])),\n        num_neighbors=[1] * 2,\n        batch_size=1,\n        weight_attr='edge_weight',\n    )\n    assert len(loader) == 1\n\n    batch = next(iter(loader))\n\n    assert batch['paper'].num_nodes == 3\n    assert batch['paper'].n_id.tolist() == [2, 3, 4]\n    assert batch['paper', 'paper'].num_edges == 2\n    global_edge_index = batch['paper'].n_id[batch['paper', 'paper'].edge_index]\n    assert global_edge_index.tolist() == [[3, 4], [2, 3]]\n\n\n@pytest.mark.skipif(\n    not WITH_EDGE_TIME_NEIGHBOR_SAMPLE,\n    reason=\"'pyg-lib' does not support edge-level temporal neighbor sampling\",\n)\ndef test_edge_level_temporal_homo_neighbor_loader():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4],\n        [1, 0, 2, 1, 3, 2, 4, 3],\n    ])\n    edge_time = torch.arange(edge_index.size(1))\n\n    data = Data(edge_index=edge_index, edge_time=edge_time, num_nodes=5)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[-1, -1],\n        input_time=torch.tensor([4, 4, 4, 4, 4]),\n        time_attr='edge_time',\n        batch_size=1,\n    )\n\n    for batch in loader:\n        assert batch.edge_time.numel() == 
batch.num_edges\n        if batch.edge_time.numel() > 0:\n            assert batch.edge_time.max() <= 4\n\n\n@pytest.mark.skipif(\n    not WITH_EDGE_TIME_NEIGHBOR_SAMPLE,\n    reason=\"'pyg-lib' does not support edge-level temporal neighbor sampling\",\n)\ndef test_edge_level_temporal_hetero_neighbor_loader():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4],\n        [1, 0, 2, 1, 3, 2, 4, 3],\n    ])\n    edge_time = torch.arange(edge_index.size(1))\n\n    data = HeteroData()\n    data['A'].num_nodes = 5\n    data['A', 'A'].edge_index = edge_index\n    data['A', 'A'].edge_time = edge_time\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[-1, -1],\n        input_nodes='A',\n        input_time=torch.tensor([4, 4, 4, 4, 4]),\n        time_attr='edge_time',\n        batch_size=1,\n    )\n\n    for batch in loader:\n        assert batch['A', 'A'].edge_time.numel() == batch['A', 'A'].num_edges\n        if batch['A', 'A'].edge_time.numel() > 0:\n            assert batch['A', 'A'].edge_time.max() <= 4\n\n\n@withCUDA\n@onlyNeighborSampler\n@withPackage('torch_frame')\ndef test_neighbor_loader_with_tensor_frame(device):\n    data = Data()\n    data.tf = get_random_tensor_frame(num_rows=100, device=device)\n    data.edge_index = get_random_edge_index(100, 100, 500, device=device)\n    data.edge_attr = get_random_tensor_frame(500, device=device)\n    data.global_tf = get_random_tensor_frame(num_rows=1, device=device)\n\n    loader = NeighborLoader(data, num_neighbors=[5] * 2, batch_size=20)\n    assert len(loader) == 5\n\n    for batch in loader:\n        assert isinstance(batch.tf, TensorFrame)\n        assert batch.tf.device == device\n        assert batch.tf.num_rows == batch.n_id.numel()\n        assert batch.tf == data.tf[batch.n_id]\n\n        assert isinstance(batch.edge_attr, TensorFrame)\n        assert batch.edge_attr.device == device\n        assert batch.edge_attr.num_rows == batch.e_id.numel()\n        assert batch.edge_attr == 
data.edge_attr[batch.e_id]\n\n        assert isinstance(batch.global_tf, TensorFrame)\n        assert batch.global_tf.device == device\n        assert batch.global_tf.num_rows == 1\n        assert batch.global_tf == data.global_tf\n\n\n@onlyNeighborSampler\ndef test_neighbor_loader_input_id():\n    data = HeteroData()\n    data['a'].num_nodes = 10\n    data['b'].num_nodes = 12\n\n    row = torch.randint(0, data['a'].num_nodes, (40, ))\n    col = torch.randint(0, data['b'].num_nodes, (40, ))\n    data['a', 'b'].edge_index = torch.stack([row, col], dim=0)\n    data['b', 'a'].edge_index = torch.stack([col, row], dim=0)\n\n    mask = torch.ones(data['a'].num_nodes, dtype=torch.bool)\n    mask[0] = False\n\n    loader = NeighborLoader(\n        data,\n        input_nodes=('a', mask),\n        batch_size=2,\n        num_neighbors=[2, 2],\n    )\n    for i, batch in enumerate(loader):\n        if i < 4:\n            expected = [(2 * i) + 1, (2 * i) + 2]\n        else:\n            expected = [(2 * i) + 1]\n\n        assert batch['a'].input_id.tolist() == expected\n\n\n@withPackage('pyg_lib')\ndef test_temporal_neighbor_loader_single_link():\n    data = HeteroData()\n    data['a'].x = torch.arange(10)\n    data['b'].x = torch.arange(10)\n    data['c'].x = torch.arange(10)\n\n    data['b'].time = torch.arange(0, 10)\n    data['c'].time = torch.arange(1, 11)\n\n    data['a', 'b'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)\n    data['b', 'a'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)\n    data['a', 'c'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)\n    data['c', 'a'].edge_index = torch.arange(10).view(1, -1).repeat(2, 1)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[-1],\n        input_nodes='a',\n        time_attr='time',\n        input_time=torch.arange(0, 10),\n        batch_size=10,\n    )\n    batch = next(iter(loader))\n    assert batch['a'].num_nodes == 10\n    assert batch['b'].num_nodes == 10\n    assert 
batch['c'].num_nodes == 0\n"
  },
  {
    "path": "test/loader/test_neighbor_sampler.py",
    "content": "import numpy as np\nimport torch\n\nfrom torch_geometric.loader import NeighborSampler\nfrom torch_geometric.nn.conv import GATConv, SAGEConv\nfrom torch_geometric.testing import onlyOnline, withPackage\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import erdos_renyi_graph\n\n\n@withPackage('torch_sparse')\ndef test_neighbor_sampler_basic():\n    edge_index = erdos_renyi_graph(num_nodes=10, edge_prob=0.5)\n    adj_t = SparseTensor.from_edge_index(edge_index, sparse_sizes=(10, 10)).t()\n    E = edge_index.size(1)\n\n    loader = NeighborSampler(edge_index, sizes=[2, 4], batch_size=2)\n    assert str(loader) == 'NeighborSampler(sizes=[2, 4])'\n    assert len(loader) == 5\n\n    for batch_size, n_id, adjs in loader:\n        assert batch_size == 2\n        assert all(np.isin(n_id, torch.arange(10)).tolist())\n        assert n_id.unique().size(0) == n_id.size(0)\n        for (edge_index, e_id, size) in adjs:\n            assert int(edge_index[0].max() + 1) <= size[0]\n            assert int(edge_index[1].max() + 1) <= size[1]\n            assert all(np.isin(e_id, torch.arange(E)).tolist())\n            assert e_id.unique().size(0) == e_id.size(0)\n            assert size[0] >= size[1]\n\n    out = loader.sample([1, 2])\n    assert len(out) == 3\n\n    loader = NeighborSampler(adj_t, sizes=[2, 4], batch_size=2)\n\n    for _, _, adjs in loader:\n        for adj_t, _, size in adjs:\n            assert adj_t.size(0) == size[1]\n            assert adj_t.size(1) == size[0]\n\n\n@withPackage('torch_sparse')\ndef test_neighbor_sampler_invalid_kwargs():\n    # Ignore `collate_fn` and `dataset` arguments:\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    NeighborSampler(edge_index, sizes=[-1], collate_fn=None, dataset=None)\n\n\n@onlyOnline\n@withPackage('torch_sparse')\ndef test_neighbor_sampler_on_cora(get_dataset):\n    dataset = get_dataset(name='Cora')\n    data = dataset[0]\n\n    batch = torch.arange(10)\n    loader = 
NeighborSampler(data.edge_index, sizes=[-1, -1, -1],\n                             node_idx=batch, batch_size=10)\n\n    class SAGE(torch.nn.Module):\n        def __init__(self, in_channels, out_channels):\n            super().__init__()\n\n            self.convs = torch.nn.ModuleList()\n            self.convs.append(SAGEConv(in_channels, 16))\n            self.convs.append(SAGEConv(16, 16))\n            self.convs.append(SAGEConv(16, out_channels))\n\n        def batch(self, x, adjs):\n            for i, (edge_index, _, size) in enumerate(adjs):\n                x_target = x[:size[1]]  # Target nodes are always placed first.\n                x = self.convs[i]((x, x_target), edge_index)\n            return x\n\n        def full(self, x, edge_index):\n            for conv in self.convs:\n                x = conv(x, edge_index)\n            return x\n\n    model = SAGE(dataset.num_features, dataset.num_classes)\n\n    _, n_id, adjs = next(iter(loader))\n    out1 = model.batch(data.x[n_id], adjs)\n    out2 = model.full(data.x, data.edge_index)[batch]\n    assert torch.allclose(out1, out2, atol=1e-7)\n\n    class GAT(torch.nn.Module):\n        def __init__(self, in_channels, out_channels):\n            super().__init__()\n\n            self.convs = torch.nn.ModuleList()\n            self.convs.append(GATConv(in_channels, 16, heads=2))\n            self.convs.append(GATConv(32, 16, heads=2))\n            self.convs.append(GATConv(32, out_channels, heads=2, concat=False))\n\n        def batch(self, x, adjs):\n            for i, (edge_index, _, size) in enumerate(adjs):\n                x_target = x[:size[1]]  # Target nodes are always placed first.\n                x = self.convs[i]((x, x_target), edge_index)\n            return x\n\n        def full(self, x, edge_index):\n            for conv in self.convs:\n                x = conv(x, edge_index)\n            return x\n\n    # Instantiate the GAT model so that the check below exercises it:\n    model = GAT(dataset.num_features, dataset.num_classes)\n\n    _, n_id, adjs = next(iter(loader))\n    out1 = model.batch(data.x[n_id], adjs)\n    out2 = 
model.full(data.x, data.edge_index)[batch]\n    assert torch.allclose(out1, out2, atol=1e-7)\n"
  },
  {
    "path": "test/loader/test_prefetch.py",
    "content": "import torch\n\nfrom torch_geometric.loader import NeighborLoader, PrefetchLoader\nfrom torch_geometric.nn import GraphSAGE\nfrom torch_geometric.testing import withCUDA\n\n\n@withCUDA\ndef test_prefetch_loader(device):\n    data = [torch.randn(5, 5) for _ in range(10)]\n\n    loader = PrefetchLoader(data, device=device)\n    assert str(loader).startswith('PrefetchLoader')\n    assert len(loader) == 10\n\n    for i, batch in enumerate(loader):\n        assert batch.device == device\n        assert torch.equal(batch.cpu(), data[i])\n\n\nif __name__ == '__main__':\n    import argparse\n\n    from ogb.nodeproppred import PygNodePropPredDataset\n    from tqdm import tqdm\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--num_workers', type=int, default=0)\n    args = parser.parse_args()\n\n    data = PygNodePropPredDataset('ogbn-products', root='/tmp/ogb')[0]\n\n    model = GraphSAGE(\n        in_channels=data.x.size(-1),\n        hidden_channels=64,\n        num_layers=2,\n    ).cuda()\n\n    loader = NeighborLoader(\n        data,\n        input_nodes=torch.arange(1024 * 200),\n        batch_size=1024,\n        num_neighbors=[10, 10],\n        num_workers=args.num_workers,\n        persistent_workers=args.num_workers > 0,\n    )\n\n    print('Forward pass without prefetching...')\n    for batch in tqdm(loader):\n        with torch.no_grad():\n            batch = batch.cuda()\n            model(batch.x, batch.edge_index)\n\n    print('Forward pass with prefetching...')\n    for batch in tqdm(PrefetchLoader(loader)):\n        with torch.no_grad():\n            model(batch.x, batch.edge_index)\n"
  },
  {
    "path": "test/loader/test_random_node_loader.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.loader import RandomNodeLoader\nfrom torch_geometric.testing import get_random_edge_index\n\n\ndef test_random_node_loader():\n    data = Data()\n    data.x = torch.randn(100, 128)\n    data.node_id = torch.arange(100)\n    data.edge_index = get_random_edge_index(100, 100, 500)\n    data.edge_attr = torch.randn(500, 32)\n\n    loader = RandomNodeLoader(data, num_parts=4, shuffle=True)\n    assert len(loader) == 4\n\n    for batch in loader:\n        assert len(batch) == 4\n        assert batch.node_id.min() >= 0\n        assert batch.node_id.max() < 100\n        assert batch.edge_index.size(1) == batch.edge_attr.size(0)\n        assert torch.allclose(batch.x, data.x[batch.node_id])\n        batch.validate()\n\n\ndef test_heterogeneous_random_node_loader():\n    data = HeteroData()\n    data['paper'].x = torch.randn(100, 128)\n    data['paper'].node_id = torch.arange(100)\n    data['author'].x = torch.randn(200, 128)\n    data['author'].node_id = torch.arange(200)\n    data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 500)\n    data['paper', 'author'].edge_attr = torch.randn(500, 32)\n    data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 400)\n    data['author', 'paper'].edge_attr = torch.randn(400, 32)\n    data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 600)\n    data['paper', 'paper'].edge_attr = torch.randn(600, 32)\n\n    loader = RandomNodeLoader(data, num_parts=4, shuffle=True)\n    assert len(loader) == 4\n\n    for batch in loader:\n        assert len(batch) == 4\n        assert batch.node_types == data.node_types\n        assert batch.edge_types == data.edge_types\n        batch.validate()\n"
  },
  {
    "path": "test/loader/test_shadow.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import ShaDowKHopSampler\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.typing import SparseTensor\n\n\n@withPackage('torch_sparse')\ndef test_shadow_k_hop_sampler():\n    row = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5])\n    col = torch.tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4])\n    edge_index = torch.stack([row, col], dim=0)\n    edge_weight = torch.arange(row.size(0))\n    x = torch.randn(6, 16)\n    y = torch.randint(3, (6, ), dtype=torch.long)\n    data = Data(edge_index=edge_index, edge_weight=edge_weight, x=x, y=y)\n\n    train_mask = torch.tensor([1, 1, 0, 0, 0, 0], dtype=torch.bool)\n    loader = ShaDowKHopSampler(data, depth=1, num_neighbors=3,\n                               node_idx=train_mask, batch_size=2)\n    assert len(loader) == 1\n\n    batch1 = next(iter(loader))\n    assert batch1.num_graphs == len(batch1) == 2\n\n    assert batch1.batch.tolist() == [0, 0, 0, 0, 1, 1, 1]\n    assert batch1.ptr.tolist() == [0, 4, 7]\n    assert batch1.root_n_id.tolist() == [0, 5]\n    assert batch1.x.tolist() == x[torch.tensor([0, 1, 2, 3, 0, 1, 2])].tolist()\n    assert batch1.y.tolist() == y[train_mask].tolist()\n    row, col = batch1.edge_index\n    assert row.tolist() == [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6]\n    assert col.tolist() == [1, 2, 3, 0, 2, 0, 1, 0, 5, 6, 4, 6, 4, 5]\n    e_id = torch.tensor([0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3, 4, 5, 6])\n    assert batch1.edge_weight.tolist() == edge_weight[e_id].tolist()\n\n    adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],\n                         value=edge_weight).t()\n    data = Data(adj_t=adj_t, x=x, y=y)\n\n    loader = ShaDowKHopSampler(data, depth=1, num_neighbors=3,\n                               node_idx=train_mask, batch_size=2)\n    assert len(loader) == 1\n\n    batch2 = next(iter(loader))\n    assert batch2.num_graphs == len(batch2) == 
2\n\n    assert batch1.batch.tolist() == batch2.batch.tolist()\n    assert batch1.ptr.tolist() == batch2.ptr.tolist()\n    assert batch1.root_n_id.tolist() == batch2.root_n_id.tolist()\n    assert batch1.x.tolist() == batch2.x.tolist()\n    assert batch1.y.tolist() == batch2.y.tolist()\n    row, col, value = batch2.adj_t.t().coo()\n    assert batch1.edge_index[0].tolist() == row.tolist()\n    assert batch1.edge_index[1].tolist() == col.tolist()\n    assert batch1.edge_weight.tolist() == value.tolist()\n"
  },
  {
    "path": "test/loader/test_temporal_dataloader.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import TemporalData\nfrom torch_geometric.loader import TemporalDataLoader\n\n\n@pytest.mark.parametrize('batch_size,drop_last', [(4, True), (2, False)])\ndef test_temporal_dataloader(batch_size, drop_last):\n    src = dst = t = torch.arange(10)\n    msg = torch.randn(10, 16)\n\n    data = TemporalData(src=src, dst=dst, t=t, msg=msg)\n\n    loader = TemporalDataLoader(\n        data,\n        batch_size=batch_size,\n        drop_last=drop_last,\n    )\n    assert len(loader) == 10 // batch_size\n\n    for i, batch in enumerate(loader):\n        assert len(batch) == batch_size\n        arange = range(len(batch) * i, len(batch) * i + len(batch))\n        assert batch.src.tolist() == data.src[arange].tolist()\n        assert batch.dst.tolist() == data.dst[arange].tolist()\n        assert batch.t.tolist() == data.t[arange].tolist()\n        assert batch.msg.tolist() == data.msg[arange].tolist()\n"
  },
  {
    "path": "test/loader/test_utils.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.loader.utils import index_select\n\n\ndef test_index_select():\n    x = torch.randn(3, 5)\n    index = torch.tensor([0, 2])\n    assert torch.equal(index_select(x, index), x[index])\n    assert torch.equal(index_select(x, index, dim=-1), x[..., index])\n\n\ndef test_index_select_out_of_range():\n    with pytest.raises(IndexError, match=\"out of range\"):\n        index_select(torch.randn(3, 5), torch.tensor([0, 2, 3]))\n"
  },
  {
    "path": "test/loader/test_zip_loader.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import NeighborLoader, ZipLoader\nfrom torch_geometric.testing import onlyNeighborSampler\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('filter_per_worker', [True, False])\ndef test_zip_loader(filter_per_worker):\n    x = torch.arange(100)\n    edge_index = torch.randint(0, 100, (2, 1000))\n    data = Data(x=x, edge_index=edge_index)\n\n    loaders = [\n        NeighborLoader(data, [5], input_nodes=torch.arange(0, 50)),\n        NeighborLoader(data, [5], input_nodes=torch.arange(50, 95)),\n    ]\n\n    loader = ZipLoader(loaders, batch_size=10,\n                       filter_per_worker=filter_per_worker)\n\n    batches = loader(torch.arange(5))\n    assert isinstance(batches, tuple)\n    assert len(batches) == 2\n\n    assert str(loader) == ('ZipLoader(loaders=[NeighborLoader(), '\n                           'NeighborLoader()])')\n    assert len(loader) == 5\n    assert loader.dataset == range(0, 45)\n\n    for i, (batch1, batch2) in enumerate(loader):\n        n_id1 = batch1.n_id[:batch1.batch_size]\n        n_id2 = batch2.n_id[:batch2.batch_size]\n\n        if i < 4:\n            assert batch1.batch_size == 10\n            assert batch2.batch_size == 10\n            assert torch.equal(n_id1, torch.arange(0 + i * 10, 10 + i * 10))\n            assert torch.equal(n_id2, torch.arange(50 + i * 10, 60 + i * 10))\n        else:\n            assert batch1.batch_size == 5\n            assert batch2.batch_size == 5\n            assert torch.equal(n_id1, torch.arange(0 + i * 10, 5 + i * 10))\n            assert torch.equal(n_id2, torch.arange(50 + i * 10, 55 + i * 10))\n"
  },
  {
    "path": "test/metrics/test_link_pred_metric.py",
    "content": "from typing import List\n\nimport pytest\nimport torch\n\nfrom torch_geometric.metrics import (\n    LinkPredAveragePopularity,\n    LinkPredCoverage,\n    LinkPredDiversity,\n    LinkPredF1,\n    LinkPredHitRatio,\n    LinkPredMAP,\n    LinkPredMetricCollection,\n    LinkPredMRR,\n    LinkPredNDCG,\n    LinkPredPersonalization,\n    LinkPredPrecision,\n    LinkPredRecall,\n)\nfrom torch_geometric.testing import withCUDA\n\n\n@pytest.mark.parametrize('num_src_nodes', [100])\n@pytest.mark.parametrize('num_dst_nodes', [1000])\n@pytest.mark.parametrize('num_edges', [3000])\n@pytest.mark.parametrize('batch_size', [32])\n@pytest.mark.parametrize('k', [1, 10, 100])\ndef test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k):\n    row = torch.randint(0, num_src_nodes, (num_edges, ))\n    col = torch.randint(0, num_dst_nodes, (num_edges, ))\n    edge_label_index = torch.stack([row, col], dim=0)\n\n    pred = torch.rand(num_src_nodes, num_dst_nodes)\n    pred[row, col] += 0.3  # Offset positive links by a little.\n    pred_index_mat = pred.topk(k, dim=1)[1]\n\n    metric = LinkPredPrecision(k)\n    assert str(metric) == f'LinkPredPrecision(k={k})'\n\n    for node_id in torch.split(torch.randperm(num_src_nodes), batch_size):\n        mask = torch.isin(edge_label_index[0], node_id)\n\n        y_batch, y_index = edge_label_index[:, mask]\n        # Remap `y_batch` back to `[0, batch_size - 1]` range:\n        arange = torch.empty(num_src_nodes, dtype=node_id.dtype)\n        arange[node_id] = torch.arange(node_id.numel())\n        y_batch = arange[y_batch]\n\n        metric.update(pred_index_mat[node_id], (y_batch, y_index))\n\n    out = metric.compute()\n    metric.reset()\n\n    values: List[float] = []\n    for i in range(num_src_nodes):  # Naive computation per node:\n        y_index = col[row == i]\n        if y_index.numel() > 0:\n            mask = torch.isin(pred_index_mat[i], y_index)\n            precision = float(mask.sum() / k)\n      
      values.append(precision)\n    expected = torch.tensor(values).mean()\n    assert torch.allclose(out, expected)\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :k - 1], edge_label_index)\n    metric.compute()\n    metric.reset()\n\n\ndef test_recall():\n    pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])\n    edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])\n    edge_label_weight = torch.tensor([4.0, 1.0, 2.0, 3.0, 0.5])\n\n    metric = LinkPredRecall(k=2)\n    assert str(metric) == 'LinkPredRecall(k=2)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5))\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index)\n    metric.compute()\n    metric.reset()\n\n    metric = LinkPredRecall(k=2, weighted=True)\n    assert str(metric) == 'LinkPredRecall(k=2, weighted=True)'\n    with pytest.raises(ValueError, match=\"'edge_label_weight'\"):\n        metric.update(pred_index_mat, edge_label_index)\n\n    metric.update(pred_index_mat, edge_label_index, edge_label_weight)\n    result = metric.compute()\n    metric.reset()\n    assert float(result) == pytest.approx(0.5 * (5.0 / 7.0 + 3.0 / 3.5))\n\n    edge_label_weight[0] = -2\n    metric.update(pred_index_mat, edge_label_index, edge_label_weight)\n    result = metric.compute()\n    metric.reset()\n    assert float(result) == pytest.approx(0.5 * (1.0 / 3.0 + 3.0 / 3.5))\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight)\n    metric.compute()\n    metric.reset()\n\n\ndef test_f1():\n    pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])\n    edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])\n\n    metric = LinkPredF1(k=2)\n    assert str(metric) == 'LinkPredF1(k=2)'\n    metric.update(pred_index_mat, 
edge_label_index)\n    result = metric.compute()\n    assert float(result) == pytest.approx(0.6500)\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index)\n    metric.compute()\n    metric.reset()\n\n\ndef test_map():\n    pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])\n    edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])\n\n    metric = LinkPredMAP(k=2)\n    assert str(metric) == 'LinkPredMAP(k=2)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    assert float(result) == pytest.approx(0.6250)\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index)\n    metric.compute()\n    metric.reset()\n\n\ndef test_ndcg():\n    pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2]])\n    edge_label_index = torch.tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]])\n    edge_label_weight = torch.tensor([1.0, 2.0, 0.1, 3.0, 0.5])\n\n    metric = LinkPredNDCG(k=2)\n    assert str(metric) == 'LinkPredNDCG(k=2)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    assert float(result) == pytest.approx(0.6934264)\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index)\n    metric.compute()\n    metric.reset()\n\n    metric = LinkPredNDCG(k=2, weighted=True)\n    assert str(metric) == 'LinkPredNDCG(k=2, weighted=True)'\n    with pytest.raises(ValueError, match=\"'edge_label_weight'\"):\n        metric.update(pred_index_mat, edge_label_index)\n\n    metric.update(pred_index_mat, edge_label_index, edge_label_weight)\n    result = metric.compute()\n    metric.reset()\n    assert float(result) == pytest.approx(0.7854486)\n\n    perm = torch.randperm(edge_label_weight.size(0))\n    metric.update(pred_index_mat, edge_label_index[:, perm],\n                  edge_label_weight[perm])\n    assert metric.compute() == result\n\n    # Test with `k 
> pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index, edge_label_weight)\n    metric.compute()\n    metric.reset()\n\n\ndef test_mrr():\n    pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])\n    edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]])\n\n    metric = LinkPredMRR(k=2)\n    assert str(metric) == 'LinkPredMRR(k=2)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n\n    assert float(result) == pytest.approx((1 + 0.5 + 0) / 3)\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index)\n    metric.compute()\n    metric.reset()\n\n\ndef test_hit_ratio():\n    pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])\n    edge_label_index = torch.tensor([[0, 0, 2, 2, 3], [0, 1, 2, 1, 2]])\n\n    metric = LinkPredHitRatio(k=2)\n    assert str(metric) == 'LinkPredHitRatio(k=2)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n\n    assert float(result) == pytest.approx(2 / 3)\n\n    # Test with `k > pred_index_mat.size(1)`:\n    metric.update(pred_index_mat[:, :1], edge_label_index)\n    metric.compute()\n    metric.reset()\n\n\ndef test_coverage():\n    pred_index_mat = torch.tensor([[1, 0], [1, 2], [0, 2], [0, 1]])\n    edge_label_index = torch.empty(2, 0, dtype=torch.long)\n\n    metric = LinkPredCoverage(k=2, num_dst_nodes=3)\n    assert str(metric) == 'LinkPredCoverage(k=2, num_dst_nodes=3)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    metric.reset()\n    assert metric.mask.sum() == 0\n\n    assert float(result) == 1.0\n\n    metric = LinkPredCoverage(k=1, num_dst_nodes=4)\n    assert str(metric) == 'LinkPredCoverage(k=1, num_dst_nodes=4)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    metric.reset()\n    assert metric.mask.sum() == 0\n\n    assert float(result) == 2 / 4\n\n\ndef 
test_diversity():\n    pred_index_mat = torch.tensor([[0, 1, 2], [3, 1, 0]])\n    category = torch.tensor([0, 1, 2, 0])\n    edge_label_index = torch.empty(2, 0, dtype=torch.long)\n\n    metric = LinkPredDiversity(k=3, category=category)\n    assert str(metric) == 'LinkPredDiversity(k=3)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    metric.reset()\n\n    assert pytest.approx(float(result)) == (1 + 2 / 3) / 2\n\n\n@withCUDA\ndef test_personalization(device):\n    pred_index_mat = torch.tensor([[0, 1, 2, 3], [2, 1, 0, 4], [1, 0, 2, 5]],\n                                  device=device)\n    edge_label_index = torch.empty(2, 0, dtype=torch.long, device=device)\n\n    metric = LinkPredPersonalization(k=4).to(device)\n    assert str(metric) == 'LinkPredPersonalization(k=4)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    assert result.device == device\n    assert float(result) == 0.25\n    metric.reset()\n    assert metric.preds == []\n\n    metric.update(pred_index_mat[:0], edge_label_index)\n    result = metric.compute()\n    assert result.device == device\n    assert float(result) == 0.0\n    metric.reset()\n\n\ndef test_average_popularity():\n    pred_index_mat = torch.tensor([[0, 1, 2], [3, 1, 0]])\n    popularity = torch.tensor([10, 5, 2, 1])\n    edge_label_index = torch.empty(2, 0, dtype=torch.long)\n\n    metric = LinkPredAveragePopularity(k=3, popularity=popularity)\n    assert str(metric) == 'LinkPredAveragePopularity(k=3)'\n    metric.update(pred_index_mat, edge_label_index)\n    result = metric.compute()\n    metric.reset()\n\n    assert pytest.approx(float(result)) == (10 + 5 + 2 + 1 + 5 + 10) / 6\n\n\n@pytest.mark.parametrize('num_src_nodes', [10])\n@pytest.mark.parametrize('num_dst_nodes', [50])\n@pytest.mark.parametrize('num_edges', [200])\ndef test_metric_collection(num_src_nodes, num_dst_nodes, num_edges):\n    metrics = [\n        LinkPredMAP(k=10),\n        
LinkPredPrecision(k=100),\n        LinkPredRecall(k=50),\n        LinkPredF1(k=20),\n        LinkPredMRR(k=40),\n        LinkPredNDCG(k=80),\n        LinkPredCoverage(k=5, num_dst_nodes=num_dst_nodes),\n    ]\n\n    row = torch.randint(0, num_src_nodes, (num_edges, ))\n    col = torch.randint(0, num_dst_nodes, (num_edges, ))\n    edge_label_index = torch.stack([row, col], dim=0)\n\n    pred = torch.rand(num_src_nodes, num_dst_nodes)\n    pred[row, col] += 0.3  # Offset positive links by a little.\n    pred_index_mat = pred.argsort(dim=1)\n\n    metric_collection = LinkPredMetricCollection(metrics)\n    assert str(metric_collection) == (\n        'LinkPredMetricCollection([\\n'\n        '  LinkPredMAP@10: LinkPredMAP(k=10),\\n'\n        '  LinkPredPrecision@100: LinkPredPrecision(k=100),\\n'\n        '  LinkPredRecall@50: LinkPredRecall(k=50),\\n'\n        '  LinkPredF1@20: LinkPredF1(k=20),\\n'\n        '  LinkPredMRR@40: LinkPredMRR(k=40),\\n'\n        '  LinkPredNDCG@80: LinkPredNDCG(k=80),\\n'\n        '  LinkPredCoverage@5: LinkPredCoverage(k=5, num_dst_nodes=50),\\n'\n        '])')\n    assert metric_collection.max_k == 100\n\n    expected = {}\n    for metric in metrics:\n        metric.update(pred_index_mat[:, :metric.k], edge_label_index)\n        out = metric.compute()\n        expected[f'{metric.__class__.__name__}@{metric.k}'] = out\n        metric.reset()\n\n    metric_collection.update(pred_index_mat, edge_label_index)\n    assert metric_collection.compute() == expected\n    metric_collection.reset()\n\n\ndef test_empty_ground_truth():\n    pred = torch.rand(10, 5)\n    pred_index_mat = pred.argsort(dim=1)\n    edge_label_index = torch.empty(2, 0, dtype=torch.long)\n    edge_label_weight = torch.empty(0)\n\n    metric = LinkPredMAP(k=5)\n    metric.update(pred_index_mat, edge_label_index)\n    assert metric.compute() == 0\n    metric.reset()\n\n    metric = LinkPredNDCG(k=5, weighted=True)\n    metric.update(pred_index_mat, edge_label_index, 
edge_label_weight)\n    assert metric.compute() == 0\n    metric.reset()\n"
  },
  {
    "path": "test/my_config.yaml",
    "content": "defaults:\n  - dataset: KarateClub\n  - transform@dataset.transform:\n      - NormalizeFeatures\n      - AddSelfLoops\n  - model: GCN\n  - optimizer: Adam\n  - lr_scheduler: ReduceLROnPlateau\n  - _self_\n\nmodel:\n  in_channels: 34\n  out_channels: 4\n  hidden_channels: 16\n  num_layers: 2\n"
  },
  {
    "path": "test/nn/aggr/test_aggr_utils.py",
    "content": "import torch\n\nfrom torch_geometric.nn.aggr.utils import (\n    InducedSetAttentionBlock,\n    MultiheadAttentionBlock,\n    PoolingByMultiheadAttention,\n    SetAttentionBlock,\n)\nfrom torch_geometric.testing import withCUDA\n\n\n@withCUDA\ndef test_multihead_attention_block(device: torch.device):\n    x = torch.randn(2, 4, 8, device=device)\n    y = torch.randn(2, 3, 8, device=device)\n    x_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool,\n                          device=device)\n    y_mask = torch.tensor([[1, 1, 0], [1, 1, 1]], dtype=torch.bool,\n                          device=device)\n\n    block = MultiheadAttentionBlock(8, heads=2, device=device)\n    block.reset_parameters()\n    assert str(block) == ('MultiheadAttentionBlock(8, heads=2, '\n                          'layer_norm=True, dropout=0.0)')\n\n    out = block(x, y, x_mask, y_mask)\n    assert out.size() == (2, 4, 8)\n\n    jit = torch.jit.script(block)\n    assert torch.allclose(jit(x, y, x_mask, y_mask), out)\n\n\n@withCUDA\ndef test_multihead_attention_block_dropout(device: torch.device):\n    x = torch.randn(2, 4, 8, device=device)\n\n    block = MultiheadAttentionBlock(8, dropout=0.5, device=device)\n    assert not torch.allclose(block(x, x), block(x, x))\n\n\ndef test_set_attention_block():\n    x = torch.randn(2, 4, 8)\n    mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool)\n\n    block = SetAttentionBlock(8, heads=2)\n    block.reset_parameters()\n    assert str(block) == ('SetAttentionBlock(8, heads=2, layer_norm=True, '\n                          'dropout=0.0)')\n\n    out = block(x, mask)\n    assert out.size() == (2, 4, 8)\n\n    jit = torch.jit.script(block)\n    assert torch.allclose(jit(x, mask), out)\n\n\ndef test_induced_set_attention_block():\n    x = torch.randn(2, 4, 8)\n    mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool)\n\n    block = InducedSetAttentionBlock(8, num_induced_points=2, heads=2)\n    
assert str(block) == ('InducedSetAttentionBlock(8, num_induced_points=2, '\n                          'heads=2, layer_norm=True, dropout=0.0)')\n\n    out = block(x, mask)\n    assert out.size() == (2, 4, 8)\n\n    jit = torch.jit.script(block)\n    assert torch.allclose(jit(x, mask), out)\n\n\ndef test_pooling_by_multihead_attention():\n    x = torch.randn(2, 4, 8)\n    mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]], dtype=torch.bool)\n\n    block = PoolingByMultiheadAttention(8, num_seed_points=2, heads=2)\n    assert str(block) == ('PoolingByMultiheadAttention(8, num_seed_points=2, '\n                          'heads=2, layer_norm=True, dropout=0.0)')\n\n    out = block(x, mask)\n    assert out.size() == (2, 2, 8)\n\n    jit = torch.jit.script(block)\n    assert torch.allclose(jit(x, mask), out)\n"
  },
  {
    "path": "test/nn/aggr/test_attention.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import MLP\nfrom torch_geometric.nn.aggr import AttentionalAggregation\n\n\n@pytest.mark.parametrize('dim', [2, 3])\ndef test_attentional_aggregation(dim):\n    channels = 16\n    x = torch.randn(6, channels) if dim == 2 else torch.randn(2, 6, channels)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n    ptr = torch.tensor([0, 2, 5, 6])\n\n    gate_nn = MLP([channels, 1], act='relu')\n    nn = MLP([channels, channels], act='relu')\n    aggr = AttentionalAggregation(gate_nn, nn)\n    aggr.reset_parameters()\n    assert str(aggr) == (f'AttentionalAggregation(gate_nn=MLP({channels}, 1), '\n                         f'nn=MLP({channels}, {channels}))')\n\n    out = aggr(x, index)\n    # Parenthesize the conditional so that the size is checked for both dims:\n    assert out.size() == ((3, channels) if dim == 2 else (2, 3, channels))\n\n    if (not torch_geometric.typing.WITH_TORCH_SCATTER\n            and (dim == 3 or not torch_geometric.typing.WITH_PT20)):\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            aggr(x, ptr=ptr)\n    else:\n        assert torch.allclose(out, aggr(x, ptr=ptr))\n"
  },
  {
    "path": "test/nn/aggr/test_basic.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import (\n    MaxAggregation,\n    MeanAggregation,\n    MinAggregation,\n    MulAggregation,\n    PowerMeanAggregation,\n    SoftmaxAggregation,\n    StdAggregation,\n    SumAggregation,\n    VarAggregation,\n)\n\n\ndef test_validate():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n    ptr = torch.tensor([0, 2, 5, 6])\n\n    aggr = MeanAggregation()\n\n    with pytest.raises(ValueError, match=\"invalid dimension\"):\n        aggr(x, index, dim=-3)\n\n    with pytest.raises(ValueError, match=\"invalid 'dim_size'\"):\n        aggr(x, ptr=ptr, dim_size=2)\n\n    with pytest.raises(ValueError, match=\"invalid 'dim_size'\"):\n        aggr(x, index, dim_size=2)\n\n\n@pytest.mark.parametrize('Aggregation', [\n    MeanAggregation,\n    SumAggregation,\n    MaxAggregation,\n    MinAggregation,\n    MulAggregation,\n    VarAggregation,\n    StdAggregation,\n])\ndef test_basic_aggregation(Aggregation):\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n    ptr = torch.tensor([0, 2, 5, 6])\n\n    aggr = Aggregation()\n    assert str(aggr) == f'{Aggregation.__name__}()'\n\n    out = aggr(x, index)\n    assert out.size() == (3, x.size(1))\n\n    if isinstance(aggr, MulAggregation):\n        with pytest.raises(RuntimeError, match=\"requires 'index'\"):\n            aggr(x, ptr=ptr)\n    elif (not torch_geometric.typing.WITH_TORCH_SCATTER\n          and not torch_geometric.typing.WITH_PT20):\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            aggr(x, ptr=ptr)\n    else:\n        assert torch.allclose(out, aggr(x, ptr=ptr))\n\n\ndef test_var_aggregation():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    var_aggr = VarAggregation()\n    out = var_aggr(x, index)\n\n    mean_aggr = MeanAggregation()\n    expected = mean_aggr((x - mean_aggr(x, 
index)[index]).pow(2), index)\n    assert torch.allclose(out, expected, atol=1e-6)\n\n\ndef test_empty_std_aggregation():\n    aggr = StdAggregation()\n\n    x = torch.empty(0, 6).reshape(0, 6)\n    index = torch.empty(0, dtype=torch.long)\n\n    out = aggr(x, index, dim_size=5)\n    assert out.size() == (5, 6)\n    assert float(out.abs().sum()) == 0.0\n\n\n@pytest.mark.parametrize('Aggregation', [\n    SoftmaxAggregation,\n    PowerMeanAggregation,\n])\n@pytest.mark.parametrize('learn', [True, False])\ndef test_learnable_aggregation(Aggregation, learn):\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n    ptr = torch.tensor([0, 2, 5, 6])\n\n    aggr = Aggregation(learn=learn)\n    assert str(aggr) == f'{Aggregation.__name__}(learn={learn})'\n\n    out = aggr(x, index)\n    assert out.size() == (3, x.size(1))\n\n    if (not torch_geometric.typing.WITH_TORCH_SCATTER\n            and not torch_geometric.typing.WITH_PT20):\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            aggr(x, ptr=ptr)\n    else:\n        assert torch.allclose(out, aggr(x, ptr=ptr))\n\n    if learn:\n        out.mean().backward()\n        for param in aggr.parameters():\n            assert not torch.isnan(param.grad).any()\n\n\n@pytest.mark.parametrize('Aggregation', [\n    SoftmaxAggregation,\n    PowerMeanAggregation,\n])\ndef test_learnable_channels_aggregation(Aggregation):\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n    ptr = torch.tensor([0, 2, 5, 6])\n\n    aggr = Aggregation(learn=True, channels=16)\n    assert str(aggr) == f'{Aggregation.__name__}(learn=True)'\n\n    out = aggr(x, index)\n    assert out.size() == (3, x.size(1))\n\n    if (not torch_geometric.typing.WITH_TORCH_SCATTER\n            and not torch_geometric.typing.WITH_PT20):\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            aggr(x, ptr=ptr)\n    else:\n        assert 
torch.allclose(out, aggr(x, ptr=ptr))\n\n    out.mean().backward()\n    for param in aggr.parameters():\n        assert not torch.isnan(param.grad).any()\n"
  },
  {
    "path": "test/nn/aggr/test_deep_sets.py",
    "content": "import torch\n\nfrom torch_geometric.nn import DeepSetsAggregation, Linear\n\n\ndef test_deep_sets_aggregation():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    aggr = DeepSetsAggregation(\n        local_nn=Linear(16, 32),\n        global_nn=Linear(32, 64),\n    )\n    aggr.reset_parameters()\n    assert str(aggr) == ('DeepSetsAggregation('\n                         'local_nn=Linear(16, 32, bias=True), '\n                         'global_nn=Linear(32, 64, bias=True))')\n\n    out = aggr(x, index)\n    assert out.size() == (3, 64)\n"
  },
  {
    "path": "test/nn/aggr/test_equilibrium.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn.aggr import EquilibriumAggregation\n\n\n@pytest.mark.parametrize('iter', [0, 1, 5])\n@pytest.mark.parametrize('alpha', [0, .1, 5])\ndef test_equilibrium(iter, alpha):\n    batch_size = 10\n    feature_channels = 3\n    output_channels = 2\n    x = torch.randn(batch_size, feature_channels)\n    model = EquilibriumAggregation(feature_channels, output_channels,\n                                   num_layers=[10, 10], grad_iter=iter)\n\n    assert str(model) == 'EquilibriumAggregation()'\n    out = model(x)\n    assert out.size() == (1, 2)\n\n    out = model(x, dim_size=3)\n    assert out.size() == (3, 2)\n    assert torch.all(out[1:, :] == 0)\n\n\n@pytest.mark.parametrize('iter', [0, 1, 5])\n@pytest.mark.parametrize('alpha', [0, .1, 5])\ndef test_equilibrium_batch(iter, alpha):\n    batch_1, batch_2 = 4, 6\n    feature_channels = 3\n    output_channels = 2\n    x = torch.randn(batch_1 + batch_2, feature_channels)\n    batch = torch.tensor([0 for _ in range(batch_1)] +\n                         [1 for _ in range(batch_2)])\n\n    model = EquilibriumAggregation(feature_channels, output_channels,\n                                   num_layers=[10, 10], grad_iter=iter)\n\n    assert str(model) == 'EquilibriumAggregation()'\n    out = model(x, batch)\n    assert out.size() == (2, 2)\n\n    out = model(x, dim_size=3)\n    assert out.size() == (3, 2)\n    assert torch.all(out[1:, :] == 0)\n"
  },
  {
    "path": "test/nn/aggr/test_fused.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn.aggr.fused import FusedAggregation\nfrom torch_geometric.nn.resolver import aggregation_resolver\nfrom torch_geometric.profile import benchmark\n\n\n@pytest.mark.parametrize('aggrs', [\n    ['sum', 'mean', 'min', 'max', 'mul', 'var', 'std'],\n    ['sum', 'min', 'max', 'mul', 'var', 'std'],\n    ['min', 'max', 'mul', 'var', 'std'],\n    ['mean', 'min', 'max', 'mul', 'var', 'std'],\n    ['sum', 'min', 'max', 'mul', 'std'],\n    ['mean', 'min', 'max', 'mul', 'std'],\n    ['min', 'max', 'mul', 'std'],\n])\ndef test_fused_aggregation(aggrs):\n    aggrs = [aggregation_resolver(aggr) for aggr in aggrs]\n\n    x = torch.randn(6, 1)\n    y = x.clone()\n    index = torch.tensor([0, 0, 1, 1, 1, 3])\n\n    x.requires_grad_(True)\n    y.requires_grad_(True)\n\n    aggr = FusedAggregation(aggrs)\n    assert str(aggr) == 'FusedAggregation()'\n    out = torch.cat(aggr(x, index), dim=-1)\n\n    expected = torch.cat([aggr(y, index) for aggr in aggrs], dim=-1)\n    assert torch.allclose(out, expected, atol=1e-5)\n\n    jit = torch.jit.script(aggr)\n    assert torch.allclose(torch.cat(jit(x, index), dim=-1), out, atol=1e-5)\n\n    out.mean().backward()\n    assert x.grad is not None\n    expected.mean().backward()\n    assert y.grad is not None\n    assert torch.allclose(x.grad, y.grad, atol=1e-5)\n\n\ndef test_empty_fused_std_aggregation():\n    aggrs = [aggregation_resolver(aggr) for aggr in ['mean', 'var', 'std']]\n    aggr = FusedAggregation(aggrs)\n\n    x = torch.empty(0, 6).reshape(0, 6)\n    index = torch.empty(0, dtype=torch.long)\n\n    out = torch.cat(aggr(x, index, dim_size=5), dim=-1)\n    assert out.size() == (5, 18)\n    assert float(out.abs().sum()) == 0.0\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = 
parser.parse_args()\n\n    num_nodes, num_edges = 1_000, 50_000\n    x = torch.randn(num_edges, 64, device=args.device)\n    index = torch.randint(num_nodes, (num_edges, ), device=args.device)\n\n    aggrs = ['sum', 'mean', 'max', 'std']\n    print(f'Aggregators: {\", \".join(aggrs)}')\n\n    aggrs = [aggregation_resolver(aggr) for aggr in aggrs]\n    fused_aggregation = FusedAggregation(aggrs)\n\n    def naive_aggr(x, index, dim_size):\n        outs = [aggr(x, index, dim_size=dim_size) for aggr in aggrs]\n        return torch.cat(outs, dim=-1)\n\n    def fused_aggr(x, index, dim_size):\n        outs = fused_aggregation(x, index, dim_size=dim_size)\n        return torch.cat(outs, dim=-1)\n\n    benchmark(\n        funcs=[naive_aggr, fused_aggr],\n        func_names=['Naive', 'Fused'],\n        args=(x, index, num_nodes),\n        num_steps=100 if args.device == 'cpu' else 1000,\n        num_warmups=50 if args.device == 'cpu' else 500,\n        backward=args.backward,\n    )\n"
  },
  {
    "path": "test/nn/aggr/test_gmt.py",
    "content": "import torch\n\nfrom torch_geometric.nn.aggr import GraphMultisetTransformer\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_graph_multiset_transformer():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    aggr = GraphMultisetTransformer(16, k=2, heads=2)\n    aggr.reset_parameters()\n    assert str(aggr) == ('GraphMultisetTransformer(16, k=2, heads=2, '\n                         'layer_norm=False, dropout=0.0)')\n\n    out = aggr(x, index)\n    assert out.size() == (3, 16)\n\n    if is_full_test():\n        jit = torch.jit.script(aggr)\n        assert torch.allclose(jit(x, index), out)\n"
  },
  {
    "path": "test/nn/aggr/test_gru.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GRUAggregation\n\n\ndef test_gru_aggregation():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    aggr = GRUAggregation(16, 32)\n    assert str(aggr) == 'GRUAggregation(16, 32)'\n\n    out = aggr(x, index)\n    assert out.size() == (3, 32)\n"
  },
  {
    "path": "test/nn/aggr/test_lcm.py",
    "content": "from itertools import product\n\nimport pytest\nimport torch\n\nfrom torch_geometric.nn import LCMAggregation\nfrom torch_geometric.profile import benchmark\n\n\ndef test_lcm_aggregation_with_project():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    aggr = LCMAggregation(16, 32)\n    assert str(aggr) == 'LCMAggregation(16, 32, project=True)'\n\n    out = aggr(x, index)\n    assert out.size() == (3, 32)\n\n\ndef test_lcm_aggregation_without_project():\n    x = torch.randn(5, 16)\n    index = torch.tensor([0, 1, 1, 2, 2])\n\n    aggr = LCMAggregation(16, 16, project=False)\n    assert str(aggr) == 'LCMAggregation(16, 16, project=False)'\n\n    out = aggr(x, index)\n    assert out.size() == (3, 16)\n\n\ndef test_lcm_aggregation_error_handling():\n    with pytest.raises(ValueError, match=\"must be projected\"):\n        LCMAggregation(16, 32, project=False)\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    channels = 128\n    batch_size_list = [2**i for i in range(10, 12)]\n    num_nodes_list = [2**i for i in range(15, 18)]\n\n    aggr = LCMAggregation(channels, channels, project=False)\n    aggr = aggr.to(args.device)\n\n    funcs = []\n    func_names = []\n    args_list = []\n    for batch_size, num_nodes in product(batch_size_list, num_nodes_list):\n        x = torch.randn((num_nodes, channels), device=args.device)\n        index = torch.randint(0, batch_size, (num_nodes, ), device=args.device)\n        index = index.sort()[0]\n\n        funcs.append(aggr)\n        func_names.append(f'B={batch_size}, N={num_nodes}')\n        args_list.append((x, index))\n\n    benchmark(\n        funcs=funcs,\n        func_names=func_names,\n        args=args_list,\n        num_steps=10 if args.device == 'cpu' else 100,\n        
num_warmups=5 if args.device == 'cpu' else 50,\n        backward=args.backward,\n        progress_bar=True,\n    )\n"
  },
  {
    "path": "test/nn/aggr/test_lstm.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import LSTMAggregation\n\n\ndef test_lstm_aggregation():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    aggr = LSTMAggregation(16, 32)\n    assert str(aggr) == 'LSTMAggregation(16, 32)'\n\n    with pytest.raises(ValueError, match=\"is not sorted\"):\n        aggr(x, torch.tensor([0, 1, 0, 1, 2, 1]))\n\n    out = aggr(x, index)\n    assert out.size() == (3, 32)\n"
  },
  {
    "path": "test/nn/aggr/test_mlp_aggr.py",
    "content": "import torch\n\nfrom torch_geometric.nn import MLPAggregation\n\n\ndef test_mlp_aggregation():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    aggr = MLPAggregation(\n        in_channels=16,\n        out_channels=32,\n        max_num_elements=3,\n        num_layers=1,\n    )\n    assert str(aggr) == 'MLPAggregation(16, 32, max_num_elements=3)'\n\n    out = aggr(x, index)\n    assert out.size() == (3, 32)\n"
  },
  {
    "path": "test/nn/aggr/test_multi.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import MultiAggregation\n\n\n@pytest.mark.parametrize('multi_aggr_tuple', [\n    (dict(mode='cat'), 3),\n    (dict(mode='proj', mode_kwargs=dict(in_channels=16, out_channels=16)), 1),\n    (dict(mode='attn', mode_kwargs=dict(in_channels=16, out_channels=16,\n                                        num_heads=4)), 1),\n    (dict(mode='sum'), 1),\n    (dict(mode='mean'), 1),\n    (dict(mode='max'), 1),\n    (dict(mode='min'), 1),\n    (dict(mode='logsumexp'), 1),\n    (dict(mode='std'), 1),\n    (dict(mode='var'), 1),\n])\ndef test_multi_aggr(multi_aggr_tuple):\n    # The 'cat' combine mode will expand the output dimensions by\n    # the number of aggregators which is 3 here, while the other\n    # modes keep output dimensions unchanged.\n    aggr_kwargs, expand = multi_aggr_tuple\n    x = torch.randn(7, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2, 3])\n    ptr = torch.tensor([0, 2, 5, 6, 7])\n\n    aggrs = ['mean', 'sum', 'max']\n    aggr = MultiAggregation(aggrs, **aggr_kwargs)\n    aggr.reset_parameters()\n    assert str(aggr) == ('MultiAggregation([\\n'\n                         '  MeanAggregation(),\\n'\n                         '  SumAggregation(),\\n'\n                         '  MaxAggregation(),\\n'\n                         f\"], mode={aggr_kwargs['mode']})\")\n\n    out = aggr(x, index)\n    assert out.size() == (4, expand * x.size(1))\n\n    if (not torch_geometric.typing.WITH_TORCH_SCATTER\n            and not torch_geometric.typing.WITH_PT20):\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            aggr(x, ptr=ptr)\n    else:\n        assert torch.allclose(out, aggr(x, ptr=ptr))\n\n    jit = torch.jit.script(aggr)\n    assert torch.allclose(out, jit(x, index))\n"
  },
  {
    "path": "test/nn/aggr/test_patch_transformer.py",
    "content": "import torch\n\nfrom torch_geometric.nn import PatchTransformerAggregation\nfrom torch_geometric.testing import withCUDA\n\n\n@withCUDA\ndef test_patch_transformer_aggregation(device: torch.device) -> None:\n    aggr = PatchTransformerAggregation(\n        in_channels=16,\n        out_channels=32,\n        patch_size=2,\n        hidden_channels=8,\n        num_transformer_blocks=1,\n        heads=2,\n        dropout=0.2,\n        aggr=['sum', 'mean', 'min', 'max', 'var', 'std'],\n        device=device,\n    )\n    aggr.reset_parameters()\n    assert str(aggr) == 'PatchTransformerAggregation(16, 32, patch_size=2)'\n\n    index = torch.tensor([0, 0, 1, 1, 1, 2], device=device)\n    x = torch.randn(index.size(0), 16, device=device)\n\n    out = aggr(x, index)\n    assert out.device == device\n    assert out.size() == (3, aggr.out_channels)\n"
  },
  {
    "path": "test/nn/aggr/test_quantile.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import MedianAggregation, QuantileAggregation\n\n\n@pytest.mark.parametrize('q', [0., .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.])\n@pytest.mark.parametrize('interpolation', QuantileAggregation.interpolations)\n@pytest.mark.parametrize('dim', [0, 1])\n@pytest.mark.parametrize('dim_size', [None, 15])\n@pytest.mark.parametrize('fill_value', [0.0, 10.0])\ndef test_quantile_aggregation(q, interpolation, dim, dim_size, fill_value):\n    x = torch.tensor([\n        [0.0, 1.0, 2.0],\n        [3.0, 4.0, 5.0],\n        [6.0, 7.0, 8.0],\n        [9.0, 0.0, 1.0],\n        [2.0, 3.0, 4.0],\n        [5.0, 6.0, 7.0],\n        [8.0, 9.0, 0.0],\n        [1.0, 2.0, 3.0],\n        [4.0, 5.0, 6.0],\n        [7.0, 8.0, 9.0],\n    ])\n    index = torch.zeros(x.size(dim), dtype=torch.long)\n\n    aggr = QuantileAggregation(q=q, interpolation=interpolation,\n                               fill_value=fill_value)\n    assert str(aggr) == f\"QuantileAggregation(q={q})\"\n\n    out = aggr(x, index, dim=dim, dim_size=dim_size)\n    expected = x.quantile(q, dim, interpolation=interpolation, keepdim=True)\n\n    assert torch.allclose(out.narrow(dim, 0, 1), expected)\n\n    if out.size(0) > index.max() + 1:\n        padding = out.narrow(dim, 1, out.size(dim) - 1)\n        assert torch.allclose(padding, torch.tensor(fill_value))\n\n\ndef test_median_aggregation():\n    x = torch.tensor([\n        [0.0, 1.0, 2.0],\n        [3.0, 4.0, 5.0],\n        [6.0, 7.0, 8.0],\n        [9.0, 0.0, 1.0],\n        [2.0, 3.0, 4.0],\n        [5.0, 6.0, 7.0],\n        [8.0, 9.0, 0.0],\n        [1.0, 2.0, 3.0],\n        [4.0, 5.0, 6.0],\n        [7.0, 8.0, 9.0],\n    ])\n\n    aggr = MedianAggregation()\n    assert str(aggr) == \"MedianAggregation()\"\n\n    index = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])\n    assert aggr(x, index).tolist() == [\n        [3.0, 1.0, 2.0],\n        [5.0, 6.0, 4.0],\n        [4.0, 5.0, 6.0],\n    ]\n\n    index = 
torch.tensor([0, 1, 0])\n    assert aggr(x, index, dim=1).tolist() == [\n        [0.0, 1.0],\n        [3.0, 4.0],\n        [6.0, 7.0],\n        [1.0, 0.0],\n        [2.0, 3.0],\n        [5.0, 6.0],\n        [0.0, 9.0],\n        [1.0, 2.0],\n        [4.0, 5.0],\n        [7.0, 8.0],\n    ]\n\n\ndef test_quantile_aggregation_multi():\n    x = torch.tensor([\n        [0.0, 1.0, 2.0],\n        [3.0, 4.0, 5.0],\n        [6.0, 7.0, 8.0],\n        [9.0, 0.0, 1.0],\n        [2.0, 3.0, 4.0],\n        [5.0, 6.0, 7.0],\n        [8.0, 9.0, 0.0],\n        [1.0, 2.0, 3.0],\n        [4.0, 5.0, 6.0],\n        [7.0, 8.0, 9.0],\n    ])\n    index = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])\n\n    qs = [0.25, 0.5, 0.75]\n\n    assert torch.allclose(\n        QuantileAggregation(qs)(x, index),\n        torch.cat([QuantileAggregation(q)(x, index) for q in qs], dim=-1),\n    )\n\n\ndef test_quantile_aggregation_validate():\n    with pytest.raises(ValueError, match=\"at least one quantile\"):\n        QuantileAggregation(q=[])\n\n    with pytest.raises(ValueError, match=\"must be in the range\"):\n        QuantileAggregation(q=-1)\n\n    with pytest.raises(ValueError, match=\"Invalid interpolation method\"):\n        QuantileAggregation(q=0.5, interpolation=None)\n"
  },
  {
    "path": "test/nn/aggr/test_scaler.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import DegreeScalerAggregation\n\n\n@pytest.mark.parametrize('train_norm', [True, False])\ndef test_degree_scaler_aggregation(train_norm):\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 2])\n    ptr = torch.tensor([0, 2, 5, 6])\n    deg = torch.tensor([0, 3, 0, 1, 1, 0])\n\n    aggr = ['mean', 'sum', 'max']\n    scaler = [\n        'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'\n    ]\n    aggr = DegreeScalerAggregation(aggr, scaler, deg, train_norm=train_norm)\n    assert str(aggr) == 'DegreeScalerAggregation()'\n\n    out = aggr(x, index)\n    assert out.size() == (3, 240)\n    assert torch.allclose(torch.jit.script(aggr)(x, index), out)\n\n    with pytest.raises(NotImplementedError, match=\"requires 'index'\"):\n        aggr(x, ptr=ptr)\n"
  },
  {
    "path": "test/nn/aggr/test_set2set.py",
    "content": "import torch\n\nfrom torch_geometric.nn.aggr import Set2Set\n\n\ndef test_set2set():\n    set2set = Set2Set(in_channels=2, processing_steps=1)\n    assert str(set2set) == 'Set2Set(2, 4)'\n\n    N = 4\n    x_1, batch_1 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)\n    out_1 = set2set(x_1, batch_1).view(-1)\n\n    N = 6\n    x_2, batch_2 = torch.randn(N, 2), torch.zeros(N, dtype=torch.long)\n    out_2 = set2set(x_2, batch_2).view(-1)\n\n    x, batch = torch.cat([x_1, x_2]), torch.cat([batch_1, batch_2 + 1])\n    out = set2set(x, batch)\n    assert out.size() == (2, 4)\n    assert torch.allclose(out_1, out[0])\n    assert torch.allclose(out_2, out[1])\n\n    x, batch = torch.cat([x_2, x_1]), torch.cat([batch_2, batch_1 + 1])\n    out = set2set(x, batch)\n    assert out.size() == (2, 4)\n    assert torch.allclose(out_1, out[1])\n    assert torch.allclose(out_2, out[0])\n"
  },
  {
    "path": "test/nn/aggr/test_set_transformer.py",
    "content": "import warnings\n\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn.aggr import SetTransformerAggregation\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_set_transformer_aggregation():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 3])\n\n    aggr = SetTransformerAggregation(16, num_seed_points=2, heads=2)\n    aggr.reset_parameters()\n    assert str(aggr) == ('SetTransformerAggregation(16, num_seed_points=2, '\n                         'heads=2, layer_norm=False, dropout=0.0)')\n\n    out = aggr(x, index)\n    assert out.size() == (4, 2 * 16)\n    assert out.isnan().sum() == 0\n    if torch_geometric.typing.WITH_PT25:\n        if not out[2].abs().sum() != 0:\n            warnings.warn(\"'SetTransformerAggregation' broken on PyTorch>2.4\",\n                          stacklevel=2)\n    else:\n        assert out[2].abs().sum() == 0\n\n    if is_full_test():\n        jit = torch.jit.script(aggr)\n        assert torch.allclose(jit(x, index), out)\n"
  },
  {
    "path": "test/nn/aggr/test_sort.py",
    "content": "import torch\n\nfrom torch_geometric.nn.aggr import SortAggregation\n\n\ndef test_sort_aggregation():\n    N_1, N_2 = 4, 6\n    x = torch.randn(N_1 + N_2, 4)\n    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])\n\n    aggr = SortAggregation(k=5)\n    assert str(aggr) == 'SortAggregation(k=5)'\n\n    out = aggr(x, index)\n    assert out.size() == (2, 5 * 4)\n\n    out_dim = out = aggr(x, index, dim=0)\n    assert torch.allclose(out_dim, out)\n\n    out = out.view(2, 5, 4)\n\n    # First graph output has been filled up with zeros.\n    assert out[0, -1].tolist() == [0, 0, 0, 0]\n\n    # Nodes are sorted.\n    assert torch.equal(out[0, :4, -1].argsort(), 3 - torch.arange(4))\n    assert torch.equal(out[1, :, -1].argsort(), 4 - torch.arange(5))\n\n\ndef test_sort_aggregation_smaller_than_k():\n    N_1, N_2 = 4, 6\n    x = torch.randn(N_1 + N_2, 4)\n    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])\n\n    # Set k which is bigger than both N_1=4 and N_2=6.\n    aggr = SortAggregation(k=10)\n    assert str(aggr) == 'SortAggregation(k=10)'\n\n    out = aggr(x, index)\n    assert out.size() == (2, 10 * 4)\n\n    out_dim = out = aggr(x, index, dim=0)\n    assert torch.allclose(out_dim, out)\n\n    out = out.view(2, 10, 4)\n\n    # Both graph outputs have been filled up with zeros.\n    assert out[0, -1].tolist() == [0, 0, 0, 0]\n    assert out[1, -1].tolist() == [0, 0, 0, 0]\n\n    # Nodes are sorted.\n    assert torch.equal(out[0, :4, -1].argsort(), 3 - torch.arange(4))\n    assert torch.equal(out[1, :6, -1].argsort(), 5 - torch.arange(6))\n\n\ndef test_sort_aggregation_dim_size():\n    N_1, N_2 = 4, 6\n    x = torch.randn(N_1 + N_2, 4)\n    index = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])\n\n    aggr = SortAggregation(k=5)\n    assert str(aggr) == 'SortAggregation(k=5)'\n\n    # expand batch output by 1\n    out = aggr(x, index, dim_size=3)\n    assert out.size() == (3, 5 * 4)\n\n    
out = out.view(3, 5, 4)\n\n    # Both first and last graph outputs have been filled up with zeros.\n    assert out[0, -1].tolist() == [0, 0, 0, 0]\n    assert out[2, -1].tolist() == [0, 0, 0, 0]\n"
  },
  {
    "path": "test/nn/aggr/test_variance_preserving.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import (\n    MeanAggregation,\n    SumAggregation,\n    VariancePreservingAggregation,\n)\n\n\ndef test_variance_preserving():\n    x = torch.randn(6, 16)\n    index = torch.tensor([0, 0, 1, 1, 1, 3])\n    ptr = torch.tensor([0, 2, 5, 5, 6])\n\n    vpa_aggr = VariancePreservingAggregation()\n    mean_aggr = MeanAggregation()\n    sum_aggr = SumAggregation()\n\n    out_vpa = vpa_aggr(x, index)\n    out_mean = mean_aggr(x, index)\n    out_sum = sum_aggr(x, index)\n\n    # Equivalent formulation:\n    expected = torch.sqrt(out_mean.abs() * out_sum.abs()) * out_sum.sign()\n\n    assert out_vpa.size() == (4, 16)\n    assert torch.allclose(out_vpa, expected)\n\n    if (not torch_geometric.typing.WITH_TORCH_SCATTER\n            and not torch_geometric.typing.WITH_PT20):\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            vpa_aggr(x, ptr=ptr)\n    else:\n        assert torch.allclose(out_vpa, vpa_aggr(x, ptr=ptr))\n"
  },
  {
    "path": "test/nn/attention/test_performer_attention.py",
    "content": "import torch\n\nfrom torch_geometric.nn.attention import PerformerAttention\n\n\ndef test_performer_attention():\n    x = torch.randn(1, 4, 16)\n    mask = torch.ones([1, 4], dtype=torch.bool)\n    attn = PerformerAttention(channels=16, heads=4)\n    out = attn(x, mask)\n    assert out.shape == (1, 4, 16)\n    assert str(attn) == ('PerformerAttention(heads=4, '\n                         'head_channels=64 kernel=ReLU())')\n"
  },
  {
    "path": "test/nn/attention/test_polynormer_attention.py",
    "content": "import torch\n\nfrom torch_geometric.nn.attention import PolynormerAttention\n\n\ndef test_performer_attention():\n    x = torch.randn(1, 4, 16)\n    mask = torch.ones([1, 4], dtype=torch.bool)\n    attn = PolynormerAttention(channels=16, heads=4)\n    out = attn(x, mask)\n    assert out.shape == (1, 4, 256)\n    assert str(attn) == 'PolynormerAttention(heads=4, head_channels=64)'\n"
  },
  {
    "path": "test/nn/attention/test_qformer.py",
    "content": "import torch\n\nfrom torch_geometric.nn.attention import QFormer\n\n\ndef test_qformer():\n    x = torch.randn(1, 4, 16)\n    attn = QFormer(input_dim=16, hidden_dim=16, output_dim=32, num_heads=4,\n                   num_layers=2)\n    out = attn(x)\n\n    assert out.shape == (1, 4, 32)\n    assert str(attn) == ('QFormer(num_heads=4, num_layers=2)')\n"
  },
  {
    "path": "test/nn/conv/cugraph/test_cugraph_gat_conv.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn import CuGraphGATConv, GATConv\nfrom torch_geometric.testing import onlyCUDA, withPackage\n\n\n@onlyCUDA\n@withPackage('pylibcugraphops>=23.02')\n@pytest.mark.parametrize('bias', [True, False])\n@pytest.mark.parametrize('bipartite', [True, False])\n@pytest.mark.parametrize('concat', [True, False])\n@pytest.mark.parametrize('edge_attr', [True, False])\n@pytest.mark.parametrize('heads', [1, 2, 3])\n@pytest.mark.parametrize('max_num_neighbors', [8, None])\ndef test_gat_conv_equality(bias, bipartite, concat, edge_attr, heads,\n                           max_num_neighbors):\n    in_channels, out_channels = 5, 2\n    kwargs = dict(bias=bias, concat=concat)\n\n    size = (10, 8) if bipartite else (10, 10)\n    x = torch.rand(size[0], in_channels, device='cuda')\n    edge_index = torch.tensor([\n        [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],\n        [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],\n    ], device='cuda')\n\n    conv1 = GATConv(in_channels, out_channels, heads, add_self_loops=False,\n                    **kwargs).cuda()\n    conv2 = CuGraphGATConv(in_channels, out_channels, heads, **kwargs).cuda()\n\n    with torch.no_grad():\n        conv2.lin.weight.data[:, :] = conv1.lin.weight.data\n        conv2.att.data[:heads * out_channels] = conv1.att_src.data.flatten()\n        conv2.att.data[heads * out_channels:] = conv1.att_dst.data.flatten()\n    if edge_attr and not bipartite:\n        e_attrs = torch.randn(size=(edge_index.size(1), 10))\n        out1 = conv1(x, edge_index, edge_attr=e_attrs)\n\n        out2 = conv2(\n            x,\n            EdgeIndex(edge_index, sparse_size=size),\n            max_num_neighbors=max_num_neighbors,\n            edge_attr=e_attrs,\n        )\n    else:\n        if bipartite:\n            out1 = conv1((x, x[:size[1]]), edge_index)\n        else:\n            out1 = conv1(x, edge_index)\n\n        out2 = 
conv2(\n            x,\n            EdgeIndex(edge_index, sparse_size=size),\n            max_num_neighbors=max_num_neighbors,\n        )\n    assert torch.allclose(out1, out2, atol=1e-3)\n\n    grad_output = torch.rand_like(out1)\n    out1.backward(grad_output)\n    out2.backward(grad_output)\n\n    assert torch.allclose(conv1.lin.weight.grad, conv2.lin.weight.grad,\n                          atol=1e-3)\n    assert torch.allclose(conv1.att_src.grad.flatten(),\n                          conv2.att.grad[:heads * out_channels], atol=1e-3)\n    assert torch.allclose(conv1.att_dst.grad.flatten(),\n                          conv2.att.grad[heads * out_channels:], atol=1e-3)\n    if bias:\n        assert torch.allclose(conv1.bias.grad, conv2.bias.grad, atol=1e-3)\n"
  },
  {
    "path": "test/nn/conv/cugraph/test_cugraph_rgcn_conv.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn import CuGraphRGCNConv\nfrom torch_geometric.nn import FastRGCNConv as RGCNConv\nfrom torch_geometric.testing import onlyCUDA, withPackage\n\n\n@onlyCUDA\n@withPackage('pylibcugraphops>=23.02')\n@pytest.mark.parametrize('aggr', ['add', 'sum', 'mean'])\n@pytest.mark.parametrize('bias', [True, False])\n@pytest.mark.parametrize('bipartite', [True, False])\n@pytest.mark.parametrize('max_num_neighbors', [8, None])\n@pytest.mark.parametrize('num_bases', [1, 2, None])\n@pytest.mark.parametrize('root_weight', [True, False])\ndef test_rgcn_conv_equality(aggr, bias, bipartite, max_num_neighbors,\n                            num_bases, root_weight):\n\n    in_channels, out_channels, num_relations = (4, 2, 3)\n    kwargs = dict(aggr=aggr, bias=bias, num_bases=num_bases,\n                  root_weight=root_weight)\n\n    size = (10, 8) if bipartite else (10, 10)\n    x = torch.rand(size[0], in_channels, device='cuda')\n    edge_index = torch.tensor([\n        [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],\n        [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],\n    ], device='cuda')\n    edge_type = torch.tensor([1, 2, 1, 0, 2, 1, 2, 0, 2, 2, 1, 1, 1, 2,\n                              2]).cuda()\n\n    torch.manual_seed(12345)\n    conv1 = RGCNConv(in_channels, out_channels, num_relations, **kwargs).cuda()\n    torch.manual_seed(12345)\n    conv2 = CuGraphRGCNConv(in_channels, out_channels, num_relations,\n                            **kwargs).cuda()\n\n    if bipartite:\n        out1 = conv1((x, x[:size[1]]), edge_index, edge_type)\n    else:\n        out1 = conv1(x, edge_index, edge_type)\n\n    out2 = conv2(\n        x,\n        EdgeIndex(edge_index, sparse_size=size),\n        edge_type,\n        max_num_neighbors=max_num_neighbors,\n    )\n    assert torch.allclose(out1, out2, atol=1e-3)\n\n    grad_out = torch.rand_like(out1)\n    out1.backward(grad_out)\n    
out2.backward(grad_out)\n\n    end = -1 if root_weight else None\n    assert torch.allclose(conv1.weight.grad, conv2.weight.grad[:end],\n                          atol=1e-3)\n\n    if root_weight:\n        assert torch.allclose(conv1.root.grad, conv2.weight.grad[-1],\n                              atol=1e-3)\n\n    if num_bases is not None:\n        assert torch.allclose(conv1.comp.grad, conv2.comp.grad, atol=1e-3)\n"
  },
  {
    "path": "test/nn/conv/cugraph/test_cugraph_sage_conv.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn import CuGraphSAGEConv, SAGEConv\nfrom torch_geometric.testing import onlyCUDA, withPackage\n\n\n@onlyCUDA\n@withPackage('pylibcugraphops>=23.02')\n@pytest.mark.parametrize('aggr', ['sum', 'mean', 'min', 'max'])\n@pytest.mark.parametrize('bias', [True, False])\n@pytest.mark.parametrize('bipartite', [True, False])\n@pytest.mark.parametrize('max_num_neighbors', [8, None])\n@pytest.mark.parametrize('normalize', [True, False])\n@pytest.mark.parametrize('root_weight', [True, False])\ndef test_sage_conv_equality(aggr, bias, bipartite, max_num_neighbors,\n                            normalize, root_weight):\n\n    in_channels, out_channels = (8, 16)\n    kwargs = dict(aggr=aggr, bias=bias, normalize=normalize,\n                  root_weight=root_weight)\n\n    size = (10, 8) if bipartite else (10, 10)\n    x = torch.rand(size[0], in_channels, device='cuda')\n    edge_index = torch.tensor([\n        [7, 0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 8, 9],\n        [0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7],\n    ], device='cuda')\n\n    conv1 = SAGEConv(in_channels, out_channels, **kwargs).cuda()\n    conv2 = CuGraphSAGEConv(in_channels, out_channels, **kwargs).cuda()\n\n    with torch.no_grad():\n        conv2.lin.weight.data[:, :in_channels] = conv1.lin_l.weight.data\n        if root_weight:\n            conv2.lin.weight.data[:, in_channels:] = conv1.lin_r.weight.data\n        if bias:\n            conv2.lin.bias.data[:] = conv1.lin_l.bias.data\n\n    if bipartite:\n        out1 = conv1((x, x[:size[1]]), edge_index)\n    else:\n        out1 = conv1(x, edge_index)\n\n    out2 = conv2(\n        x,\n        EdgeIndex(edge_index, sparse_size=size),\n        max_num_neighbors=max_num_neighbors,\n    )\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n    grad_out = torch.rand_like(out1)\n    out1.backward(grad_out)\n    out2.backward(grad_out)\n\n    assert 
torch.allclose(\n        conv1.lin_l.weight.grad,\n        conv2.lin.weight.grad[:, :in_channels],\n        atol=1e-6,\n    )\n\n    if root_weight:\n        assert torch.allclose(\n            conv1.lin_r.weight.grad,\n            conv2.lin.weight.grad[:, in_channels:],\n            atol=1e-6,\n        )\n\n    if bias:\n        assert torch.allclose(\n            conv1.lin_l.bias.grad,\n            conv2.lin.bias.grad,\n            atol=1e-6,\n        )\n"
  },
  {
    "path": "test/nn/conv/test_agnn_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import AGNNConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('requires_grad', [True, False])\ndef test_agnn_conv(requires_grad):\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = AGNNConv(requires_grad=requires_grad)\n    assert str(conv) == 'AGNNConv()'\n    out = conv(x, edge_index)\n    assert out.size() == (4, 16)\n    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_antisymmetric_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import AntiSymmetricConv\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_antisymmetric_conv():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = AntiSymmetricConv(8)\n    assert str(conv) == ('AntiSymmetricConv(8, phi=GCNConv(8, 8), '\n                         'num_iters=1, epsilon=0.1, gamma=0.1)')\n\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 8)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 8)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_appnp.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import APPNP\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_appnp():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = APPNP(K=3, alpha=0.1, cached=True)\n    assert str(conv) == 'APPNP(K=3, alpha=0.1)'\n    out = conv(x, edge_index)\n    assert out.size() == (4, 16)\n    assert torch.allclose(conv(x, adj1.t()), out, rtol=1e-5, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out, rtol=1e-5, atol=1e-6)\n\n    # Run again to test the cached functionality:\n    assert conv._cached_edge_index is not None\n    assert torch.allclose(conv(x, edge_index), conv(x, adj1.t()), rtol=1e-5,\n                          atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert conv._cached_adj_t is not None\n        assert torch.allclose(conv(x, edge_index), conv(x, adj2.t()),\n                              rtol=1e-5, atol=1e-6)\n\n    conv.reset_parameters()\n    assert conv._cached_edge_index is None\n    assert conv._cached_adj_t is None\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out, rtol=1e-5, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj2.t()), out, rtol=1e-5, atol=1e-6)\n\n\ndef test_appnp_dropout():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    # With dropout probability of 1.0, the final output equals to alpha * x:\n    conv = APPNP(K=2, 
alpha=0.1, dropout=1.0)\n    assert torch.allclose(0.1 * x, conv(x, edge_index), rtol=1e-5, atol=1e-6)\n    assert torch.allclose(0.1 * x, conv(x, adj1.t()), rtol=1e-5, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(0.1 * x, conv(x, adj2.t()), rtol=1e-5, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_arma_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import ARMAConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_arma_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = ARMAConv(16, 32, num_stacks=8, num_layers=4)\n    assert str(conv) == 'ARMAConv(16, 32, num_stacks=8, num_layers=4)'\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n    with pytest.raises(RuntimeError):  # No 3D feature tensor support.\n        assert torch.allclose(conv(x, adj1.t()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-6)\n\n\ndef test_lazy_arma_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    conv = ARMAConv(-1, 32, num_stacks=8, num_layers=4)\n    assert str(conv) == 'ARMAConv(-1, 32, num_stacks=8, num_layers=4)'\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_cg_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import CGConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('batch_norm', [False, True])\ndef test_cg_conv(batch_norm):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = CGConv(8, batch_norm=batch_norm)\n    assert str(conv) == 'CGConv(8, dim=0)'\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 8)\n    assert torch.allclose(conv(x1, adj1.t()), out)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    conv = CGConv((8, 16))\n    assert str(conv) == 'CGConv((8, 16), dim=0)'\n    out = conv((x1, x2), edge_index)\n    assert out.size() == (2, 16)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)\n\n\ndef 
test_cg_conv_with_edge_features():\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.rand(edge_index.size(1), 3)\n\n    conv = CGConv(8, dim=3)\n    assert str(conv) == 'CGConv(8, dim=3)'\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 8)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x1, adj.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, value), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj.t()), out)\n\n    # Test bipartite message passing:\n    conv = CGConv((8, 16), dim=3)\n    assert str(conv) == 'CGConv((8, 16), dim=3)'\n    out = conv((x1, x2), edge_index, value)\n    assert out.size() == (2, 16)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, value), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj.t()), out)\n"
  },
  {
    "path": "test/nn/conv/test_cheb_conv.py",
    "content": "import torch\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.nn import ChebConv\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_cheb_conv():\n    in_channels, out_channels = (16, 32)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    num_nodes = edge_index.max().item() + 1\n    edge_weight = torch.rand(edge_index.size(1))\n    x = torch.randn((num_nodes, in_channels))\n\n    conv = ChebConv(in_channels, out_channels, K=3)\n    assert str(conv) == 'ChebConv(16, 32, K=3, normalization=sym)'\n    out1 = conv(x, edge_index)\n    assert out1.size() == (num_nodes, out_channels)\n    out2 = conv(x, edge_index, edge_weight)\n    assert out2.size() == (num_nodes, out_channels)\n    out3 = conv(x, edge_index, edge_weight, lambda_max=3.0)\n    assert out3.size() == (num_nodes, out_channels)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1)\n        assert torch.allclose(jit(x, edge_index, edge_weight), out2)\n        assert torch.allclose(\n            jit(x, edge_index, edge_weight, lambda_max=torch.tensor(3.0)),\n            out3)\n\n    batch = torch.tensor([0, 0, 1, 1])\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n    num_nodes = edge_index.max().item() + 1\n    edge_weight = torch.rand(edge_index.size(1))\n    x = torch.randn((num_nodes, in_channels))\n    lambda_max = torch.tensor([2.0, 3.0])\n\n    out4 = conv(x, edge_index, edge_weight, batch)\n    assert out4.size() == (num_nodes, out_channels)\n    out5 = conv(x, edge_index, edge_weight, batch, lambda_max)\n    assert out5.size() == (num_nodes, out_channels)\n\n    if is_full_test():\n        assert torch.allclose(jit(x, edge_index, edge_weight, batch), out4)\n        assert torch.allclose(\n            jit(x, edge_index, edge_weight, batch, lambda_max), out5)\n\n\ndef test_cheb_conv_batch():\n    x1 = torch.randn(4, 8)\n    edge_index1 = 
torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n    edge_weight1 = torch.rand(edge_index1.size(1))\n    data1 = Data(x=x1, edge_index=edge_index1, edge_weight=edge_weight1)\n\n    x2 = torch.randn(3, 8)\n    edge_index2 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_weight2 = torch.rand(edge_index2.size(1))\n    data2 = Data(x=x2, edge_index=edge_index2, edge_weight=edge_weight2)\n\n    conv = ChebConv(8, 16, K=2)\n\n    out1 = conv(x1, edge_index1, edge_weight1)\n    out2 = conv(x2, edge_index2, edge_weight2)\n\n    batch = Batch.from_data_list([data1, data2])\n    out = conv(batch.x, batch.edge_index, batch.edge_weight, batch.batch)\n\n    assert out.size() == (7, 16)\n    assert torch.allclose(out1, out[:4], atol=1e-6)\n    assert torch.allclose(out2, out[4:], atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_cluster_gcn_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import ClusterGCNConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_cluster_gcn_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = ClusterGCNConv(16, 32, diag_lambda=1.)\n    assert str(conv) == 'ClusterGCNConv(16, 32, diag_lambda=1.0)'\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-5)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out, atol=1e-5)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-5)\n"
  },
  {
    "path": "test/nn/conv/test_create_gnn.py",
    "content": "import torch\n\nfrom torch_geometric.nn import MessagePassing\nfrom torch_geometric.utils import add_self_loops, degree\n\n\nclass GCNConv(MessagePassing):\n    def __init__(self, in_channels, out_channels):\n        super().__init__(aggr='add')\n        self.lin = torch.nn.Linear(in_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))\n\n        row, col = edge_index\n        deg = degree(row, x.size(0), dtype=x.dtype)\n        deg_inv_sqrt = deg.pow(-0.5)\n        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n\n        x = self.lin(x)\n        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x,\n                              norm=norm)\n\n    def message(self, x_j, norm):\n        return norm.view(-1, 1) * x_j\n\n    def update(self, aggr_out):\n        return aggr_out\n\n\ndef test_create_gnn():\n    conv = GCNConv(16, 32)\n    x = torch.randn(5, 16)\n    edge_index = torch.randint(5, (2, 64), dtype=torch.long)\n    out = conv(x, edge_index)\n    assert out.size() == (5, 32)\n"
  },
  {
    "path": "test/nn/conv/test_dir_gnn_conv.py",
    "content": "import torch\n\nfrom torch_geometric.nn import DirGNNConv, SAGEConv\n\n\ndef test_dir_gnn_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])\n\n    conv = DirGNNConv(SAGEConv(16, 32))\n    assert str(conv) == 'DirGNNConv(SAGEConv(16, 32, aggr=mean), alpha=0.5)'\n\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n\n\ndef test_static_dir_gnn_conv():\n    x = torch.randn(3, 4, 16)\n    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])\n\n    conv = DirGNNConv(SAGEConv(16, 32))\n\n    out = conv(x, edge_index)\n    assert out.size() == (3, 4, 32)\n"
  },
  {
    "path": "test/nn/conv/test_dna_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import DNAConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('channels', [32])\n@pytest.mark.parametrize('num_layers', [3])\ndef test_dna_conv(channels, num_layers):\n    x = torch.randn((4, num_layers, channels))\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    conv = DNAConv(channels, heads=4, groups=8, dropout=0.0)\n    assert str(conv) == 'DNAConv(32, heads=4, groups=8)'\n    out = conv(x, edge_index)\n    assert out.size() == (4, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)\n\n    conv = DNAConv(channels, heads=1, groups=1, dropout=0.0)\n    assert str(conv) == 'DNAConv(32, heads=1, groups=1)'\n    out = conv(x, edge_index)\n    assert out.size() == (4, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)\n\n    conv = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True)\n    out = conv(x, edge_index)\n    assert conv._cached_edge_index is not None\n    out = conv(x, edge_index)\n    assert out.size() == (4, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)\n\n\n@pytest.mark.parametrize('channels', [32])\n@pytest.mark.parametrize('num_layers', [3])\ndef test_dna_conv_sparse_tensor(channels, num_layers):\n    x = torch.randn((4, num_layers, channels))\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = DNAConv(32, 
heads=4, groups=8, dropout=0.0)\n    assert str(conv) == 'DNAConv(32, heads=4, groups=8)'\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n\n    conv = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True)\n\n    out1 = conv(x, adj1.t())\n    assert conv._cached_edge_index is not None\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert conv._cached_adj_t is not None\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_edge_conv.py",
    "content": "import torch\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import DynamicEdgeConv, EdgeConv\nfrom torch_geometric.testing import is_full_test, withPackage\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_edge_conv_conv():\n    x1 = torch.randn(4, 16)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    nn = Seq(Lin(32, 16), ReLU(), Lin(16, 32))\n    conv = EdgeConv(nn)\n    assert str(conv) == (\n        'EdgeConv(nn=Sequential(\\n'\n        '  (0): Linear(in_features=32, out_features=16, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=16, out_features=32, bias=True)\\n'\n        '))')\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv((x1, x1), edge_index), out, atol=1e-6)\n    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)\n    assert torch.allclose(conv((x1, x1), adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n        assert torch.allclose(conv((x1, x1), adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)\n        assert torch.allclose(jit((x1, x1), edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n            assert torch.allclose(jit((x1, x1), adj2.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    out = conv((x1, x2), 
edge_index)\n    assert out.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)\n\n\n@withPackage('torch_cluster')\ndef test_dynamic_edge_conv():\n    x1 = torch.randn(8, 16)\n    x2 = torch.randn(4, 16)\n    batch1 = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])\n    batch2 = torch.tensor([0, 0, 1, 1])\n\n    nn = Seq(Lin(32, 16), ReLU(), Lin(16, 32))\n    conv = DynamicEdgeConv(nn, k=2)\n    assert str(conv) == (\n        'DynamicEdgeConv(nn=Sequential(\\n'\n        '  (0): Linear(in_features=32, out_features=16, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=16, out_features=32, bias=True)\\n'\n        '), k=2)')\n    out11 = conv(x1)\n    assert out11.size() == (8, 32)\n\n    out12 = conv(x1, batch1)\n    assert out12.size() == (8, 32)\n\n    out21 = conv((x1, x2))\n    assert out21.size() == (4, 32)\n\n    out22 = conv((x1, x2), (batch1, batch2))\n    assert out22.size() == (4, 32)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1), out11)\n        assert torch.allclose(jit(x1, batch1), out12)\n        assert torch.allclose(jit((x1, x2)), out21)\n        assert torch.allclose(jit((x1, x2), (batch1, batch2)), out22)\n"
  },
  {
    "path": "test/nn/conv/test_eg_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import EGConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_eg_conv_with_error():\n    with pytest.raises(ValueError, match=\"must be divisible by the number of\"):\n        EGConv(16, 30, num_heads=8)\n\n    with pytest.raises(ValueError, match=\"Unsupported aggregator\"):\n        EGConv(16, 32, aggregators=['xxx'])\n\n\n@pytest.mark.parametrize('aggregators', [\n    ['symnorm'],\n    ['sum', 'symnorm', 'std'],\n])\n@pytest.mark.parametrize('add_self_loops', [True, False])\ndef test_eg_conv(aggregators, add_self_loops):\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = EGConv(\n        in_channels=16,\n        out_channels=32,\n        aggregators=aggregators,\n        add_self_loops=add_self_loops,\n    )\n    assert str(conv) == f\"EGConv(16, 32, aggregators={aggregators})\"\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out, atol=1e-2)\n\n    conv.cached = True\n    assert torch.allclose(conv(x, edge_index), out, atol=1e-2)\n    assert conv._cached_edge_index is not None\n    assert torch.allclose(conv(x, edge_index), out, atol=1e-2)\n    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x, adj2.t()), out, atol=1e-2)\n        assert conv._cached_adj_t is not None\n        assert torch.allclose(conv(x, adj2.t()), out, atol=1e-2)\n\n    if is_full_test():\n        jit = 
torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out, atol=1e-2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj2.t()), out, atol=1e-2)\n\n\ndef test_eg_conv_with_sparse_input_feature():\n    x = torch.randn(4, 16).to_sparse_coo()\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    conv = EGConv(16, 32)\n    assert conv(x, edge_index).size() == (4, 32)\n"
  },
  {
    "path": "test/nn/conv/test_fa_conv.py",
    "content": "from typing import Tuple\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import FAConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import Adj, SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_fa_conv():\n    x = torch.randn(4, 16)\n    x_0 = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = FAConv(16, eps=1.0, cached=True)\n    assert str(conv) == 'FAConv(16, eps=1.0)'\n    out = conv(x, x_0, edge_index)\n    assert conv._cached_edge_index is not None\n    assert out.size() == (4, 16)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6)\n        assert conv._cached_adj_t is not None\n        assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                x_0: Tensor,\n                edge_index: Adj,\n            ) -> Tensor:\n                return self.conv(x, x_0, edge_index)\n\n        jit = torch.jit.script(MyModule())\n        assert torch.allclose(jit(x, x_0, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, x_0, adj2.t()), out)\n\n    conv.reset_parameters()\n    assert conv._cached_edge_index is None\n    assert conv._cached_adj_t is None\n\n    # Test without caching:\n    conv.cached = False\n    out = conv(x, x_0, edge_index)\n    assert torch.allclose(conv(x, x_0, adj1.t()), out, atol=1e-6)\n    if 
torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x, x_0, adj2.t()), out, atol=1e-6)\n\n    # Test `return_attention_weights`:\n    result = conv(x, x_0, edge_index, return_attention_weights=True)\n    assert torch.allclose(result[0], out, atol=1e-6)\n    assert result[1][0].size() == (2, 10)\n    assert result[1][1].size() == (10, )\n    assert conv._alpha is None\n\n    result = conv(x, x_0, adj1.t(), return_attention_weights=True)\n    assert torch.allclose(result[0], out, atol=1e-6)\n    assert result[1][0].size() == torch.Size([4, 4])\n    assert result[1][0]._nnz() == 10\n    assert conv._alpha is None\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        result = conv(x, x_0, adj2.t(), return_attention_weights=True)\n        assert torch.allclose(result[0], out, atol=1e-6)\n        assert result[1].sizes() == [4, 4] and result[1].nnz() == 10\n        assert conv._alpha is None\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                x_0: Tensor,\n                edge_index: Tensor,\n            ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n                return self.conv(x, x_0, edge_index,\n                                 return_attention_weights=True)\n\n        jit = torch.jit.script(MyModule())\n        result = jit(x, x_0, edge_index)\n        assert torch.allclose(result[0], out, atol=1e-6)\n        assert result[1][0].size() == (2, 10)\n        assert result[1][1].size() == (10, )\n        assert conv._alpha is None\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n\n            class MyModule(torch.nn.Module):\n                def __init__(self):\n                    super().__init__()\n                    self.conv = conv\n\n                def forward(\n                    self,\n                    
x: Tensor,\n                    x_0: Tensor,\n                    edge_index: SparseTensor,\n                ) -> Tuple[Tensor, SparseTensor]:\n                    return self.conv(x, x_0, edge_index,\n                                     return_attention_weights=True)\n\n            jit = torch.jit.script(MyModule())\n            result = jit(x, x_0, adj2.t())\n            assert torch.allclose(result[0], out, atol=1e-6)\n            assert result[1].sizes() == [4, 4] and result[1].nnz() == 10\n            assert conv._alpha is None\n"
  },
  {
    "path": "test/nn/conv/test_feast_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import FeaStConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_feast_conv():\n    x1 = torch.randn(4, 16)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = FeaStConv(16, 32, heads=2)\n    assert str(conv) == 'FeaStConv(16, 32, heads=2)'\n\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    out = conv((x1, x2), edge_index)\n    assert out.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_film_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import FiLMConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_film_conv():\n    x1 = torch.randn(4, 4)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1]])\n    edge_type = torch.tensor([0, 1, 1, 0, 0, 1])\n\n    conv = FiLMConv(4, 32)\n    assert str(conv) == 'FiLMConv(4, 32, num_relations=1)'\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)\n\n    conv = FiLMConv(4, 32, num_relations=2)\n    assert str(conv) == 'FiLMConv(4, 32, num_relations=2)'\n    out = conv(x1, edge_index, edge_type)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 4))\n        assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, edge_type), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    conv = FiLMConv((4, 16), 32)\n    assert str(conv) == 'FiLMConv((4, 16), 32, num_relations=1)'\n    out = conv((x1, x2), edge_index)\n    assert out.size() == (2, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 
2))\n        assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6)\n\n    conv = FiLMConv((4, 16), 32, num_relations=2)\n    assert str(conv) == 'FiLMConv((4, 16), 32, num_relations=2)'\n    out = conv((x1, x2), edge_index, edge_type)\n    assert out.size() == (2, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, edge_type), out,\n                              atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_fused_gat_conv.py",
    "content": "import torch\n\nfrom torch_geometric.nn import FusedGATConv\nfrom torch_geometric.testing import onlyCUDA, withPackage\n\n\ndef test_to_graph_format() -> None:\n    edge_index = torch.tensor([[1, 0, 2, 3], [0, 0, 1, 1]])\n\n    csr, csc, perm = FusedGATConv.to_graph_format(edge_index, size=(4, 4))\n\n    assert csr[0].dtype == torch.int\n    assert torch.equal(csr[0], torch.tensor([0, 1, 2, 3, 4], dtype=torch.int))\n    assert csr[1].dtype == torch.int\n    assert torch.equal(csr[1], torch.tensor([0, 0, 1, 1], dtype=torch.int))\n    assert csc[0].dtype == torch.int\n    assert torch.equal(csc[0], torch.tensor([0, 1, 2, 3], dtype=torch.int))\n    assert csc[1].dtype == torch.int\n    assert torch.equal(csc[1], torch.tensor([0, 2, 4, 4, 4], dtype=torch.int))\n    assert perm.dtype == torch.int\n    assert torch.equal(perm, torch.tensor([0, 1, 2, 3], dtype=torch.int))\n\n\n@onlyCUDA\n@withPackage('dgNN')\ndef test_fused_gat_conv() -> None:\n    device = torch.device('cuda')\n\n    x = torch.randn(4, 8, device=device)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device)\n\n    csr, csc, perm = FusedGATConv.to_graph_format(edge_index, size=(4, 4))\n\n    conv = FusedGATConv(8, 32, heads=2, add_self_loops=False).to(device)\n    assert str(conv) == 'FusedGATConv(8, 32, heads=2)'\n\n    out = conv(x, csr, csc, perm)\n    assert out.size() == (4, 64)\n"
  },
  {
    "path": "test/nn/conv/test_gat_conv.py",
    "content": "from typing import Optional, Tuple\n\nimport pytest\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GATConv\nfrom torch_geometric.testing import is_full_test, withDevice\nfrom torch_geometric.typing import Adj, Size, SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('residual', [False, True])\ndef test_gat_conv(residual):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = GATConv(8, 32, heads=2, residual=residual)\n    assert str(conv) == 'GATConv(8, 32, heads=2)'\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 64)\n    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out)\n    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Adj,\n                size: Size = None,\n            ) -> Tensor:\n                return self.conv(x, edge_index, size=size)\n\n        jit = torch.jit.script(MyModule())\n        assert torch.allclose(jit(x1, edge_index), out)\n        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n\n    # Test `return_attention_weights`.\n    result = conv(x1, edge_index, return_attention_weights=True)\n    assert torch.allclose(result[0], out)\n    assert 
result[1][0].size() == (2, 7)\n    assert result[1][1].size() == (7, 2)\n    assert result[1][1].min() >= 0 and result[1][1].max() <= 1\n\n    result = conv(x1, adj1.t(), return_attention_weights=True)\n    assert torch.allclose(result[0], out, atol=1e-6)\n    assert result[1][0].size() == torch.Size([4, 4, 2])\n    assert result[1][0]._nnz() == 7\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        result = conv(x1, adj2.t(), return_attention_weights=True)\n        assert torch.allclose(result[0], out, atol=1e-6)\n        assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Tensor,\n            ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n                return self.conv(x, edge_index, return_attention_weights=True)\n\n        jit = torch.jit.script(MyModule())\n        result = jit(x1, edge_index)\n        assert torch.allclose(result[0], out)\n        assert result[1][0].size() == (2, 7)\n        assert result[1][1].size() == (7, 2)\n        assert result[1][1].min() >= 0 and result[1][1].max() <= 1\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n\n            class MyModule(torch.nn.Module):\n                def __init__(self):\n                    super().__init__()\n                    self.conv = conv\n\n                def forward(\n                    self,\n                    x: Tensor,\n                    edge_index: SparseTensor,\n                ) -> Tuple[Tensor, SparseTensor]:\n                    return self.conv(x, edge_index,\n                                     return_attention_weights=True)\n\n            jit = torch.jit.script(MyModule())\n            result = jit(x1, adj2.t())\n            assert torch.allclose(result[0], out, atol=1e-6)\n 
           assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    conv = GATConv((8, 16), 32, heads=2, residual=residual)\n    assert str(conv) == 'GATConv((8, 16), 32, heads=2)'\n\n    out1 = conv((x1, x2), edge_index)\n    assert out1.size() == (2, 64)\n    assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6)\n\n    out2 = conv((x1, None), edge_index, size=(4, 2))\n    assert out2.size() == (2, 64)\n    assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)\n        assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tuple[Tensor, Optional[Tensor]],\n                edge_index: Adj,\n                size: Size = None,\n            ) -> Tensor:\n                return self.conv(x, edge_index, size=size)\n\n        jit = torch.jit.script(MyModule())\n        assert torch.allclose(jit((x1, x2), edge_index), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)\n        assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out1, atol=1e-6)\n            assert torch.allclose(jit((x1, None), adj2.t()), out2, atol=1e-6)\n\n\ndef test_gat_conv_with_edge_attr():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]])\n    
edge_weight = torch.randn(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 4)\n\n    conv = GATConv(8, 32, heads=2, edge_dim=1, fill_value=0.5)\n    out = conv(x, edge_index, edge_weight)\n    assert out.size() == (4, 64)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj1 = SparseTensor.from_edge_index(edge_index, edge_weight, (4, 4))\n        with pytest.raises(NotImplementedError):\n            assert torch.allclose(conv(x, adj1.t()), out)\n\n    conv = GATConv(8, 32, heads=2, edge_dim=1, fill_value='mean')\n    out = conv(x, edge_index, edge_weight)\n    assert out.size() == (4, 64)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        with pytest.raises(NotImplementedError):\n            assert torch.allclose(conv(x, adj1.t()), out)\n\n    conv = GATConv(8, 32, heads=2, edge_dim=4, fill_value=0.5)\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 64)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))\n        with pytest.raises(NotImplementedError):\n            assert torch.allclose(conv(x, adj2.t()), out)\n\n    conv = GATConv(8, 32, heads=2, edge_dim=4, fill_value='mean')\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 64)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        with pytest.raises(NotImplementedError):\n            assert torch.allclose(conv(x, adj2.t()), out)\n\n\n@withDevice\ndef test_gat_conv_empty_edge_index(device):\n    x = torch.randn(0, 8, device=device)\n    edge_index = torch.empty(2, 0, dtype=torch.long, device=device)\n\n    conv = GATConv(8, 32, heads=2).to(device)\n    out = conv(x, edge_index)\n    assert out.size() == (0, 64)\n"
  },
  {
    "path": "test/nn/conv/test_gated_graph_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GatedGraphConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_gated_graph_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = GatedGraphConv(32, num_layers=3)\n    assert str(conv) == 'GatedGraphConv(32, num_layers=3)'\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_gatv2_conv.py",
    "content": "from typing import Tuple\n\nimport pytest\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GATv2Conv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import Adj, SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('residual', [False, True])\ndef test_gatv2_conv(residual):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = GATv2Conv(8, 32, heads=2, residual=residual)\n    assert str(conv) == 'GATv2Conv(8, 32, heads=2)'\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 64)\n    assert torch.allclose(conv(x1, edge_index), out)\n    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Adj,\n            ) -> Tensor:\n                return self.conv(x, edge_index)\n\n        jit = torch.jit.script(MyModule())\n        assert torch.allclose(jit(x1, edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n\n    # Test `return_attention_weights`.\n    result = conv(x1, edge_index, return_attention_weights=True)\n    assert torch.allclose(result[0], out)\n    assert result[1][0].size() == (2, 7)\n    assert result[1][1].size() == (7, 2)\n    assert result[1][1].min() >= 0 and result[1][1].max() <= 1\n\n    result = 
conv(x1, adj1.t(), return_attention_weights=True)\n    assert torch.allclose(result[0], out, atol=1e-6)\n    assert result[1][0].size() == torch.Size([4, 4, 2])\n    assert result[1][0]._nnz() == 7\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        result = conv(x1, adj2.t(), return_attention_weights=True)\n        assert torch.allclose(result[0], out, atol=1e-6)\n        assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Tensor,\n            ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n                return self.conv(x, edge_index, return_attention_weights=True)\n\n        jit = torch.jit.script(MyModule())\n        result = jit(x1, edge_index)\n        assert torch.allclose(result[0], out)\n        assert result[1][0].size() == (2, 7)\n        assert result[1][1].size() == (7, 2)\n        assert result[1][1].min() >= 0 and result[1][1].max() <= 1\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n\n            class MyModule(torch.nn.Module):\n                def __init__(self):\n                    super().__init__()\n                    self.conv = conv\n\n                def forward(\n                    self,\n                    x: Tensor,\n                    edge_index: SparseTensor,\n                ) -> Tuple[Tensor, SparseTensor]:\n                    return self.conv(x, edge_index,\n                                     return_attention_weights=True)\n\n            jit = torch.jit.script(MyModule())\n            result = jit(x1, adj2.t())\n            assert torch.allclose(result[0], out, atol=1e-6)\n            assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7\n\n    # Test bipartite message passing:\n    adj1 = 
to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    out = conv((x1, x2), edge_index)\n    assert out.size() == (2, 64)\n    assert torch.allclose(conv((x1, x2), edge_index), out)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tuple[Tensor, Tensor],\n                edge_index: Adj,\n            ) -> Tensor:\n                return self.conv(x, edge_index)\n\n        jit = torch.jit.script(MyModule())\n        assert torch.allclose(jit((x1, x2), edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)\n\n\ndef test_gatv2_conv_with_edge_attr():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 4)\n\n    conv = GATv2Conv(8, 32, heads=2, edge_dim=1, fill_value=0.5)\n    out = conv(x, edge_index, edge_weight)\n    assert out.size() == (4, 64)\n\n    conv = GATv2Conv(8, 32, heads=2, edge_dim=1, fill_value='mean')\n    out = conv(x, edge_index, edge_weight)\n    assert out.size() == (4, 64)\n\n    conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value=0.5)\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 64)\n\n    conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value='mean')\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 64)\n"
  },
  {
    "path": "test/nn/conv/test_gcn2_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GCN2Conv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_gcn2_conv():\n    x = torch.randn(4, 16)\n    x_0 = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = GCN2Conv(16, alpha=0.2)\n    assert str(conv) == 'GCN2Conv(16, alpha=0.2, beta=1.0)'\n    out1 = conv(x, x_0, edge_index)\n    assert out1.size() == (4, 16)\n    assert torch.allclose(conv(x, x_0, adj1.t()), out1, atol=1e-6)\n    out2 = conv(x, x_0, edge_index, value)\n    assert out2.size() == (4, 16)\n    assert torch.allclose(conv(x, x_0, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, x_0, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, x_0, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, x_0, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, x_0, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, x_0, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, x_0, adj4.t()), out2, atol=1e-6)\n\n    conv.cached = True\n    conv(x, x_0, edge_index)\n    assert conv._cached_edge_index is not None\n    assert torch.allclose(conv(x, x_0, edge_index), out1, atol=1e-6)\n    assert torch.allclose(conv(x, x_0, adj1.t()), out1, atol=1e-6)\n\n    if 
torch_geometric.typing.WITH_TORCH_SPARSE:\n        # The first call fills the cache; the second call must reuse it.\n        conv(x, x_0, adj3.t())\n        assert conv._cached_adj_t is not None\n        assert torch.allclose(conv(x, x_0, adj3.t()), out1, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_gcn_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import WITH_PT21, SparseTensor\nfrom torch_geometric.utils import to_torch_coo_tensor, to_torch_csc_tensor\n\n\ndef test_gcn_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = GCNConv(16, 32)\n    assert str(conv) == 'GCNConv(16, 32)'\n\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n\n    conv.cached = True\n    conv(x, edge_index)\n    assert conv._cached_edge_index is not None\n    assert torch.allclose(conv(x, edge_index), out1, atol=1e-6)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    if 
torch_geometric.typing.WITH_TORCH_SPARSE:\n        conv(x, adj3.t())\n        assert conv._cached_adj_t is not None\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n\n\ndef test_gcn_conv_with_decomposed_layers():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    def hook(module, inputs):\n        assert inputs[0]['x_j'].size() == (10, 32 // module.decomposed_layers)\n\n    conv = GCNConv(16, 32)\n    conv.register_message_forward_pre_hook(hook)\n    out1 = conv(x, edge_index)\n\n    conv.decomposed_layers = 2\n    assert conv.propagate.__module__.endswith('message_passing')\n    out2 = conv(x, edge_index)\n    assert torch.allclose(out1, out2)\n\n    # TorchScript should still work since it relies on class methods\n    # (but without decomposition).\n    torch.jit.script(conv)\n\n    conv.decomposed_layers = 1\n    assert conv.propagate.__module__.endswith('GCNConv_propagate')\n\n\ndef test_gcn_conv_with_sparse_input_feature():\n    x = torch.sparse_coo_tensor(\n        indices=torch.tensor([[0, 0], [0, 1]]),\n        values=torch.tensor([1., 1.]),\n        size=torch.Size([4, 16]),\n    )\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    conv = GCNConv(16, 32)\n    assert conv(x, edge_index).size() == (4, 32)\n\n\ndef test_static_gcn_conv():\n    x = torch.randn(3, 4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    conv = GCNConv(16, 32)\n    out = conv(x, edge_index)\n    assert out.size() == (3, 4, 32)\n\n\ndef test_gcn_conv_error():\n    with pytest.raises(ValueError, match=\"does not support adding self-loops\"):\n        GCNConv(16, 32, normalize=False, add_self_loops=True)\n\n\ndef test_gcn_conv_flow():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]])\n\n    conv = GCNConv(16, 32, flow=\"source_to_target\")\n    out1 = conv(x, edge_index)\n    conv.flow = \"target_to_source\"\n    
out2 = conv(x, edge_index.flip(0))\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n\n@pytest.mark.parametrize('requires_grad', [False, True])\n@pytest.mark.parametrize('layout', [torch.sparse_coo, torch.sparse_csr])\ndef test_gcn_norm_gradient(requires_grad, layout):\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    edge_weight = torch.ones(edge_index.size(1), requires_grad=requires_grad)\n    adj = to_torch_coo_tensor(edge_index, edge_weight)\n    if layout == torch.sparse_csr:\n        adj = adj.to_sparse_csr()\n\n    # TODO Sparse CSR tensor doesn't inherit `requires_grad` for PyTorch < 2.1.\n    if layout == torch.sparse_csr and not WITH_PT21:\n        assert not gcn_norm(adj)[0].requires_grad\n    else:\n        assert adj.requires_grad == gcn_norm(adj)[0].requires_grad\n"
  },
  {
    "path": "test/nn/conv/test_gen_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GENConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_coo_tensor\n\n\n@pytest.mark.parametrize('aggr', [\n    'softmax',\n    'powermean',\n    ['softmax', 'powermean'],\n])\ndef test_gen_conv(aggr):\n    x1 = torch.randn(4, 16)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.randn(edge_index.size(1), 16)\n    adj1 = to_torch_coo_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_coo_tensor(edge_index, value, size=(4, 4))\n\n    conv = GENConv(16, 32, aggr, edge_dim=16, msg_norm=True)\n    assert str(conv) == f'GENConv(16, 32, aggr={aggr})'\n    out1 = conv(x1, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out1)\n    assert torch.allclose(conv(x1, adj1.t().coalesce()), out1)\n\n    out2 = conv(x1, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out2)\n    # t() expects a tensor with <= 2 sparse and 0 dense dimensions\n    assert torch.allclose(conv(x1, adj2.transpose(1, 0).coalesce()), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x1, adj3.t()), out1, atol=1e-4)\n        assert torch.allclose(conv(x1, adj4.t()), out2, atol=1e-4)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out1, atol=1e-4)\n        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out1,\n                              atol=1e-4)\n        assert torch.allclose(jit(x1, edge_index, value), out2, atol=1e-4)\n        assert 
torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out2,\n                              atol=1e-4)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj3.t()), out1, atol=1e-4)\n            assert torch.allclose(jit(x1, adj4.t()), out2, atol=1e-4)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_coo_tensor(edge_index, size=(4, 2))\n    adj2 = to_torch_coo_tensor(edge_index, value, size=(4, 2))\n\n    out1 = conv((x1, x2), edge_index)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1)\n    assert torch.allclose(conv((x1, x2), adj1.t().coalesce()), out1)\n\n    out2 = conv((x1, x2), edge_index, value)\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out2)\n    assert torch.allclose(conv((x1, x2),\n                               adj2.transpose(1, 0).coalesce()), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj3.t()), out1, atol=1e-4)\n        assert torch.allclose(conv((x1, x2), adj4.t()), out2, atol=1e-4)\n\n    if is_full_test():\n        assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-4)\n        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1,\n                              atol=1e-4)\n        assert torch.allclose(jit((x1, x2), edge_index, value), out2,\n                              atol=1e-4)\n        assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out2,\n                              atol=1e-4)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj3.t()), out1, atol=1e-4)\n            assert torch.allclose(jit((x1, x2), adj4.t()), out2, atol=1e-4)\n\n    # Test bipartite 
message passing with unequal feature dimensions:\n    conv.reset_parameters()\n    assert float(conv.msg_norm.scale) == 1\n\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n\n    conv = GENConv((8, 16), 32, aggr)\n    assert str(conv) == f'GENConv((8, 16), 32, aggr={aggr})'\n\n    out1 = conv((x1, x2), edge_index)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1)\n    assert torch.allclose(conv((x1, x2), adj1.t().coalesce()), out1)\n\n    out2 = conv((x1, None), edge_index, size=(4, 2))\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, None), adj1.t().coalesce()), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv((x1, x2), adj3.t()), out1, atol=1e-4)\n        assert torch.allclose(conv((x1, None), adj3.t()), out2, atol=1e-4)\n\n    # Test lazy initialization:\n    conv = GENConv((-1, -1), 32, aggr, edge_dim=-1)\n    assert str(conv) == f'GENConv((-1, -1), 32, aggr={aggr})'\n    out1 = conv((x1, x2), edge_index, value)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, size=(4, 2)), out1)\n    assert torch.allclose(conv((x1, x2),\n                               adj2.transpose(1, 0).coalesce()), out1)\n\n    out2 = conv((x1, None), edge_index, value, size=(4, 2))\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, None),\n                               adj2.transpose(1, 0).coalesce()), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv((x1, x2), adj4.t()), out1, atol=1e-4)\n        assert torch.allclose(conv((x1, None), adj4.t()), out2, atol=1e-4)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, value), out1,\n                              atol=1e-4)\n        assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),\n                              
out1, atol=1e-4)\n        assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),\n                              out2, atol=1e-4)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj4.t()), out1, atol=1e-4)\n            assert torch.allclose(jit((x1, None), adj4.t()), out2, atol=1e-4)\n"
  },
  {
    "path": "test/nn/conv/test_general_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GeneralConv\nfrom torch_geometric.typing import SparseTensor\n\n\n@pytest.mark.parametrize('kwargs', [\n    dict(),\n    dict(skip_linear=True),\n    dict(directed_msg=False),\n    dict(heads=3),\n    dict(attention=True),\n    dict(heads=3, attention=True),\n    dict(heads=3, attention=True, attention_type='dot_product'),\n    dict(l2_normalize=True),\n])\ndef test_general_conv(kwargs):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_attr = torch.randn(edge_index.size(1), 16)\n\n    conv = GeneralConv(8, 32, **kwargs)\n    assert str(conv) == 'GeneralConv(8, 32)'\n\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)\n\n    conv = GeneralConv(8, 32, in_edge_channels=16, **kwargs)\n    assert str(conv) == 'GeneralConv(8, 32)'\n\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))\n        assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_gin_conv.py",
    "content": "import torch\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GINConv, GINEConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_gin_conv():\n    x1 = torch.randn(4, 16)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32))\n    conv = GINConv(nn, train_eps=True)\n    assert str(conv) == (\n        'GINConv(nn=Sequential(\\n'\n        '  (0): Linear(in_features=16, out_features=32, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=32, out_features=32, bias=True)\\n'\n        '))')\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out, atol=1e-6)\n    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out, atol=1e-6)\n        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    out1 = conv((x1, x2), edge_index)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6)\n\n    out2 = 
conv((x1, None), edge_index, (4, 2))\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)\n        assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        assert torch.allclose(jit((x1, x2), edge_index), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)\n        assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out1)\n            assert torch.allclose(jit((x1, None), adj2.t()), out2)\n\n\ndef test_gine_conv():\n    x1 = torch.randn(4, 16)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.randn(edge_index.size(1), 16)\n\n    nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32))\n    conv = GINEConv(nn, train_eps=True)\n    assert str(conv) == (\n        'GINEConv(nn=Sequential(\\n'\n        '  (0): Linear(in_features=16, out_features=32, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=32, out_features=32, bias=True)\\n'\n        '))')\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x1, adj.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, value), out)\n        assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            
assert torch.allclose(jit(x1, adj.t()), out)\n\n    # Test bipartite message passing:\n    out1 = conv((x1, x2), edge_index, value)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)\n\n    out2 = conv((x1, None), edge_index, value, (4, 2))\n    assert out2.size() == (2, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj.t()), out1)\n        assert torch.allclose(conv((x1, None), adj.t()), out2)\n\n    if is_full_test():\n        assert torch.allclose(jit((x1, x2), edge_index, value), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),\n                              out1)\n        assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),\n                              out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj.t()), out1)\n            assert torch.allclose(jit((x1, None), adj.t()), out2)\n\n\ndef test_gine_conv_edge_dim():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32))\n    conv = GINEConv(nn, train_eps=True, edge_dim=8)\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 32)\n\n    nn = Lin(16, 32)\n    conv = GINEConv(nn, train_eps=True, edge_dim=8)\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 32)\n\n\ndef test_static_gin_conv():\n    x = torch.randn(3, 4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32))\n    conv = GINConv(nn, train_eps=True)\n    out = conv(x, edge_index)\n    assert out.size() == (3, 4, 32)\n\n\ndef test_static_gine_conv():\n    x = torch.randn(3, 4, 16)\n    edge_index = 
torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    edge_attr = torch.randn(edge_index.size(1), 16)\n\n    nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32))\n    conv = GINEConv(nn, train_eps=True)\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (3, 4, 32)\n"
  },
  {
    "path": "test/nn/conv/test_gmm_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GMMConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_coo_tensor\n\n\n@pytest.mark.parametrize('separate_gaussians', [True, False])\ndef test_gmm_conv(separate_gaussians):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.rand(edge_index.size(1), 3)\n    adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 4))\n\n    conv = GMMConv(8, 32, dim=3, kernel_size=25,\n                   separate_gaussians=separate_gaussians)\n    assert str(conv) == 'GMMConv(8, 32, dim=3)'\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out)\n    # t() expects a tensor with <= 2 sparse and 0 dense dimensions\n    assert torch.allclose(conv(x1, adj1.transpose(0, 1).coalesce()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, value), out)\n        assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 2))\n\n    conv = GMMConv((8, 16), 32, dim=3, kernel_size=5,\n                   separate_gaussians=separate_gaussians)\n    assert str(conv) == 'GMMConv((8, 16), 32, dim=3)'\n\n    out1 = conv((x1, x2), edge_index, value)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)\n    
assert torch.allclose(conv((x1, x2),\n                               adj1.transpose(0, 1).coalesce()), out1)\n\n    out2 = conv((x1, None), edge_index, value, (4, 2))\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, None),\n                               adj1.transpose(0, 1).coalesce()), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out1)\n        assert torch.allclose(conv((x1, None), adj2.t()), out2)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, value), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),\n                              out1)\n        assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),\n                              out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out1)\n            assert torch.allclose(jit((x1, None), adj2.t()), out2)\n\n\n@pytest.mark.parametrize('separate_gaussians', [True, False])\ndef test_lazy_gmm_conv(separate_gaussians):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.rand(edge_index.size(1), 3)\n\n    conv = GMMConv(-1, 32, dim=3, kernel_size=25,\n                   separate_gaussians=separate_gaussians)\n    assert str(conv) == 'GMMConv(-1, 32, dim=3)'\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 32)\n\n    conv = GMMConv((-1, -1), 32, dim=3, kernel_size=25,\n                   separate_gaussians=separate_gaussians)\n    assert str(conv) == 'GMMConv((-1, -1), 32, dim=3)'\n    out = conv((x1, x2), edge_index, value)\n    assert out.size() == (2, 32)\n"
  },
  {
    "path": "test/nn/conv/test_gps_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GPSConv, SAGEConv\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('attn_type', ['multihead', 'performer'])\n@pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm'])\ndef test_gps_conv(norm, attn_type):\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]])\n    batch = torch.tensor([0, 0, 1, 1])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = GPSConv(16, conv=SAGEConv(16, 16), heads=4, norm=norm,\n                   attn_type=attn_type)\n    conv.reset_parameters()\n    assert str(conv) == (f'GPSConv(16, conv=SAGEConv(16, 16, aggr=mean), '\n                         f'heads=4, attn_type={attn_type})')\n\n    out = conv(x, edge_index)\n    assert out.size() == (4, 16)\n    assert torch.allclose(conv(x, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out, atol=1e-6)\n\n    out = conv(x, edge_index, batch)\n    assert out.size() == (4, 16)\n    assert torch.allclose(conv(x, adj1.t(), batch), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x, adj2.t(), batch), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_graph_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn import GraphConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_graph_conv():\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.randn(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = GraphConv(8, 32)\n    assert str(conv) == 'GraphConv(8, 32)'\n    out1 = conv(x1, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out1, atol=1e-6)\n    assert torch.allclose(conv(x1, adj1.t()), out1, atol=1e-6)\n\n    assert conv(\n        x1,\n        EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 4)),\n    ).allclose(out1, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj3.t()), out1, atol=1e-6)\n\n    out2 = conv(x1, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out2,\n                          atol=1e-6)\n    assert torch.allclose(conv(x1, adj2.t()), out2, atol=1e-6)\n\n    assert conv(\n        x1,\n        EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 4)),\n        value,\n    ).allclose(out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x1, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out1)\n        assert torch.allclose(jit(x1, edge_index, 
size=(4, 4)), out1)\n        assert torch.allclose(jit(x1, edge_index, value), out2)\n        assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x1, adj4.t()), out2, atol=1e-6)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 2))\n\n    conv = GraphConv((8, 16), 32)\n    assert str(conv) == 'GraphConv((8, 16), 32)'\n    out1 = conv((x1, x2), edge_index)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out1)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6)\n\n    assert conv(\n        (x1, x2),\n        EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),\n    ).allclose(out1, atol=1e-6)\n\n    out2 = conv((x1, None), edge_index, size=(4, 2))\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6)\n\n    assert conv(\n        (x1, None),\n        EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),\n    ).allclose(out2, atol=1e-6)\n\n    out3 = conv((x1, x2), edge_index, value)\n    assert out3.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out3)\n    assert torch.allclose(conv((x1, x2), adj2.t()), out3, atol=1e-6)\n\n    assert conv(\n        (x1, x2),\n        EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),\n        value,\n    ).allclose(out3, atol=1e-6)\n\n    out4 = conv((x1, None), edge_index, value, size=(4, 2))\n    assert out4.size() == (2, 32)\n    assert torch.allclose(conv((x1, None), adj2.t()), out4, atol=1e-6)\n\n    assert conv(\n        (x1, None),\n        EdgeIndex(edge_index, sort_order='col', sparse_size=(4, 2)),\n        value,\n    ).allclose(out4, atol=1e-6)\n\n   
 if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv((x1, None), adj3.t()), out2, atol=1e-6)\n        assert torch.allclose(conv((x1, x2), adj4.t()), out3, atol=1e-6)\n        assert torch.allclose(conv((x1, None), adj4.t()), out4, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)\n        assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2)\n        assert torch.allclose(jit((x1, x2), edge_index, value), out3)\n        assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out3)\n        assert torch.allclose(jit((x1, None), edge_index, value, (4, 2)), out4)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit((x1, None), adj3.t()), out2, atol=1e-6)\n            assert torch.allclose(jit((x1, x2), adj4.t()), out3, atol=1e-6)\n            assert torch.allclose(jit((x1, None), adj4.t()), out4, atol=1e-6)\n\n\nclass EdgeGraphConv(GraphConv):\n    def message(self, x_j, edge_weight):\n        return edge_weight.view(-1, 1) * x_j\n\n\ndef test_inheritance():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_weight = torch.rand(4)\n\n    conv = EdgeGraphConv(8, 16)\n    assert conv(x, edge_index, edge_weight).size() == (4, 16)\n"
  },
  {
    "path": "test/nn/conv/test_gravnet_conv.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GravNetConv\nfrom torch_geometric.testing import is_full_test, withPackage\n\n\n@withPackage('torch_cluster')\ndef test_gravnet_conv():\n    x1 = torch.randn(8, 16)\n    x2 = torch.randn(4, 16)\n    batch1 = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])\n    batch2 = torch.tensor([0, 0, 1, 1])\n\n    conv = GravNetConv(16, 32, space_dimensions=4, propagate_dimensions=8, k=2)\n    assert str(conv) == 'GravNetConv(16, 32, k=2)'\n\n    out11 = conv(x1)\n    assert out11.size() == (8, 32)\n\n    out12 = conv(x1, batch1)\n    assert out12.size() == (8, 32)\n\n    out21 = conv((x1, x2))\n    assert out21.size() == (4, 32)\n\n    out22 = conv((x1, x2), (batch1, batch2))\n    assert out22.size() == (4, 32)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1), out11)\n        assert torch.allclose(jit(x1, batch1), out12)\n\n        assert torch.allclose(jit((x1, x2)), out21)\n        assert torch.allclose(jit((x1, x2), (batch1, batch2)), out22)\n"
  },
  {
    "path": "test/nn/conv/test_han_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import HANConv\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import coalesce, to_torch_csc_tensor\n\n\ndef test_han_conv():\n    x_dict = {\n        'author': torch.randn(6, 16),\n        'paper': torch.randn(5, 12),\n        'term': torch.randn(4, 3)\n    }\n    edge_index1 = coalesce(torch.randint(0, 6, (2, 7)))\n    edge_index2 = coalesce(torch.randint(0, 5, (2, 4)))\n    edge_index3 = coalesce(torch.randint(0, 3, (2, 5)))\n    edge_index_dict = {\n        ('author', 'metapath0', 'author'): edge_index1,\n        ('paper', 'metapath1', 'paper'): edge_index2,\n        ('paper', 'metapath2', 'paper'): edge_index3,\n    }\n\n    adj_t_dict1 = {}\n    for edge_type, edge_index in edge_index_dict.items():\n        src_type, _, dst_type = edge_type\n        adj_t_dict1[edge_type] = to_torch_csc_tensor(\n            edge_index,\n            size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)),\n        ).t()\n\n    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))\n    in_channels = {'author': 16, 'paper': 12, 'term': 3}\n\n    conv = HANConv(in_channels, 16, metadata, heads=2)\n    assert str(conv) == 'HANConv(16, heads=2)'\n    out_dict1 = conv(x_dict, edge_index_dict)\n    assert len(out_dict1) == 3\n    assert out_dict1['author'].size() == (6, 16)\n    assert out_dict1['paper'].size() == (5, 16)\n    assert out_dict1['term'] is None\n    del out_dict1['term']\n    del x_dict['term']\n\n    out_dict2 = conv(x_dict, adj_t_dict1)\n    assert len(out_dict1) == len(out_dict2)\n    for key in out_dict1.keys():\n        assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict2 = {}\n        for edge_type, edge_index in edge_index_dict.items():\n            adj_t_dict2[edge_type] = SparseTensor.from_edge_index(\n                edge_index,\n                
sparse_sizes=adj_t_dict1[edge_type].size()[::-1],\n            ).t()\n        out_dict3 = conv(x_dict, adj_t_dict2)\n        assert len(out_dict1) == len(out_dict3)\n        for key in out_dict3.keys():\n            assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6)\n\n    # Test non-zero dropout:\n    conv = HANConv(in_channels, 16, metadata, heads=2, dropout=0.1)\n    assert str(conv) == 'HANConv(16, heads=2)'\n    out_dict1 = conv(x_dict, edge_index_dict)\n    assert len(out_dict1) == 2\n    assert out_dict1['author'].size() == (6, 16)\n    assert out_dict1['paper'].size() == (5, 16)\n\n\ndef test_han_conv_lazy():\n    x_dict = {\n        'author': torch.randn(6, 16),\n        'paper': torch.randn(5, 12),\n    }\n    edge_index1 = coalesce(torch.randint(0, 6, (2, 8)))\n    edge_index2 = coalesce(torch.randint(0, 5, (2, 6)))\n    edge_index_dict = {\n        ('author', 'to', 'author'): edge_index1,\n        ('paper', 'to', 'paper'): edge_index2,\n    }\n\n    adj_t_dict1 = {}\n    for edge_type, edge_index in edge_index_dict.items():\n        src_type, _, dst_type = edge_type\n        adj_t_dict1[edge_type] = to_torch_csc_tensor(\n            edge_index,\n            size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)),\n        ).t()\n\n    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))\n    conv = HANConv(-1, 16, metadata, heads=2)\n    assert str(conv) == 'HANConv(16, heads=2)'\n    out_dict1 = conv(x_dict, edge_index_dict)\n    assert len(out_dict1) == 2\n    assert out_dict1['author'].size() == (6, 16)\n    assert out_dict1['paper'].size() == (5, 16)\n\n    out_dict2 = conv(x_dict, adj_t_dict1)\n    assert len(out_dict1) == len(out_dict2)\n    for key in out_dict1.keys():\n        assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict2 = {}\n        for edge_type, edge_index in edge_index_dict.items():\n            adj_t_dict2[edge_type] = 
SparseTensor.from_edge_index(\n                edge_index,\n                sparse_sizes=adj_t_dict1[edge_type].size()[::-1],\n            ).t()\n        out_dict3 = conv(x_dict, adj_t_dict2)\n        assert len(out_dict1) == len(out_dict3)\n        for key in out_dict1.keys():\n            assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6)\n\n\ndef test_han_conv_empty_tensor():\n    x_dict = {\n        'author': torch.randn(6, 16),\n        'paper': torch.empty(0, 12),\n    }\n    edge_index_dict = {\n        ('paper', 'to', 'author'): torch.empty((2, 0), dtype=torch.long),\n        ('author', 'to', 'paper'): torch.empty((2, 0), dtype=torch.long),\n        ('paper', 'to', 'paper'): torch.empty((2, 0), dtype=torch.long),\n    }\n\n    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))\n    in_channels = {'author': 16, 'paper': 12}\n    conv = HANConv(in_channels, 16, metadata, heads=2)\n\n    out_dict = conv(x_dict, edge_index_dict)\n    assert len(out_dict) == 2\n    assert out_dict['author'].size() == (6, 16)\n    assert torch.all(out_dict['author'] == 0)\n    assert out_dict['paper'].size() == (0, 16)\n"
  },
  {
    "path": "test/nn/conv/test_heat_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import HEATConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\n@pytest.mark.parametrize('concat', [True, False])\ndef test_heat_conv(concat):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_attr = torch.randn((4, 2))\n    node_type = torch.tensor([0, 0, 1, 2])\n    edge_type = torch.tensor([0, 2, 1, 2])\n\n    conv = HEATConv(in_channels=8, out_channels=16, num_node_types=3,\n                    num_edge_types=3, edge_type_emb_dim=5, edge_dim=2,\n                    edge_attr_emb_dim=6, heads=2, concat=concat)\n    assert str(conv) == 'HEATConv(8, 16, heads=2)'\n\n    out = conv(x, edge_index, node_type, edge_type, edge_attr)\n    assert out.size() == (4, 32 if concat else 16)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))\n        assert torch.allclose(conv(x, adj.t(), node_type, edge_type), out,\n                              atol=1e-5)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(\n            jit(x, edge_index, node_type, edge_type, edge_attr), out,\n            atol=1e-5)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj.t(), node_type, edge_type), out,\n                                  atol=1e-5)\n"
  },
  {
    "path": "test/nn/conv/test_hetero_conv.py",
    "content": "import random\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.datasets import FakeHeteroDataset\nfrom torch_geometric.nn import (\n    GATConv,\n    GCN2Conv,\n    GCNConv,\n    HeteroConv,\n    Linear,\n    MessagePassing,\n    SAGEConv,\n)\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import (\n    get_random_edge_index,\n    onlyLinux,\n    withDevice,\n    withPackage,\n)\n\n\n@pytest.mark.parametrize('aggr', ['sum', 'mean', 'min', 'max', 'cat', None])\ndef test_hetero_conv(aggr):\n    data = HeteroData()\n    data['paper'].x = torch.randn(50, 32)\n    data['author'].x = torch.randn(30, 64)\n    data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)\n    data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)\n    data['paper', 'author'].edge_attr = torch.randn(100, 3)\n    data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)\n    data['paper', 'paper'].edge_weight = torch.rand(200)\n\n    # Unspecified edge types should be ignored:\n    data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100)\n\n    conv = HeteroConv(\n        {\n            ('paper', 'to', 'paper'):\n            GCNConv(-1, 64),\n            ('author', 'to', 'paper'):\n            SAGEConv((-1, -1), 64),\n            ('paper', 'to', 'author'):\n            GATConv((-1, -1), 64, edge_dim=3, add_self_loops=False),\n        },\n        aggr=aggr,\n    )\n\n    assert len(list(conv.parameters())) > 0\n    assert str(conv) == 'HeteroConv(num_relations=3)'\n\n    out_dict = conv(\n        data.x_dict,\n        data.edge_index_dict,\n        data.edge_attr_dict,\n        edge_weight_dict=data.edge_weight_dict,\n    )\n\n    assert len(out_dict) == 2\n    if aggr == 'cat':\n        assert out_dict['paper'].size() == (50, 128)\n        assert out_dict['author'].size() == (30, 64)\n    elif aggr is not None:\n        assert 
out_dict['paper'].size() == (50, 64)\n        assert out_dict['author'].size() == (30, 64)\n    else:\n        assert out_dict['paper'].size() == (50, 2, 64)\n        assert out_dict['author'].size() == (30, 1, 64)\n\n\ndef test_gcn2_hetero_conv():\n    data = HeteroData()\n    data['paper'].x = torch.randn(50, 32)\n    data['author'].x = torch.randn(30, 64)\n    data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)\n    data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100)\n    data['paper', 'paper'].edge_weight = torch.rand(200)\n\n    conv = HeteroConv({\n        ('paper', 'to', 'paper'): GCN2Conv(32, alpha=0.1),\n        ('author', 'to', 'author'): GCN2Conv(64, alpha=0.2),\n    })\n\n    out_dict = conv(\n        data.x_dict,\n        data.x_dict,\n        data.edge_index_dict,\n        edge_weight_dict=data.edge_weight_dict,\n    )\n\n    assert len(out_dict) == 2\n    assert out_dict['paper'].size() == (50, 32)\n    assert out_dict['author'].size() == (30, 64)\n\n\nclass CustomConv(MessagePassing):\n    def __init__(self, out_channels):\n        super().__init__(aggr='add')\n        self.lin = Linear(-1, out_channels)\n\n    def forward(self, x, edge_index, y, z):\n        return self.propagate(edge_index, x=x, y=y, z=z)\n\n    def message(self, x_j, y_j, z_j):\n        return self.lin(torch.cat([x_j, y_j, z_j], dim=-1))\n\n\ndef test_hetero_conv_with_custom_conv():\n    data = HeteroData()\n    data['paper'].x = torch.randn(50, 32)\n    data['paper'].y = torch.randn(50, 3)\n    data['paper'].z = torch.randn(50, 3)\n    data['author'].x = torch.randn(30, 64)\n    data['author'].y = torch.randn(30, 3)\n    data['author'].z = torch.randn(30, 3)\n    data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)\n    data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)\n    data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100)\n\n    conv = HeteroConv({key: CustomConv(64) for key in 
data.edge_types})\n    # Test node `args_dict` and `kwargs_dict` with `y_dict` and `z_dict`:\n    out_dict = conv(\n        data.x_dict,\n        data.edge_index_dict,\n        data.y_dict,\n        z_dict=data.z_dict,\n    )\n    assert len(out_dict) == 2\n    assert out_dict['paper'].size() == (50, 64)\n    assert out_dict['author'].size() == (30, 64)\n\n\nclass MessagePassingLoops(MessagePassing):\n    def __init__(self):\n        super().__init__()\n        self.add_self_loops = True\n\n\ndef test_hetero_conv_self_loop_error():\n    HeteroConv({('a', 'to', 'a'): MessagePassingLoops()})\n    with pytest.raises(ValueError, match=\"incorrect message passing\"):\n        HeteroConv({('a', 'to', 'b'): MessagePassingLoops()})\n\n\ndef test_hetero_conv_with_dot_syntax_node_types():\n    data = HeteroData()\n    data['src.paper'].x = torch.randn(50, 32)\n    data['author'].x = torch.randn(30, 64)\n    edge_index = get_random_edge_index(50, 50, 200)\n    data['src.paper', 'src.paper'].edge_index = edge_index\n    data['src.paper', 'author'].edge_index = get_random_edge_index(50, 30, 100)\n    data['author', 'src.paper'].edge_index = get_random_edge_index(30, 50, 100)\n    data['src.paper', 'src.paper'].edge_weight = torch.rand(200)\n\n    conv = HeteroConv({\n        ('src.paper', 'to', 'src.paper'):\n        GCNConv(-1, 64),\n        ('author', 'to', 'src.paper'):\n        SAGEConv((-1, -1), 64),\n        ('src.paper', 'to', 'author'):\n        GATConv((-1, -1), 64, add_self_loops=False),\n    })\n\n    assert len(list(conv.parameters())) > 0\n    assert str(conv) == 'HeteroConv(num_relations=3)'\n\n    out_dict = conv(\n        data.x_dict,\n        data.edge_index_dict,\n        edge_weight_dict=data.edge_weight_dict,\n    )\n\n    assert len(out_dict) == 2\n    assert out_dict['src.paper'].size() == (50, 64)\n    assert out_dict['author'].size() == (30, 64)\n\n\n@withDevice\n@onlyLinux\n@withPackage('torch>=2.1.0')\ndef 
test_compile_hetero_conv_graph_breaks(device):\n    import torch._dynamo as dynamo\n\n    data = HeteroData()\n    data['a'].x = torch.randn(50, 16, device=device)\n    data['b'].x = torch.randn(50, 16, device=device)\n    edge_index = get_random_edge_index(50, 50, 100, device=device)\n    data['a', 'to', 'b'].edge_index = edge_index\n    data['b', 'to', 'a'].edge_index = edge_index.flip([0])\n\n    conv = HeteroConv({\n        ('a', 'to', 'b'): SAGEConv(16, 32),\n        ('b', 'to', 'a'): SAGEConv(16, 32),\n    }).to(device)\n\n    explanation = dynamo.explain(conv)(data.x_dict, data.edge_index_dict)\n    assert explanation.graph_break_count == 0\n\n    compiled_conv = torch.compile(conv)\n\n    expected = conv(data.x_dict, data.edge_index_dict)\n    out = compiled_conv(data.x_dict, data.edge_index_dict)\n    assert len(out) == len(expected)\n    for key in expected.keys():\n        assert torch.allclose(out[key], expected[key], atol=1e-6)\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    dataset = FakeHeteroDataset(num_graphs=10).to(args.device)\n\n    def gen_args():\n        data = dataset[random.randrange(len(dataset))]\n        return data.x_dict, data.edge_index_dict\n\n    class HeteroGNN(torch.nn.Module):\n        def __init__(self, channels: int = 32, num_layers: int = 2):\n            super().__init__()\n            self.convs = torch.nn.ModuleList()\n\n            conv = HeteroConv({\n                edge_type:\n                SAGEConv(\n                    in_channels=(\n                        dataset.num_features[edge_type[0]],\n                        dataset.num_features[edge_type[-1]],\n                    ),\n                    out_channels=channels,\n                )\n                for edge_type in dataset[0].edge_types\n            })\n 
           self.convs.append(conv)\n\n            for _ in range(num_layers - 1):\n                conv = HeteroConv({\n                    edge_type:\n                    SAGEConv((channels, channels), channels)\n                    for edge_type in dataset[0].edge_types\n                })\n                self.convs.append(conv)\n\n            self.lin = Linear(channels, 1)\n\n        def forward(self, x_dict, edge_index_dict):\n            for conv in self.convs:\n                x_dict = conv(x_dict, edge_index_dict)\n                x_dict = {key: x.relu() for key, x in x_dict.items()}\n            return self.lin(x_dict['v0'])\n\n    model = HeteroGNN().to(args.device)\n    compiled_model = torch.compile(model)\n\n    benchmark(\n        funcs=[model, compiled_model],\n        func_names=['Vanilla', 'Compiled'],\n        args=gen_args,\n        num_steps=50 if args.device == 'cpu' else 500,\n        num_warmups=10 if args.device == 'cpu' else 100,\n        backward=args.backward,\n    )\n"
  },
  {
    "path": "test/nn/conv/test_hgt_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.nn import HGTConv\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import get_random_edge_index\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import coalesce, to_torch_csc_tensor\n\n\ndef test_hgt_conv_same_dimensions():\n    x_dict = {\n        'author': torch.randn(4, 16),\n        'paper': torch.randn(6, 16),\n    }\n    edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20))\n\n    edge_index_dict = {\n        ('author', 'writes', 'paper'): edge_index,\n        ('paper', 'written_by', 'author'): edge_index.flip([0]),\n    }\n\n    adj_t_dict1 = {}\n    for edge_type, edge_index in edge_index_dict.items():\n        src_type, _, dst_type = edge_type\n        adj_t_dict1[edge_type] = to_torch_csc_tensor(\n            edge_index,\n            size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)),\n        ).t()\n\n    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))\n\n    conv = HGTConv(16, 16, metadata, heads=2)\n    assert str(conv) == 'HGTConv(-1, 16, heads=2)'\n    out_dict1 = conv(x_dict, edge_index_dict)\n    assert len(out_dict1) == 2\n    assert out_dict1['author'].size() == (4, 16)\n    assert out_dict1['paper'].size() == (6, 16)\n\n    out_dict2 = conv(x_dict, adj_t_dict1)\n    assert len(out_dict1) == len(out_dict2)\n    for key in out_dict1.keys():\n        assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict2 = {}\n        for edge_type, edge_index in edge_index_dict.items():\n            adj_t_dict2[edge_type] = SparseTensor.from_edge_index(\n                edge_index,\n                sparse_sizes=adj_t_dict1[edge_type].size()[::-1],\n            ).t()\n        out_dict3 = conv(x_dict, adj_t_dict2)\n        assert len(out_dict1) == len(out_dict3)\n        
for key in out_dict1.keys():\n            assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6)\n\n    # TODO: Test JIT functionality. We need to wait on this one until PyTorch\n    # allows indexing `ParameterDict` mappings :(\n\n\ndef test_hgt_conv_different_dimensions():\n    x_dict = {\n        'author': torch.randn(4, 16),\n        'paper': torch.randn(6, 32),\n    }\n    edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20))\n\n    edge_index_dict = {\n        ('author', 'writes', 'paper'): edge_index,\n        ('paper', 'written_by', 'author'): edge_index.flip([0]),\n    }\n\n    adj_t_dict1 = {}\n    for edge_type, edge_index in edge_index_dict.items():\n        src_type, _, dst_type = edge_type\n        adj_t_dict1[edge_type] = to_torch_csc_tensor(\n            edge_index,\n            size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)),\n        ).t()\n\n    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))\n\n    conv = HGTConv(in_channels={\n        'author': 16,\n        'paper': 32\n    }, out_channels=32, metadata=metadata, heads=2)\n    assert str(conv) == 'HGTConv(-1, 32, heads=2)'\n    out_dict1 = conv(x_dict, edge_index_dict)\n    assert len(out_dict1) == 2\n    assert out_dict1['author'].size() == (4, 32)\n    assert out_dict1['paper'].size() == (6, 32)\n\n    out_dict2 = conv(x_dict, adj_t_dict1)\n    assert len(out_dict1) == len(out_dict2)\n    for key in out_dict1.keys():\n        assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict2 = {}\n        for edge_type, edge_index in edge_index_dict.items():\n            adj_t_dict2[edge_type] = SparseTensor.from_edge_index(\n                edge_index,\n                sparse_sizes=adj_t_dict1[edge_type].size()[::-1],\n            ).t()\n        out_dict3 = conv(x_dict, adj_t_dict2)\n        assert len(out_dict1) == len(out_dict3)\n        for key in out_dict1.keys():\n         
   assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6)\n\n\ndef test_hgt_conv_lazy():\n    x_dict = {\n        'author': torch.randn(4, 16),\n        'paper': torch.randn(6, 32),\n    }\n    edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20))\n\n    edge_index_dict = {\n        ('author', 'writes', 'paper'): edge_index,\n        ('paper', 'written_by', 'author'): edge_index.flip([0]),\n    }\n\n    adj_t_dict1 = {}\n    for edge_type, edge_index in edge_index_dict.items():\n        src_type, _, dst_type = edge_type\n        adj_t_dict1[edge_type] = to_torch_csc_tensor(\n            edge_index,\n            size=(x_dict[src_type].size(0), x_dict[dst_type].size(0)),\n        ).t()\n\n    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))\n\n    conv = HGTConv(-1, 32, metadata, heads=2)\n    assert str(conv) == 'HGTConv(-1, 32, heads=2)'\n    out_dict1 = conv(x_dict, edge_index_dict)\n    assert len(out_dict1) == 2\n    assert out_dict1['author'].size() == (4, 32)\n    assert out_dict1['paper'].size() == (6, 32)\n\n    out_dict2 = conv(x_dict, adj_t_dict1)\n    assert len(out_dict1) == len(out_dict2)\n    for key in out_dict1.keys():\n        assert torch.allclose(out_dict1[key], out_dict2[key], atol=1e-6)\n\n    if False and torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict2 = {}\n        for edge_type, edge_index in edge_index_dict.items():\n            adj_t_dict2[edge_type] = SparseTensor.from_edge_index(\n                edge_index,\n                sparse_sizes=adj_t_dict1[edge_type].size()[::-1],\n            ).t()\n        out_dict3 = conv(x_dict, adj_t_dict2)\n        assert len(out_dict1) == len(out_dict3)\n        for key in out_dict1.keys():\n            assert torch.allclose(out_dict1[key], out_dict3[key], atol=1e-6)\n\n\ndef test_hgt_conv_out_of_place():\n    data = HeteroData()\n    data['author'].x = torch.randn(4, 16)\n    data['paper'].x = torch.randn(6, 32)\n\n    edge_index = 
coalesce(get_random_edge_index(4, 6, num_edges=20))\n\n    data['author', 'paper'].edge_index = edge_index\n    data['paper', 'author'].edge_index = edge_index.flip([0])\n\n    conv = HGTConv(-1, 64, data.metadata(), heads=1)\n\n    x_dict, edge_index_dict = data.x_dict, data.edge_index_dict\n    assert x_dict['author'].size() == (4, 16)\n    assert x_dict['paper'].size() == (6, 32)\n\n    _ = conv(x_dict, edge_index_dict)\n\n    assert x_dict['author'].size() == (4, 16)\n    assert x_dict['paper'].size() == (6, 32)\n\n\ndef test_hgt_conv_missing_dst_node_type():\n    data = HeteroData()\n    data['author'].x = torch.randn(4, 16)\n    data['paper'].x = torch.randn(6, 32)\n    data['university'].x = torch.randn(10, 32)\n\n    data['author', 'paper'].edge_index = get_random_edge_index(4, 6, 20)\n    data['paper', 'author'].edge_index = get_random_edge_index(6, 4, 20)\n    data['university', 'author'].edge_index = get_random_edge_index(10, 4, 10)\n\n    conv = HGTConv(-1, 64, data.metadata(), heads=1)\n\n    out_dict = conv(data.x_dict, data.edge_index_dict)\n    assert out_dict['author'].size() == (4, 64)\n    assert out_dict['paper'].size() == (6, 64)\n    assert 'university' not in out_dict\n\n\ndef test_hgt_conv_missing_input_node_type():\n    data = HeteroData()\n    data['author'].x = torch.randn(4, 16)\n    data['paper'].x = torch.randn(6, 32)\n    data['author', 'writes',\n         'paper'].edge_index = get_random_edge_index(4, 6, 20)\n\n    # Some nodes from metadata are missing in data.\n    # This might happen while using NeighborLoader.\n    metadata = (['author', 'paper',\n                 'university'], [('author', 'writes', 'paper')])\n    conv = HGTConv(-1, 64, metadata, heads=1)\n\n    out_dict = conv(data.x_dict, data.edge_index_dict)\n    assert out_dict['paper'].size() == (6, 64)\n    assert 'university' not in out_dict\n\n\ndef test_hgt_conv_missing_edge_type():\n    data = HeteroData()\n    data['author'].x = torch.randn(4, 16)\n    
data['paper'].x = torch.randn(6, 32)\n    data['university'].x = torch.randn(10, 32)\n\n    data['author', 'writes',\n         'paper'].edge_index = get_random_edge_index(4, 6, 20)\n\n    metadata = (['author', 'paper',\n                 'university'], [('author', 'writes', 'paper'),\n                                 ('university', 'employs', 'author')])\n    conv = HGTConv(-1, 64, metadata, heads=1)\n\n    out_dict = conv(data.x_dict, data.edge_index_dict)\n    assert out_dict['author'].size() == (4, 64)\n    assert out_dict['paper'].size() == (6, 64)\n    assert 'university' not in out_dict\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    args = parser.parse_args()\n\n    num_nodes, num_edges = 30_000, 300_000\n    x_dict = {\n        'paper': torch.randn(num_nodes, 64, device=args.device),\n        'author': torch.randn(num_nodes, 64, device=args.device),\n    }\n    edge_index_dict = {\n        ('paper', 'to', 'paper'):\n        torch.randint(num_nodes, (2, num_edges), device=args.device),\n        ('author', 'to', 'paper'):\n        torch.randint(num_nodes, (2, num_edges), device=args.device),\n        ('paper', 'to', 'author'):\n        torch.randint(num_nodes, (2, num_edges), device=args.device),\n    }\n\n    conv = HGTConv(\n        in_channels=64,\n        out_channels=64,\n        metadata=(list(x_dict.keys()), list(edge_index_dict.keys())),\n        heads=4,\n    ).to(args.device)\n\n    benchmark(\n        funcs=[conv],\n        args=(x_dict, edge_index_dict),\n        num_steps=10 if args.device == 'cpu' else 100,\n        num_warmups=5 if args.device == 'cpu' else 50,\n        backward=False,\n    )\n"
  },
  {
    "path": "test/nn/conv/test_hypergraph_conv.py",
    "content": "import torch\n\nfrom torch_geometric.nn import HypergraphConv\n\n\ndef test_hypergraph_conv_with_more_nodes_than_edges():\n    in_channels, out_channels = (16, 32)\n    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3], [0, 1, 0, 1, 0, 1]])\n    num_nodes = hyperedge_index[0].max().item() + 1\n    num_edges = hyperedge_index[1].max().item() + 1\n    x = torch.randn((num_nodes, in_channels))\n    hyperedge_weight = torch.tensor([1.0, 0.5])\n    hyperedge_attr = torch.randn((num_edges, in_channels))\n\n    conv = HypergraphConv(in_channels, out_channels)\n    assert str(conv) == 'HypergraphConv(16, 32)'\n    out = conv(x, hyperedge_index)\n    assert out.size() == (num_nodes, out_channels)\n    out = conv(x, hyperedge_index, hyperedge_weight)\n    assert out.size() == (num_nodes, out_channels)\n\n    conv = HypergraphConv(in_channels, out_channels, use_attention=True,\n                          heads=2)\n    out = conv(x, hyperedge_index, hyperedge_attr=hyperedge_attr)\n    assert out.size() == (num_nodes, 2 * out_channels)\n    out = conv(x, hyperedge_index, hyperedge_weight, hyperedge_attr)\n    assert out.size() == (num_nodes, 2 * out_channels)\n\n    conv = HypergraphConv(in_channels, out_channels, use_attention=True,\n                          heads=2, concat=False, dropout=0.5)\n    out = conv(x, hyperedge_index, hyperedge_weight, hyperedge_attr)\n    assert out.size() == (num_nodes, out_channels)\n\n\ndef test_hypergraph_conv_with_more_edges_than_nodes():\n    in_channels, out_channels = (16, 32)\n    hyperedge_index = torch.tensor([[0, 0, 1, 1, 2, 3, 3, 3, 2, 1, 2],\n                                    [0, 1, 2, 1, 2, 1, 0, 3, 3, 4, 4]])\n    hyperedge_weight = torch.tensor([1.0, 0.5, 0.8, 0.2, 0.7])\n    num_nodes = hyperedge_index[0].max().item() + 1\n    x = torch.randn((num_nodes, in_channels))\n\n    conv = HypergraphConv(in_channels, out_channels)\n    assert str(conv) == 'HypergraphConv(16, 32)'\n    out = conv(x, hyperedge_index)\n    
assert out.size() == (num_nodes, out_channels)\n    out = conv(x, hyperedge_index, hyperedge_weight)\n    assert out.size() == (num_nodes, out_channels)\n"
  },
  {
    "path": "test/nn/conv/test_le_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import LEConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_le_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = LEConv(16, 32)\n    assert str(conv) == 'LEConv(16, 32)'\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        torch.allclose(jit(x, edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj2.t()), out)\n"
  },
  {
    "path": "test/nn/conv/test_lg_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import LGConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_lg_conv():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = LGConv()\n    assert str(conv) == 'LGConv()'\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 8)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 8)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_meshcnn_conv.py",
    "content": "import pytest\nimport torch\nfrom torch.nn import Linear, ModuleList, Sequential, Sigmoid\n\nfrom torch_geometric.nn import MeshCNNConv\n\n\n@pytest.mark.parametrize('in_channels, out_channels', [\n    (1, 1),\n    (1, 2),\n    (8, 3),\n    (8, 3),\n    (42, 40),\n])\ndef test_meshcnn_conv(in_channels: int, out_channels: int):\n    # m = (V, F), shape [|V| x 3, 3 * |F|]\n    # The simplest manifold triangular mesh is a tetrahedron\n    E_cardinality = 6  # |E|, the number of edges\n    x0 = torch.randn(E_cardinality, in_channels)  # X^(k), the prior layer\n    edge_index = torch.tensor([[\n        0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5\n    ], [\n        1, 2, 3, 4, 2, 0, 4, 5, 5, 3, 0, 1, 2, 5, 4, 0, 0, 3, 5, 1, 1, 4, 3, 2\n    ]], dtype=torch.int64)\n\n    # in_channels is the `Dim-Out(k)` in torch.nn.conv.MeshCNNConv\n    # out_channels is the `Dim-Out(k+1)` in torch.nn.conv.MeshCNNConv\n    conv = MeshCNNConv(in_channels, out_channels)\n\n    # Assert right representation (defined by the class's __repr__ method)\n    # WARN: For now we do not account for the 5 default kernels in the\n    # representation.\n    assert str(conv) == f\"MeshCNNConv({in_channels}, {out_channels})\"\n\n    x1 = conv(x0, edge_index)\n    assert x1.size() == (E_cardinality, out_channels)\n    # assert determinism\n    assert torch.allclose(conv(x0, edge_index), x1)\n\n    # kernels MUST be a ModuleList of length 5.\n    # Where kernels[0] is known as W_0^{(k+1)} in MeshCNNConv etc\n    kernels = ModuleList([\n        Sequential(Linear(in_channels, out_channels), Sigmoid())\n        for _ in range(5)\n    ])\n    with pytest.warns(UserWarning, match=\"does not have attribute\"):\n        conv = MeshCNNConv(in_channels, out_channels, kernels)\n    # WARN: For now we do not account for the 5 kernels in the\n    # representation\n    assert str(conv) == f\"MeshCNNConv({in_channels}, {out_channels})\"\n    x1 = conv(x0, edge_index)\n    assert 
x1.size() == (E_cardinality, out_channels)\n"
  },
  {
    "path": "test/nn/conv/test_message_passing.py",
    "content": "import copy\nimport os.path as osp\nfrom typing import Optional, Tuple, Union\n\nimport pytest\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn import MessagePassing, aggr\nfrom torch_geometric.typing import (\n    Adj,\n    OptPairTensor,\n    OptTensor,\n    Size,\n    SparseTensor,\n)\nfrom torch_geometric.utils import (\n    add_self_loops,\n    scatter,\n    spmm,\n    to_torch_csc_tensor,\n)\n\n\nclass MyConv(MessagePassing):\n    def __init__(self, in_channels: Union[int, Tuple[int, int]],\n                 out_channels: int, aggr: str = 'add'):\n        super().__init__(aggr=aggr)\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.lin_l = Linear(in_channels[0], out_channels)\n        self.lin_r = Linear(in_channels[1], out_channels)\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,\n                             size=size)\n        out = self.lin_l(out)\n\n        x_r = x[1]\n        if x_r is not None:\n            out += self.lin_r(x_r)\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n        return spmm(adj_t, x[0], reduce=self.aggr)\n\n\nclass MyConvWithSelfLoops(MessagePassing):\n    def __init__(self, aggr: str = 'add'):\n        super().__init__(aggr=aggr)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> 
Tensor:\n        edge_index, _ = add_self_loops(edge_index)\n\n        # propagate_type: (x: Tensor)\n        return self.propagate(edge_index, x=x)\n\n\ndef test_my_conv_basic():\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.randn(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n\n    conv = MyConv(8, 32)\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out, atol=1e-6)\n    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n    conv.fuse = False\n    assert torch.allclose(conv(x1, adj1.t()), out, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n    conv.fuse = True\n\n    # Bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, value, size=(4, 2))\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n\n    conv = MyConv((8, 16), 32)\n    out1 = conv((x1, x2), edge_index, value)\n    out2 = conv((x1, None), edge_index, value, (4, 2))\n    assert out1.size() == (2, 32)\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)\n    assert torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6)\n    assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)\n        assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)\n    conv.fuse = False\n    assert 
torch.allclose(conv((x1, x2), adj1.t()), out1, atol=1e-6)\n    assert torch.allclose(conv((x1, None), adj1.t()), out2, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv((x1, x2), adj2.t()), out1, atol=1e-6)\n        assert torch.allclose(conv((x1, None), adj2.t()), out2, atol=1e-6)\n\n    # Test gradient computation for `torch.sparse` tensors:\n    conv.fuse = True\n    torch_adj_t = adj1.t().requires_grad_()\n    out = conv((x1, x2), torch_adj_t)\n    out.sum().backward()\n    assert torch_adj_t.grad is not None\n\n\ndef test_my_conv_save(tmp_path):\n    conv = MyConv(8, 32)\n    assert conv._jinja_propagate is not None\n    assert conv.__class__._jinja_propagate is not None\n    assert conv._orig_propagate is not None\n    assert conv.__class__._orig_propagate is not None\n\n    path = osp.join(tmp_path, 'model.pt')\n    torch.save(conv, path)\n    conv = torch.load(path, weights_only=False)\n    assert conv._jinja_propagate is not None\n    assert conv.__class__._jinja_propagate is not None\n    assert conv._orig_propagate is not None\n    assert conv.__class__._orig_propagate is not None\n\n\ndef test_my_conv_edge_index():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_index = EdgeIndex(edge_index, sparse_size=(4, 4), sort_order='col')\n\n    conv = MyConv(8, 32)\n\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n\n\nclass MyCommentedConv(MessagePassing):\n    r\"\"\"This layer calls `self.propagate()` internally.\"\"\"\n    def __init__(self) -> None:\n        super().__init__()\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        # `self.propagate()` is used here to propagate messages.\n        return self.propagate(edge_index, x=x)\n\n\ndef test_my_commented_conv():\n    # Check that `self.propagate` occurrences in comments are correctly\n    # ignored.\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 
3], [0, 0, 1, 1]])\n\n    conv = MyCommentedConv()\n    conv(x, edge_index)\n\n    jit = torch.jit.script(conv)\n    jit(x, edge_index)\n\n\nclass MyKwargsConv(MessagePassing):\n    def __init__(self) -> None:\n        super().__init__()\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        return self.propagate(x=x, edge_index=edge_index)\n\n\ndef test_my_kwargs_conv():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = MyKwargsConv()\n    conv(x, edge_index)\n\n    jit = torch.jit.script(conv)\n    jit(x, edge_index)\n\n\ndef test_my_conv_out_of_bounds():\n    x = torch.randn(3, 8)\n    value = torch.randn(4)\n\n    conv = MyConv(8, 32)\n\n    with pytest.raises(IndexError, match=\"valid indices\"):\n        edge_index = torch.tensor([[-1, 1, 2, 2], [0, 0, 1, 1]])\n        conv(x, edge_index, value)\n\n    with pytest.raises(IndexError, match=\"valid indices\"):\n        edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n        conv(x, edge_index, value)\n\n\ndef test_my_conv_jit():\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.randn(edge_index.size(1))\n\n    conv = MyConv(8, 32)\n    out = conv(x1, edge_index, value)\n\n    jit = torch.jit.script(conv)\n    assert torch.allclose(jit(x1, edge_index, value), out, atol=1e-6)\n    assert torch.allclose(jit(x1, edge_index, value, (4, 4)), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n\n        assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)\n        jit.fuse = False\n        assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)\n        jit.fuse = True\n\n    conv = MyConv((8, 16), 32)\n    out1 = conv((x1, x2), edge_index, value)\n    out2 = conv((x1, None), edge_index, value, (4, 2))\n\n    jit = torch.jit.script(conv)\n    assert 
torch.allclose(jit((x1, x2), edge_index, value), out1)\n    assert torch.allclose(jit((x1, x2), edge_index, value, (4, 2)), out1)\n    assert torch.allclose(jit((x1, None), edge_index, value, (4, 2)), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n\n        assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)\n        assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)\n        jit.fuse = False\n        assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)\n        assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)\n        jit.fuse = True\n\n\ndef test_my_conv_jit_save(tmp_path):\n    path = osp.join(tmp_path, 'model.pt')\n\n    conv = MyConv(8, 32)\n    conv = torch.jit.script(conv)\n    torch.jit.save(conv, path)\n    conv = torch.jit.load(path)\n\n\n@pytest.mark.parametrize('aggr', ['add', 'sum', 'mean', 'min', 'max', 'mul'])\ndef test_my_conv_aggr(aggr):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n\n    conv = MyConv(8, 32, aggr=aggr)\n    out = conv(x, edge_index, edge_weight)\n    assert out.size() == (4, 32)\n\n\ndef test_my_static_graph_conv():\n    x1 = torch.randn(3, 4, 8)\n    x2 = torch.randn(3, 2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.randn(edge_index.size(1))\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n\n    conv = MyConv(8, 32)\n    out = conv(x1, edge_index, value)\n    assert out.size() == (3, 4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, (4, 4)), out)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x1, adj.t()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n\n    conv = 
MyConv((8, 16), 32)\n    out1 = conv((x1, x2), edge_index, value)\n    out2 = conv((x1, None), edge_index, value, (4, 2))\n    assert out1.size() == (3, 2, 32)\n    assert out2.size() == (3, 2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv((x1, x2), adj.t()), out1)\n        assert torch.allclose(conv((x1, None), adj.t()), out2)\n\n\nclass MyMultipleAggrConv(MessagePassing):\n    def __init__(self, **kwargs):\n        super().__init__(aggr=['add', 'mean', 'max'], **kwargs)\n\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        # propagate_type: (x: Tensor)\n        return self.propagate(edge_index, x=x)\n\n\n@pytest.mark.parametrize('multi_aggr_tuple', [\n    (dict(mode='cat'), 3),\n    (dict(mode='proj', mode_kwargs=dict(in_channels=16, out_channels=16)), 1)\n])\ndef test_my_multiple_aggr_conv(multi_aggr_tuple):\n    # The 'cat' combine mode will expand the output dimensions by\n    # the number of aggregators which is 3 here, while the 'proj'\n    # mode keeps output dimensions unchanged.\n    aggr_kwargs, expand = multi_aggr_tuple\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n\n    conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs)\n    out = conv(x, edge_index)\n    assert out.size() == (4, 16 * expand)\n    assert torch.allclose(conv(x, adj1.t()), out)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x, adj2.t()), out)\n\n\ndef test_my_multiple_aggr_conv_jit():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = MyMultipleAggrConv()\n    out = conv(x, edge_index)\n\n    jit = torch.jit.script(conv)\n    assert 
torch.allclose(jit(x, edge_index), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(jit(x, adj.t()), out)\n\n\ndef test_copy():\n    conv = MyConv(8, 32)\n    conv2 = copy.copy(conv)\n\n    assert conv != conv2\n    assert torch.equal(conv.lin_l.weight, conv2.lin_l.weight)\n    assert torch.equal(conv.lin_r.weight, conv2.lin_r.weight)\n    assert conv.lin_l.weight.data_ptr == conv2.lin_l.weight.data_ptr\n    assert conv.lin_r.weight.data_ptr == conv2.lin_r.weight.data_ptr\n\n    conv = copy.deepcopy(conv)\n    assert conv != conv2\n    assert torch.equal(conv.lin_l.weight, conv2.lin_l.weight)\n    assert torch.equal(conv.lin_r.weight, conv2.lin_r.weight)\n    assert conv.lin_l.weight.data_ptr != conv2.lin_l.weight.data_ptr\n    assert conv.lin_r.weight.data_ptr != conv2.lin_r.weight.data_ptr\n\n\nclass MyEdgeConv(MessagePassing):\n    def __init__(self):\n        super().__init__(aggr='add')\n\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        # edge_updater_type: (x: Tensor)\n        edge_attr = self.edge_updater(edge_index, x=x)\n\n        # propagate_type: (edge_attr: Tensor)\n        return self.propagate(edge_index, edge_attr=edge_attr,\n                              size=(x.size(0), x.size(0)))\n\n    def edge_update(self, x_j: Tensor, x_i: Tensor) -> Tensor:\n        return x_j - x_i\n\n    def message(self, edge_attr: Tensor) -> Tensor:\n        return edge_attr\n\n\ndef test_my_edge_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    row, col = edge_index\n    expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='sum')\n\n    conv = MyEdgeConv()\n    out = conv(x, edge_index)\n    assert out.size() == (4, 16)\n    assert torch.allclose(out, expected)\n    assert torch.allclose(conv(x, adj1.t()), 
out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out)\n\n\ndef test_my_edge_conv_jit():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = MyEdgeConv()\n    out = conv(x, edge_index)\n\n    jit = torch.jit.script(conv)\n    assert torch.allclose(jit(x, edge_index), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(jit(x, adj.t()), out)\n\n\nnum_pre_hook_calls = 0\nnum_hook_calls = 0\n\n\ndef test_message_passing_hooks():\n    conv = MyConv(8, 32)\n\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.randn(edge_index.size(1))\n    adj = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    def pre_hook(module, inputs):\n        assert module == conv\n        global num_pre_hook_calls\n        num_pre_hook_calls += 1\n        return inputs\n\n    def hook(module, inputs, output):\n        assert module == conv\n        global num_hook_calls\n        num_hook_calls += 1\n        return output\n\n    handle1 = conv.register_propagate_forward_pre_hook(pre_hook)\n    assert len(conv._propagate_forward_pre_hooks) == 1\n    handle2 = conv.register_propagate_forward_hook(hook)\n    assert len(conv._propagate_forward_hooks) == 1\n\n    handle3 = conv.register_message_forward_pre_hook(pre_hook)\n    assert len(conv._message_forward_pre_hooks) == 1\n    handle4 = conv.register_message_forward_hook(hook)\n    assert len(conv._message_forward_hooks) == 1\n\n    handle5 = conv.register_aggregate_forward_pre_hook(pre_hook)\n    assert len(conv._aggregate_forward_pre_hooks) == 1\n    handle6 = conv.register_aggregate_forward_hook(hook)\n    assert len(conv._aggregate_forward_hooks) == 1\n\n    handle7 = 
conv.register_message_and_aggregate_forward_pre_hook(pre_hook)\n    assert len(conv._message_and_aggregate_forward_pre_hooks) == 1\n    handle8 = conv.register_message_and_aggregate_forward_hook(hook)\n    assert len(conv._message_and_aggregate_forward_hooks) == 1\n\n    out1 = conv(x, edge_index, value)\n    assert num_pre_hook_calls == 3\n    assert num_hook_calls == 3\n    out2 = conv(x, adj.t())\n    assert num_pre_hook_calls == 5\n    assert num_hook_calls == 5\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n    handle1.remove()\n    assert len(conv._propagate_forward_pre_hooks) == 0\n    handle2.remove()\n    assert len(conv._propagate_forward_hooks) == 0\n\n    handle3.remove()\n    assert len(conv._message_forward_pre_hooks) == 0\n    handle4.remove()\n    assert len(conv._message_forward_hooks) == 0\n\n    handle5.remove()\n    assert len(conv._aggregate_forward_pre_hooks) == 0\n    handle6.remove()\n    assert len(conv._aggregate_forward_hooks) == 0\n\n    handle7.remove()\n    assert len(conv._message_and_aggregate_forward_pre_hooks) == 0\n    handle8.remove()\n    assert len(conv._message_and_aggregate_forward_hooks) == 0\n\n    conv = MyEdgeConv()\n\n    handle1 = conv.register_edge_update_forward_pre_hook(pre_hook)\n    assert len(conv._edge_update_forward_pre_hooks) == 1\n    handle2 = conv.register_edge_update_forward_hook(hook)\n    assert len(conv._edge_update_forward_hooks) == 1\n\n    out1 = conv(x, edge_index)\n    assert num_pre_hook_calls == 6\n    assert num_hook_calls == 6\n    out2 = conv(x, adj.t())\n    assert num_pre_hook_calls == 7\n    assert num_hook_calls == 7\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n    handle1.remove()\n    assert len(conv._edge_update_forward_pre_hooks) == 0\n    handle2.remove()\n    assert len(conv._edge_update_forward_hooks) == 0\n\n\ndef test_modified_message_passing_hook():\n    conv = MyConv(8, 32)\n\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    
edge_weight = torch.randn(edge_index.size(1))\n\n    out1 = conv(x, edge_index, edge_weight)\n\n    def hook(module, inputs, output):\n        assert len(inputs) == 1\n        assert len(inputs[-1]) == 2\n        assert 'x_j' in inputs[-1]\n        assert 'edge_weight' in inputs[-1]\n        return output + 1.\n\n    conv.register_message_forward_hook(hook)\n\n    out2 = conv(x, edge_index, edge_weight)\n    assert not torch.allclose(out1, out2, atol=1e-6)\n\n\nclass MyDefaultArgConv(MessagePassing):\n    def __init__(self):\n        super().__init__(aggr='mean')\n\n    # propagate_type: (x: Tensor)\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        return self.propagate(edge_index, x=x)\n\n    def message(self, x_j, zeros: bool = True):\n        return x_j * 0 if zeros else x_j\n\n\ndef test_my_default_arg_conv():\n    x = torch.randn(4, 1)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = MyDefaultArgConv()\n    assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]\n    assert conv(x, adj1.t()).view(-1).tolist() == [0, 0, 0, 0]\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert conv(x, adj2.t()).view(-1).tolist() == [0, 0, 0, 0]\n\n    jit = torch.jit.script(conv)\n    assert jit(x, edge_index).view(-1).tolist() == [0, 0, 0, 0]\n    assert jit(x, adj1.t()).view(-1).tolist() == [0, 0, 0, 0]\n\n\nclass MyMultipleOutputConv(MessagePassing):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]:\n        # propagate_type: (x: Tensor)\n        return self.propagate(edge_index, x=x)\n\n    def message(self, x_j: Tensor) -> Tuple[Tensor, Tensor]:\n        return x_j, x_j\n\n    def aggregate(self, inputs: Tuple[Tensor, Tensor],\n                  index: Tensor) -> Tuple[Tensor, Tensor]:\n    
    return (scatter(inputs[0], index, dim=0, reduce='sum'),\n                scatter(inputs[0], index, dim=0, reduce='mean'))\n\n    def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:\n        return inputs\n\n\ndef test_tuple_output():\n    conv = MyMultipleOutputConv()\n\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    out1 = conv(x, edge_index)\n    assert isinstance(out1, tuple) and len(out1) == 2\n\n\ndef test_tuple_output_jit():\n    conv = MyMultipleOutputConv()\n\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    out1 = conv(x, edge_index)\n    assert isinstance(out1, tuple) and len(out1) == 2\n\n    jit = torch.jit.script(conv)\n    out2 = jit(x, edge_index)\n    assert isinstance(out2, tuple) and len(out2) == 2\n    assert torch.allclose(out1[0], out2[0])\n    assert torch.allclose(out1[1], out2[1])\n\n\nclass MyExplainConv(MessagePassing):\n    def __init__(self):\n        super().__init__(aggr='add')\n\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        return self.propagate(edge_index, x=x)\n\n\ndef test_explain_message():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = MyExplainConv()\n    conv.explain = True\n    assert conv.propagate.__module__.endswith('message_passing')\n\n    with pytest.raises(ValueError, match=\"pre-defined 'edge_mask'\"):\n        conv(x, edge_index)\n\n    conv._edge_mask = torch.tensor([0.0, 0.0, 0.0, 0.0])\n    conv._apply_sigmoid = False\n    assert conv(x, edge_index).abs().sum() == 0.\n\n    conv._edge_mask = torch.tensor([1.0, 1.0, 1.0, 1.0])\n    conv._apply_sigmoid = False\n    out1 = conv(x, edge_index)\n\n    # TorchScript should still work since it relies on class methods\n    # (but without explainability).\n    torch.jit.script(conv)\n\n    conv.explain = False\n    assert 
conv.propagate.__module__.endswith('MyExplainConv_propagate')\n    out2 = conv(x, edge_index)\n    assert torch.allclose(out1, out2)\n\n\nclass MyAggregatorConv(MessagePassing):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        # propagate_type: (x: Tensor)\n        return self.propagate(edge_index, x=x)\n\n\n@pytest.mark.parametrize('aggr_module', [\n    aggr.MeanAggregation(),\n    aggr.SumAggregation(),\n    aggr.MaxAggregation(),\n    aggr.SoftmaxAggregation(),\n    aggr.PowerMeanAggregation(),\n    aggr.MultiAggregation(['mean', 'max'])\n])\ndef test_message_passing_with_aggr_module(aggr_module):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    row, col = edge_index\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = MyAggregatorConv(aggr=aggr_module)\n    assert isinstance(conv.aggr_module, aggr.Aggregation)\n    out = conv(x, edge_index)\n    assert out.size(0) == 4 and out.size(1) in {8, 16}\n    assert torch.allclose(conv(x, adj1.t()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj2.t()), out)\n\n\ndef test_message_passing_int32_edge_index():\n    # Check that we can dispatch an int32 edge_index up to aggregation\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=torch.int32)\n    edge_weight = torch.randn(edge_index.shape[1])\n\n    # Use a hook to promote the edge_index to long to work around the PyTorch\n    # CPU backend restriction to int64 for the index.\n    def cast_index_hook(module, inputs):\n        input_dict = inputs[-1]\n        input_dict['index'] = input_dict['index'].long()\n        return (input_dict, )\n\n    conv = MyConv(8, 32)\n    conv.register_aggregate_forward_pre_hook(cast_index_hook)\n\n    assert conv(x, 
edge_index, edge_weight).size() == (4, 32)\n\n\n@pytest.mark.parametrize('num_nodes', [4, 8, 2, 0])\ndef test_traceable_my_conv_with_self_loops(num_nodes):\n    # `torch.jit.trace` a `MessagePassing` layer that adds self loops and test\n    # it across different input sizes.\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n\n    conv = MyConvWithSelfLoops()\n    traced_conv = torch.jit.trace(conv, ((x, edge_index)))\n    scripted_conv = torch.jit.script(conv)\n\n    x = torch.randn(num_nodes, 16)\n    if num_nodes > 0:\n        edge_index = torch.stack([\n            torch.arange(0, num_nodes - 1),\n            torch.arange(1, num_nodes),\n        ], dim=0)\n    else:\n        edge_index = torch.empty((2, 0), dtype=torch.long)\n\n    out = conv(x, edge_index)\n    traced_out = traced_conv(x, edge_index)\n    scripted_out = scripted_conv(x, edge_index)\n\n    assert torch.allclose(out, traced_out)\n    assert torch.allclose(out, scripted_out)\n\n\ndef test_pickle(tmp_path):\n    path = osp.join(tmp_path, 'model.pt')\n    model = MyConv(16, 32)\n    torch.save(model, path)\n\n    MyConv.propagate = MyConv._orig_propagate\n\n    model = torch.load(path, weights_only=False)\n    torch.jit.script(model)\n\n\nclass MyOptionalEdgeAttrConv(MessagePassing):\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x, edge_index, edge_attr=None):\n        return self.propagate(edge_index, x=x, edge_attr=edge_attr)\n\n    def message(self, x_j, edge_attr=None):\n        return x_j if edge_attr is None else x_j * edge_attr.view(-1, 1)\n\n\ndef test_my_optional_edge_attr_conv():\n    conv = MyOptionalEdgeAttrConv()\n\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    out = conv(x, edge_index)\n    assert out.size() == (4, 8)\n"
  },
  {
    "path": "test/nn/conv/test_mf_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import MFConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_mf_conv():\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = MFConv(8, 32)\n    assert str(conv) == 'MFConv(8, 32)'\n    out = conv(x1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out)\n        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj.t()), out)\n\n    # Test bipartite message passing:\n    conv = MFConv((8, 16), 32)\n    assert str(conv) == 'MFConv((8, 16), 32)'\n\n    out1 = conv((x1, x2), edge_index)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, (4, 2)), out1)\n\n    out2 = conv((x1, None), edge_index, (4, 2))\n    assert out2.size() == (2, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj.t()), out1)\n        assert torch.allclose(conv((x1, None), adj.t()), out2)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)\n        assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert 
torch.allclose(jit((x1, x2), adj.t()), out1)\n            assert torch.allclose(jit((x1, None), adj.t()), out2)\n"
  },
  {
    "path": "test/nn/conv/test_mixhop_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import MixHopConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_mixhop_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = MixHopConv(16, 32, powers=[0, 1, 2, 4])\n    assert str(conv) == 'MixHopConv(16, 32, powers=[0, 1, 2, 4])'\n\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 128)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 128)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_nn_conv.py",
    "content": "import torch\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import NNConv\nfrom torch_geometric.testing import is_full_test, withCUDA\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_coo_tensor\n\n\n@withCUDA\ndef test_nn_conv(device):\n    x1 = torch.randn(4, 8, device=device)\n    x2 = torch.randn(2, 16, device=device)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device)\n    value = torch.rand(edge_index.size(1), 3, device=device)\n    adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 4))\n\n    nn = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32))\n    conv = NNConv(8, 32, nn=nn).to(device)\n    assert str(conv) == (\n        'NNConv(8, 32, aggr=add, nn=Sequential(\\n'\n        '  (0): Linear(in_features=3, out_features=32, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=32, out_features=256, bias=True)\\n'\n        '))')\n\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out)\n    assert torch.allclose(conv(x1, adj1.transpose(0, 1).coalesce()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, value), out)\n        assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_coo_tensor(edge_index, value, size=(4, 2))\n\n    conv = NNConv((8, 16), 32, nn=nn).to(device)\n    assert str(conv) == (\n        'NNConv((8, 16), 
32, aggr=add, nn=Sequential(\\n'\n        '  (0): Linear(in_features=3, out_features=32, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=32, out_features=256, bias=True)\\n'\n        '))')\n\n    out1 = conv((x1, x2), edge_index, value)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)\n    assert torch.allclose(conv((x1, x2),\n                               adj1.transpose(0, 1).coalesce()), out1)\n\n    out2 = conv((x1, None), edge_index, value, (4, 2))\n    assert out2.size() == (2, 32)\n    assert torch.allclose(conv((x1, None),\n                               adj1.transpose(0, 1).coalesce()), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out1)\n        assert torch.allclose(conv((x1, None), adj2.t()), out2)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, value), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),\n                              out1)\n        assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),\n                              out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out1)\n            assert torch.allclose(jit((x1, None), adj2.t()), out2)\n"
  },
  {
    "path": "test/nn/conv/test_pan_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import PANConv\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@withPackage('torch_sparse')  # TODO `PANConv` returns a `SparseTensor`.\ndef test_pan_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n\n    conv = PANConv(16, 32, filter_size=2)\n    assert str(conv) == 'PANConv(16, 32, filter_size=2)'\n    out1, M1 = conv(x, edge_index)\n    assert out1.size() == (4, 32)\n\n    out2, M2 = conv(x, adj1.t())\n    assert torch.allclose(out1, out2, atol=1e-6)\n    assert torch.allclose(M1.to_dense(), M2.to_dense())\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        out3, M3 = conv(x, adj2.t())\n        assert torch.allclose(out1, out3, atol=1e-6)\n        assert torch.allclose(M1.to_dense(), M3.to_dense())\n"
  },
  {
    "path": "test/nn/conv/test_pdn_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import PDNConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_pdn_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128)\n    assert str(conv) == \"PDNConv(16, 32)\"\n\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))\n        assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index, edge_attr), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)\n\n\ndef test_pdn_conv_with_sparse_node_input_feature():\n    x = torch.sparse_coo_tensor(\n        indices=torch.tensor([[0, 0], [0, 1]]),\n        values=torch.tensor([1.0, 1.0]),\n        size=torch.Size([4, 16]),\n    )\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128)\n\n    out = conv(x, edge_index, edge_attr)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))\n        assert torch.allclose(conv(x, adj.t(), edge_attr), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index, edge_attr), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj.t(), edge_attr), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_pna_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import DataLoader, NeighborLoader\nfrom torch_geometric.nn import PNAConv\nfrom torch_geometric.testing import is_full_test, onlyNeighborSampler\nfrom torch_geometric.typing import SparseTensor\n\naggregators = ['sum', 'mean', 'min', 'max', 'var', 'std']\nscalers = [\n    'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'\n]\n\n\n@pytest.mark.parametrize('divide_input', [True, False])\ndef test_pna_conv(divide_input):\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    deg = torch.tensor([0, 3, 0, 1])\n    value = torch.rand(edge_index.size(1), 3)\n\n    conv = PNAConv(16, 32, aggregators, scalers, deg=deg, edge_dim=3, towers=4,\n                   pre_layers=2, post_layers=2, divide_input=divide_input)\n    assert str(conv) == 'PNAConv(16, 32, towers=4, edge_dim=3)'\n\n    out = conv(x, edge_index, value)\n    assert out.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index, value), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)\n\n\n@onlyNeighborSampler\ndef test_pna_conv_get_degree_histogram_neighbor_loader():\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])\n    data = Data(num_nodes=5, edge_index=edge_index)\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[-1],\n        input_nodes=None,\n        batch_size=5,\n        shuffle=False,\n    )\n    deg_hist = PNAConv.get_degree_histogram(loader)\n    assert torch.equal(deg_hist, torch.tensor([1, 2, 
1, 1]))\n\n\ndef test_pna_conv_get_degree_histogram_dataloader():\n    edge_index_1 = torch.tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]])\n    edge_index_2 = torch.tensor([[1, 1, 2, 2, 0, 3, 3], [2, 3, 3, 1, 1, 0, 2]])\n    edge_index_3 = torch.tensor([[1, 3, 2, 0, 0, 4, 2], [2, 0, 4, 1, 1, 0, 3]])\n    edge_index_4 = torch.tensor([[0, 1, 2, 4, 0, 1, 3], [2, 3, 3, 1, 1, 0, 2]])\n\n    data_1 = Data(num_nodes=5, edge_index=edge_index_1)  # hist = [1, 2, 1, 1]\n    data_2 = Data(num_nodes=5, edge_index=edge_index_2)  # hist = [1, 1, 3]\n    data_3 = Data(num_nodes=5, edge_index=edge_index_3)  # hist = [0, 3, 2]\n    data_4 = Data(num_nodes=5, edge_index=edge_index_4)  # hist = [1, 1, 3]\n\n    loader = DataLoader(\n        [data_1, data_2, data_3, data_4],\n        batch_size=1,\n        shuffle=False,\n    )\n    deg_hist = PNAConv.get_degree_histogram(loader)\n    assert torch.equal(deg_hist, torch.tensor([3, 7, 9, 1]))\n"
  },
  {
    "path": "test/nn/conv/test_point_conv.py",
    "content": "import torch\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import PointNetConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_point_net_conv():\n    x1 = torch.randn(4, 16)\n    pos1 = torch.randn(4, 3)\n    pos2 = torch.randn(2, 3)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32))\n    global_nn = Seq(Lin(32, 32))\n    conv = PointNetConv(local_nn, global_nn)\n    assert str(conv) == (\n        'PointNetConv(local_nn=Sequential(\\n'\n        '  (0): Linear(in_features=19, out_features=32, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=32, out_features=32, bias=True)\\n'\n        '), global_nn=Sequential(\\n'\n        '  (0): Linear(in_features=32, out_features=32, bias=True)\\n'\n        '))')\n    out = conv(x1, pos1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, pos1, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, pos1, adj2.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    out = conv(x1, (pos1, pos2), edge_index)\n    assert out.size() == (2, 32)\n    assert torch.allclose(conv((x1, None), (pos1, pos2), edge_index), out)\n    assert 
torch.allclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6)\n    assert torch.allclose(conv((x1, None), (pos1, pos2), adj1.t()), out,\n                          atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6)\n        assert torch.allclose(conv((x1, None), (pos1, pos2), adj2.t()), out,\n                              atol=1e-6)\n\n    if is_full_test():\n        assert torch.allclose(jit((x1, None), (pos1, pos2), edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, None), (pos1, pos2), adj2.t()), out,\n                                  atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_point_gnn_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import MLP, PointGNNConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_point_gnn_conv():\n    x = torch.randn(6, 8)\n    pos = torch.randn(6, 3)\n    edge_index = torch.tensor([[0, 1, 1, 1, 2, 5], [1, 2, 3, 4, 3, 4]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(6, 6))\n\n    conv = PointGNNConv(\n        mlp_h=MLP([8, 16, 3]),\n        mlp_f=MLP([3 + 8, 16, 8]),\n        mlp_g=MLP([8, 16, 8]),\n    )\n    assert str(conv) == ('PointGNNConv(\\n'\n                         '  mlp_h=MLP(8, 16, 3),\\n'\n                         '  mlp_f=MLP(11, 16, 8),\\n'\n                         '  mlp_g=MLP(8, 16, 8),\\n'\n                         ')')\n\n    out = conv(x, pos, edge_index)\n    assert out.size() == (6, 8)\n    assert torch.allclose(conv(x, pos, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6))\n        assert torch.allclose(conv(x, pos, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, pos, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, pos, adj2.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_point_transformer_conv.py",
    "content": "import torch\nfrom torch.nn import Linear, ReLU, Sequential\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import PointTransformerConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_point_transformer_conv():\n    x1 = torch.rand(4, 16)\n    x2 = torch.randn(2, 8)\n    pos1 = torch.rand(4, 3)\n    pos2 = torch.randn(2, 3)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = PointTransformerConv(in_channels=16, out_channels=32)\n    assert str(conv) == 'PointTransformerConv(16, 32)'\n\n    out = conv(x1, pos1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, pos1, edge_index), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, pos1, adj2.t()), out, atol=1e-6)\n\n    pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32))\n    attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32))\n    conv = PointTransformerConv(16, 32, pos_nn, attn_nn)\n\n    out = conv(x1, pos1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, pos1, adj1.t()), out, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x1, pos1, adj2.t()), out, atol=1e-6)\n\n    # Test biparitite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    conv = PointTransformerConv((16, 8), 32)\n    assert str(conv) == 'PointTransformerConv((16, 8), 32)'\n\n    
out = conv((x1, x2), (pos1, pos2), edge_index)\n    assert out.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), (pos1, pos2), adj1.t()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), (pos1, pos2), adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), (pos1, pos2), edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), (pos1, pos2), adj2.t()), out)\n"
  },
  {
    "path": "test/nn/conv/test_ppf_conv.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import PPFConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_ppf_conv():\n    x1 = torch.randn(4, 16)\n    pos1 = torch.randn(4, 3)\n    pos2 = torch.randn(2, 3)\n    n1 = F.normalize(torch.rand(4, 3), dim=-1)\n    n2 = F.normalize(torch.rand(2, 3), dim=-1)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32))\n    global_nn = Seq(Lin(32, 32))\n    conv = PPFConv(local_nn, global_nn)\n    assert str(conv) == (\n        'PPFConv(local_nn=Sequential(\\n'\n        '  (0): Linear(in_features=20, out_features=32, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=32, out_features=32, bias=True)\\n'\n        '), global_nn=Sequential(\\n'\n        '  (0): Linear(in_features=32, out_features=32, bias=True)\\n'\n        '))')\n\n    out = conv(x1, pos1, n1, edge_index)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, pos1, n1, adj1.t()), out, atol=1e-3)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, pos1, n1, adj2.t()), out, atol=1e-3)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, pos1, n1, edge_index), out, atol=1e-3)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, pos1, n1, adj2.t()), out, atol=1e-3)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    out = conv(x1, (pos1, pos2), (n1, 
n2), edge_index)\n    assert out.size() == (2, 32)\n    assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), edge_index),\n                          out, atol=1e-3)\n    assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj1.t()), out,\n                          atol=1e-3)\n    assert torch.allclose(conv((x1, None), (pos1, pos2), (n1, n2), adj1.t()),\n                          out, atol=1e-3)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv(x1, (pos1, pos2), (n1, n2), adj2.t()), out,\n                              atol=1e-3)\n        assert torch.allclose(\n            conv((x1, None), (pos1, pos2), (n1, n2), adj2.t()), out, atol=1e-3)\n\n    if is_full_test():\n        assert torch.allclose(\n            jit((x1, None), (pos1, pos2), (n1, n2), edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(\n                jit((x1, None), (pos1, pos2), (n1, n2), adj2.t()), out,\n                atol=1e-3)\n"
  },
  {
    "path": "test/nn/conv/test_res_gated_graph_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import ResGatedGraphConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('edge_dim', [None, 4])\ndef test_res_gated_graph_conv(edge_dim):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 32)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_attr = torch.randn(edge_index.size(1), edge_dim) if edge_dim else None\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = ResGatedGraphConv(8, 32, edge_dim=edge_dim)\n    assert str(conv) == 'ResGatedGraphConv(8, 32)'\n\n    out = conv(x1, edge_index, edge_attr)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, adj1.t(), edge_attr), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, edge_attr), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    conv = ResGatedGraphConv((8, 32), 32, edge_dim=edge_dim)\n    assert str(conv) == 'ResGatedGraphConv((8, 32), 32)'\n\n    out = conv((x1, x2), edge_index, edge_attr)\n    assert out.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), adj1.t(), edge_attr), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = 
torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, edge_attr), out,\n                              atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_rgat_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import RGATConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_coo_tensor\n\n\n@pytest.mark.parametrize('mod', [\n    'additive',\n    'scaled',\n    'f-additive',\n    'f-scaled',\n])\n@pytest.mark.parametrize('attention_mechanism', [\n    'within-relation',\n    'across-relation',\n])\n@pytest.mark.parametrize('attention_mode', [\n    'additive-self-attention',\n    'multiplicative-self-attention',\n])\n@pytest.mark.parametrize('concat', [True, False])\n@pytest.mark.parametrize('edge_dim', [8, None])\ndef test_rgat_conv(mod, attention_mechanism, attention_mode, concat, edge_dim):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_type = torch.tensor([0, 2, 1, 2])\n    edge_attr = torch.randn((4, edge_dim)) if edge_dim else None\n\n    conv1 = RGATConv(  # `num_bases` is not None:\n        in_channels=8,\n        out_channels=16,\n        num_relations=4,\n        num_bases=4,\n        mod=mod,\n        attention_mechanism=attention_mechanism,\n        attention_mode=attention_mode,\n        heads=2,\n        dim=1,\n        concat=concat,\n        edge_dim=edge_dim,\n    )\n\n    conv2 = RGATConv(  # `num_blocks` is not `None`\n        in_channels=8,\n        out_channels=16,\n        num_relations=4,\n        num_blocks=4,\n        mod=mod,\n        attention_mechanism=attention_mechanism,\n        attention_mode=attention_mode,\n        heads=2,\n        dim=1,\n        concat=concat,\n        edge_dim=edge_dim,\n    )\n\n    conv3 = RGATConv(  # Both `num_bases` and `num_blocks` are `None`:\n        in_channels=8,\n        out_channels=16,\n        num_relations=4,\n        mod=mod,\n        attention_mechanism=attention_mechanism,\n        attention_mode=attention_mode,\n        heads=2,\n        dim=1,\n        
concat=concat,\n        edge_dim=edge_dim,\n    )\n\n    conv4 = RGATConv(  # `dropout > 0` and `mod` is `None`:\n        in_channels=8,\n        out_channels=16,\n        num_relations=4,\n        mod=None,\n        attention_mechanism=attention_mechanism,\n        attention_mode=attention_mode,\n        heads=2,\n        dim=1,\n        concat=concat,\n        edge_dim=edge_dim,\n        dropout=0.5,\n    )\n\n    for conv in [conv1, conv2, conv3, conv4]:\n        assert str(conv) == 'RGATConv(8, 16, heads=2)'\n\n        out = conv(x, edge_index, edge_type, edge_attr)\n        assert out.size() == (4, 16 * (2 if concat else 1))\n\n        out, (adj, alpha) = conv(x, edge_index, edge_type, edge_attr,\n                                 return_attention_weights=True)\n        assert out.size() == (4, 16 * (2 if concat else 1))\n        assert adj.size() == edge_index.size()\n        assert alpha.size() == (4, 2)\n\n\ndef test_rgat_conv_jit():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_attr = torch.randn((edge_index.size(1), 8))\n    edge_type = torch.tensor([0, 2, 1, 2])\n    adj1 = to_torch_coo_tensor(edge_index, edge_attr, size=(4, 4))\n\n    conv = RGATConv(8, 20, num_relations=4, num_bases=4, mod='additive',\n                    attention_mechanism='across-relation',\n                    attention_mode='additive-self-attention', heads=2, dim=1,\n                    edge_dim=8, bias=False)\n\n    out = conv(x, edge_index, edge_type, edge_attr)\n    assert out.size() == (4, 40)\n    # t() expects a tensor with <= 2 sparse and 0 dense dimensions\n    adj1_t = adj1.transpose(0, 1).coalesce()\n    assert torch.allclose(conv(x, adj1_t, edge_type), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4))\n        assert torch.allclose(conv(x, adj2.t(), edge_type), out, atol=1e-6)\n\n    if is_full_test():\n        jit = 
torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index, edge_type),\n                              conv(x, edge_index, edge_type))\n"
  },
  {
    "path": "test/nn/conv/test_rgcn_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import FastRGCNConv, RGCNConv\nfrom torch_geometric.testing import is_full_test, withCUDA, withDevice\nfrom torch_geometric.typing import SparseTensor\n\nclasses = [RGCNConv, FastRGCNConv]\nconfs = [(None, None), (2, None), (None, 2)]\n\n\n@withDevice\n@pytest.mark.parametrize('conf', confs)\ndef test_rgcn_conv_equality(conf, device):\n    num_bases, num_blocks = conf\n\n    x1 = torch.randn(4, 4, device=device)\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [0, 0, 1, 0, 1, 1],\n    ], device=device)\n    edge_type = torch.tensor([0, 1, 1, 0, 0, 1], device=device)\n\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3],\n        [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1],\n    ], device=device)\n    edge_type = torch.tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3],\n                             device=device)\n\n    torch.manual_seed(12345)\n    conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks, aggr='sum').to(device)\n\n    torch.manual_seed(12345)\n    conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks,\n                         aggr='sum').to(device)\n\n    out1 = conv1(x1, edge_index, edge_type)\n    out2 = conv2(x1, edge_index, edge_type)\n    assert torch.allclose(out1, out2, atol=1e-2)\n\n    if num_blocks is None:\n        out1 = conv1(None, edge_index, edge_type)\n        out2 = conv2(None, edge_index, edge_type)\n        assert torch.allclose(out1, out2, atol=1e-2)\n\n\n@withCUDA\n@pytest.mark.parametrize('cls', classes)\n@pytest.mark.parametrize('conf', confs)\ndef test_rgcn_conv_basic(cls, conf, device):\n    num_bases, num_blocks = conf\n\n    x1 = torch.randn(4, 4, device=device)\n    x2 = torch.randn(2, 16, device=device)\n    idx1 = torch.arange(4, device=device)\n    idx2 = torch.arange(2, device=device)\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [0, 0, 1, 0, 1, 1],\n    ], 
device=device)\n    edge_type = torch.tensor([0, 1, 1, 0, 0, 1], device=device)\n\n    conv = cls(4, 32, 2, num_bases, num_blocks, aggr='sum').to(device)\n    assert str(conv) == f'{cls.__name__}(4, 32, num_relations=2)'\n\n    out1 = conv(x1, edge_index, edge_type)\n    assert out1.size() == (4, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 4))\n        assert torch.allclose(conv(x1, adj.t()), out1, atol=1e-3)\n\n    if num_blocks is None:\n        out2 = conv(None, edge_index, edge_type)\n        assert torch.allclose(conv(idx1, edge_index, edge_type), out2,\n                              atol=1e-3)\n        assert out2.size() == (4, 32)\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(conv(None, adj.t()), out2, atol=1e-3)\n            assert torch.allclose(conv(idx1, adj.t()), out2, atol=1e-3)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, edge_type), out1, atol=1e-3)\n        if num_blocks is None:\n            assert torch.allclose(jit(idx1, edge_index, edge_type), out2,\n                                  atol=1e-3)\n            assert torch.allclose(jit(None, edge_index, edge_type), out2,\n                                  atol=1e-3)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj.t()), out1)\n            if num_blocks is None:\n                assert torch.allclose(jit(idx1, adj.t()), out2, atol=1e-3)\n                assert torch.allclose(jit(None, adj.t()), out2, atol=1e-3)\n\n    # Test bipartite message passing:\n    conv = cls((4, 16), 32, 2, num_bases, num_blocks, aggr='sum').to(device)\n    assert str(conv) == f'{cls.__name__}((4, 16), 32, num_relations=2)'\n\n    out1 = conv((x1, x2), edge_index, edge_type)\n    assert out1.size() == (2, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = 
SparseTensor.from_edge_index(edge_index, edge_type, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-3)\n\n    if num_blocks is None:\n        out2 = conv((None, idx2), edge_index, edge_type)\n        assert out2.size() == (2, 32)\n        assert torch.allclose(conv((idx1, idx2), edge_index, edge_type), out2,\n                              atol=1e-3)\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(conv((None, idx2), adj.t()), out2, atol=1e-3)\n            assert torch.allclose(conv((idx1, idx2), adj.t()), out2, atol=1e-3)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, edge_type), out1,\n                              atol=1e-3)\n        if num_blocks is None:\n            assert torch.allclose(jit((None, idx2), edge_index, edge_type),\n                                  out2, atol=1e-3)\n            assert torch.allclose(jit((idx1, idx2), edge_index, edge_type),\n                                  out2, atol=1e-3)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-3)\n            if num_blocks is None:\n                assert torch.allclose(jit((None, idx2), adj.t()), out2,\n                                      atol=1e-3)\n                assert torch.allclose(jit((idx1, idx2), adj.t()), out2,\n                                      atol=1e-3)\n"
  },
  {
    "path": "test/nn/conv/test_sage_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import MLPAggregation, SAGEConv\nfrom torch_geometric.testing import (\n    assert_module,\n    is_full_test,\n    onlyLinux,\n    withDevice,\n    withPackage,\n)\nfrom torch_geometric.typing import SparseTensor\n\n\n@pytest.mark.parametrize('project', [False, True])\n@pytest.mark.parametrize('aggr', ['mean', 'sum'])\ndef test_sage_conv(project, aggr):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = SAGEConv(8, 32, project=project, aggr=aggr)\n    assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})'\n\n    out = assert_module(conv, x, edge_index, expected_size=(4, 32))\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, size=(4, 4)), out, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n            assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n\n    conv = SAGEConv((8, 16), 32, project=project, aggr=aggr)\n    assert str(conv) == f'SAGEConv((8, 16), 32, aggr={aggr})'\n\n    out1 = assert_module(conv, (x1, x2), edge_index, expected_size=(2, 32))\n    out2 = assert_module(conv, (x1, None), edge_index, size=(4, 2),\n                         expected_size=(2, 32))\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit((x1, x2), edge_index, size=(4, 2)), out1)\n        assert torch.allclose(jit((x1, None), edge_index, size=(4, 2)), out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 
2))\n            assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)\n            assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)\n\n\n@pytest.mark.parametrize('project', [False, True])\ndef test_lazy_sage_conv(project):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    if project:\n        with pytest.raises(ValueError, match=\"does not support lazy\"):\n            SAGEConv(-1, 32, project=project)\n    else:\n        conv = SAGEConv(-1, 32, project=project)\n        assert str(conv) == 'SAGEConv(-1, 32, aggr=mean)'\n\n        out = conv(x, edge_index)\n        assert out.size() == (4, 32)\n\n\ndef test_lstm_aggr_sage_conv():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = SAGEConv(8, 32, aggr='lstm')\n    assert str(conv) == 'SAGEConv(8, 32, aggr=lstm)'\n\n    assert_module(conv, x, edge_index, expected_size=(4, 32),\n                  test_edge_permutation=False)\n\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 0]])\n    with pytest.raises(ValueError, match=\"'index' tensor is not sorted\"):\n        conv(x, edge_index)\n\n\ndef test_mlp_sage_conv():\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = SAGEConv(\n        in_channels=8,\n        out_channels=32,\n        aggr=MLPAggregation(\n            in_channels=8,\n            out_channels=8,\n            max_num_elements=2,\n            num_layers=1,\n        ),\n    )\n\n    out = conv(x, edge_index)\n    assert out.size() == (4, 32)\n\n\n@pytest.mark.parametrize('aggr_kwargs', [\n    dict(mode='cat'),\n    dict(mode='proj', mode_kwargs=dict(in_channels=8, out_channels=16)),\n    dict(mode='attn', mode_kwargs=dict(in_channels=8, out_channels=16,\n                                       num_heads=4)),\n    dict(mode='sum'),\n])\ndef test_multi_aggr_sage_conv(aggr_kwargs):\n    x = torch.randn(4, 8)\n    edge_index = 
torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    aggr_kwargs['aggrs_kwargs'] = [{}, {}, {}, dict(learn=True, t=1)]\n    conv = SAGEConv(8, 32, aggr=['mean', 'max', 'sum', 'softmax'],\n                    aggr_kwargs=aggr_kwargs)\n\n    assert_module(conv, x, edge_index, expected_size=(4, 32))\n\n\n@withDevice\n@onlyLinux\n@withPackage('torch>=2.1.0')\ndef test_compile_multi_aggr_sage_conv(device):\n    import torch._dynamo as dynamo\n\n    x = torch.randn(4, 8, device=device)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]], device=device)\n\n    conv = SAGEConv(\n        in_channels=8,\n        out_channels=32,\n        aggr=['mean', 'sum', 'min', 'max', 'std'],\n    ).to(device)\n\n    explanation = dynamo.explain(conv)(x, edge_index)\n    assert explanation.graph_break_count == 0\n\n    compiled_conv = torch.compile(conv)\n\n    expected = conv(x, edge_index)\n    out = compiled_conv(x, edge_index)\n    assert torch.allclose(out, expected, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_sg_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import SGConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_sg_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = SGConv(16, 32, K=10)\n    assert str(conv) == 'SGConv(16, 32, K=10)'\n\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n\n    conv.cached = True\n    conv(x, edge_index)\n    assert conv._cached_x is not None\n    assert torch.allclose(conv(x, edge_index), out1, atol=1e-6)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_signed_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import SignedConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_signed_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv1 = SignedConv(16, 32, first_aggr=True)\n    assert str(conv1) == 'SignedConv(16, 32, first_aggr=True)'\n\n    conv2 = SignedConv(32, 48, first_aggr=False)\n    assert str(conv2) == 'SignedConv(32, 48, first_aggr=False)'\n\n    out1 = conv1(x, edge_index, edge_index)\n    assert out1.size() == (4, 64)\n    assert torch.allclose(conv1(x, adj1.t(), adj1.t()), out1)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv1(x, adj2.t(), adj2.t()), out1)\n\n    out2 = conv2(out1, edge_index, edge_index)\n    assert out2.size() == (4, 96)\n    assert torch.allclose(conv2(out1, adj1.t(), adj1.t()), out2)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv2(out1, adj2.t(), adj2.t()), out2)\n\n    if is_full_test():\n        jit1 = torch.jit.script(conv1)\n        jit2 = torch.jit.script(conv2)\n        assert torch.allclose(jit1(x, edge_index, edge_index), out1)\n        assert torch.allclose(jit2(out1, edge_index, edge_index), out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit1(x, adj2.t(), adj2.t()), out1)\n            assert torch.allclose(jit2(out1, adj2.t(), adj2.t()), out2)\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    assert torch.allclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2],\n                          atol=1e-6)\n    assert torch.allclose(conv1((x, x[:2]), 
adj1.t(), adj1.t()), out1[:2],\n                          atol=1e-6)\n    assert torch.allclose(conv2((out1, out1[:2]), edge_index, edge_index),\n                          out2[:2], atol=1e-6)\n    assert torch.allclose(conv2((out1, out1[:2]), adj1.t(), adj1.t()),\n                          out2[:2], atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2))\n        assert torch.allclose(conv1((x, x[:2]), adj2.t(), adj2.t()), out1[:2],\n                              atol=1e-6)\n        assert torch.allclose(conv2((out1, out1[:2]), adj2.t(), adj2.t()),\n                              out2[:2], atol=1e-6)\n\n    if is_full_test():\n        assert torch.allclose(jit1((x, x[:2]), edge_index, edge_index),\n                              out1[:2], atol=1e-6)\n        assert torch.allclose(jit2((out1, out1[:2]), edge_index, edge_index),\n                              out2[:2], atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit1((x, x[:2]), adj2.t(), adj2.t()),\n                                  out1[:2], atol=1e-6)\n            assert torch.allclose(jit2((out1, out1[:2]), adj2.t(), adj2.t()),\n                                  out2[:2], atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_simple_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import SimpleConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('aggr, combine_root', [\n    ('mean', None),\n    ('sum', 'sum'),\n    (['mean', 'max'], 'cat'),\n    ('mean', 'self_loop'),\n])\ndef test_simple_conv(aggr, combine_root):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 1, 1]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = SimpleConv(aggr, combine_root)\n    assert str(conv) == 'SimpleConv()'\n\n    num_aggrs = 1 if isinstance(aggr, str) else len(aggr)\n    output_size = sum([8] * num_aggrs) + (8 if combine_root == 'cat' else 0)\n\n    out = conv(x1, edge_index)\n    assert out.size() == (4, output_size)\n    assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out)\n    assert torch.allclose(conv(x1, adj1.t()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index), out)\n        assert torch.allclose(jit(x1, edge_index, size=(4, 4)), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out)\n\n    # Test bipartite message passing:\n    if combine_root != 'self_loop':\n        adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n        out = conv((x1, x2), edge_index)\n        assert out.size() == (2, output_size)\n        assert torch.allclose(conv((x1, x2), edge_index, size=(4, 2)), out)\n        assert torch.allclose(conv((x1, x2), adj1.t()), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            
adj2 = SparseTensor.from_edge_index(edge_index,\n                                                sparse_sizes=(4, 2))\n            assert torch.allclose(conv((x1, x2), adj2.t()), out)\n"
  },
  {
    "path": "test/nn/conv/test_spline_conv.py",
    "content": "import warnings\n\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import SplineConv\nfrom torch_geometric.testing import is_full_test, withPackage\nfrom torch_geometric.typing import SparseTensor\n\n\n@withPackage('pyg_lib')\ndef test_spline_conv():\n    warnings.filterwarnings('ignore', '.*non-optimized CPU version.*')\n\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.rand(edge_index.size(1), 3)\n\n    conv = SplineConv(8, 32, dim=3, kernel_size=5)\n    assert str(conv) == 'SplineConv(8, 32, dim=3)'\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 32)\n    assert torch.allclose(conv(x1, edge_index, value, size=(4, 4)), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x1, adj.t()), out, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x1, edge_index, value), out, atol=1e-6)\n        assert torch.allclose(jit(x1, edge_index, value, size=(4, 4)), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    conv = SplineConv((8, 16), 32, dim=3, kernel_size=5)\n    assert str(conv) == 'SplineConv((8, 16), 32, dim=3)'\n\n    out1 = conv((x1, x2), edge_index, value)\n    assert out1.size() == (2, 32)\n    assert torch.allclose(conv((x1, x2), edge_index, value, (4, 2)), out1)\n\n    out2 = conv((x1, None), edge_index, value, (4, 2))\n    assert out2.size() == (2, 32)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, value, (4, 2))\n        assert torch.allclose(conv((x1, x2), adj.t()), out1, atol=1e-6)\n        assert torch.allclose(conv((x1, None), adj.t()), out2, atol=1e-6)\n\n    
if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit((x1, x2), edge_index, value), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, value, size=(4, 2)),\n                              out1, atol=1e-6)\n        assert torch.allclose(jit((x1, None), edge_index, value, size=(4, 2)),\n                              out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj.t()), out1, atol=1e-6)\n            assert torch.allclose(jit((x1, None), adj.t()), out2, atol=1e-6)\n\n\n@withPackage('pyg_lib')\ndef test_lazy_spline_conv():\n    warnings.filterwarnings('ignore', '.*non-optimized CPU version.*')\n\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    value = torch.rand(edge_index.size(1), 3)\n\n    conv = SplineConv(-1, 32, dim=3, kernel_size=5)\n    assert str(conv) == 'SplineConv(-1, 32, dim=3)'\n    out = conv(x1, edge_index, value)\n    assert out.size() == (4, 32)\n\n    conv = SplineConv((-1, -1), 32, dim=3, kernel_size=5)\n    assert str(conv) == 'SplineConv((-1, -1), 32, dim=3)'\n    out = conv((x1, x2), edge_index, value)\n    assert out.size() == (2, 32)\n"
  },
  {
    "path": "test/nn/conv/test_ssg_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import SSGConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_ssg_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = SSGConv(16, 32, alpha=0.1, K=10)\n    assert str(conv) == 'SSGConv(16, 32, K=10, alpha=0.1)'\n\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n\n    conv.cached = True\n    conv(x, edge_index)\n    assert conv._cached_h is not None\n    assert torch.allclose(conv(x, edge_index), out1, atol=1e-6)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(conv(x, adj3.t()), out1, 
atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_static_graph.py",
    "content": "import torch\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.nn import ChebConv, GCNConv, MessagePassing\n\n\nclass MyConv(MessagePassing):\n    def forward(self, x, edge_index):\n        return self.propagate(edge_index, x=x)\n\n\ndef test_static_graph():\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    x1, x2 = torch.randn(3, 8), torch.randn(3, 8)\n\n    data1 = Data(edge_index=edge_index, x=x1)\n    data2 = Data(edge_index=edge_index, x=x2)\n    batch = Batch.from_data_list([data1, data2])\n\n    x = torch.stack([x1, x2], dim=0)\n    for conv in [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]:\n        out1 = conv(batch.x, batch.edge_index)\n        assert out1.size(0) == 6\n        conv.node_dim = 1\n        out2 = conv(x, edge_index)\n        assert out2.size()[:2] == (2, 3)\n        assert torch.allclose(out1, out2.view(-1, out2.size(-1)))\n"
  },
  {
    "path": "test/nn/conv/test_supergat_conv.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import SuperGATConv\nfrom torch_geometric.typing import SparseTensor\n\n\n@pytest.mark.parametrize('att_type', ['MX', 'SD'])\ndef test_supergat_conv(att_type):\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    conv = SuperGATConv(8, 32, heads=2, attention_type=att_type,\n                        neg_sample_ratio=1.0, edge_sample_ratio=1.0)\n    assert str(conv) == f'SuperGATConv(8, 32, heads=2, type={att_type})'\n\n    out = conv(x, edge_index)\n    assert out.size() == (4, 64)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x, adj.t()), out, atol=1e-6)\n\n    # Negative samples are given:\n    neg_edge_index = conv.negative_sampling(edge_index, x.size(0))\n    assert torch.allclose(conv(x, edge_index, neg_edge_index), out)\n    att_loss = conv.get_attention_loss()\n    assert isinstance(att_loss, torch.Tensor) and att_loss > 0\n\n    # Batch of graphs:\n    x = torch.randn(8, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],\n                               [0, 0, 1, 1, 4, 4, 5, 5]])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])\n    out = conv(x, edge_index, batch=batch)\n    assert out.size() == (8, 64)\n\n    # Batch of graphs and negative samples are given:\n    neg_edge_index = conv.negative_sampling(edge_index, x.size(0), batch)\n    assert torch.allclose(conv(x, edge_index, neg_edge_index), out)\n    att_loss = conv.get_attention_loss()\n    assert isinstance(att_loss, torch.Tensor) and att_loss > 0\n"
  },
  {
    "path": "test/nn/conv/test_tag_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import TAGConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\ndef test_tag_conv():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    value = torch.rand(edge_index.size(1))\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))\n\n    conv = TAGConv(16, 32)\n    assert str(conv) == 'TAGConv(16, 32, K=3)'\n\n    out1 = conv(x, edge_index)\n    assert out1.size() == (4, 32)\n    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)\n\n    out2 = conv(x, edge_index, value)\n    assert out2.size() == (4, 32)\n    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4))\n        assert torch.allclose(conv(x, adj3.t()), out1, atol=1e-6)\n        assert torch.allclose(conv(x, adj4.t()), out2, atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out1, atol=1e-6)\n        assert torch.allclose(jit(x, edge_index, value), out2, atol=1e-6)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj3.t()), out1, atol=1e-6)\n            assert torch.allclose(jit(x, adj4.t()), out2, atol=1e-6)\n\n\ndef test_static_tag_conv():\n    x = torch.randn(3, 4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    conv = TAGConv(16, 32)\n    out = conv(x, edge_index)\n    assert out.size() == (3, 4, 32)\n"
  },
  {
    "path": "test/nn/conv/test_transformer_conv.py",
    "content": "from typing import Optional, Tuple\n\nimport pytest\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import TransformerConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import Adj, SparseTensor\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\n@pytest.mark.parametrize('edge_dim', [None, 8])\n@pytest.mark.parametrize('concat', [True, False])\ndef test_transformer_conv(edge_dim, concat):\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 16)\n    out_channels = 32\n    heads = 2\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_attr = torch.randn(edge_index.size(1), edge_dim) if edge_dim else None\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = TransformerConv(8, out_channels, heads, beta=True,\n                           edge_dim=edge_dim, concat=concat)\n    assert str(conv) == f'TransformerConv(8, {out_channels}, heads={heads})'\n\n    out = conv(x1, edge_index, edge_attr)\n    assert out.size() == (4, out_channels * (heads if concat else 1))\n    assert torch.allclose(conv(x1, adj1.t(), edge_attr), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr,\n                                            sparse_sizes=(4, 4))\n        assert torch.allclose(conv(x1, adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Adj,\n                edge_attr: Optional[Tensor] = None,\n            ) -> Tensor:\n                return self.conv(x, edge_index, edge_attr)\n\n        jit = torch.jit.script(MyModule())\n        assert torch.allclose(jit(x1, edge_index, edge_attr), out)\n\n        if 
torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x1, adj2.t()), out, atol=1e-6)\n\n    # Test `return_attention_weights`.\n    result = conv(x1, edge_index, edge_attr, return_attention_weights=True)\n    assert torch.allclose(result[0], out)\n    assert result[1][0].size() == (2, 4)\n    assert result[1][1].size() == (4, 2)\n    assert result[1][1].min() >= 0 and result[1][1].max() <= 1\n    assert conv._alpha is None\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        result = conv(x1, adj2.t(), return_attention_weights=True)\n        assert torch.allclose(result[0], out, atol=1e-6)\n        assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4\n        assert conv._alpha is None\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tensor,\n                edge_index: Tensor,\n                edge_attr: Optional[Tensor],\n            ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n                return self.conv(x, edge_index, edge_attr,\n                                 return_attention_weights=True)\n\n        jit = torch.jit.script(MyModule())\n        result = jit(x1, edge_index, edge_attr)\n        assert torch.allclose(result[0], out)\n        assert result[1][0].size() == (2, 4)\n        assert result[1][1].size() == (4, 2)\n        assert result[1][1].min() >= 0 and result[1][1].max() <= 1\n        assert conv._alpha is None\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n\n            class MyModule(torch.nn.Module):\n                def __init__(self):\n                    super().__init__()\n                    self.conv = conv\n\n                def forward(\n                    self,\n                    x: Tensor,\n                    edge_index: SparseTensor,\n                ) -> Tuple[Tensor, 
SparseTensor]:\n                    return self.conv(x, edge_index,\n                                     return_attention_weights=True)\n\n            jit = torch.jit.script(MyModule())\n            result = jit(x1, adj2.t())\n            assert torch.allclose(result[0], out, atol=1e-6)\n            assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 4\n            assert conv._alpha is None\n\n    # Test bipartite message passing:\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 2))\n\n    conv = TransformerConv((8, 16), out_channels, heads=heads, beta=True,\n                           edge_dim=edge_dim, concat=concat)\n    assert str(conv) == (f'TransformerConv((8, 16), {out_channels}, '\n                         f'heads={heads})')\n\n    out = conv((x1, x2), edge_index, edge_attr)\n    assert out.size() == (2, out_channels * (heads if concat else 1))\n    assert torch.allclose(conv((x1, x2), adj1.t(), edge_attr), out, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, edge_attr,\n                                            sparse_sizes=(4, 2))\n        assert torch.allclose(conv((x1, x2), adj2.t()), out, atol=1e-6)\n\n    if is_full_test():\n\n        class MyModule(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv = conv\n\n            def forward(\n                self,\n                x: Tuple[Tensor, Tensor],\n                edge_index: Adj,\n                edge_attr: Optional[Tensor] = None,\n            ) -> Tensor:\n                return self.conv(x, edge_index, edge_attr)\n\n        jit = torch.jit.script(MyModule())\n        assert torch.allclose(jit((x1, x2), edge_index, edge_attr), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, x2), adj2.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_wl_conv.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import WLConv\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import one_hot, to_torch_csc_tensor\n\n\ndef test_wl_conv():\n    x1 = torch.tensor([1, 0, 0, 1])\n    x2 = one_hot(x1)\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))\n\n    conv = WLConv()\n    assert str(conv) == 'WLConv()'\n\n    out = conv(x1, edge_index)\n    assert out.tolist() == [0, 1, 1, 0]\n    assert torch.equal(conv(x2, edge_index), out)\n    assert torch.equal(conv(x1, adj1.t()), out)\n    assert torch.equal(conv(x2, adj1.t()), out)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.equal(conv(x1, adj2.t()), out)\n        assert torch.equal(conv(x2, adj2.t()), out)\n\n    assert conv.histogram(out).tolist() == [[2, 2]]\n    assert torch.allclose(conv.histogram(out, norm=True),\n                          torch.tensor([[0.7071, 0.7071]]))\n"
  },
  {
    "path": "test/nn/conv/test_wl_conv_continuous.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import WLConvContinuous\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_wl_conv():\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)\n    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)\n\n    conv = WLConvContinuous()\n    assert str(conv) == 'WLConvContinuous()'\n\n    out = conv(x, edge_index)\n    assert out.tolist() == [[-0.5], [0.0], [0.5]]\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3))\n        assert torch.allclose(conv(x, adj.t()), out)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n        assert torch.allclose(jit(x, edge_index), out)\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)\n\n    # Test bipartite message passing:\n    x1 = torch.randn(4, 8)\n    x2 = torch.randn(2, 8)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n\n    out1 = conv((x1, None), edge_index, edge_weight, size=(4, 2))\n    assert out1.size() == (2, 8)\n\n    out2 = conv((x1, x2), edge_index, edge_weight)\n    assert out2.size() == (2, 8)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_weight, (4, 2))\n        assert torch.allclose(conv((x1, None), adj.t()), out1)\n        assert torch.allclose(conv((x1, x2), adj.t()), out2)\n\n    if is_full_test():\n        assert torch.allclose(\n            jit((x1, None), edge_index, edge_weight, size=(4, 2)), out1)\n        assert torch.allclose(jit((x1, x2), edge_index, edge_weight), out2)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit((x1, None), adj.t()), out1, atol=1e-6)\n            assert 
torch.allclose(jit((x1, x2), adj.t()), out2, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/test_x_conv.py",
    "content": "import torch\n\nfrom torch_geometric.nn import XConv\nfrom torch_geometric.testing import is_full_test, withPackage\n\n\n@withPackage('torch_cluster')\ndef test_x_conv():\n    x = torch.randn(8, 16)\n    pos = torch.rand(8, 3)\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])\n\n    conv = XConv(16, 32, dim=3, kernel_size=2, dilation=2)\n    assert str(conv) == 'XConv(16, 32)'\n\n    torch.manual_seed(12345)\n    out1 = conv(x, pos)\n    assert out1.size() == (8, 32)\n\n    torch.manual_seed(12345)\n    out2 = conv(x, pos, batch)\n    assert out2.size() == (8, 32)\n\n    if is_full_test():\n        jit = torch.jit.script(conv)\n\n        torch.manual_seed(12345)\n        assert torch.allclose(jit(x, pos), out1, atol=1e-6)\n\n        torch.manual_seed(12345)\n        assert torch.allclose(jit(x, pos, batch), out2, atol=1e-6)\n"
  },
  {
    "path": "test/nn/conv/utils/test_gnn_cheatsheet.py",
    "content": "from torch_geometric.nn.conv import utils\n\n\ndef test_gnn_cheatsheet():\n    assert utils.paper_title('GCNConv') == ('Semi-supervised Classification '\n                                            'with Graph Convolutional '\n                                            'Networks')\n    assert utils.paper_link('GCNConv') == 'https://arxiv.org/abs/1609.02907'\n\n    assert utils.supports_sparse_tensor('GCNConv')\n    assert not utils.supports_sparse_tensor('ChebConv')\n\n    assert utils.supports_edge_weights('GraphConv')\n    assert not utils.supports_edge_weights('SAGEConv')\n\n    assert utils.supports_edge_features('GATConv')\n    assert not utils.supports_edge_features('SimpleConv')\n\n    assert utils.supports_bipartite_graphs('SAGEConv')\n    assert not utils.supports_bipartite_graphs('GCNConv')\n\n    assert utils.supports_static_graphs('GCNConv')\n    assert not utils.supports_static_graphs('GATConv')\n\n    assert utils.supports_lazy_initialization('SAGEConv')\n    assert not utils.supports_lazy_initialization('GatedGraphConv')\n\n    assert utils.processes_heterogeneous_graphs('RGCNConv')\n    assert utils.processes_heterogeneous_graphs('HeteroConv')\n    assert not utils.processes_heterogeneous_graphs('GCNConv')\n\n    assert utils.processes_hypergraphs('HypergraphConv')\n    assert not utils.processes_hypergraphs('SAGEConv')\n\n    assert utils.processes_point_clouds('DynamicEdgeConv')\n    assert utils.processes_point_clouds('XConv')\n    assert not utils.processes_point_clouds('CuGraphSAGEConv')\n"
  },
  {
    "path": "test/nn/dense/test_dense_gat_conv.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import DenseGATConv, GATConv\nfrom torch_geometric.testing import is_full_test\n\n\n@pytest.mark.parametrize('heads', [1, 4])\n@pytest.mark.parametrize('concat', [True, False])\ndef test_dense_gat_conv(heads, concat):\n    channels = 16\n    sparse_conv = GATConv(channels, channels, heads=heads, concat=concat)\n    dense_conv = DenseGATConv(channels, channels, heads=heads, concat=concat)\n    assert str(dense_conv) == f'DenseGATConv(16, 16, heads={heads})'\n\n    # Ensure same weights and bias:\n    dense_conv.lin = sparse_conv.lin\n    dense_conv.att_src = sparse_conv.att_src\n    dense_conv.att_dst = sparse_conv.att_dst\n    dense_conv.bias = sparse_conv.bias\n\n    x = torch.randn((5, channels))\n    edge_index = torch.tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]])\n\n    sparse_out = sparse_conv(x, edge_index)\n\n    x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels)\n    adj = torch.tensor([\n        [\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 1.0],\n            [0.0, 1.0, 0.0],\n        ],\n        [\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0],\n        ],\n    ])\n    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)\n\n    dense_out = dense_conv(x, adj, mask)\n\n    if is_full_test():\n        jit = torch.jit.script(dense_conv)\n        assert torch.allclose(jit(x, adj, mask), dense_out)\n\n    assert dense_out[1, 2].abs().sum() == 0\n    dense_out = dense_out.view(6, dense_out.size(-1))[:-1]\n    assert torch.allclose(sparse_out, dense_out, atol=1e-4)\n\n\ndef test_dense_gat_conv_with_broadcasting():\n    batch_size, num_nodes, channels = 8, 3, 16\n    conv = DenseGATConv(channels, channels, heads=4)\n\n    x = torch.randn(batch_size, num_nodes, channels)\n    adj = torch.tensor([\n        [0.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0],\n        [1.0, 1.0, 0.0],\n    ])\n\n    assert conv(x, 
adj).size() == (batch_size, num_nodes, 64)\n    mask = torch.tensor([1, 1, 1], dtype=torch.bool)\n    assert conv(x, adj, mask).size() == (batch_size, num_nodes, 64)\n"
  },
  {
    "path": "test/nn/dense/test_dense_gcn_conv.py",
    "content": "import torch\n\nfrom torch_geometric.nn import DenseGCNConv, GCNConv\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_dense_gcn_conv():\n    channels = 16\n    sparse_conv = GCNConv(channels, channels)\n    dense_conv = DenseGCNConv(channels, channels)\n    assert str(dense_conv) == 'DenseGCNConv(16, 16)'\n\n    # Ensure same weights and bias:\n    dense_conv.lin.weight = sparse_conv.lin.weight\n    dense_conv.bias = sparse_conv.bias\n\n    x = torch.randn((5, channels))\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],\n                               [1, 2, 0, 2, 0, 1, 4, 3]])\n\n    sparse_out = sparse_conv(x, edge_index)\n    assert sparse_out.size() == (5, channels)\n\n    x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels)\n    adj = torch.tensor([\n        [\n            [0.0, 1.0, 1.0],\n            [1.0, 0.0, 1.0],\n            [1.0, 1.0, 0.0],\n        ],\n        [\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0],\n        ],\n    ])\n    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)\n\n    dense_out = dense_conv(x, adj, mask)\n    assert dense_out.size() == (2, 3, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(dense_conv)\n        assert torch.allclose(jit(x, adj, mask), dense_out)\n\n    assert dense_out[1, 2].abs().sum() == 0\n    dense_out = dense_out.view(6, channels)[:-1]\n    assert torch.allclose(sparse_out, dense_out, atol=1e-4)\n\n\ndef test_dense_gcn_conv_with_broadcasting():\n    batch_size, num_nodes, channels = 8, 3, 16\n    conv = DenseGCNConv(channels, channels)\n\n    x = torch.randn(batch_size, num_nodes, channels)\n    adj = torch.tensor([\n        [0.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0],\n        [1.0, 1.0, 0.0],\n    ])\n\n    assert conv(x, adj).size() == (batch_size, num_nodes, channels)\n    mask = torch.tensor([1, 1, 1], dtype=torch.bool)\n    assert conv(x, adj, mask).size() == (batch_size, 
num_nodes, channels)\n"
  },
  {
    "path": "test/nn/dense/test_dense_gin_conv.py",
    "content": "import torch\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.nn import DenseGINConv, GINConv\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_dense_gin_conv():\n    channels = 16\n    nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels))\n    sparse_conv = GINConv(nn)\n    dense_conv = DenseGINConv(nn)\n    dense_conv = DenseGINConv(nn, train_eps=True)\n    assert str(dense_conv) == (\n        'DenseGINConv(nn=Sequential(\\n'\n        '  (0): Linear(in_features=16, out_features=16, bias=True)\\n'\n        '  (1): ReLU()\\n'\n        '  (2): Linear(in_features=16, out_features=16, bias=True)\\n'\n        '))')\n\n    x = torch.randn((5, channels))\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],\n                               [1, 2, 0, 2, 0, 1, 4, 3]])\n\n    sparse_out = sparse_conv(x, edge_index)\n    assert sparse_out.size() == (5, channels)\n\n    x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels)\n    adj = torch.tensor([\n        [\n            [0.0, 1.0, 1.0],\n            [1.0, 0.0, 1.0],\n            [1.0, 1.0, 0.0],\n        ],\n        [\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0],\n        ],\n    ])\n    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)\n\n    dense_out = dense_conv(x, adj, mask)\n    assert dense_out.size() == (2, 3, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(dense_conv)\n        assert torch.allclose(jit(x, adj, mask), dense_out)\n\n    assert dense_out[1, 2].abs().sum().item() == 0\n    dense_out = dense_out.view(6, channels)[:-1]\n    assert torch.allclose(sparse_out, dense_out, atol=1e-04)\n\n\ndef test_dense_gin_conv_with_broadcasting():\n    batch_size, num_nodes, channels = 8, 3, 16\n    nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels))\n    conv = DenseGINConv(nn)\n\n    x = 
torch.randn(batch_size, num_nodes, channels)\n    adj = torch.tensor([\n        [0.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0],\n        [1.0, 1.0, 0.0],\n    ])\n\n    assert conv(x, adj).size() == (batch_size, num_nodes, channels)\n    mask = torch.tensor([1, 1, 1], dtype=torch.bool)\n    assert conv(x, adj, mask).size() == (batch_size, num_nodes, channels)\n"
  },
  {
    "path": "test/nn/dense/test_dense_graph_conv.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import DenseGraphConv, GraphConv\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import to_dense_adj\n\n\n@pytest.mark.parametrize('aggr', ['add', 'mean', 'max'])\ndef test_dense_graph_conv(aggr):\n    channels = 16\n    sparse_conv = GraphConv(channels, channels, aggr=aggr)\n    dense_conv = DenseGraphConv(channels, channels, aggr=aggr)\n    assert str(dense_conv) == 'DenseGraphConv(16, 16)'\n\n    # Ensure same weights and bias.\n    dense_conv.lin_rel = sparse_conv.lin_rel\n    dense_conv.lin_root = sparse_conv.lin_root\n\n    x = torch.randn((5, channels))\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],\n                               [1, 2, 0, 2, 0, 1, 4, 3]])\n\n    sparse_out = sparse_conv(x, edge_index)\n    assert sparse_out.size() == (5, channels)\n\n    adj = to_dense_adj(edge_index)\n    mask = torch.ones(5, dtype=torch.bool)\n\n    dense_out = dense_conv(x, adj, mask)[0]\n\n    assert dense_out.size() == (5, channels)\n    assert torch.allclose(sparse_out, dense_out, atol=1e-04)\n\n    if is_full_test():\n        jit = torch.jit.script(dense_conv)\n        assert torch.allclose(jit(x, adj, mask), dense_out)\n\n\n@pytest.mark.parametrize('aggr', ['add', 'mean', 'max'])\ndef test_dense_graph_conv_batch(aggr):\n    channels = 16\n    sparse_conv = GraphConv(channels, channels, aggr=aggr)\n    dense_conv = DenseGraphConv(channels, channels, aggr=aggr)\n\n    # Ensure same weights and bias.\n    dense_conv.lin_rel = sparse_conv.lin_rel\n    dense_conv.lin_root = sparse_conv.lin_root\n\n    x = torch.randn((5, channels))\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],\n                               [1, 2, 0, 2, 0, 1, 4, 3]])\n\n    sparse_out = sparse_conv(x, edge_index)\n    assert sparse_out.size() == (5, channels)\n\n    x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels)\n    adj = torch.tensor([\n        [\n      
      [0.0, 1.0, 1.0],\n            [1.0, 0.0, 1.0],\n            [1.0, 1.0, 0.0],\n        ],\n        [\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0],\n        ],\n    ])\n    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)\n\n    dense_out = dense_conv(x, adj, mask)\n    assert dense_out.size() == (2, 3, channels)\n    dense_out = dense_out.view(-1, channels)\n\n    assert torch.allclose(sparse_out, dense_out[:5], atol=1e-04)\n    assert dense_out[-1].abs().sum() == 0\n\n\n@pytest.mark.parametrize('aggr', ['add', 'mean', 'max'])\ndef test_dense_graph_conv_with_broadcasting(aggr):\n    batch_size, num_nodes, channels = 8, 3, 16\n    conv = DenseGraphConv(channels, channels, aggr=aggr)\n\n    x = torch.randn(batch_size, num_nodes, channels)\n    adj = torch.tensor([\n        [0.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0],\n        [1.0, 1.0, 0.0],\n    ])\n\n    assert conv(x, adj).size() == (batch_size, num_nodes, channels)\n    mask = torch.tensor([1, 1, 1], dtype=torch.bool)\n    assert conv(x, adj, mask).size() == (batch_size, num_nodes, channels)\n"
  },
  {
    "path": "test/nn/dense/test_dense_sage_conv.py",
    "content": "import torch\n\nfrom torch_geometric.nn import DenseSAGEConv, SAGEConv\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_dense_sage_conv():\n    channels = 16\n    sparse_conv = SAGEConv(channels, channels, normalize=True)\n    dense_conv = DenseSAGEConv(channels, channels, normalize=True)\n    assert str(dense_conv) == 'DenseSAGEConv(16, 16)'\n\n    # Ensure same weights and bias.\n    dense_conv.lin_rel = sparse_conv.lin_l\n    dense_conv.lin_root = sparse_conv.lin_r\n\n    x = torch.randn((5, channels))\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 4],\n                               [1, 2, 0, 2, 0, 1, 4, 3]])\n\n    sparse_out = sparse_conv(x, edge_index)\n    assert sparse_out.size() == (5, channels)\n\n    x = torch.cat([x, x.new_zeros(1, channels)], dim=0).view(2, 3, channels)\n    adj = torch.tensor([\n        [\n            [0.0, 1.0, 1.0],\n            [1.0, 0.0, 1.0],\n            [1.0, 1.0, 0.0],\n        ],\n        [\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0],\n        ],\n    ])\n    mask = torch.tensor([[1, 1, 1], [1, 1, 0]], dtype=torch.bool)\n\n    dense_out = dense_conv(x, adj, mask)\n    assert dense_out.size() == (2, 3, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(dense_conv)\n        assert torch.allclose(jit(x, adj, mask), dense_out)\n\n    assert dense_out[1, 2].abs().sum().item() == 0\n    dense_out = dense_out.view(6, channels)[:-1]\n    assert torch.allclose(sparse_out, dense_out, atol=1e-04)\n\n\ndef test_dense_sage_conv_with_broadcasting():\n    batch_size, num_nodes, channels = 8, 3, 16\n    conv = DenseSAGEConv(channels, channels)\n\n    x = torch.randn(batch_size, num_nodes, channels)\n    adj = torch.tensor([\n        [0.0, 1.0, 1.0],\n        [1.0, 0.0, 1.0],\n        [1.0, 1.0, 0.0],\n    ])\n\n    assert conv(x, adj).size() == (batch_size, num_nodes, channels)\n    mask = torch.tensor([1, 1, 1], dtype=torch.bool)\n    assert 
conv(x, adj, mask).size() == (batch_size, num_nodes, channels)\n"
  },
  {
    "path": "test/nn/dense/test_diff_pool.py",
    "content": "from itertools import product\n\nimport torch\n\nfrom torch_geometric.nn import dense_diff_pool\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_dense_diff_pool():\n    batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)\n    x = torch.randn((batch_size, num_nodes, channels))\n    adj = torch.rand((batch_size, num_nodes, num_nodes))\n    s = torch.randn((batch_size, num_nodes, num_clusters))\n    mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)\n\n    x_out, adj_out, link_loss, ent_loss = dense_diff_pool(x, adj, s, mask)\n    assert x_out.size() == (2, 10, 16)\n    assert adj_out.size() == (2, 10, 10)\n    assert link_loss.item() >= 0\n    assert ent_loss.item() >= 0\n\n    if is_full_test():\n        jit = torch.jit.script(dense_diff_pool)\n        x_jit, adj_jit, link_loss, ent_loss = jit(x, adj, s, mask)\n        assert torch.allclose(x_jit, x_out)\n        assert torch.allclose(adj_jit, adj_out)\n        assert link_loss.item() >= 0\n        assert ent_loss.item() >= 0\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    BS = [2**i for i in range(4, 8)]\n    NS = [2**i for i in range(4, 8)]\n    FS = [2**i for i in range(5, 9)]\n    CS = [2**i for i in range(5, 9)]\n\n    funcs = []\n    func_names = []\n    args_list = []\n    for B, N, F, C in product(BS, NS, FS, CS):\n        x = torch.randn(B, N, F, device=args.device)\n        adj = torch.randint(0, 2, (B, N, N), dtype=x.dtype, device=args.device)\n        s = torch.randn(B, N, C, device=args.device)\n\n        funcs.append(dense_diff_pool)\n        func_names.append(f'B={B}, N={N}, F={F}, C={C}')\n        args_list.append((x, adj, s))\n\n    benchmark(\n        funcs=funcs,\n        
func_names=func_names,\n        args=args_list,\n        num_steps=50 if args.device == 'cpu' else 500,\n        num_warmups=10 if args.device == 'cpu' else 100,\n        backward=args.backward,\n        progress_bar=True,\n    )\n"
  },
  {
    "path": "test/nn/dense/test_dmon_pool.py",
    "content": "import math\n\nimport torch\n\nfrom torch_geometric.nn import DMoNPooling\n\n\ndef test_dmon_pooling():\n    batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)\n    x = torch.randn((batch_size, num_nodes, channels))\n    adj = torch.ones((batch_size, num_nodes, num_nodes))\n    mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)\n\n    pool = DMoNPooling([channels, channels], num_clusters)\n    assert str(pool) == 'DMoNPooling(16, num_clusters=10)'\n\n    s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask)\n    assert s.size() == (2, 20, 10)\n    assert x.size() == (2, 10, 16)\n    assert adj.size() == (2, 10, 10)\n    assert -1 <= spectral_loss <= 0.5\n    assert 0 <= ortho_loss <= math.sqrt(2)\n    assert 0 <= cluster_loss <= math.sqrt(num_clusters) - 1\n"
  },
  {
    "path": "test/nn/dense/test_linear.py",
    "content": "import copy\nimport warnings\nfrom typing import List\n\nimport pytest\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear as PTLinear\nfrom torch.nn.parameter import UninitializedParameter\n\nimport torch_geometric.backend\nfrom torch_geometric.nn import HeteroDictLinear, HeteroLinear, Linear\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import withCUDA, withDevice, withPackage\nfrom torch_geometric.typing import pyg_lib\nfrom torch_geometric.utils import cumsum\n\nweight_inits = ['glorot', 'kaiming_uniform', None]\nbias_inits = ['zeros', None]\n\n\n@withDevice\n@pytest.mark.parametrize('weight', weight_inits)\n@pytest.mark.parametrize('bias', bias_inits)\ndef test_linear(weight, bias, device):\n    x = torch.randn(3, 4, 16, device=device)\n    lin = Linear(16, 32, weight_initializer=weight, bias_initializer=bias)\n    lin = lin.to(device)\n    assert str(lin) == 'Linear(16, 32, bias=True)'\n    assert lin(x).size() == (3, 4, 32)\n\n\n@withDevice\n@pytest.mark.parametrize('weight', weight_inits)\n@pytest.mark.parametrize('bias', bias_inits)\ndef test_lazy_linear(weight, bias, device):\n    x = torch.randn(3, 4, 16, device=device)\n    lin = Linear(-1, 32, weight_initializer=weight, bias_initializer=bias)\n    lin = lin.to(device)\n    copied_lin = copy.deepcopy(lin)\n\n    assert lin.weight.device == device\n    assert lin.bias.device == device\n    assert str(lin) == 'Linear(-1, 32, bias=True)'\n    assert lin(x).size() == (3, 4, 32)\n    assert str(lin) == 'Linear(16, 32, bias=True)'\n\n    assert copied_lin.weight.device == device\n    assert copied_lin.bias.device == device\n    assert copied_lin(x).size() == (3, 4, 32)\n\n\n@withDevice\n@pytest.mark.parametrize('dim1', [-1, 16])\n@pytest.mark.parametrize('dim2', [-1, 16])\n@pytest.mark.parametrize('bias', [True, False])\ndef test_load_lazy_linear(dim1, dim2, bias, device):\n    lin1 = Linear(dim1, 32, bias=bias).to(device)\n    lin2 = 
Linear(dim2, 32, bias=bias).to(device)\n    lin2.load_state_dict(lin1.state_dict())\n\n    if dim1 != -1:\n        assert isinstance(lin1.weight, torch.nn.Parameter)\n        assert isinstance(lin2.weight, torch.nn.Parameter)\n        assert torch.allclose(lin1.weight, lin2.weight)\n        assert not hasattr(lin1, '_hook')\n        assert not hasattr(lin2, '_hook')\n    else:\n        assert isinstance(lin1.weight, UninitializedParameter)\n        assert isinstance(lin2.weight, UninitializedParameter)\n        assert hasattr(lin1, '_hook')\n        assert hasattr(lin2, '_hook')\n\n    if bias:\n        assert isinstance(lin1.bias, torch.nn.Parameter)\n        assert isinstance(lin2.bias, torch.nn.Parameter)\n        if dim1 != -1:  # Only check for equality on materialized bias:\n            assert torch.allclose(lin1.bias, lin2.bias)\n    else:\n        assert lin1.bias is None\n        assert lin2.bias is None\n\n    with pytest.raises(RuntimeError, match=\"in state_dict\"):\n        lin1.load_state_dict({}, strict=True)\n    lin1.load_state_dict({}, strict=False)\n\n\n@pytest.mark.parametrize('lazy', [True, False])\ndef test_identical_linear_default_initialization(lazy):\n    x = torch.randn(3, 4, 16)\n\n    torch.manual_seed(12345)\n    lin1 = Linear(-1 if lazy else 16, 32)\n    lin1(x)\n\n    torch.manual_seed(12345)\n    lin2 = PTLinear(16, 32)\n\n    assert torch.equal(lin1.weight, lin2.weight)\n    assert torch.equal(lin1.bias, lin2.bias)\n    assert torch.allclose(lin1(x), lin2(x))\n\n\ndef test_copy_unintialized_parameter():\n    weight = UninitializedParameter()\n    copy.deepcopy(weight)\n\n\n@withDevice\n@pytest.mark.parametrize('lazy', [True, False])\ndef test_copy_linear(lazy, device):\n    lin = Linear(-1 if lazy else 16, 32).to(device)\n\n    copied_lin = copy.copy(lin).to(device)\n    assert id(copied_lin) != id(lin)\n    assert id(copied_lin.weight) == id(lin.weight)\n    if not isinstance(copied_lin.weight, UninitializedParameter):\n        
assert copied_lin.weight.data_ptr() == lin.weight.data_ptr()\n    assert id(copied_lin.bias) == id(lin.bias)\n    assert copied_lin.bias.data_ptr() == lin.bias.data_ptr()\n\n    copied_lin = copy.deepcopy(lin).to(device)\n    assert id(copied_lin) != id(lin)\n    assert id(copied_lin.weight) != id(lin.weight)\n    if not isinstance(copied_lin.weight, UninitializedParameter):\n        assert copied_lin.weight.data_ptr() != lin.weight.data_ptr()\n        assert torch.allclose(copied_lin.weight, lin.weight)\n    assert id(copied_lin.bias) != id(lin.bias)\n    assert copied_lin.bias.data_ptr() != lin.bias.data_ptr()\n    if int(torch.isnan(lin.bias).sum()) == 0:\n        assert torch.allclose(copied_lin.bias, lin.bias)\n\n\n@withCUDA\ndef test_hetero_linear_basic(device):\n    x = torch.randn(3, 16, device=device)\n    type_vec = torch.tensor([0, 1, 2], device=device)\n\n    lin = HeteroLinear(16, 32, num_types=3).to(device)\n    assert str(lin) == 'HeteroLinear(16, 32, num_types=3, bias=True)'\n\n    out = lin(x, type_vec)\n    assert out.size() == (3, 32)\n\n    jit = torch.jit.script(lin)\n    assert torch.allclose(jit(x, type_vec), out, atol=1e-3)\n\n\ndef test_hetero_linear_initializer():\n    lin = HeteroLinear(\n        16,\n        32,\n        num_types=3,\n        weight_initializer='glorot',\n        bias_initializer='zeros',\n    )\n    assert torch.equal(lin.bias, torch.zeros_like(lin.bias))\n\n\n@withCUDA\n@pytest.mark.parametrize('use_segment_matmul', [None, True, False])\ndef test_hetero_linear_amp(device, use_segment_matmul):\n    warnings.filterwarnings('ignore', '.*but CUDA is not available.*')\n\n    old_state = torch_geometric.backend.use_segment_matmul\n    torch_geometric.backend.use_segment_matmul = use_segment_matmul\n\n    x = torch.randn(3, 16, device=device)\n    type_vec = torch.tensor([0, 1, 2], device=device)\n\n    lin = HeteroLinear(16, 32, num_types=3).to(device)\n\n    with torch.amp.autocast('cuda'):\n        assert lin(x, 
type_vec).size() == (3, 32)\n\n    torch_geometric.backend.use_segment_matmul = old_state\n\n\n@withCUDA\ndef test_lazy_hetero_linear(device):\n    x = torch.randn(3, 16, device=device)\n    type_vec = torch.tensor([0, 1, 2], device=device)\n\n    lin = HeteroLinear(-1, 32, num_types=3).to(device)\n    assert str(lin) == 'HeteroLinear(-1, 32, num_types=3, bias=True)'\n\n    out = lin(x, type_vec)\n    assert out.size() == (3, 32)\n\n\n@withDevice\n@pytest.mark.parametrize('bias', [True, False])\ndef test_hetero_dict_linear(bias, device):\n    x_dict = {\n        'v': torch.randn(3, 16, device=device),\n        'w': torch.randn(2, 8, device=device),\n    }\n\n    lin = HeteroDictLinear({'v': 16, 'w': 8}, 32, bias=bias).to(device)\n    assert str(lin) == (f\"HeteroDictLinear({{'v': 16, 'w': 8}}, 32, \"\n                        f\"bias={bias})\")\n\n    out_dict = lin(x_dict)\n    assert len(out_dict) == 2\n    assert out_dict['v'].size() == (3, 32)\n    assert out_dict['w'].size() == (2, 32)\n\n    x_dict = {\n        'v': torch.randn(3, 16, device=device),\n        'w': torch.randn(2, 16, device=device),\n    }\n\n    lin = HeteroDictLinear(16, 32, types=['v', 'w'], bias=bias).to(device)\n    assert str(lin) == (f\"HeteroDictLinear({{'v': 16, 'w': 16}}, 32, \"\n                        f\"bias={bias})\")\n\n    out_dict = lin(x_dict)\n    assert len(out_dict) == 2\n    assert out_dict['v'].size() == (3, 32)\n    assert out_dict['w'].size() == (2, 32)\n\n\ndef test_hetero_dict_linear_jit():\n    x_dict = {\n        'v': torch.randn(3, 16),\n        'w': torch.randn(2, 8),\n    }\n\n    lin = HeteroDictLinear({'v': 16, 'w': 8}, 32)\n\n    jit = torch.jit.script(lin)\n    assert len(jit(x_dict)) == 2\n\n\n@withDevice\ndef test_lazy_hetero_dict_linear(device):\n    x_dict = {\n        'v': torch.randn(3, 16, device=device),\n        'w': torch.randn(2, 8, device=device),\n    }\n\n    lin = HeteroDictLinear(-1, 32, types=['v', 'w']).to(device)\n    assert str(lin) == 
\"HeteroDictLinear({'v': -1, 'w': -1}, 32, bias=True)\"\n\n    out_dict = lin(x_dict)\n    assert len(out_dict) == 2\n    assert out_dict['v'].size() == (3, 32)\n    assert out_dict['w'].size() == (2, 32)\n\n\n@withCUDA\n@withPackage('pyg_lib')\n@pytest.mark.parametrize('type_vec', [\n    torch.tensor([0, 0, 1, 1, 2, 2]),\n    torch.tensor([0, 1, 2, 0, 1, 2]),\n])\ndef test_hetero_linear_sort(type_vec, device):\n    x = torch.randn(type_vec.numel(), 16, device=device)\n\n    lin = HeteroLinear(16, 32, num_types=3).to(device)\n    out = lin(x, type_vec)\n\n    for i in range(type_vec.numel()):\n        node_type = int(type_vec[i])\n        expected = x[i] @ lin.weight[node_type] + lin.bias[node_type]\n        assert torch.allclose(out[i], expected, atol=1e-3)\n\n\nif __name__ == '__main__':\n    import argparse\n    try:\n        import dgl\n        WITH_DGL = True\n    except Exception:\n        WITH_DGL = False\n\n    warnings.filterwarnings('ignore', '.*API of nested tensors.*')\n    warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    torch.manual_seed(12345)\n\n    def get_xs(mean: float, std: float, num_types: int,\n               channels: int) -> List[Tensor]:\n        num_nodes_list = torch.normal(\n            mean=torch.tensor([mean] * num_types, dtype=torch.float),\n            std=torch.tensor([std] * num_types, dtype=torch.float),\n        ).round().to(torch.long).tolist()\n\n        return [\n            torch.randn(num_nodes, channels, device=args.device)\n            for num_nodes in num_nodes_list\n        ]\n\n    def sequential(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]:\n        return [x @ weight for x, weight in zip(xs, weights)]\n\n    def nested(xs: List[Tensor], weights: List[Tensor]) -> List[Tensor]:\n        x 
= torch.nested.nested_tensor(xs)\n        weight = torch.nested.nested_tensor(weights)\n        return list(torch.matmul(x, weight).unbind(0))\n\n    def grouped(x: Tensor, ptr: Tensor, weight: Tensor) -> Tensor:\n        return pyg_lib.ops.segment_matmul(x, ptr, weight)\n\n    def padded(x: Tensor, weight: Tensor) -> Tensor:\n        return torch.matmul(x, weight)\n\n    def dgl_mm(x: Tensor, count: Tensor, weight: Tensor) -> Tensor:\n        return dgl.ops.segment_mm(x, weight, count)\n\n    num_nodes, channels = 1_000_000, 64\n\n    for num_types in [3, 5, 10, 50, 100, 200, 500, 1000]:\n        print(f'Number of types: {num_types}')\n        mean = num_nodes // num_types\n        std = mean // 4\n\n        xs = get_xs(mean, std, num_types, channels)\n        count = torch.tensor([x.size(0) for x in xs])\n        ptr = cumsum(torch.tensor([x.size(0) for x in xs]))\n        x = torch.cat(xs, dim=0)\n        padded_x = torch.nested.nested_tensor(xs).to_padded_tensor(padding=0.0)\n        weight = torch.randn(num_types, channels, channels, device=args.device)\n        weights = list(weight.unbind(0))\n\n        funcs = [sequential, grouped, padded]\n        func_names = ['Sequential', 'Grouped', 'Padded']\n        args_list = [(xs, weights), (x, ptr, weight), (padded_x, weight)]\n\n        if WITH_DGL:\n            funcs.append(dgl_mm)\n            func_names.append('DGL')\n            args_list.append((x, count, weight))\n\n        benchmark(\n            funcs=funcs,\n            func_names=func_names,\n            args=args_list,\n            num_steps=50 if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n"
  },
  {
    "path": "test/nn/dense/test_mincut_pool.py",
    "content": "import math\n\nimport torch\n\nfrom torch_geometric.nn import dense_mincut_pool\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_dense_mincut_pool():\n    batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10)\n    x = torch.randn((batch_size, num_nodes, channels))\n    adj = torch.ones((batch_size, num_nodes, num_nodes))\n    s = torch.randn((batch_size, num_nodes, num_clusters))\n    mask = torch.randint(0, 2, (batch_size, num_nodes), dtype=torch.bool)\n\n    x_out, adj_out, mincut_loss, ortho_loss = dense_mincut_pool(\n        x, adj, s, mask)\n    assert x_out.size() == (2, 10, 16)\n    assert adj_out.size() == (2, 10, 10)\n    assert -1 <= mincut_loss <= 0\n    assert 0 <= ortho_loss <= 2\n\n    if is_full_test():\n        jit = torch.jit.script(dense_mincut_pool)\n\n        x_jit, adj_jit, mincut_loss, ortho_loss = jit(x, adj, s, mask)\n        assert x_jit.size() == (2, 10, 16)\n        assert adj_jit.size() == (2, 10, 10)\n        assert -1 <= mincut_loss <= 0\n        assert 0 <= ortho_loss <= math.sqrt(2)\n"
  },
  {
    "path": "test/nn/functional/test_bro.py",
    "content": "import torch\n\nfrom torch_geometric.nn.functional import bro\n\n\ndef test_bro():\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 2, 2])\n\n    g1 = torch.tensor([\n        [0.2, 0.2, 0.2, 0.2],\n        [0.0, 0.2, 0.2, 0.2],\n        [0.2, 0.0, 0.2, 0.2],\n        [0.2, 0.2, 0.0, 0.2],\n    ])\n\n    g2 = torch.tensor([\n        [0.2, 0.2, 0.2, 0.2],\n        [0.0, 0.2, 0.2, 0.2],\n        [0.2, 0.0, 0.2, 0.2],\n    ])\n\n    g3 = torch.tensor([\n        [0.2, 0.2, 0.2, 0.2],\n        [0.2, 0.0, 0.2, 0.2],\n    ])\n\n    s = 0.\n    for g in [g1, g2, g3]:\n        s += torch.norm(g @ g.t() - torch.eye(g.shape[0]), p=2)\n\n    assert torch.isclose(s / 3., bro(torch.cat([g1, g2, g3], dim=0), batch))\n"
  },
  {
    "path": "test/nn/functional/test_gini.py",
    "content": "import torch\n\nfrom torch_geometric.nn.functional import gini\n\n\ndef test_gini():\n    w = torch.tensor([[0., 0., 0., 0.], [0., 0., 0., 1000.0]])\n    assert torch.isclose(gini(w), torch.tensor(0.5))\n"
  },
  {
    "path": "test/nn/kge/test_complex.py",
    "content": "import torch\n\nfrom torch_geometric.nn import ComplEx\n\n\ndef test_complex_scoring():\n    model = ComplEx(num_nodes=5, num_relations=2, hidden_channels=1)\n\n    model.node_emb.weight.data = torch.tensor([\n        [2.],\n        [3.],\n        [5.],\n        [1.],\n        [2.],\n    ])\n    model.node_emb_im.weight.data = torch.tensor([\n        [4.],\n        [1.],\n        [3.],\n        [1.],\n        [2.],\n    ])\n    model.rel_emb.weight.data = torch.tensor([\n        [2.],\n        [3.],\n    ])\n    model.rel_emb_im.weight.data = torch.tensor([\n        [3.],\n        [1.],\n    ])\n\n    score = model(\n        head_index=torch.tensor([1, 3]),\n        rel_type=torch.tensor([1, 0]),\n        tail_index=torch.tensor([2, 4]),\n    )\n    assert score.tolist() == [58., 8.]\n\n\ndef test_complex():\n    model = ComplEx(num_nodes=10, num_relations=5, hidden_channels=32)\n    assert str(model) == 'ComplEx(10, num_relations=5, hidden_channels=32)'\n\n    head_index = torch.tensor([0, 2, 4, 6, 8])\n    rel_type = torch.tensor([0, 1, 2, 3, 4])\n    tail_index = torch.tensor([1, 3, 5, 7, 9])\n\n    loader = model.loader(head_index, rel_type, tail_index, batch_size=5)\n    for h, r, t in loader:\n        out = model(h, r, t)\n        assert out.size() == (5, )\n\n        loss = model.loss(h, r, t)\n        assert loss >= 0.\n\n        mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)\n        assert 0 <= mean_rank <= 10\n        assert 0 < mrr <= 1\n        assert hits == 1.0\n"
  },
  {
    "path": "test/nn/kge/test_distmult.py",
    "content": "import torch\n\nfrom torch_geometric.nn import DistMult\n\n\ndef test_distmult():\n    model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32)\n    assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)'\n\n    head_index = torch.tensor([0, 2, 4, 6, 8])\n    rel_type = torch.tensor([0, 1, 2, 3, 4])\n    tail_index = torch.tensor([1, 3, 5, 7, 9])\n\n    loader = model.loader(head_index, rel_type, tail_index, batch_size=5)\n    for h, r, t in loader:\n        out = model(h, r, t)\n        assert out.size() == (5, )\n\n        loss = model.loss(h, r, t)\n        assert loss >= 0.\n\n        mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)\n        assert 0 <= mean_rank <= 10\n        assert 0 < mrr <= 1\n        assert hits == 1.0\n"
  },
  {
    "path": "test/nn/kge/test_rotate.py",
    "content": "import torch\n\nfrom torch_geometric.nn import RotatE\n\n\ndef test_rotate():\n    model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32)\n    assert str(model) == 'RotatE(10, num_relations=5, hidden_channels=32)'\n\n    head_index = torch.tensor([0, 2, 4, 6, 8])\n    rel_type = torch.tensor([0, 1, 2, 3, 4])\n    tail_index = torch.tensor([1, 3, 5, 7, 9])\n\n    loader = model.loader(head_index, rel_type, tail_index, batch_size=5)\n    for h, r, t in loader:\n        out = model(h, r, t)\n        assert out.size() == (5, )\n\n        loss = model.loss(h, r, t)\n        assert loss >= 0.\n\n        mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)\n        assert 0 <= mean_rank <= 10\n        assert 0 < mrr <= 1\n        assert hits == 1.0\n"
  },
  {
    "path": "test/nn/kge/test_transe.py",
    "content": "import torch\n\nfrom torch_geometric.nn import TransE\n\n\ndef test_transe():\n    model = TransE(num_nodes=10, num_relations=5, hidden_channels=32)\n    assert str(model) == 'TransE(10, num_relations=5, hidden_channels=32)'\n\n    head_index = torch.tensor([0, 2, 4, 6, 8])\n    rel_type = torch.tensor([0, 1, 2, 3, 4])\n    tail_index = torch.tensor([1, 3, 5, 7, 9])\n\n    loader = model.loader(head_index, rel_type, tail_index, batch_size=5)\n    for h, r, t in loader:\n        out = model(h, r, t)\n        assert out.size() == (5, )\n\n        loss = model.loss(h, r, t)\n        assert loss >= 0.\n\n        mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False)\n        assert 0 <= mean_rank <= 10\n        assert 0 < mrr <= 1\n        assert hits == 1.0\n"
  },
  {
    "path": "test/nn/models/test_attentive_fp.py",
    "content": "import torch\n\nfrom torch_geometric.nn import AttentiveFP\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_attentive_fp():\n    model = AttentiveFP(8, 16, 32, edge_dim=3, num_layers=2, num_timesteps=2)\n    assert str(model) == ('AttentiveFP(in_channels=8, hidden_channels=16, '\n                          'out_channels=32, edge_dim=3, num_layers=2, '\n                          'num_timesteps=2)')\n\n    x = torch.randn(4, 8)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    edge_attr = torch.randn(edge_index.size(1), 3)\n    batch = torch.tensor([0, 0, 0, 0])\n\n    out = model(x, edge_index, edge_attr, batch)\n    assert out.size() == (1, 32)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        assert torch.allclose(jit(x, edge_index, edge_attr, batch), out)\n"
  },
  {
    "path": "test/nn/models/test_attract_repel.py",
    "content": "import torch\n\nfrom torch_geometric.nn.models import ARLinkPredictor\n\n\ndef test_ar_link_predictor():\n    model = ARLinkPredictor(in_channels=16, hidden_channels=32, num_layers=2)\n    x = torch.randn(4, 16)  # 4 nodes with 16 features each\n    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])  # 3 edges\n\n    # Test forward pass\n    pred = model(x, edge_index)\n    assert pred.size(0) == edge_index.size(1)\n    assert torch.all(pred >= 0) and torch.all(pred <= 1)\n\n    # Test encode function\n    attract_z, repel_z = model.encode(x)\n    assert attract_z.size() == (\n        4, 16)  # Default attract_ratio=0.5, so half of hidden_channels\n    assert repel_z.size() == (4, 16)\n\n    # Test decode function\n    raw_scores = model.decode(attract_z, repel_z, edge_index)\n    assert raw_scores.size(0) == edge_index.size(1)\n\n    # Test R-fraction calculation\n    r_fraction = model.calculate_r_fraction(attract_z, repel_z)\n    assert 0 <= r_fraction <= 1\n\n\ndef test_ar_link_predictor_with_custom_ratio():\n    # Test with custom attract_ratio\n    model = ARLinkPredictor(in_channels=8, hidden_channels=20,\n                            attract_ratio=0.7)\n    x = torch.randn(5, 8)\n\n    # Check dimensions\n    attract_z, repel_z = model.encode(x)\n    assert attract_z.size() == (5, 14)  # 70% of 20 = 14\n    assert repel_z.size() == (5, 6)  # 30% of 20 = 6\n"
  },
  {
    "path": "test/nn/models/test_autoencoder.py",
    "content": "import torch\nfrom torch import Tensor as T\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.nn import ARGA, ARGVA, GAE, VGAE\nfrom torch_geometric.testing import has_package, is_full_test\nfrom torch_geometric.transforms import RandomLinkSplit\n\n\ndef test_gae():\n    model = GAE(encoder=lambda x: x)\n    model.reset_parameters()\n\n    x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]])\n    z = model.encode(x)\n    assert torch.allclose(z, x)\n\n    adj = model.decoder.forward_all(z)\n    expected = torch.tensor([\n        [2.0, -1.0, 1.0],\n        [-1.0, 5.0, 4.0],\n        [1.0, 4.0, 5.0],\n    ]).sigmoid()\n    assert torch.allclose(adj, expected)\n\n    edge_index = torch.tensor([[0, 1], [1, 2]])\n    value = model.decode(z, edge_index)\n    assert torch.allclose(value, torch.tensor([-1.0, 4.0]).sigmoid())\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        assert torch.allclose(jit.encode(x), z)\n        assert torch.allclose(jit.decode(z, edge_index), value)\n\n    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n    data = Data(edge_index=edge_index, num_nodes=11)\n    transform = RandomLinkSplit(split_labels=True,\n                                add_negative_train_samples=False)\n    train_data, val_data, test_data = transform(data)\n\n    z = torch.randn(11, 16)\n    loss = model.recon_loss(z, train_data.pos_edge_label_index)\n    assert float(loss) > 0\n\n    if has_package('sklearn'):\n        auc, ap = model.test(z, val_data.pos_edge_label_index,\n                             val_data.neg_edge_label_index)\n        assert auc >= 0 and auc <= 1 and ap >= 0 and ap <= 1\n\n\ndef test_vgae():\n    model = VGAE(encoder=lambda x: (x, x))\n\n    x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]])\n    model.encode(x)\n    assert float(model.kl_loss()) > 0\n\n    model.eval()\n    model.encode(x)\n\n    if 
is_full_test():\n        jit = torch.jit.export(model)\n        jit.encode(x)\n        assert float(jit.kl_loss()) > 0\n\n\ndef test_arga():\n    model = ARGA(encoder=lambda x: x, discriminator=lambda x: T([0.5]))\n    model.reset_parameters()\n\n    x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]])\n    z = model.encode(x)\n\n    assert float(model.reg_loss(z)) > 0\n    assert float(model.discriminator_loss(z)) > 0\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        assert torch.allclose(jit.encode(x), z)\n        assert float(jit.reg_loss(z)) > 0\n        assert float(jit.discriminator_loss(z)) > 0\n\n\ndef test_argva():\n    model = ARGVA(encoder=lambda x: (x, x), discriminator=lambda x: T([0.5]))\n\n    x = torch.tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]])\n    model.encode(x)\n    model.reparametrize(model.__mu__, model.__logstd__)\n    assert float(model.kl_loss()) > 0\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        jit.encode(x)\n        jit.reparametrize(jit.__mu__, jit.__logstd__)\n        assert float(jit.kl_loss()) > 0\n\n\ndef test_init():\n    encoder = torch.nn.Linear(16, 32)\n    decoder = torch.nn.Linear(32, 16)\n    discriminator = torch.nn.Linear(32, 1)\n\n    GAE(encoder, decoder)\n    VGAE(encoder, decoder)\n    ARGA(encoder, discriminator, decoder)\n    ARGVA(encoder, discriminator, decoder)\n"
  },
  {
    "path": "test/nn/models/test_basic_gnn.py",
    "content": "import os\nimport os.path as osp\nimport random\nimport warnings\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import SAGEConv\nfrom torch_geometric.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import (\n    onlyFullTest,\n    onlyLinux,\n    onlyNeighborSampler,\n    onlyOnline,\n    withDevice,\n    withPackage,\n)\n\nout_dims = [None, 8]\ndropouts = [0.0, 0.5]\nacts = [None, 'leaky_relu', torch.relu_, F.elu, torch.nn.ReLU()]\nnorms = [None, 'batch_norm', 'layer_norm']\njks = [None, 'last', 'cat', 'max', 'lstm']\n\n\n@pytest.mark.parametrize('out_dim', out_dims)\n@pytest.mark.parametrize('dropout', dropouts)\n@pytest.mark.parametrize('act', acts)\n@pytest.mark.parametrize('norm', norms)\n@pytest.mark.parametrize('jk', jks)\ndef test_gcn(out_dim, dropout, act, norm, jk):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    out_channels = 16 if out_dim is None else out_dim\n\n    model = GCN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout,\n                act=act, norm=norm, jk=jk)\n    assert str(model) == f'GCN(8, {out_channels}, num_layers=3)'\n    assert model(x, edge_index).size() == (3, out_channels)\n\n\n@pytest.mark.parametrize('out_dim', out_dims)\n@pytest.mark.parametrize('dropout', dropouts)\n@pytest.mark.parametrize('act', acts)\n@pytest.mark.parametrize('norm', norms)\n@pytest.mark.parametrize('jk', jks)\ndef test_graph_sage(out_dim, dropout, act, norm, jk):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    out_channels = 16 if out_dim is None else out_dim\n\n    model = GraphSAGE(8, 16, num_layers=3, out_channels=out_dim,\n                      dropout=dropout, act=act, norm=norm, jk=jk)\n    assert str(model) == 
f'GraphSAGE(8, {out_channels}, num_layers=3)'\n    assert model(x, edge_index).size() == (3, out_channels)\n\n\n@pytest.mark.parametrize('out_dim', out_dims)\n@pytest.mark.parametrize('dropout', dropouts)\n@pytest.mark.parametrize('act', acts)\n@pytest.mark.parametrize('norm', norms)\n@pytest.mark.parametrize('jk', jks)\ndef test_gin(out_dim, dropout, act, norm, jk):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    out_channels = 16 if out_dim is None else out_dim\n\n    model = GIN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout,\n                act=act, norm=norm, jk=jk)\n    assert str(model) == f'GIN(8, {out_channels}, num_layers=3)'\n    assert model(x, edge_index).size() == (3, out_channels)\n\n\n@pytest.mark.parametrize('out_dim', out_dims)\n@pytest.mark.parametrize('dropout', dropouts)\n@pytest.mark.parametrize('act', acts)\n@pytest.mark.parametrize('norm', norms)\n@pytest.mark.parametrize('jk', jks)\ndef test_gat(out_dim, dropout, act, norm, jk):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    out_channels = 16 if out_dim is None else out_dim\n\n    for v2 in [False, True]:\n        model = GAT(8, 16, num_layers=3, out_channels=out_dim, v2=v2,\n                    dropout=dropout, act=act, norm=norm, jk=jk)\n        assert str(model) == f'GAT(8, {out_channels}, num_layers=3)'\n        assert model(x, edge_index).size() == (3, out_channels)\n\n        model = GAT(8, 16, num_layers=3, out_channels=out_dim, v2=v2,\n                    dropout=dropout, act=act, norm=norm, jk=jk, heads=4)\n        assert str(model) == f'GAT(8, {out_channels}, num_layers=3)'\n        assert model(x, edge_index).size() == (3, out_channels)\n\n\n@pytest.mark.parametrize('out_dim', out_dims)\n@pytest.mark.parametrize('dropout', dropouts)\n@pytest.mark.parametrize('act', acts)\n@pytest.mark.parametrize('norm', norms)\n@pytest.mark.parametrize('jk', jks)\ndef test_pna(out_dim, 
dropout, act, norm, jk):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    deg = torch.tensor([0, 2, 1])\n    out_channels = 16 if out_dim is None else out_dim\n    aggregators = ['mean', 'min', 'max', 'std', 'var', 'sum']\n    scalers = [\n        'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear'\n    ]\n\n    model = PNA(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout,\n                act=act, norm=norm, jk=jk, aggregators=aggregators,\n                scalers=scalers, deg=deg)\n    assert str(model) == f'PNA(8, {out_channels}, num_layers=3)'\n    assert model(x, edge_index).size() == (3, out_channels)\n\n\n@pytest.mark.parametrize('out_dim', out_dims)\n@pytest.mark.parametrize('dropout', dropouts)\n@pytest.mark.parametrize('act', acts)\n@pytest.mark.parametrize('norm', norms)\n@pytest.mark.parametrize('jk', jks)\ndef test_edge_cnn(out_dim, dropout, act, norm, jk):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    out_channels = 16 if out_dim is None else out_dim\n\n    model = EdgeCNN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout,\n                    act=act, norm=norm, jk=jk)\n    assert str(model) == f'EdgeCNN(8, {out_channels}, num_layers=3)'\n    assert model(x, edge_index).size() == (3, out_channels)\n\n\ndef test_jit():\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    model = GCN(8, 16, num_layers=2)\n    model = torch.jit.script(model)\n\n    assert model(x, edge_index).size() == (3, 16)\n\n\n@pytest.mark.parametrize('out_dim', out_dims)\n@pytest.mark.parametrize('jk', jks)\ndef test_one_layer_gnn(out_dim, jk):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    out_channels = 16 if out_dim is None else out_dim\n\n    model = GraphSAGE(8, 16, num_layers=1, out_channels=out_dim, jk=jk)\n    assert model(x, edge_index).size() == (3, 
out_channels)\n\n\n@pytest.mark.parametrize('norm', [\n    'BatchNorm',\n    'GraphNorm',\n    'InstanceNorm',\n    'LayerNorm',\n])\ndef test_batch(norm):\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    batch = torch.tensor([0, 0, 1])\n\n    model = GraphSAGE(8, 16, num_layers=2, norm=norm)\n    assert model.supports_norm_batch == (norm != 'BatchNorm')\n\n    out = model(x, edge_index, batch=batch)\n    assert out.size() == (3, 16)\n\n    if model.supports_norm_batch:\n        with pytest.raises(RuntimeError, match=\"out of bounds\"):\n            model(x, edge_index, batch=batch, batch_size=1)\n\n\n@onlyOnline\n@onlyNeighborSampler\n@pytest.mark.parametrize('jk', [None, 'last'])\ndef test_basic_gnn_inference(get_dataset, jk):\n    dataset = get_dataset(name='karate')\n    data = dataset[0]\n\n    model = GraphSAGE(dataset.num_features, hidden_channels=16, num_layers=2,\n                      out_channels=dataset.num_classes, jk=jk)\n    model.eval()\n\n    out1 = model(data.x, data.edge_index)\n    assert out1.size() == (data.num_nodes, dataset.num_classes)\n\n    loader = NeighborLoader(data, num_neighbors=[-1], batch_size=128)\n    out2 = model.inference(loader)\n    assert out1.size() == out2.size()\n    assert torch.allclose(out1, out2, atol=1e-4)\n\n    assert 'n_id' not in data\n\n\n@withDevice\n@onlyLinux\n@onlyFullTest\n@withPackage('torch>=2.0.0')\ndef test_compile_basic(device):\n    x = torch.randn(3, 8, device=device)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)\n\n    model = GCN(8, 16, num_layers=3).to(device)\n    compiled_model = torch.compile(model)\n\n    expected = model(x, edge_index)\n    out = compiled_model(x, edge_index)\n    assert torch.allclose(out, expected, atol=1e-6)\n\n\ndef test_packaging():\n    warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')\n\n    os.makedirs(torch.hub._get_torch_home(), exist_ok=True)\n\n    x = torch.randn(3, 8)\n   
 edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    model = GraphSAGE(8, 16, num_layers=3)\n    path = osp.join(torch.hub._get_torch_home(), 'pyg_test_model.pt')\n    torch.save(model, path)\n\n    model = torch.load(path, weights_only=False)\n    with torch.no_grad():\n        assert model(x, edge_index).size() == (3, 16)\n\n    model = GraphSAGE(8, 16, num_layers=3)\n    path = osp.join(torch.hub._get_torch_home(), 'pyg_test_package.pt')\n    with torch.package.PackageExporter(path) as pe:\n        pe.extern('torch_geometric.nn.**')\n        pe.extern('torch_geometric.inspector')\n        pe.extern('torch_geometric.utils._trim_to_layer')\n        pe.extern('_operator')\n        pe.save_pickle('models', 'model.pkl', model)\n\n    pi = torch.package.PackageImporter(path)\n    model = pi.load_pickle('models', 'model.pkl')\n    with torch.no_grad():\n        assert model(x, edge_index).size() == (3, 16)\n\n\n@onlyLinux\n@withPackage('torch>=2.6.0')\n@withPackage('onnx', 'onnxruntime', 'onnxscript')\ndef test_onnx(tmp_path: str) -> None:\n    import onnx\n    import onnxruntime as ort\n\n    from torch_geometric import safe_onnx_export\n\n    warnings.filterwarnings('ignore', '.*tensor to a Python boolean.*')\n    warnings.filterwarnings('ignore', '.*shape inference of prim::Constant.*')\n\n    class MyModel(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.conv1 = SAGEConv(8, 16)\n            self.conv2 = SAGEConv(16, 16)\n\n        def forward(self, x, edge_index):\n            x = self.conv1(x, edge_index).relu()\n            x = self.conv2(x, edge_index)\n            return x\n\n    model = MyModel()\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])\n    expected = model(x, edge_index)\n    assert expected.size() == (3, 16)\n\n    path = osp.join(tmp_path, 'model.onnx')\n    success = safe_onnx_export(\n        model,\n        (x, edge_index),\n        path,\n        
input_names=('x', 'edge_index'),\n        opset_version=18,\n        dynamo=True,  # False is deprecated by PyTorch\n        skip_on_error=True,  # Skip gracefully in CI if upstream issue occurs\n    )\n\n    if not success:\n        # ONNX export was skipped due to known upstream issue\n        # This allows CI to pass while the upstream bug exists\n        warnings.warn(\n            \"ONNX export test skipped due to known upstream onnx_ir issue. \"\n            \"This is expected and does not indicate a problem with PyTorch \"\n            \"Geometric.\", UserWarning, stacklevel=2)\n        return\n\n    onnx_model = onnx.load(path)\n    onnx.checker.check_model(onnx_model)\n\n    providers = ['CPUExecutionProvider']\n    ort_session = ort.InferenceSession(path, providers=providers)\n\n    out = ort_session.run(None, {\n        'x': x.numpy(),\n        'edge_index': edge_index.numpy()\n    })[0]\n    out = torch.from_numpy(out)\n    assert torch.allclose(out, expected, atol=1e-6)\n\n\n@withPackage('pyg_lib')\ndef test_trim_to_layer():\n    x = torch.randn(14, 16)\n    edge_index = torch.tensor([\n        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],\n        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],\n    ])\n    data = Data(x=x, edge_index=edge_index)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[1, 2, 4],\n        batch_size=2,\n        shuffle=False,\n    )\n    batch = next(iter(loader))\n\n    model = GraphSAGE(in_channels=16, hidden_channels=16, num_layers=3)\n    out1 = model(batch.x, batch.edge_index)[:2]\n    assert out1.size() == (2, 16)\n\n    out2 = model(\n        batch.x,\n        batch.edge_index,\n        num_sampled_nodes_per_hop=batch.num_sampled_nodes,\n        num_sampled_edges_per_hop=batch.num_sampled_edges,\n    )[:2]\n    assert out2.size() == (2, 16)\n\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n\n@withDevice\n@onlyLinux\n@withPackage('torch>=2.1.0')\n@pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, 
PNA])\ndef test_compile_graph_breaks(Model, device):\n    import torch._dynamo as dynamo\n\n    x = torch.randn(3, 8, device=device)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)\n\n    kwargs = {}\n    if Model in {GCN, GAT}:\n        # Adding self-loops inside the model leads to graph breaks :(\n        kwargs['add_self_loops'] = False\n\n    if Model in {PNA}:  # `PNA` requires additional arguments:\n        kwargs['aggregators'] = ['sum', 'mean', 'min', 'max', 'var', 'std']\n        kwargs['scalers'] = ['identity', 'amplification', 'attenuation']\n        kwargs['deg'] = torch.tensor([1, 2, 1])\n\n    model = Model(\n        in_channels=8,\n        hidden_channels=16,\n        num_layers=2,\n        **kwargs,\n    ).to(device)\n\n    explanation = dynamo.explain(model)(x, edge_index)\n    assert explanation.graph_break_count == 0\n\n\n@withPackage('pyg_lib')\ndef test_basic_gnn_cache():\n    x = torch.randn(14, 16)\n    edge_index = torch.tensor([\n        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],\n        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],\n    ])\n\n    loader = NeighborLoader(\n        Data(x=x, edge_index=edge_index),\n        num_neighbors=[-1],\n        batch_size=2,\n    )\n\n    model = GCN(in_channels=16, hidden_channels=16, num_layers=2)\n    model.eval()\n\n    out1 = model.inference(loader, cache=False)\n    out2 = model.inference(loader, cache=True)\n\n    assert torch.allclose(out1, out2)\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    parser.add_argument('--dynamic', action='store_true')\n    args = parser.parse_args()\n\n    if args.dynamic:\n        min_num_nodes, max_num_nodes = 10_000, 15_000\n        min_num_edges, max_num_edges = 200_000, 300_000\n    else:\n        min_num_nodes, max_num_nodes = 10_000, 10_000\n        min_num_edges, max_num_edges 
= 200_000, 200_000\n\n    def gen_args():\n        N = random.randint(min_num_nodes, max_num_nodes)\n        E = random.randint(min_num_edges, max_num_edges)\n\n        x = torch.randn(N, 64, device=args.device)\n        edge_index = torch.randint(N, (2, E), device=args.device)\n\n        return x, edge_index\n\n    for Model in [GCN, GraphSAGE, GIN, EdgeCNN]:\n        print(f'Model: {Model.__name__}')\n\n        model = Model(64, 64, num_layers=3).to(args.device)\n        compiled_model = torch.compile(model)\n\n        benchmark(\n            funcs=[model, compiled_model],\n            func_names=['Vanilla', 'Compiled'],\n            args=gen_args,\n            num_steps=50 if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n"
  },
  {
    "path": "test/nn/models/test_correct_and_smooth.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn.models import CorrectAndSmooth\nfrom torch_geometric.testing import noWindows\nfrom torch_geometric.typing import SparseTensor\n\n\n@noWindows\ndef test_correct_and_smooth():\n    y_soft = torch.tensor([0.1, 0.5, 0.4]).repeat(6, 1)\n    y_true = torch.tensor([1, 0, 0, 2, 1, 1])\n    edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]])\n    mask = torch.randint(0, 2, (6, ), dtype=torch.bool)\n\n    model = CorrectAndSmooth(\n        num_correction_layers=2,\n        correction_alpha=0.5,\n        num_smoothing_layers=2,\n        smoothing_alpha=0.5,\n    )\n    assert str(model) == ('CorrectAndSmooth(\\n'\n                          '  correct: num_layers=2, alpha=0.5\\n'\n                          '  smooth:  num_layers=2, alpha=0.5\\n'\n                          '  autoscale=True, scale=1.0\\n'\n                          ')')\n\n    out = model.correct(y_soft, y_true[mask], mask, edge_index)\n    assert out.size() == (6, 3)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6))\n        assert torch.allclose(\n            out, model.correct(y_soft, y_true[mask], mask, adj.t()))\n\n    out = model.smooth(y_soft, y_true[mask], mask, edge_index)\n    assert out.size() == (6, 3)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(\n            out, model.smooth(y_soft, y_true[mask], mask, adj.t()))\n\n    # Test without autoscale:\n    model = CorrectAndSmooth(\n        num_correction_layers=2,\n        correction_alpha=0.5,\n        num_smoothing_layers=2,\n        smoothing_alpha=0.5,\n        autoscale=False,\n    )\n    out = model.correct(y_soft, y_true[mask], mask, edge_index)\n    assert out.size() == (6, 3)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(\n            out, model.correct(y_soft, y_true[mask], mask, adj.t()))\n"
  },
  {
    "path": "test/nn/models/test_deep_graph_infomax.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GCN, DeepGraphInfomax\nfrom torch_geometric.testing import has_package, is_full_test, withDevice\n\n\n@withDevice\ndef test_infomax(device):\n    def corruption(z):\n        return z + 1\n\n    model = DeepGraphInfomax(\n        hidden_channels=16,\n        encoder=lambda x: x,\n        summary=lambda z, *args: z.mean(dim=0),\n        corruption=lambda x: x + 1,\n    ).to(device)\n    assert str(model) == 'DeepGraphInfomax(16)'\n\n    x = torch.ones(20, 16, device=device)\n\n    pos_z, neg_z, summary = model(x)\n    assert pos_z.size() == (20, 16)\n    assert neg_z.size() == (20, 16)\n    assert summary.size() == (16, )\n\n    loss = model.loss(pos_z, neg_z, summary)\n    assert float(loss) >= 0\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        pos_z, neg_z, summary = jit(x)\n        assert pos_z.size() == (20, 16) and neg_z.size() == (20, 16)\n        assert summary.size() == (16, )\n\n    if has_package('sklearn'):\n        acc = model.test(\n            train_z=torch.ones(20, 16),\n            train_y=torch.randint(10, (20, )),\n            test_z=torch.ones(20, 16),\n            test_y=torch.randint(10, (20, )),\n        )\n        assert 0 <= acc <= 1\n\n\n@withDevice\ndef test_infomax_predefined_model(device):\n    def corruption(x, edge_index, edge_weight):\n        return (\n            x[torch.randperm(x.size(0), device=x.device)],\n            edge_index,\n            edge_weight,\n        )\n\n    model = DeepGraphInfomax(\n        hidden_channels=16,\n        encoder=GCN(16, 16, num_layers=2),\n        summary=lambda z, *args, **kwargs: z.mean(dim=0).sigmoid(),\n        corruption=corruption,\n    ).to(device)\n\n    x = torch.randn(4, 16, device=device)\n    edge_index = torch.tensor(\n        [[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]],\n        device=device,\n    )\n    edge_weight = torch.rand(edge_index.size(1), device=device)\n\n    pos_z, neg_z, summary = 
model(x, edge_index, edge_weight=edge_weight)\n    assert pos_z.size() == (4, 16)\n    assert neg_z.size() == (4, 16)\n    assert summary.size() == (16, )\n\n    loss = model.loss(pos_z, neg_z, summary)\n    assert float(loss) >= 0\n"
  },
  {
    "path": "test/nn/models/test_deepgcn.py",
    "content": "import pytest\nimport torch\nfrom torch.nn import ReLU\n\nfrom torch_geometric.nn import DeepGCNLayer, GENConv, LayerNorm\n\n\n@pytest.mark.parametrize(\n    'block_tuple',\n    [('res+', 1), ('res', 1), ('dense', 2), ('plain', 1)],\n)\n@pytest.mark.parametrize('ckpt_grad', [True, False])\ndef test_deepgcn(block_tuple, ckpt_grad):\n    block, expansion = block_tuple\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    conv = GENConv(8, 8)\n    norm = LayerNorm(8)\n    act = ReLU()\n    layer = DeepGCNLayer(conv, norm, act, block=block, ckpt_grad=ckpt_grad)\n    assert str(layer) == f'DeepGCNLayer(block={block})'\n\n    out = layer(x, edge_index)\n    assert out.size() == (3, 8 * expansion)\n"
  },
  {
    "path": "test/nn/models/test_dimenet.py",
    "content": "import pytest\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.nn import DimeNet, DimeNetPlusPlus\nfrom torch_geometric.nn.models.dimenet import (\n    BesselBasisLayer,\n    Envelope,\n    ResidualLayer,\n)\nfrom torch_geometric.testing import is_full_test, withPackage\n\n\ndef test_dimenet_modules():\n    env = Envelope(exponent=5)\n    x = torch.randn(10, 3)\n    assert env(x).size() == (10, 3)  # Isotonic layer.\n\n    bbl = BesselBasisLayer(5)\n    x = torch.randn(10, 3)\n    assert bbl(x).size() == (10, 3, 5)  # Non-isotonic layer.\n\n    rl = ResidualLayer(128, torch.nn.functional.relu)\n    x = torch.randn(128, 128)\n    assert rl(x).size() == (128, 128)  # Isotonic layer.\n\n\n@withPackage('sympy')\n@withPackage('torch_sparse')  # TODO `triplet` requires `SparseTensor` for now.\n@withPackage('torch-cluster')\n@pytest.mark.parametrize('Model', [DimeNet, DimeNetPlusPlus])\ndef test_dimenet(Model):\n    z = torch.randint(1, 10, (20, ))\n    pos = torch.randn(20, 3)\n\n    if Model == DimeNet:\n        kwargs = dict(num_bilinear=3)\n    else:\n        kwargs = dict(out_emb_channels=3, int_emb_size=5, basis_emb_size=5)\n\n    model = Model(\n        hidden_channels=5,\n        out_channels=1,\n        num_blocks=5,\n        num_spherical=5,\n        num_radial=5,\n        **kwargs,\n    )\n    model.reset_parameters()\n\n    with torch.no_grad():\n        out = model(z, pos)\n        assert out.size() == (1, )\n\n        jit = torch.jit.export(model)\n        assert torch.allclose(jit(z, pos), out)\n\n    if is_full_test():\n        optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n\n        min_loss = float('inf')\n        for _ in range(100):\n            optimizer.zero_grad()\n            out = model(z, pos)\n            loss = F.l1_loss(out, torch.tensor([1.0]))\n            loss.backward()\n            optimizer.step()\n            min_loss = min(float(loss), min_loss)\n        assert min_loss < 2\n"
  },
  {
    "path": "test/nn/models/test_gnnff.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GNNFF\nfrom torch_geometric.testing import is_full_test, withPackage\n\n\n@withPackage('torch_sparse')  # TODO `triplet` requires `SparseTensor` for now.\n@withPackage('torch-cluster')\ndef test_gnnff():\n    z = torch.randint(1, 10, (20, ))\n    pos = torch.randn(20, 3)\n\n    model = GNNFF(\n        hidden_node_channels=5,\n        hidden_edge_channels=5,\n        num_layers=5,\n    )\n    model.reset_parameters()\n\n    out = model(z, pos)\n    assert out.size() == (20, 3)\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        assert torch.allclose(jit(z, pos), out)\n"
  },
  {
    "path": "test/nn/models/test_gpse.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.nn import GPSE, GPSENodeEncoder\nfrom torch_geometric.nn.models.gpse import (\n    IdentityHead,\n    gpse_loss,\n    gpse_process,\n    process_batch_idx,\n)\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.transforms import VirtualNode\n\n\ndef test_gpse_training():\n    x = torch.randn(6, 20)\n    y = torch.randn(6, 51)\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])\n\n    data = Data(x=x, y=y, edge_index=edge_index)\n    data = VirtualNode()(data)\n    data.y_graph = torch.randn(11)\n\n    batch = Batch.from_data_list([data])\n    model = GPSE()\n\n    with torch.no_grad():\n        out = model(batch)\n        assert out[0].size() == out[1].size()\n        assert out[0].size() == (7, 62)\n\n    if is_full_test():\n        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n        min_loss = float('inf')\n        for _ in range(100):\n            optimizer.zero_grad()\n            pred, true = model(batch)\n            batch_idx = process_batch_idx(batch.batch, true)\n            loss, _ = gpse_loss(pred, true, batch_idx)\n            loss.backward()\n            optimizer.step()\n            min_loss = min(float(loss), min_loss)\n        assert min_loss < 2\n\n\ndef test_gpse_from_pretrained():\n    x = torch.randn(6, 4)\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])\n    data = Data(x=x, edge_index=edge_index)\n    data = VirtualNode()(data)\n\n    model = GPSE()\n    model.post_mp = IdentityHead()\n\n    with torch.no_grad():\n        out = gpse_process(model, data, 'NormalSE')\n        assert out.size() == (7, 512)\n\n\n@pytest.mark.parametrize('expand_x', [False, True])\ndef test_gpse_node_encoder(expand_x):\n    x = torch.randn(6, 4)\n    pestat_GPSE 
= torch.randn(6, 512)\n\n    encoder = GPSENodeEncoder(\n        dim_emb=128,\n        dim_pe_in=512,\n        dim_pe_out=64,\n        dim_in=4,\n        expand_x=expand_x,\n    )\n    out = encoder(x, pestat_GPSE)\n    assert out.size() == (6, 128) if expand_x else (6, 64)\n"
  },
  {
    "path": "test/nn/models/test_graph_mixer.py",
    "content": "import torch\n\nfrom torch_geometric.nn.models.graph_mixer import (\n    LinkEncoder,\n    NodeEncoder,\n    get_latest_k_edge_attr,\n)\n\n\ndef test_node_encoder():\n    x = torch.arange(4, dtype=torch.float).view(-1, 1)\n    edge_index = torch.tensor([[1, 2, 0, 0, 1, 3], [0, 0, 1, 2, 2, 2]])\n    edge_time = torch.tensor([0, 1, 1, 1, 2, 3])\n    seed_time = torch.tensor([2, 2, 2, 2])\n\n    encoder = NodeEncoder(time_window=2)\n    encoder.reset_parameters()\n    assert str(encoder) == 'NodeEncoder(time_window=2)'\n\n    out = encoder(x, edge_index, edge_time, seed_time)\n    # Node 0 aggregates information from node 2 (excluding node 1).\n    # Node 1 aggregates information from node 0.\n    # Node 2 aggregates information from node 0 and node 1 (excluding node 3).\n    # Node 3 aggregates no information.\n    expected = torch.tensor([\n        [0 + 2],\n        [1 + 0],\n        [2 + 0.5 * (0 + 1)],\n        [3],\n    ])\n    assert torch.allclose(out, expected)\n\n\ndef test_link_encoder():\n    num_nodes = 3\n    num_edges = 6\n    edge_attr = torch.rand((num_edges, 10))\n    edge_index = torch.randint(low=0, high=num_nodes, size=(2, num_edges))\n    edge_time = torch.rand(num_edges)\n    seed_time = torch.ones(num_nodes)\n\n    encoder = LinkEncoder(\n        k=3,\n        in_channels=edge_attr.size(1),\n        hidden_channels=7,\n        out_channels=11,\n        time_channels=13,\n    )\n    encoder.reset_parameters()\n    assert str(encoder) == ('LinkEncoder(k=3, in_channels=10, '\n                            'hidden_channels=7, out_channels=11, '\n                            'time_channels=13, dropout=0.0)')\n\n    out = encoder(edge_index, edge_attr, edge_time, seed_time)\n    assert out.size() == (num_nodes, 11)\n\n\ndef test_latest_k_edge_attr():\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 0], [0, 1, 0, 1, 0, 1, 2]])\n    edge_time = torch.tensor([3, 1, 2, 3, 1, 2, 3])\n    edge_attr = torch.tensor([1, -1, 3, 4, -1, 6, 
7]).view(-1, 1)\n\n    k = 2\n    out = get_latest_k_edge_attr(k, edge_index, edge_attr, edge_time,\n                                 num_nodes=3)\n    expected = torch.tensor([[[1], [3]], [[4], [6]], [[7], [0]]])\n    assert out.size() == (3, 2, 1)\n    assert torch.equal(out, expected)\n\n    k = 1\n    out = get_latest_k_edge_attr(k, edge_index, edge_attr, edge_time,\n                                 num_nodes=3)\n    expected = torch.tensor([[[1]], [[4]], [[7]]])\n    assert out.size() == (3, 1, 1)\n    assert torch.equal(out, expected)\n"
  },
  {
    "path": "test/nn/models/test_graph_unet.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GraphUNet\nfrom torch_geometric.testing import is_full_test, onlyLinux\n\n\n@onlyLinux  # TODO  (matthias) Investigate CSR @ CSR support on Windows.\ndef test_graph_unet():\n    model = GraphUNet(16, 32, 8, depth=3)\n    out = 'GraphUNet(16, 32, 8, depth=3, pool_ratios=[0.5, 0.5, 0.5])'\n    assert str(model) == out\n\n    x = torch.randn(3, 16)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    out = model(x, edge_index)\n    assert out.size() == (3, 8)\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        out = jit(x, edge_index)\n        assert out.size() == (3, 8)\n"
  },
  {
    "path": "test/nn/models/test_jumping_knowledge.py",
    "content": "import torch\n\nfrom torch_geometric.nn import HeteroJumpingKnowledge, JumpingKnowledge\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_jumping_knowledge():\n    num_nodes, channels, num_layers = 100, 17, 5\n    xs = list([torch.randn(num_nodes, channels) for _ in range(num_layers)])\n\n    model = JumpingKnowledge('cat')\n    assert str(model) == 'JumpingKnowledge(cat)'\n\n    out = model(xs)\n    assert out.size() == (num_nodes, channels * num_layers)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        assert torch.allclose(jit(xs), out)\n\n    model = JumpingKnowledge('max')\n    assert str(model) == 'JumpingKnowledge(max)'\n\n    out = model(xs)\n    assert out.size() == (num_nodes, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        assert torch.allclose(jit(xs), out)\n\n    model = JumpingKnowledge('lstm', channels, num_layers)\n    assert str(model) == (f'JumpingKnowledge(lstm, channels='\n                          f'{channels}, layers={num_layers})')\n\n    out = model(xs)\n    assert out.size() == (num_nodes, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        assert torch.allclose(jit(xs), out)\n\n\ndef test_hetero_jumping_knowledge():\n    num_nodes, channels, num_layers = 100, 17, 5\n\n    types = [\"author\", \"paper\"]\n    xs_dict = {\n        key: [torch.randn(num_nodes, channels) for _ in range(num_layers)]\n        for key in types\n    }\n\n    model = HeteroJumpingKnowledge(types, mode='cat')\n    model.reset_parameters()\n    assert str(model) == 'HeteroJumpingKnowledge(num_types=2, mode=cat)'\n\n    out_dict = model(xs_dict)\n    for out in out_dict.values():\n        assert out.size() == (num_nodes, channels * num_layers)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        jit_out = jit(xs_dict)\n        for key in types:\n            assert torch.allclose(jit_out[key], out_dict[key])\n\n    model = 
HeteroJumpingKnowledge(types, mode='max')\n    assert str(model) == 'HeteroJumpingKnowledge(num_types=2, mode=max)'\n\n    out_dict = model(xs_dict)\n    for out in out_dict.values():\n        assert out.size() == (num_nodes, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        jit_out = jit(xs_dict)\n        for key in types:\n            assert torch.allclose(jit_out[key], out_dict[key])\n\n    model = HeteroJumpingKnowledge(types, mode='lstm', channels=channels,\n                                   num_layers=num_layers)\n    assert str(model) == (f'HeteroJumpingKnowledge(num_types=2, mode=lstm, '\n                          f'channels={channels}, layers={num_layers})')\n\n    out_dict = model(xs_dict)\n    for out in out_dict.values():\n        assert out.size() == (num_nodes, channels)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        jit_out = jit(xs_dict)\n        for key in types:\n            assert torch.allclose(jit_out[key], out_dict[key])\n"
  },
  {
    "path": "test/nn/models/test_label_prop.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn.models import LabelPropagation\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_label_prop():\n    y = torch.tensor([1, 0, 0, 2, 1, 1])\n    edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]])\n    mask = torch.randint(0, 2, (6, ), dtype=torch.bool)\n\n    model = LabelPropagation(num_layers=2, alpha=0.5)\n    assert str(model) == 'LabelPropagation(num_layers=2, alpha=0.5)'\n\n    # Test without mask:\n    out = model(y, edge_index)\n    assert out.size() == (6, 3)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6))\n        assert torch.allclose(model(y, adj.t()), out)\n\n    # Test with mask:\n    out = model(y, edge_index, mask)\n    assert out.size() == (6, 3)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(model(y, adj.t(), mask), out)\n\n    # Test post step:\n    out = model(y, edge_index, mask, post_step=lambda y: torch.zeros_like(y))\n    assert torch.sum(out) == 0.\n"
  },
  {
    "path": "test/nn/models/test_lightgcn.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn.models import LightGCN\n\n\n@pytest.mark.parametrize('embedding_dim', [32, 64])\n@pytest.mark.parametrize('with_edge_weight', [False, True])\n@pytest.mark.parametrize('lambda_reg', [0, 1e-4])\n@pytest.mark.parametrize('alpha', [0, .25, torch.tensor([0.4, 0.3, 0.2])])\ndef test_lightgcn_ranking(embedding_dim, with_edge_weight, lambda_reg, alpha):\n    num_nodes = 500\n    num_edges = 400\n    edge_index = torch.randint(0, num_nodes, (2, num_edges))\n    edge_weight = torch.rand(num_edges) if with_edge_weight else None\n    edge_label_index = torch.randint(0, num_nodes, (2, 100))\n\n    model = LightGCN(num_nodes, embedding_dim, num_layers=2, alpha=alpha)\n    assert str(model) == f'LightGCN(500, {embedding_dim}, num_layers=2)'\n\n    pred = model(edge_index, edge_label_index, edge_weight)\n    assert pred.size() == (100, )\n\n    loss = model.recommendation_loss(\n        pos_edge_rank=pred[:50],\n        neg_edge_rank=pred[50:],\n        node_id=edge_index.unique(),\n        lambda_reg=lambda_reg,\n    )\n    assert loss.dim() == 0 and loss > 0\n\n    out = model.recommend(edge_index, edge_weight, k=2)\n    assert out.size() == (500, 2)\n    assert out.min() >= 0 and out.max() < 500\n\n    src_index = torch.arange(0, 250)\n    dst_index = torch.arange(250, 500)\n\n    out = model.recommend(edge_index, edge_weight, src_index, dst_index, k=2)\n    assert out.size() == (250, 2)\n    assert out.min() >= 250 and out.max() < 500\n\n\n@pytest.mark.parametrize('embedding_dim', [32, 64])\n@pytest.mark.parametrize('with_edge_weight', [False, True])\n@pytest.mark.parametrize('alpha', [0, .25, torch.tensor([0.4, 0.3, 0.2])])\ndef test_lightgcn_link_prediction(embedding_dim, with_edge_weight, alpha):\n    num_nodes = 500\n    num_edges = 400\n    edge_index = torch.randint(0, num_nodes, (2, num_edges))\n    edge_weight = torch.rand(num_edges) if with_edge_weight else None\n    edge_label_index = 
torch.randint(0, num_nodes, (2, 100))\n    edge_label = torch.randint(0, 2, (edge_label_index.size(1), ))\n\n    model = LightGCN(num_nodes, embedding_dim, num_layers=2, alpha=alpha)\n    assert str(model) == f'LightGCN(500, {embedding_dim}, num_layers=2)'\n\n    pred = model(edge_index, edge_label_index, edge_weight)\n    assert pred.size() == (100, )\n\n    loss = model.link_pred_loss(pred, edge_label)\n    assert loss.dim() == 0 and loss > 0\n\n    prob = model.predict_link(edge_index, edge_label_index, edge_weight,\n                              prob=True)\n    assert prob.size() == (100, )\n    assert prob.min() > 0 and prob.max() < 1\n\n    prob = model.predict_link(edge_index, edge_label_index, edge_weight,\n                              prob=False)\n    assert prob.size() == (100, )\n    assert ((prob == 0) | (prob == 1)).sum() == 100\n"
  },
  {
    "path": "test/nn/models/test_linkx.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import LINKX\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\n@pytest.mark.parametrize('num_edge_layers', [1, 2])\ndef test_linkx(num_edge_layers):\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])\n    edge_weight = torch.rand(edge_index.size(1))\n\n    model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32,\n                  out_channels=8, num_layers=2,\n                  num_edge_layers=num_edge_layers)\n    assert str(model) == 'LINKX(num_nodes=4, in_channels=16, out_channels=8)'\n\n    out = model(x, edge_index)\n    assert out.size() == (4, 8)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4))\n        assert torch.allclose(out, model(x, adj.t()), atol=1e-6)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        assert torch.allclose(jit(x, edge_index), out)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)\n\n    out = model(None, edge_index)\n    assert out.size() == (4, 8)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(out, model(None, adj.t()), atol=1e-6)\n\n    out = model(x, edge_index, edge_weight)\n    assert out.size() == (4, 8)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, edge_weight,\n                                           sparse_sizes=(4, 4))\n        assert torch.allclose(model(x, adj.t()), out, atol=1e-6)\n\n    out = model(None, edge_index, edge_weight)\n    assert out.size() == (4, 8)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(model(None, adj.t()), out, atol=1e-6)\n"
  },
  {
    "path": "test/nn/models/test_lpformer.py",
    "content": "import torch\n\nfrom torch_geometric.nn import LPFormer\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.utils import to_undirected\n\n\n@withPackage('numba')  # For ppr calculation\ndef test_lpformer():\n    model = LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)\n    assert str(\n        model\n    ) == 'LPFormer(16, 32, num_gnn_layers=2, num_transformer_layers=1)'\n\n    num_nodes = 20\n    x = torch.randn(num_nodes, 16)\n    edges = torch.randint(0, num_nodes - 1, (2, 110))\n    edge_index, test_edges = edges[:, :100], edges[:, 100:]\n    edge_index = to_undirected(edge_index)\n\n    ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes, eps=1e-4)\n\n    assert ppr_matrix.is_sparse\n    assert ppr_matrix.size() == (num_nodes, num_nodes)\n    assert ppr_matrix.sum().item() > 0\n\n    # Test with dense edge_index\n    out = model(test_edges, x, edge_index, ppr_matrix)\n    assert out.size() == (10, )\n\n    # Test with sparse edge_index\n    adj = torch.sparse_coo_tensor(edge_index, torch.ones(edge_index.size(1)),\n                                  [num_nodes, num_nodes])\n    out2 = model(test_edges, x, adj, ppr_matrix)\n    assert out2.size() == (10, )\n"
  },
  {
    "path": "test/nn/models/test_mask_label.py",
    "content": "import torch\n\nfrom torch_geometric.nn import MaskLabel\n\n\ndef test_mask_label():\n    model = MaskLabel(2, 10)\n    assert str(model) == 'MaskLabel()'\n\n    x = torch.rand(4, 10)\n    y = torch.tensor([1, 0, 1, 0])\n    mask = torch.tensor([False, False, True, True])\n\n    out = model(x, y, mask)\n    assert out.size() == (4, 10)\n    assert torch.allclose(out[~mask], x[~mask])\n\n    model = MaskLabel(2, 10, method='concat')\n    out = model(x, y, mask)\n    assert out.size() == (4, 20)\n    assert torch.allclose(out[:, :10], x)\n\n\ndef test_ratio_mask():\n    mask = torch.tensor([True, True, True, True, False, False, False, False])\n    out = MaskLabel.ratio_mask(mask, 0.5)\n    assert out[:4].sum() <= 4 and out[4:].sum() == 0\n"
  },
  {
    "path": "test/nn/models/test_meta.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.nn import MetaLayer\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import scatter\n\ncount = 0\n\n\ndef test_meta_layer():\n    assert str(MetaLayer()) == ('MetaLayer(\\n'\n                                '  edge_model=None,\\n'\n                                '  node_model=None,\\n'\n                                '  global_model=None\\n'\n                                ')')\n\n    def dummy_model(*args):\n        global count\n        count += 1\n        return None\n\n    x = torch.randn(20, 10)\n    edge_index = torch.randint(0, high=10, size=(2, 20), dtype=torch.long)\n\n    for edge_model in (dummy_model, None):\n        for node_model in (dummy_model, None):\n            for global_model in (dummy_model, None):\n                model = MetaLayer(edge_model, node_model, global_model)\n                out = model(x, edge_index)\n                assert isinstance(out, tuple) and len(out) == 3\n\n    assert count == 12\n\n\ndef test_meta_layer_example():\n    class EdgeModel(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.edge_mlp = Seq(Lin(2 * 10 + 5 + 20, 5), ReLU(), Lin(5, 5))\n\n        def forward(\n            self,\n            src: Tensor,\n            dst: Tensor,\n            edge_attr: Optional[Tensor],\n            u: Optional[Tensor],\n            batch: Optional[Tensor],\n        ) -> Tensor:\n            assert edge_attr is not None\n            assert u is not None\n            assert batch is not None\n            out = torch.cat([src, dst, edge_attr, u[batch]], 1)\n            return self.edge_mlp(out)\n\n    class NodeModel(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.node_mlp_1 = Seq(Lin(15, 10), 
ReLU(), Lin(10, 10))\n            self.node_mlp_2 = Seq(Lin(2 * 10 + 20, 10), ReLU(), Lin(10, 10))\n\n        def forward(\n            self,\n            x: Tensor,\n            edge_index: Tensor,\n            edge_attr: Optional[Tensor],\n            u: Optional[Tensor],\n            batch: Optional[Tensor],\n        ) -> Tensor:\n            assert edge_attr is not None\n            assert u is not None\n            assert batch is not None\n            row = edge_index[0]\n            col = edge_index[1]\n            out = torch.cat([x[row], edge_attr], dim=1)\n            out = self.node_mlp_1(out)\n            out = scatter(out, col, dim=0, dim_size=x.size(0), reduce='mean')\n            out = torch.cat([x, out, u[batch]], dim=1)\n            return self.node_mlp_2(out)\n\n    class GlobalModel(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.global_mlp = Seq(Lin(20 + 10, 20), ReLU(), Lin(20, 20))\n\n        def forward(\n            self,\n            x: Tensor,\n            edge_index: Tensor,\n            edge_attr: Optional[Tensor],\n            u: Optional[Tensor],\n            batch: Optional[Tensor],\n        ) -> Tensor:\n            assert u is not None\n            assert batch is not None\n            out = torch.cat([\n                u,\n                scatter(x, batch, dim=0, reduce='mean'),\n            ], dim=1)\n            return self.global_mlp(out)\n\n    op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())\n\n    x = torch.randn(20, 10)\n    edge_attr = torch.randn(40, 5)\n    u = torch.randn(2, 20)\n    batch = torch.tensor([0] * 10 + [1] * 10)\n    edge_index = torch.randint(0, high=10, size=(2, 20), dtype=torch.long)\n    edge_index = torch.cat([edge_index, 10 + edge_index], dim=1)\n\n    x_out, edge_attr_out, u_out = op(x, edge_index, edge_attr, u, batch)\n    assert x_out.size() == (20, 10)\n    assert edge_attr_out.size() == (40, 5)\n    assert u_out.size() == (2, 20)\n\n    if 
is_full_test():\n        jit = torch.jit.script(op)\n\n        x_out, edge_attr_out, u_out = jit(x, edge_index, edge_attr, u, batch)\n        assert x_out.size() == (20, 10)\n        assert edge_attr_out.size() == (40, 5)\n        assert u_out.size() == (2, 20)\n"
  },
  {
    "path": "test/nn/models/test_metapath2vec.py",
    "content": "import torch\n\nfrom torch_geometric.nn import MetaPath2Vec\nfrom torch_geometric.testing import has_package, withDevice\n\n\n@withDevice\ndef test_metapath2vec(device):\n    edge_index_dict = {\n        ('author', 'writes', 'paper'):\n        torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]], device=device),\n        ('paper', 'written_by', 'author'):\n        torch.tensor([[0, 0, 1, 1], [0, 1, 1, 2]], device=device)\n    }\n\n    metapath = [\n        ('author', 'writes', 'paper'),\n        ('paper', 'written_by', 'author'),\n    ]\n\n    model = MetaPath2Vec(edge_index_dict, embedding_dim=16, metapath=metapath,\n                         walk_length=2, context_size=2).to(device)\n    assert str(model) == 'MetaPath2Vec(5, 16)'\n\n    z = model('author')\n    assert z.size() == (3, 16)\n\n    z = model('paper')\n    assert z.size() == (2, 16)\n\n    z = model('author', torch.arange(2, device=device))\n    assert z.size() == (2, 16)\n\n    pos_rw, neg_rw = model._sample(torch.arange(3))\n\n    loss = model.loss(pos_rw.to(device), neg_rw.to(device))\n    assert 0 <= loss.item()\n\n    if has_package('sklearn'):\n        acc = model.test(torch.ones(20, 16), torch.randint(10, (20, )),\n                         torch.ones(20, 16), torch.randint(10, (20, )))\n        assert 0 <= acc and acc <= 1\n\n\ndef test_metapath2vec_empty_edges():\n    num_nodes_dict = {'a': 3, 'b': 4}\n    edge_index_dict = {\n        ('a', 'to', 'b'): torch.empty((2, 0), dtype=torch.long),\n        ('b', 'to', 'a'): torch.empty((2, 0), dtype=torch.long),\n    }\n    metapath = [('a', 'to', 'b'), ('b', 'to', 'a')]\n\n    model = MetaPath2Vec(\n        edge_index_dict,\n        embedding_dim=16,\n        metapath=metapath,\n        walk_length=10,\n        context_size=7,\n        walks_per_node=5,\n        num_negative_samples=5,\n        num_nodes_dict=num_nodes_dict,\n    )\n    loader = model.loader(batch_size=16, shuffle=True)\n    next(iter(loader))\n"
  },
  {
    "path": "test/nn/models/test_mlp.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import MLP\n\n\n@pytest.mark.parametrize('norm', ['batch_norm', None])\n@pytest.mark.parametrize('act_first', [False, True])\n@pytest.mark.parametrize('plain_last', [False, True])\ndef test_mlp(norm, act_first, plain_last):\n    x = torch.randn(4, 16)\n\n    torch.manual_seed(12345)\n    mlp = MLP(\n        [16, 32, 32, 64],\n        norm=norm,\n        act_first=act_first,\n        plain_last=plain_last,\n    )\n    assert str(mlp) == 'MLP(16, 32, 32, 64)'\n    out = mlp(x)\n    assert out.size() == (4, 64)\n\n    jit = torch.jit.script(mlp)\n    assert torch.allclose(jit(x), out)\n\n    torch.manual_seed(12345)\n    mlp = MLP(\n        16,\n        hidden_channels=32,\n        out_channels=64,\n        num_layers=3,\n        norm=norm,\n        act_first=act_first,\n        plain_last=plain_last,\n    )\n    assert torch.allclose(mlp(x), out)\n\n\n@pytest.mark.parametrize('norm', [\n    'BatchNorm',\n    'GraphNorm',\n    'InstanceNorm',\n    'LayerNorm',\n])\ndef test_batch(norm):\n    x = torch.randn(3, 8)\n    batch = torch.tensor([0, 0, 1])\n\n    model = MLP(\n        8,\n        hidden_channels=16,\n        out_channels=32,\n        num_layers=2,\n        norm=norm,\n    )\n    assert model.supports_norm_batch == (norm != 'BatchNorm')\n\n    out = model(x, batch=batch)\n    assert out.size() == (3, 32)\n\n    if model.supports_norm_batch:\n        with pytest.raises(RuntimeError, match=\"out of bounds\"):\n            model(x, batch=batch, batch_size=1)\n\n\ndef test_mlp_return_emb():\n    x = torch.randn(4, 16)\n\n    mlp = MLP([16, 32, 1])\n\n    out, emb = mlp(x, return_emb=True)\n    assert out.size() == (4, 1)\n    assert emb.size() == (4, 32)\n\n    out, emb = mlp(x, return_emb=False)\n    assert out.size() == (4, 1)\n    assert emb is None\n\n\n@pytest.mark.parametrize('plain_last', [False, True])\ndef test_fine_grained_mlp(plain_last):\n    mlp = MLP(\n        [16, 32, 32, 64],\n 
       dropout=[0.1, 0.2, 0.3],\n        bias=[False, True, False],\n    )\n    assert mlp(torch.randn(4, 16)).size() == (4, 64)\n"
  },
  {
    "path": "test/nn/models/test_neural_fingerprint.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import NeuralFingerprint\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\n@pytest.mark.parametrize('batch', [None, torch.tensor([0, 1, 1])])\ndef test_neural_fingerprint(batch):\n    x = torch.randn(3, 7)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    model = NeuralFingerprint(7, 16, out_channels=5, num_layers=4)\n    assert str(model) == 'NeuralFingerprint(7, 5, num_layers=4)'\n    model.reset_parameters()\n\n    out = model(x, edge_index, batch)\n    assert out.size() == (1, 5) if batch is None else (2, 5)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(3, 3))\n        assert torch.allclose(model(x, adj.t(), batch), out)\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        assert torch.allclose(jit(x, edge_index, batch), out)\n"
  },
  {
    "path": "test/nn/models/test_node2vec.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import Node2Vec\nfrom torch_geometric.testing import (\n    has_package,\n    is_full_test,\n    withDevice,\n    withPackage,\n)\n\n\n@withDevice\n@withPackage('pyg_lib|torch_cluster')\n@pytest.mark.parametrize('p', [1.0])\n@pytest.mark.parametrize('q', [1.0, 0.5])\ndef test_node2vec(device, p, q):\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)\n    kwargs = dict(embedding_dim=16, walk_length=2, context_size=2, p=p, q=q)\n\n    if not torch_geometric.typing.WITH_TORCH_CLUSTER and q != 1.0:\n        with pytest.raises(ImportError, match=\"requires the 'torch-cluster'\"):\n            model = Node2Vec(edge_index, **kwargs)\n        return\n\n    model = Node2Vec(edge_index, **kwargs).to(device)\n    assert str(model) == 'Node2Vec(3, 16)'\n\n    assert model(torch.arange(3, device=device)).size() == (3, 16)\n\n    pos_rw, neg_rw = model.sample(torch.arange(3))\n    assert float(model.loss(pos_rw.to(device), neg_rw.to(device))) >= 0\n\n    if has_package('sklearn'):\n        acc = model.test(torch.ones(20, 16), torch.randint(10, (20, )),\n                         torch.ones(20, 16), torch.randint(10, (20, )))\n        assert 0 <= acc and acc <= 1\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n\n        assert jit(torch.arange(3, device=device)).size() == (3, 16)\n\n        pos_rw, neg_rw = jit.sample(torch.arange(3))\n        assert float(jit.loss(pos_rw.to(device), neg_rw.to(device))) >= 0\n"
  },
  {
    "path": "test/nn/models/test_pmlp.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn.models import PMLP\n\n\ndef test_pmlp():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    pmlp = PMLP(in_channels=16, hidden_channels=32, out_channels=2,\n                num_layers=4)\n    assert str(pmlp) == 'PMLP(16, 2, num_layers=4)'\n\n    pmlp.training = True\n    assert pmlp(x).size() == (4, 2)\n\n    pmlp.training = False\n    assert pmlp(x, edge_index).size() == (4, 2)\n\n    with pytest.raises(ValueError, match=\"'edge_index' needs to be present\"):\n        pmlp.training = False\n        pmlp(x)\n"
  },
  {
    "path": "test/nn/models/test_polynormer.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn.models import Polynormer\n\n\n@pytest.mark.parametrize('local_attn', [True, False])\n@pytest.mark.parametrize('qk_shared', [True, False])\n@pytest.mark.parametrize('pre_ln', [True, False])\n@pytest.mark.parametrize('post_bn', [True, False])\ndef test_polynormer(local_attn, qk_shared, pre_ln, post_bn):\n    x = torch.randn(10, 16)\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n        [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],\n    ])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])\n\n    model = Polynormer(\n        in_channels=16,\n        hidden_channels=128,\n        out_channels=40,\n        qk_shared=qk_shared,\n        pre_ln=pre_ln,\n        post_bn=post_bn,\n        local_attn=local_attn,\n    )\n    out = model(x, edge_index, batch)\n    assert out.size() == (10, 40)\n    model._global = True\n    out = model(x, edge_index, batch)\n    assert out.size() == (10, 40)\n"
  },
  {
    "path": "test/nn/models/test_re_net.py",
    "content": "import torch\n\nfrom torch_geometric.datasets.icews import EventDataset\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.nn import RENet\nfrom torch_geometric.testing import is_full_test\n\n\nclass MyTestEventDataset(EventDataset):\n    def __init__(self, root, seq_len):\n        super().__init__(root, pre_transform=RENet.pre_transform(seq_len))\n        self.load(self.processed_paths[0])\n\n    @property\n    def num_nodes(self):\n        return 16\n\n    @property\n    def num_rels(self):\n        return 8\n\n    @property\n    def processed_file_names(self):\n        return 'data.pt'\n\n    def _download(self):\n        pass\n\n    def process_events(self):\n        sub = torch.randint(self.num_nodes, (64, ), dtype=torch.long)\n        rel = torch.randint(self.num_rels, (64, ), dtype=torch.long)\n        obj = torch.randint(self.num_nodes, (64, ), dtype=torch.long)\n        t = torch.arange(8, dtype=torch.long).view(-1, 1).repeat(1, 8).view(-1)\n        return torch.stack([sub, rel, obj, t], dim=1)\n\n    def process(self):\n        data_list = self._process_data_list()\n        self.save(data_list, self.processed_paths[0])\n\n\ndef test_re_net(tmp_path):\n    dataset = MyTestEventDataset(tmp_path, seq_len=4)\n    loader = DataLoader(dataset, 2, follow_batch=['h_sub', 'h_obj'])\n\n    model = RENet(dataset.num_nodes, dataset.num_rels, hidden_channels=16,\n                  seq_len=4)\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n\n    logits = torch.randn(6, 6)\n    y = torch.tensor([0, 1, 2, 3, 4, 5])\n\n    mrr, hits1, hits3, hits10 = model.test(logits, y)\n    assert 0.15 < mrr <= 1\n    assert hits1 <= hits3 and hits3 <= hits10 and hits10 == 1\n\n    for data in loader:\n        log_prob_obj, log_prob_sub = model(data)\n        if is_full_test():\n            log_prob_obj_jit, log_prob_sub_jit = jit(data)\n            assert torch.allclose(log_prob_obj_jit, log_prob_obj)\n            assert 
torch.allclose(log_prob_sub_jit, log_prob_sub)\n        model.test(log_prob_obj, data.obj)\n        model.test(log_prob_sub, data.sub)\n"
  },
  {
    "path": "test/nn/models/test_rect.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import RECT_L\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_rect():\n    x = torch.randn(6, 8)\n    y = torch.tensor([1, 0, 0, 2, 1, 1])\n    edge_index = torch.tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]])\n    mask = torch.randint(0, 2, (6, ), dtype=torch.bool)\n\n    model = RECT_L(8, 16)\n    assert str(model) == 'RECT_L(8, 16)'\n\n    out = model(x, edge_index)\n    assert out.size() == (6, 8)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=(6, 6))\n        assert torch.allclose(out, model(x, adj.t()), atol=1e-6)\n\n    # Test `embed`:\n    embed_out = model.embed(x, edge_index)\n    assert embed_out.size() == (6, 16)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert torch.allclose(embed_out, model.embed(x, adj.t()), atol=1e-6)\n\n    # Test `get_semantic_labels`:\n    labels_out = model.get_semantic_labels(x, y, mask)\n    assert labels_out.size() == (int(mask.sum()), 8)\n\n    if is_full_test():\n        jit = torch.jit.script(model)\n        assert torch.allclose(jit(x, edge_index), out, atol=1e-6)\n        assert torch.allclose(embed_out, jit.embed(x, edge_index), atol=1e-6)\n        assert torch.allclose(labels_out, jit.get_semantic_labels(x, y, mask))\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            assert torch.allclose(jit(x, adj.t()), out, atol=1e-6)\n            assert torch.allclose(embed_out, jit.embed(x, adj.t()), atol=1e-6)\n            assert torch.allclose(labels_out,\n                                  jit.get_semantic_labels(x, y, mask))\n"
  },
  {
    "path": "test/nn/models/test_rev_gnn.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GraphConv, GroupAddRev, SAGEConv\nfrom torch_geometric.nn.dense.linear import Linear\n\n\n@pytest.mark.parametrize('num_groups', [2, 4, 8, 16])\ndef test_revgnn_forward_inverse(num_groups):\n    x = torch.randn(4, 32)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    lin = Linear(32, 32)\n    conv = SAGEConv(32 // num_groups, 32 // num_groups)\n    conv = GroupAddRev(conv, num_groups=num_groups)\n    assert str(conv) == (f'GroupAddRev(SAGEConv({32 // num_groups}, '\n                         f'{32 // num_groups}, aggr=mean), '\n                         f'num_groups={num_groups})')\n\n    h = lin(x)\n    h_o = h.clone().detach()\n\n    out = conv(h, edge_index)\n    if torch_geometric.typing.WITH_PT20:\n        assert h.untyped_storage().size() == 0\n    else:\n        assert h.storage().size() == 0\n\n    h_rev = conv.inverse(out, edge_index)\n    assert torch.allclose(h_o, h_rev, atol=0.001)\n\n\n@pytest.mark.parametrize('num_groups', [2, 4, 8, 16])\ndef test_revgnn_backward(num_groups):\n    x = torch.randn(4, 32)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    lin = Linear(32, 32)\n    conv = SAGEConv(32 // num_groups, 32 // num_groups)\n    conv = GroupAddRev(conv, num_groups=num_groups)\n\n    h = lin(x)\n    out = conv(h, edge_index)\n    target = out.mean()\n    target.backward()\n\n\n@pytest.mark.parametrize('num_groups', [2, 4, 8, 16])\ndef test_revgnn_multi_backward(num_groups):\n    x = torch.randn(4, 32)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    lin = Linear(32, 32)\n    conv = SAGEConv(32 // num_groups, 32 // num_groups)\n    conv = GroupAddRev(conv, num_groups=num_groups, num_bwd_passes=4)\n\n    h = lin(x)\n    out = conv(h, edge_index)\n    target = out.mean()\n    target.backward(retain_graph=True)\n    target.backward(retain_graph=True)\n    torch.autograd.grad(outputs=target, 
inputs=[h] + list(conv.parameters()),\n                        retain_graph=True)\n    torch.autograd.grad(outputs=target, inputs=[h] + list(conv.parameters()))\n\n\n@pytest.mark.parametrize('num_groups', [2, 4, 8, 16])\ndef test_revgnn_disable(num_groups):\n    x = torch.randn(4, 32)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n\n    lin = Linear(32, 32)\n    conv = SAGEConv(32 // num_groups, 32 // num_groups)\n    conv = GroupAddRev(conv, num_groups=num_groups, disable=True)\n\n    h = lin(x)\n    out = conv(h, edge_index)\n    target = out.mean()\n    target.backward()\n\n    # Memory will not be freed if disabled:\n    if torch_geometric.typing.WITH_PT20:\n        assert h.untyped_storage().size() == 4 * 4 * 32\n    else:\n        assert h.storage().size() == 4 * 32\n\n\n@pytest.mark.parametrize('num_groups', [2, 4, 8, 16])\ndef test_revgnn_with_args(num_groups):\n    x = torch.randn(4, 32)\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    edge_weight = torch.rand(4)\n\n    lin = Linear(32, 32)\n    conv = GraphConv(32 // num_groups, 32 // num_groups)\n    conv = GroupAddRev(conv, num_groups=num_groups)\n\n    h = lin(x)\n    out = conv(h, edge_index, edge_weight)\n    target = out.mean()\n    target.backward()\n"
  },
  {
    "path": "test/nn/models/test_schnet.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.nn import SchNet\nfrom torch_geometric.nn.models.schnet import RadiusInteractionGraph\nfrom torch_geometric.testing import is_full_test, withPackage\n\n\ndef generate_data():\n    return Data(\n        z=torch.randint(1, 10, (20, )),\n        pos=torch.randn(20, 3),\n    )\n\n\n@withPackage('torch_cluster')\n@withPackage('ase')\n@pytest.mark.parametrize('use_interaction_graph', [False, True])\n@pytest.mark.parametrize('use_atomref', [False, True])\ndef test_schnet(use_interaction_graph, use_atomref):\n    data = generate_data()\n\n    interaction_graph = None\n    if use_interaction_graph:\n        interaction_graph = RadiusInteractionGraph(cutoff=6.0)\n\n    model = SchNet(\n        hidden_channels=16,\n        num_filters=16,\n        num_interactions=2,\n        interaction_graph=interaction_graph,\n        num_gaussians=10,\n        cutoff=6.0,\n        dipole=True,\n        atomref=torch.randn(100, 1) if use_atomref else None,\n    )\n\n    assert str(model) == ('SchNet(hidden_channels=16, num_filters=16, '\n                          'num_interactions=2, num_gaussians=10, cutoff=6.0)')\n\n    with torch.no_grad():\n        out = model(data.z, data.pos)\n        assert out.size() == (1, 1)\n\n        if is_full_test():\n            jit = torch.jit.export(model)\n            out = jit(data.z, data.pos)\n            assert out.size() == (1, 1)\n\n\n@withPackage('torch_cluster')\ndef test_schnet_batch():\n    num_graphs = 3\n    batch = [generate_data() for _ in range(num_graphs)]\n    batch = Batch.from_data_list(batch)\n\n    model = SchNet(\n        hidden_channels=16,\n        num_filters=16,\n        num_interactions=2,\n        num_gaussians=10,\n        cutoff=6.0,\n    )\n\n    with torch.no_grad():\n        out = model(batch.z, batch.pos, batch.batch)\n        assert out.size() == (num_graphs, 1)\n"
  },
  {
    "path": "test/nn/models/test_sgformer.py",
    "content": "import torch\n\nfrom torch_geometric.nn.models import SGFormer\n\n\ndef test_sgformer():\n    x = torch.randn(10, 16)\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n        [1, 2, 3, 4, 0, 6, 7, 8, 9, 5],\n    ])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])\n\n    model = SGFormer(\n        in_channels=16,\n        hidden_channels=128,\n        out_channels=40,\n    )\n    out = model(x, edge_index, batch)\n    assert out.size() == (10, 40)\n"
  },
  {
    "path": "test/nn/models/test_signed_gcn.py",
    "content": "import torch\n\nfrom torch_geometric.nn import SignedGCN\nfrom torch_geometric.testing import has_package, is_full_test\n\n\n# @withPackage('sklearn')\ndef test_signed_gcn():\n    model = SignedGCN(8, 16, num_layers=2, lamb=5)\n    assert str(model) == 'SignedGCN(8, 16, num_layers=2)'\n\n    pos_index = torch.randint(high=10, size=(2, 40), dtype=torch.long)\n    neg_index = torch.randint(high=10, size=(2, 40), dtype=torch.long)\n\n    train_pos_index, test_pos_index = model.split_edges(pos_index)\n    train_neg_index, test_neg_index = model.split_edges(neg_index)\n\n    assert train_pos_index.size() == (2, 32)\n    assert test_pos_index.size() == (2, 8)\n    assert train_neg_index.size() == (2, 32)\n    assert test_neg_index.size() == (2, 8)\n\n    if has_package('sklearn'):\n        x = model.create_spectral_features(\n            train_pos_index,\n            train_neg_index,\n            num_nodes=10,\n        )\n        assert x.size() == (10, 8)\n    else:\n        x = torch.randn(10, 8)\n\n    z = model(x, train_pos_index, train_neg_index)\n    assert z.size() == (10, 16)\n\n    loss = model.loss(z, train_pos_index, train_neg_index)\n    assert loss.item() >= 0\n\n    if has_package('sklearn'):\n        auc, f1 = model.test(z, test_pos_index, test_neg_index)\n        assert auc >= 0\n        assert f1 >= 0\n\n    if is_full_test():\n        jit = torch.jit.export(model)\n        assert torch.allclose(jit(x, train_pos_index, train_neg_index), z)\n"
  },
  {
    "path": "test/nn/models/test_tgn.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import TemporalData\nfrom torch_geometric.loader import TemporalDataLoader\nfrom torch_geometric.nn import TGNMemory\nfrom torch_geometric.nn.models.tgn import (\n    IdentityMessage,\n    LastAggregator,\n    LastNeighborLoader,\n)\n\n\n@pytest.mark.parametrize('neg_sampling_ratio', [0.0, 1.0])\ndef test_tgn(neg_sampling_ratio):\n    memory_dim = 16\n    time_dim = 16\n\n    src = torch.tensor([0, 1, 0, 2, 0, 3, 1, 4, 2, 3])\n    dst = torch.tensor([1, 2, 1, 1, 3, 2, 4, 3, 3, 4])\n    t = torch.arange(10)\n    msg = torch.randn(10, 16)\n    data = TemporalData(src=src, dst=dst, t=t, msg=msg)\n\n    loader = TemporalDataLoader(\n        data,\n        batch_size=5,\n        neg_sampling_ratio=neg_sampling_ratio,\n    )\n    neighbor_loader = LastNeighborLoader(data.num_nodes, size=3)\n    assert neighbor_loader.cur_e_id == 0\n    assert neighbor_loader.e_id.size() == (data.num_nodes, 3)\n\n    memory = TGNMemory(\n        num_nodes=data.num_nodes,\n        raw_msg_dim=msg.size(-1),\n        memory_dim=memory_dim,\n        time_dim=time_dim,\n        message_module=IdentityMessage(msg.size(-1), memory_dim, time_dim),\n        aggregator_module=LastAggregator(),\n    )\n    assert memory.memory.size() == (data.num_nodes, memory_dim)\n    assert memory.last_update.size() == (data.num_nodes, )\n\n    # Test TGNMemory training:\n    for i, batch in enumerate(loader):\n        n_id, edge_index, e_id = neighbor_loader(batch.n_id)\n        z, last_update = memory(n_id)\n        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)\n        neighbor_loader.insert(batch.src, batch.dst)\n        if i == 0:\n            assert n_id.size(0) >= 4\n            assert edge_index.numel() == 0\n            assert e_id.numel() == 0\n            assert z.size() == (n_id.size(0), memory_dim)\n            assert torch.sum(last_update) == 0\n        else:\n            assert n_id.size(0) == 5\n            
assert edge_index.numel() == 12\n            assert e_id.numel() == 6\n            assert z.size() == (n_id.size(0), memory_dim)\n            assert torch.equal(last_update, torch.tensor([4, 3, 3, 4, 0]))\n\n    # Test TGNMemory inference:\n    memory.eval()\n    all_n_id = torch.arange(data.num_nodes)\n    z, last_update = memory(all_n_id)\n    assert z.size() == (data.num_nodes, memory_dim)\n    assert torch.equal(last_update, torch.tensor([4, 6, 8, 9, 9]))\n\n    post_src = torch.tensor([3, 4])\n    post_dst = torch.tensor([4, 3])\n    post_t = torch.tensor([10, 10])\n    post_msg = torch.randn(2, 16)\n    memory.update_state(post_src, post_dst, post_t, post_msg)\n    post_z, post_last_update = memory(all_n_id)\n    assert torch.allclose(z[0:3], post_z[0:3])\n    assert torch.equal(post_last_update, torch.tensor([4, 6, 8, 10, 10]))\n\n    memory.reset_state()\n    assert memory.memory.sum() == 0\n    assert memory.last_update.sum() == 0\n"
  },
  {
    "path": "test/nn/models/test_visnet.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import ViSNet\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('torch_cluster')\n@pytest.mark.parametrize('kwargs', [\n    dict(lmax=2, derivative=True, vecnorm_type=None, vertex=False),\n    dict(lmax=1, derivative=False, vecnorm_type='max_min', vertex=True),\n])\ndef test_visnet(kwargs):\n    z = torch.randint(1, 10, (20, ))\n    pos = torch.randn(20, 3)\n    batch = torch.zeros(20, dtype=torch.long)\n\n    model = ViSNet(**kwargs)\n\n    model.reset_parameters()\n\n    energy, forces = model(z, pos, batch)\n\n    assert energy.size() == (1, 1)\n\n    if kwargs['derivative']:\n        assert forces.size() == (20, 3)\n    else:\n        assert forces is None\n"
  },
  {
    "path": "test/nn/norm/test_batch_norm.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import BatchNorm, HeteroBatchNorm\nfrom torch_geometric.testing import is_full_test, withDevice\n\n\n@withDevice\n@pytest.mark.parametrize('conf', [True, False])\ndef test_batch_norm(device, conf):\n    x = torch.randn(100, 16, device=device)\n\n    norm = BatchNorm(16, affine=conf, track_running_stats=conf, device=device)\n    norm.reset_running_stats()\n    norm.reset_parameters()\n    assert str(norm) == (f'BatchNorm(16, eps=1e-05, momentum=0.1, '\n                         f'affine={conf}, track_running_stats={conf})')\n\n    if is_full_test():\n        torch.jit.script(norm)\n\n    out = norm(x)\n    assert out.size() == (100, 16)\n\n\ndef test_batch_norm_single_element():\n    x = torch.randn(1, 16)\n\n    norm = BatchNorm(16)\n    with pytest.raises(ValueError, match=\"Expected more than 1 value\"):\n        norm(x)\n\n    with pytest.raises(ValueError, match=\"requires 'track_running_stats'\"):\n        norm = BatchNorm(16, track_running_stats=False,\n                         allow_single_element=True)\n\n    norm = BatchNorm(16, track_running_stats=True, allow_single_element=True)\n    out = norm(x)\n    assert torch.allclose(out, x)\n\n\n@withDevice\n@pytest.mark.parametrize('conf', [True, False])\ndef test_hetero_batch_norm(device, conf):\n    x = torch.randn((100, 16), device=device)\n\n    # Test single type:\n    norm = BatchNorm(16, affine=conf, track_running_stats=conf, device=device)\n    expected = norm(x)\n\n    type_vec = torch.zeros(100, dtype=torch.long, device=device)\n    norm = HeteroBatchNorm(16, num_types=1, affine=conf,\n                           track_running_stats=conf, device=device)\n    norm.reset_running_stats()\n    norm.reset_parameters()\n    assert str(norm) == 'HeteroBatchNorm(16, num_types=1)'\n\n    out = norm(x, type_vec)\n    assert out.size() == (100, 16)\n    assert torch.allclose(out, expected, atol=1e-3)\n\n    # Test multiple types:\n    type_vec = 
torch.randint(5, (100, ), device=device)\n    norm = HeteroBatchNorm(16, num_types=5, affine=conf,\n                           track_running_stats=conf, device=device)\n    out = norm(x, type_vec)\n    assert out.size() == (100, 16)\n\n    for i in range(5):  # Check that mean=0 and std=1 across all types:\n        mean = out[type_vec == i].mean()\n        std = out[type_vec == i].std(unbiased=False)\n        assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-7)\n        assert torch.allclose(std, torch.ones_like(std), atol=1e-7)\n"
  },
  {
    "path": "test/nn/norm/test_diff_group_norm.py",
    "content": "import torch\n\nfrom torch_geometric.nn import DiffGroupNorm\nfrom torch_geometric.testing import is_full_test, withDevice\n\n\n@withDevice\ndef test_diff_group_norm(device):\n    x = torch.randn(6, 16, device=device)\n\n    norm = DiffGroupNorm(16, groups=4, lamda=0, device=device)\n    assert str(norm) == 'DiffGroupNorm(16, groups=4)'\n\n    assert torch.allclose(norm(x), x)\n\n    if is_full_test():\n        jit = torch.jit.script(norm)\n        assert torch.allclose(jit(x), x)\n\n    norm = DiffGroupNorm(16, groups=4, lamda=0.01, device=device)\n    assert str(norm) == 'DiffGroupNorm(16, groups=4)'\n\n    out = norm(x)\n    assert out.size() == x.size()\n\n    if is_full_test():\n        jit = torch.jit.script(norm)\n        assert torch.allclose(jit(x), out)\n\n\ndef test_group_distance_ratio():\n    x = torch.randn(6, 16)\n    y = torch.tensor([0, 1, 0, 1, 1, 1])\n\n    assert DiffGroupNorm.group_distance_ratio(x, y) > 0\n\n    if is_full_test():\n        jit = torch.jit.script(DiffGroupNorm.group_distance_ratio)\n        assert jit(x, y) > 0\n"
  },
  {
    "path": "test/nn/norm/test_graph_norm.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GraphNorm\nfrom torch_geometric.testing import is_full_test, withDevice\n\n\n@withDevice\ndef test_graph_norm(device):\n    torch.manual_seed(42)\n    x = torch.randn(200, 16, device=device)\n    batch = torch.arange(4, device=device).view(-1, 1).repeat(1, 50).view(-1)\n\n    norm = GraphNorm(16, device=device)\n    assert str(norm) == 'GraphNorm(16)'\n\n    if is_full_test():\n        torch.jit.script(norm)\n\n    out = norm(x)\n    assert out.size() == (200, 16)\n    assert torch.allclose(out.mean(dim=0), torch.zeros(16, device=device),\n                          atol=1e-6)\n    assert torch.allclose(out.std(dim=0, unbiased=False),\n                          torch.ones(16, device=device), atol=1e-6)\n\n    out = norm(x, batch)\n    assert out.size() == (200, 16)\n    assert torch.allclose(out[:50].mean(dim=0), torch.zeros(16, device=device),\n                          atol=1e-6)\n    assert torch.allclose(out[:50].std(dim=0, unbiased=False),\n                          torch.ones(16, device=device), atol=1e-6)\n"
  },
  {
    "path": "test/nn/norm/test_graph_size_norm.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GraphSizeNorm\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_graph_size_norm():\n    x = torch.randn(100, 16)\n    batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long))\n\n    norm = GraphSizeNorm()\n    assert str(norm) == 'GraphSizeNorm()'\n\n    out = norm(x, batch)\n    assert out.size() == (100, 16)\n\n    if is_full_test():\n        jit = torch.jit.script(norm)\n        assert torch.allclose(jit(x, batch), out)\n"
  },
  {
    "path": "test/nn/norm/test_instance_norm.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import InstanceNorm\nfrom torch_geometric.testing import is_full_test, withDevice\n\n\n@withDevice\n@pytest.mark.parametrize('conf', [True, False])\ndef test_instance_norm(conf, device):\n    batch = torch.zeros(100, dtype=torch.long, device=device)\n\n    x1 = torch.randn(100, 16, device=device)\n    x2 = torch.randn(100, 16, device=device)\n\n    norm1 = InstanceNorm(16, affine=conf, track_running_stats=conf,\n                         device=device)\n    norm2 = InstanceNorm(16, affine=conf, track_running_stats=conf,\n                         device=device)\n    assert str(norm1) == 'InstanceNorm(16)'\n\n    if is_full_test():\n        torch.jit.script(norm1)\n\n    out1 = norm1(x1)\n    out2 = norm2(x1, batch)\n    assert out1.size() == (100, 16)\n    assert torch.allclose(out1, out2, atol=1e-7)\n    if conf:\n        assert torch.allclose(norm1.running_mean, norm2.running_mean)\n        assert torch.allclose(norm1.running_var, norm2.running_var)\n\n    out1 = norm1(x2)\n    out2 = norm2(x2, batch)\n    assert torch.allclose(out1, out2, atol=1e-7)\n    if conf:\n        assert torch.allclose(norm1.running_mean, norm2.running_mean)\n        assert torch.allclose(norm1.running_var, norm2.running_var)\n\n    norm1.eval()\n    norm2.eval()\n\n    out1 = norm1(x1)\n    out2 = norm2(x1, batch)\n    assert torch.allclose(out1, out2, atol=1e-7)\n\n    out1 = norm1(x2)\n    out2 = norm2(x2, batch)\n    assert torch.allclose(out1, out2, atol=1e-7)\n\n    out1 = norm2(x1)\n    out2 = norm2(x2)\n    out3 = norm2(torch.cat([x1, x2], dim=0), torch.cat([batch, batch + 1]))\n    assert torch.allclose(out1, out3[:100], atol=1e-7)\n    assert torch.allclose(out2, out3[100:], atol=1e-7)\n"
  },
  {
    "path": "test/nn/norm/test_layer_norm.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import HeteroLayerNorm, LayerNorm\nfrom torch_geometric.testing import is_full_test, withDevice\n\n\n@withDevice\n@pytest.mark.parametrize('affine', [True, False])\n@pytest.mark.parametrize('mode', ['graph', 'node'])\ndef test_layer_norm(device, affine, mode):\n    x = torch.randn(100, 16, device=device)\n    batch = torch.zeros(100, dtype=torch.long, device=device)\n\n    norm = LayerNorm(16, affine=affine, mode=mode, device=device)\n    assert str(norm) == f'LayerNorm(16, affine={affine}, mode={mode})'\n\n    if is_full_test():\n        torch.jit.script(norm)\n\n    out1 = norm(x)\n    assert out1.size() == (100, 16)\n    assert torch.allclose(norm(x, batch), out1, atol=1e-6)\n\n    out2 = norm(torch.cat([x, x], dim=0), torch.cat([batch, batch + 1], dim=0))\n    assert torch.allclose(out1, out2[:100], atol=1e-6)\n    assert torch.allclose(out1, out2[100:], atol=1e-6)\n\n\n@withDevice\n@pytest.mark.parametrize('affine', [False, True])\ndef test_hetero_layer_norm(device, affine):\n    x = torch.randn((100, 16), device=device)\n    expected = LayerNorm(16, affine=affine, mode='node', device=device)(x)\n\n    # Test single type:\n    type_vec = torch.zeros(100, dtype=torch.long, device=device)\n    type_ptr = [0, 100]\n\n    norm = HeteroLayerNorm(16, num_types=1, affine=affine, device=device)\n    assert str(norm) == 'HeteroLayerNorm(16, num_types=1)'\n\n    out = norm(x, type_vec)\n    assert out.size() == (100, 16)\n    assert torch.allclose(out, expected, atol=1e-3)\n    assert torch.allclose(norm(out, type_ptr=type_ptr), expected, atol=1e-3)\n\n    mean = out.mean(dim=-1)\n    std = out.std(unbiased=False, dim=-1)\n    assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-2)\n    assert torch.allclose(std, torch.ones_like(std), atol=1e-2)\n\n    # Test multiple types:\n    type_vec = torch.arange(5, device=device)\n    type_vec = type_vec.view(-1, 1).repeat(1, 20).view(-1)\n    type_ptr 
= [0, 20, 40, 60, 80, 100]\n\n    norm = HeteroLayerNorm(16, num_types=5, affine=affine, device=device)\n    assert str(norm) == 'HeteroLayerNorm(16, num_types=5)'\n\n    out = norm(x, type_vec)\n    assert out.size() == (100, 16)\n    assert torch.allclose(out, expected, atol=1e-3)\n    assert torch.allclose(norm(out, type_ptr=type_ptr), expected, atol=1e-3)\n\n    mean = out.mean(dim=-1)\n    std = out.std(unbiased=False, dim=-1)\n    assert torch.allclose(mean, torch.zeros_like(mean), atol=1e-2)\n    assert torch.allclose(std, torch.ones_like(std), atol=1e-2)\n"
  },
  {
    "path": "test/nn/norm/test_mean_subtraction_norm.py",
    "content": "import torch\n\nfrom torch_geometric.nn import MeanSubtractionNorm\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_mean_subtraction_norm():\n    x = torch.randn(6, 16)\n    batch = torch.tensor([0, 0, 1, 1, 1, 2])\n\n    norm = MeanSubtractionNorm()\n    assert str(norm) == 'MeanSubtractionNorm()'\n\n    if is_full_test():\n        torch.jit.script(norm)\n\n    out = norm(x)\n    assert out.size() == (6, 16)\n    assert torch.allclose(out.mean(), torch.tensor(0.), atol=1e-6)\n\n    out = norm(x, batch)\n    assert out.size() == (6, 16)\n    assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-6)\n    assert torch.allclose(out[0:2].mean(), torch.tensor(0.), atol=1e-6)\n"
  },
  {
    "path": "test/nn/norm/test_msg_norm.py",
    "content": "import torch\n\nfrom torch_geometric.nn import MessageNorm\nfrom torch_geometric.testing import is_full_test, withDevice\n\n\n@withDevice\ndef test_message_norm(device):\n    norm = MessageNorm(learn_scale=True, device=device)\n    assert str(norm) == 'MessageNorm(learn_scale=True)'\n    x = torch.randn(100, 16, device=device)\n    msg = torch.randn(100, 16, device=device)\n    out = norm(x, msg)\n    assert out.size() == (100, 16)\n\n    if is_full_test():\n        jit = torch.jit.script(norm)\n        assert torch.allclose(jit(x, msg), out)\n\n    norm = MessageNorm(learn_scale=False, device=device)\n    assert str(norm) == 'MessageNorm(learn_scale=False)'\n    out = norm(x, msg)\n    assert out.size() == (100, 16)\n\n    if is_full_test():\n        jit = torch.jit.script(norm)\n        assert torch.allclose(jit(x, msg), out)\n"
  },
  {
    "path": "test/nn/norm/test_pair_norm.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import PairNorm\nfrom torch_geometric.testing import is_full_test\n\n\n@pytest.mark.parametrize('scale_individually', [False, True])\ndef test_pair_norm(scale_individually):\n    x = torch.randn(100, 16)\n    batch = torch.zeros(100, dtype=torch.long)\n\n    norm = PairNorm(scale_individually=scale_individually)\n    assert str(norm) == 'PairNorm()'\n\n    if is_full_test():\n        torch.jit.script(norm)\n\n    out1 = norm(x)\n    assert out1.size() == (100, 16)\n\n    out2 = norm(torch.cat([x, x], dim=0), torch.cat([batch, batch + 1], dim=0))\n    assert torch.allclose(out1, out2[:100], atol=1e-6)\n    assert torch.allclose(out1, out2[100:], atol=1e-6)\n"
  },
  {
    "path": "test/nn/pool/connect/test_filter_edges.py",
    "content": "import torch\n\nfrom torch_geometric.nn.pool.connect import FilterEdges\nfrom torch_geometric.nn.pool.select import SelectOutput\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_filter_edges():\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 1, 3, 2, 2]])\n    edge_attr = torch.tensor([1, 2, 3, 4, 5, 6])\n    batch = torch.tensor([0, 0, 1, 1])\n\n    select_output = SelectOutput(\n        node_index=torch.tensor([1, 2]),\n        num_nodes=4,\n        cluster_index=torch.tensor([0, 1]),\n        num_clusters=2,\n    )\n\n    connect = FilterEdges()\n    assert str(connect) == 'FilterEdges()'\n\n    out1 = connect(select_output, edge_index, edge_attr, batch)\n    assert out1.edge_index.tolist() == [[0, 1], [0, 1]]\n    assert out1.edge_attr.tolist() == [3, 5]\n    assert out1.batch.tolist() == [0, 1]\n\n    if is_full_test():\n        jit = torch.jit.script(connect)\n        out2 = jit(select_output, edge_index, edge_attr, batch)\n        assert torch.equal(out1.edge_index, out2.edge_index)\n        assert torch.equal(out1.edge_attr, out2.edge_attr)\n        assert torch.equal(out1.batch, out2.batch)\n"
  },
  {
    "path": "test/nn/pool/select/test_select_topk.py",
    "content": "from itertools import product\n\nimport pytest\nimport torch\n\nfrom torch_geometric.nn.pool.select import SelectOutput, SelectTopK\nfrom torch_geometric.nn.pool.select.topk import topk\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_topk_ratio():\n    x = torch.tensor([2.0, 4.0, 5.0, 6.0, 2.0, 9.0])\n    batch = torch.tensor([0, 0, 1, 1, 1, 1])\n\n    perm1 = topk(x, 0.5, batch)\n    assert perm1.tolist() == [1, 5, 3]\n    assert x[perm1].tolist() == [4.0, 9.0, 6.0]\n    assert batch[perm1].tolist() == [0, 1, 1]\n\n    perm2 = topk(x, 2, batch)\n    assert perm2.tolist() == [1, 0, 5, 3]\n    assert x[perm2].tolist() == [4.0, 2.0, 9.0, 6.0]\n    assert batch[perm2].tolist() == [0, 0, 1, 1]\n\n    perm3 = topk(x, 3, batch)\n    assert perm3.tolist() == [1, 0, 5, 3, 2]\n    assert x[perm3].tolist() == [4.0, 2.0, 9.0, 6.0, 5.0]\n    assert batch[perm3].tolist() == [0, 0, 1, 1, 1]\n\n    if is_full_test():\n        jit = torch.jit.script(topk)\n        assert torch.equal(jit(x, 0.5, batch), perm1)\n        assert torch.equal(jit(x, 2, batch), perm2)\n        assert torch.equal(jit(x, 3, batch), perm3)\n\n\n@pytest.mark.parametrize('min_score', [None, 2.0])\ndef test_select_topk(min_score):\n    x = torch.randn(6, 16)\n    batch = torch.tensor([0, 0, 1, 1, 1, 1])\n\n    pool = SelectTopK(16, min_score=min_score)\n\n    if min_score is None:\n        assert str(pool) == 'SelectTopK(16, ratio=0.5)'\n    else:\n        assert str(pool) == 'SelectTopK(16, min_score=2.0)'\n\n    out = pool(x, batch)\n    assert isinstance(out, SelectOutput)\n\n    assert out.num_nodes == 6\n    assert out.num_clusters <= out.num_nodes\n    assert out.node_index.min() >= 0\n    assert out.node_index.max() < out.num_nodes\n    assert out.cluster_index.min() == 0\n    assert out.cluster_index.max() == out.num_clusters - 1\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n 
   parser.add_argument('--device', type=str, default='cuda')\n    args = parser.parse_args()\n\n    BS = [2**i for i in range(6, 8)]\n    NS = [2**i for i in range(8, 16)]\n\n    funcs = []\n    func_names = []\n    args_list = []\n    for B, N in product(BS, NS):\n        x = torch.randn(N, device=args.device)\n        batch = torch.randint(0, B, (N, ), device=args.device).sort()[0]\n\n        funcs.append(topk)\n        func_names.append(f'B={B}, N={N}')\n        args_list.append((x, 0.5, batch))\n\n    benchmark(\n        funcs=funcs,\n        func_names=func_names,\n        args=args_list,\n        num_steps=50 if args.device == 'cpu' else 500,\n        num_warmups=10 if args.device == 'cpu' else 100,\n        progress_bar=True,\n    )\n"
  },
  {
    "path": "test/nn/pool/test_approx_knn.py",
    "content": "import warnings\n\nimport torch\n\nfrom torch_geometric.nn import approx_knn, approx_knn_graph\nfrom torch_geometric.testing import onlyFullTest, withPackage\n\n\ndef to_set(edge_index):\n    return {(i, j) for i, j in edge_index.t().tolist()}\n\n\n@onlyFullTest  # JIT compile makes this test too slow :(\n@withPackage('pynndescent')\ndef test_approx_knn():\n    warnings.filterwarnings('ignore', '.*find n_neighbors.*')\n\n    x = torch.tensor([\n        [-1.0, -1.0],\n        [-1.0, +1.0],\n        [+1.0, +1.0],\n        [+1.0, -1.0],\n        [-1.0, -1.0],\n        [-1.0, +1.0],\n        [+1.0, +1.0],\n        [+1.0, -1.0],\n    ])\n    y = torch.tensor([\n        [+1.0, 0.0],\n        [-1.0, 0.0],\n    ])\n\n    batch_x = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])\n    batch_y = torch.tensor([0, 1])\n\n    edge_index = approx_knn(x, y, 2)\n    assert to_set(edge_index) == {(0, 2), (0, 3), (1, 0), (1, 1)}\n\n    edge_index = approx_knn(x, y, 2, batch_x, batch_y)\n    assert to_set(edge_index) == {(0, 2), (0, 3), (1, 4), (1, 5)}\n\n\n@onlyFullTest  # JIT compile makes this test too slow :(\n@withPackage('pynndescent')\ndef test_approx_knn_graph():\n    warnings.filterwarnings('ignore', '.*find n_neighbors.*')\n\n    x = torch.tensor([\n        [-1.0, -1.0],\n        [-1.0, +1.0],\n        [+1.0, +1.0],\n        [+1.0, -1.0],\n    ])\n\n    edge_index = approx_knn_graph(x, k=2, flow='target_to_source')\n    assert to_set(edge_index) == {(0, 1), (0, 3), (1, 0), (1, 2), (2, 1),\n                                  (2, 3), (3, 0), (3, 2)}\n\n    edge_index = approx_knn_graph(x, k=2, flow='source_to_target')\n    assert to_set(edge_index) == {(1, 0), (3, 0), (0, 1), (2, 1), (1, 2),\n                                  (3, 2), (0, 3), (2, 3)}\n"
  },
  {
    "path": "test/nn/pool/test_asap.py",
    "content": "import io\n\nimport torch\n\nfrom torch_geometric.nn import ASAPooling, GCNConv, GraphConv\nfrom torch_geometric.testing import is_full_test, onlyFullTest, onlyLinux\n\n\n@onlyLinux  # TODO  (matthias) Investigate CSR @ CSR support on Windows.\ndef test_asap():\n    in_channels = 16\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]])\n    num_nodes = edge_index.max().item() + 1\n    x = torch.randn((num_nodes, in_channels))\n\n    for GNN in [GraphConv, GCNConv]:\n        pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN,\n                          add_self_loops=False)\n        assert str(pool) == ('ASAPooling(16, ratio=0.5)')\n        out = pool(x, edge_index)\n        assert out[0].size() == (num_nodes // 2, in_channels)\n        assert out[1].size() == (2, 2)\n\n        if is_full_test():\n            torch.jit.script(pool)\n\n        pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, add_self_loops=True)\n        assert str(pool) == ('ASAPooling(16, ratio=0.5)')\n        out = pool(x, edge_index)\n        assert out[0].size() == (num_nodes // 2, in_channels)\n        assert out[1].size() == (2, 4)\n\n        pool = ASAPooling(in_channels, ratio=2, GNN=GNN, add_self_loops=False)\n        assert str(pool) == ('ASAPooling(16, ratio=2)')\n        out = pool(x, edge_index)\n        assert out[0].size() == (2, in_channels)\n        assert out[1].size() == (2, 2)\n\n\n@onlyFullTest\ndef test_asap_jit_save():\n    pool = ASAPooling(in_channels=16)\n    torch.jit.save(torch.jit.script(pool), io.BytesIO())\n"
  },
  {
    "path": "test/nn/pool/test_avg_pool.py",
    "content": "import torch\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.nn import avg_pool, avg_pool_neighbor_x, avg_pool_x\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_avg_pool_x():\n    cluster = torch.tensor([0, 1, 0, 1, 2, 2])\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1])\n\n    out = avg_pool_x(cluster, x, batch)\n    assert out[0].tolist() == [[3, 4], [5, 6], [10, 11]]\n    assert out[1].tolist() == [0, 0, 1]\n\n    if is_full_test():\n        jit = torch.jit.script(avg_pool_x)\n        out = jit(cluster, x, batch)\n        assert out[0].tolist() == [[3, 4], [5, 6], [10, 11]]\n        assert out[1].tolist() == [0, 0, 1]\n\n    out, _ = avg_pool_x(cluster, x, batch, size=2)\n    assert out.tolist() == [[3, 4], [5, 6], [10, 11], [0, 0]]\n\n    batch_size = int(batch.max().item()) + 1\n    out2, _ = avg_pool_x(cluster, x, batch, batch_size=batch_size, size=2)\n    assert torch.equal(out, out2)\n\n    if is_full_test():\n        jit = torch.jit.script(avg_pool_x)\n        out, _ = jit(cluster, x, batch, size=2)\n        assert out.tolist() == [[3, 4], [5, 6], [10, 11], [0, 0]]\n\n        out2, _ = jit(cluster, x, batch, batch_size=batch_size, size=2)\n        assert torch.equal(out, out2)\n\n\ndef test_avg_pool():\n    cluster = torch.tensor([0, 1, 0, 1, 2, 2])\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    pos = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 1.0],\n        [2.0, 2.0],\n        [3.0, 3.0],\n        [4.0, 4.0],\n        [5.0, 5.0],\n    ])\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    edge_attr = 
torch.ones(edge_index.size(1))\n    batch = torch.tensor([0, 0, 0, 0, 1, 1])\n\n    data = Batch(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,\n                 batch=batch)\n\n    data = avg_pool(cluster, data, transform=lambda x: x)\n\n    assert data.x.tolist() == [[3, 4], [5, 6], [10, 11]]\n    assert data.pos.tolist() == [[1, 1], [2, 2], [4.5, 4.5]]\n    assert data.edge_index.tolist() == [[0, 1], [1, 0]]\n    assert data.edge_attr.tolist() == [4, 4]\n    assert data.batch.tolist() == [0, 0, 1]\n\n\ndef test_avg_pool_neighbor_x():\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1])\n\n    data = Batch(x=x, edge_index=edge_index, batch=batch)\n    data = avg_pool_neighbor_x(data)\n\n    assert data.x.tolist() == [\n        [4, 5],\n        [4, 5],\n        [4, 5],\n        [4, 5],\n        [10, 11],\n        [10, 11],\n    ]\n    assert torch.equal(data.edge_index, edge_index)\n"
  },
  {
    "path": "test/nn/pool/test_cluster_pool.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import ClusterPooling\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('scipy')\n@pytest.mark.parametrize('edge_score_method', [\n    'tanh',\n    'sigmoid',\n    'log_softmax',\n])\ndef test_cluster_pooling(edge_score_method):\n    x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]])\n    edge_index = torch.tensor([\n        [0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6],\n        [1, 2, 3, 6, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0],\n    ])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 0])\n\n    op = ClusterPooling(in_channels=1, edge_score_method=edge_score_method)\n    assert str(op) == 'ClusterPooling(1)'\n    op.reset_parameters()\n\n    x, edge_index, batch, unpool_info = op(x, edge_index, batch)\n    assert x.size(0) <= 7\n    assert edge_index.size(0) == 2\n    if edge_index.numel() > 0:\n        assert edge_index.min() >= 0\n        assert edge_index.max() < x.size(0)\n    assert batch.size() == (x.size(0), )\n"
  },
  {
    "path": "test/nn/pool/test_consecutive.py",
    "content": "import torch\n\nfrom torch_geometric.nn.pool.consecutive import consecutive_cluster\n\n\ndef test_consecutive_cluster():\n    src = torch.tensor([8, 2, 10, 15, 100, 1, 100])\n\n    out, perm = consecutive_cluster(src)\n    assert out.tolist() == [2, 1, 3, 4, 5, 0, 5]\n    assert perm.tolist() == [5, 1, 0, 2, 3, 6]\n"
  },
  {
    "path": "test/nn/pool/test_decimation.py",
    "content": "import torch\n\nfrom torch_geometric.nn.pool.decimation import decimation_indices\n\n\ndef test_decimation_basic():\n    N_1, N_2 = 4, 6\n    decimation_factor = 2\n    ptr = torch.tensor([0, N_1, N_1 + N_2])\n\n    idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor)\n\n    expected_size = (N_1 // decimation_factor) + (N_2 // decimation_factor)\n    assert idx_decim.size(0) == expected_size\n\n    expected = torch.tensor([0, N_1 // decimation_factor, expected_size])\n    assert torch.equal(ptr_decim, expected)\n\n\ndef test_decimation_single_cloud():\n    N_1 = 4\n    decimation_factor = 2\n    ptr = torch.tensor([0, N_1])\n\n    idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor)\n\n    expected_size = N_1 // decimation_factor\n    assert idx_decim.size(0) == expected_size\n    assert torch.equal(ptr_decim, torch.tensor([0, expected_size]))\n\n\ndef test_decimation_almost_empty():\n    N_1 = 4\n    decimation_factor = 666  # greater than N_1\n    ptr = torch.tensor([0, N_1])\n\n    idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor)\n\n    assert idx_decim.size(0) == 1\n    assert torch.equal(ptr_decim, torch.tensor([0, 1]))\n"
  },
  {
    "path": "test/nn/pool/test_edge_pool.py",
    "content": "import torch\n\nfrom torch_geometric.nn import EdgePooling\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import scatter\n\n\ndef test_compute_edge_score_softmax():\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    raw = torch.randn(edge_index.size(1))\n    e = EdgePooling.compute_edge_score_softmax(raw, edge_index, 6)\n    assert torch.all(e >= 0) and torch.all(e <= 1)\n\n    # Test whether all incoming edge scores sum up to one.\n    assert torch.allclose(\n        scatter(e, edge_index[1], reduce='sum'),\n        torch.ones(6),\n    )\n\n    if is_full_test():\n        jit = torch.jit.script(EdgePooling.compute_edge_score_softmax)\n        assert torch.allclose(jit(raw, edge_index, 6), e)\n\n\ndef test_compute_edge_score_tanh():\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    raw = torch.randn(edge_index.size(1))\n    e = EdgePooling.compute_edge_score_tanh(raw, edge_index, 6)\n    assert torch.all(e >= -1) and torch.all(e <= 1)\n    assert torch.all(torch.argsort(raw) == torch.argsort(e))\n\n    if is_full_test():\n        jit = torch.jit.script(EdgePooling.compute_edge_score_tanh)\n        assert torch.allclose(jit(raw, edge_index, 6), e)\n\n\ndef test_compute_edge_score_sigmoid():\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    raw = torch.randn(edge_index.size(1))\n    e = EdgePooling.compute_edge_score_sigmoid(raw, edge_index, 6)\n    assert torch.all(e >= 0) and torch.all(e <= 1)\n    assert torch.all(torch.argsort(raw) == torch.argsort(e))\n\n    if is_full_test():\n        jit = torch.jit.script(EdgePooling.compute_edge_score_sigmoid)\n        assert torch.allclose(jit(raw, edge_index, 
6), e)\n\n\ndef test_edge_pooling():\n    x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]])\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0]])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 0])\n\n    op = EdgePooling(in_channels=1)\n    assert str(op) == 'EdgePooling(1)'\n\n    # Setting parameters fixed so we can test the expected outcome:\n    op.lin.weight.data.fill_(1.)\n    op.lin.bias.data.fill_(0.)\n\n    # Test pooling:\n    new_x, new_edge_index, new_batch, unpool_info = op(x, edge_index, batch)\n\n    assert new_x.size(0) == new_batch.size(0) == 4\n    assert new_edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [0, 1, 2, 1, 2, 2]]\n    assert new_batch.tolist() == [1, 0, 0, 0]\n\n    if is_full_test():\n        jit = torch.jit.script(op)\n        out = jit(x, edge_index, batch)\n        assert torch.allclose(new_x, out[0])\n        assert torch.equal(new_edge_index, out[1])\n        assert torch.equal(new_batch, out[2])\n\n    # Test unpooling:\n    out = op.unpool(new_x, unpool_info)\n    assert out[0].size() == x.size()\n    assert out[0].tolist() == [[1], [1], [5], [5], [9], [9], [-1]]\n    assert torch.equal(out[1], edge_index)\n    assert torch.equal(out[2], batch)\n\n    if is_full_test():\n        jit = torch.jit.export(op)\n        out = jit.unpool(new_x, unpool_info)\n        assert out[0].size() == x.size()\n        assert out[0].tolist() == [[1], [1], [5], [5], [9], [9], [-1]]\n        assert torch.equal(out[1], edge_index)\n        assert torch.equal(out[2], batch)\n\n    # Test edge cases.\n    x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]])\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1])\n    new_x, new_edge_index, new_batch, _ = op(x, edge_index, batch)\n\n    
assert new_x.size(0) == new_batch.size(0) == 3\n    assert new_batch.tolist() == [1, 0, 0]\n    assert new_edge_index.tolist() == [[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]]\n"
  },
  {
    "path": "test/nn/pool/test_glob.py",
    "content": "import torch\n\nfrom torch_geometric.nn import (\n    global_add_pool,\n    global_max_pool,\n    global_mean_pool,\n)\n\n\ndef test_global_pool():\n    N_1, N_2 = 4, 6\n    x = torch.randn(N_1 + N_2, 4)\n    batch = torch.tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)])\n\n    out = global_add_pool(x, batch)\n    assert out.size() == (2, 4)\n    assert torch.allclose(out[0], x[:4].sum(dim=0))\n    assert torch.allclose(out[1], x[4:].sum(dim=0))\n\n    out = global_add_pool(x, None)\n    assert out.size() == (1, 4)\n    assert torch.allclose(out, x.sum(dim=0, keepdim=True))\n\n    out = global_mean_pool(x, batch)\n    assert out.size() == (2, 4)\n    assert torch.allclose(out[0], x[:4].mean(dim=0))\n    assert torch.allclose(out[1], x[4:].mean(dim=0))\n\n    out = global_mean_pool(x, None)\n    assert out.size() == (1, 4)\n    assert torch.allclose(out, x.mean(dim=0, keepdim=True))\n\n    out = global_max_pool(x, batch)\n    assert out.size() == (2, 4)\n    assert torch.allclose(out[0], x[:4].max(dim=0)[0])\n    assert torch.allclose(out[1], x[4:].max(dim=0)[0])\n\n    out = global_max_pool(x, None)\n    assert out.size() == (1, 4)\n    assert torch.allclose(out, x.max(dim=0, keepdim=True)[0])\n\n\ndef test_permuted_global_pool():\n    N_1, N_2 = 4, 6\n    x = torch.randn(N_1 + N_2, 4)\n    batch = torch.cat([torch.zeros(N_1), torch.ones(N_2)]).to(torch.long)\n    perm = torch.randperm(N_1 + N_2)\n\n    px = x[perm]\n    pbatch = batch[perm]\n    px1 = px[pbatch == 0]\n    px2 = px[pbatch == 1]\n\n    out = global_add_pool(px, pbatch)\n    assert out.size() == (2, 4)\n    assert torch.allclose(out[0], px1.sum(dim=0))\n    assert torch.allclose(out[1], px2.sum(dim=0))\n\n    out = global_mean_pool(px, pbatch)\n    assert out.size() == (2, 4)\n    assert torch.allclose(out[0], px1.mean(dim=0))\n    assert torch.allclose(out[1], px2.mean(dim=0))\n\n    out = global_max_pool(px, pbatch)\n    assert out.size() == (2, 4)\n    assert 
torch.allclose(out[0], px1.max(dim=0)[0])\n    assert torch.allclose(out[1], px2.max(dim=0)[0])\n\n\ndef test_dense_global_pool():\n    x = torch.randn(3, 16, 32)\n    assert torch.allclose(global_add_pool(x, None), x.sum(dim=1))\n"
  },
  {
    "path": "test/nn/pool/test_graclus.py",
    "content": "import torch\n\nfrom torch_geometric.nn import graclus\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('torch_cluster')\ndef test_graclus():\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    assert graclus(edge_index).tolist() == [0, 0]\n"
  },
  {
    "path": "test/nn/pool/test_knn.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import (\n    ApproxL2KNNIndex,\n    ApproxMIPSKNNIndex,\n    L2KNNIndex,\n    MIPSKNNIndex,\n)\nfrom torch_geometric.testing import withCUDA, withPackage\n\n\n@withCUDA\n@withPackage('faiss')\n@pytest.mark.parametrize('k', [2])\ndef test_l2(device, k):\n    lhs = torch.randn(10, 16, device=device)\n    rhs = torch.randn(100, 16, device=device)\n\n    index = L2KNNIndex(rhs)\n    assert index.get_emb().device == device\n    assert torch.equal(index.get_emb(), rhs)\n\n    out = index.search(lhs, k)\n    assert out.score.device == device\n    assert out.index.device == device\n    assert out.score.size() == (10, k)\n    assert out.index.size() == (10, k)\n\n    mat = torch.linalg.norm(lhs.unsqueeze(1) - rhs.unsqueeze(0), dim=-1).pow(2)\n    score, index = mat.sort(dim=-1)\n\n    assert torch.allclose(out.score, score[:, :k])\n    assert torch.equal(out.index, index[:, :k])\n\n\n@withCUDA\n@withPackage('faiss')\n@pytest.mark.parametrize('k', [2])\ndef test_mips(device, k):\n    lhs = torch.randn(10, 16, device=device)\n    rhs = torch.randn(100, 16, device=device)\n\n    index = MIPSKNNIndex(rhs)\n    assert index.get_emb().device == device\n    assert torch.equal(index.get_emb(), rhs)\n\n    out = index.search(lhs, k)\n    assert out.score.device == device\n    assert out.index.device == device\n    assert out.score.size() == (10, k)\n    assert out.index.size() == (10, k)\n\n    mat = lhs @ rhs.t()\n    score, index = mat.sort(dim=-1, descending=True)\n\n    assert torch.allclose(out.score, score[:, :k])\n    assert torch.equal(out.index, index[:, :k])\n\n\n@withCUDA\n@withPackage('faiss')\n@pytest.mark.parametrize('k', [2])\n@pytest.mark.parametrize('reserve', [None, 100])\ndef test_approx_l2(device, k, reserve):\n    lhs = torch.randn(10, 16, device=device)\n    rhs = torch.randn(10_000, 16, device=device)\n\n    index = ApproxL2KNNIndex(\n        num_cells=10,\n        num_cells_to_visit=10,\n 
       bits_per_vector=8,\n        emb=rhs,\n        reserve=reserve,\n    )\n\n    out = index.search(lhs, k)\n    assert out.score.device == device\n    assert out.index.device == device\n    assert out.score.size() == (10, k)\n    assert out.index.size() == (10, k)\n    assert out.index.min() >= 0 and out.index.max() < 10_000\n\n\n@withCUDA\n@withPackage('faiss')\n@pytest.mark.parametrize('k', [2])\n@pytest.mark.parametrize('reserve', [None, 100])\ndef test_approx_mips(device, k, reserve):\n    lhs = torch.randn(10, 16, device=device)\n    rhs = torch.randn(10_000, 16, device=device)\n\n    index = ApproxMIPSKNNIndex(\n        num_cells=10,\n        num_cells_to_visit=10,\n        bits_per_vector=8,\n        emb=rhs,\n        reserve=reserve,\n    )\n\n    out = index.search(lhs, k)\n    assert out.score.device == device\n    assert out.index.device == device\n    assert out.score.size() == (10, k)\n    assert out.index.size() == (10, k)\n    assert out.index.min() >= 0 and out.index.max() < 10_000\n\n\n@withCUDA\n@withPackage('faiss')\n@pytest.mark.parametrize('k', [50])\ndef test_mips_exclude(device, k):\n    lhs = torch.randn(10, 16, device=device)\n    rhs = torch.randn(100, 16, device=device)\n\n    exclude_lhs = torch.randint(0, 10, (500, ), device=device)\n    exclude_rhs = torch.randint(0, 100, (500, ), device=device)\n    exclude_links = torch.stack([exclude_lhs, exclude_rhs], dim=0)\n    exclude_links = exclude_links.unique(dim=1)\n\n    index = MIPSKNNIndex(rhs)\n\n    out = index.search(lhs, k, exclude_links)\n    assert out.score.device == device\n    assert out.index.device == device\n    assert out.score.size() == (10, k)\n    assert out.index.size() == (10, k)\n\n    # Ensure that excluded links are not present in `out.index`:\n    batch = torch.arange(lhs.size(0), device=device).repeat_interleave(k)\n    knn_links = torch.stack([batch, out.index.view(-1)], dim=0)\n    knn_links = knn_links[:, knn_links[1] >= 0]\n\n    unique_links = 
torch.cat([knn_links, exclude_links], dim=1).unique(dim=1)\n    assert unique_links.size(1) == knn_links.size(1) + exclude_links.size(1)\n"
  },
  {
    "path": "test/nn/pool/test_max_pool.py",
    "content": "import torch\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.nn import max_pool, max_pool_neighbor_x, max_pool_x\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_max_pool_x():\n    cluster = torch.tensor([0, 1, 0, 1, 2, 2])\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1])\n\n    out = max_pool_x(cluster, x, batch)\n    assert out[0].tolist() == [[5, 6], [7, 8], [11, 12]]\n    assert out[1].tolist() == [0, 0, 1]\n\n    if is_full_test():\n        jit = torch.jit.script(max_pool_x)\n        out = jit(cluster, x, batch)\n        assert out[0].tolist() == [[5, 6], [7, 8], [11, 12]]\n        assert out[1].tolist() == [0, 0, 1]\n\n    out, _ = max_pool_x(cluster, x, batch, size=2)\n    assert out.tolist() == [[5, 6], [7, 8], [11, 12], [0, 0]]\n\n    batch_size = int(batch.max().item()) + 1\n    out2, _ = max_pool_x(cluster, x, batch, batch_size=batch_size, size=2)\n    assert torch.equal(out, out2)\n\n    if is_full_test():\n        jit = torch.jit.script(max_pool_x)\n        out, _ = jit(cluster, x, batch, size=2)\n        assert out.tolist() == [[5, 6], [7, 8], [11, 12], [0, 0]]\n\n        out2, _ = jit(cluster, x, batch, batch_size=batch_size, size=2)\n        assert torch.equal(out, out2)\n\n\ndef test_max_pool():\n    cluster = torch.tensor([0, 1, 0, 1, 2, 2])\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    pos = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 1.0],\n        [2.0, 2.0],\n        [3.0, 3.0],\n        [4.0, 4.0],\n        [5.0, 5.0],\n    ])\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    edge_attr = 
torch.ones(edge_index.size(1))\n    batch = torch.tensor([0, 0, 0, 0, 1, 1])\n\n    data = Batch(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr,\n                 batch=batch)\n\n    data = max_pool(cluster, data, transform=lambda x: x)\n\n    assert data.x.tolist() == [[5, 6], [7, 8], [11, 12]]\n    assert data.pos.tolist() == [[1, 1], [2, 2], [4.5, 4.5]]\n    assert data.edge_index.tolist() == [[0, 1], [1, 0]]\n    assert data.edge_attr.tolist() == [4, 4]\n    assert data.batch.tolist() == [0, 0, 1]\n\n\ndef test_max_pool_neighbor_x():\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1])\n\n    data = Batch(x=x, edge_index=edge_index, batch=batch)\n    data = max_pool_neighbor_x(data)\n\n    assert data.x.tolist() == [\n        [7, 8],\n        [7, 8],\n        [7, 8],\n        [7, 8],\n        [11, 12],\n        [11, 12],\n    ]\n    assert torch.equal(data.edge_index, edge_index)\n"
  },
  {
    "path": "test/nn/pool/test_mem_pool.py",
    "content": "import torch\n\nfrom torch_geometric.nn import MemPooling\nfrom torch_geometric.utils import to_dense_batch\n\n\ndef test_mem_pool():\n    mpool1 = MemPooling(4, 8, heads=3, num_clusters=2)\n    assert str(mpool1) == 'MemPooling(4, 8, heads=3, num_clusters=2)'\n    mpool2 = MemPooling(8, 4, heads=2, num_clusters=1)\n\n    x = torch.randn(17, 4)\n    batch = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4])\n    _, mask = to_dense_batch(x, batch)\n\n    out1, S = mpool1(x, batch)\n    loss = MemPooling.kl_loss(S)\n    with torch.autograd.set_detect_anomaly(True):\n        loss.backward()\n    out2, _ = mpool2(out1)\n\n    assert out1.size() == (5, 2, 8)\n    assert out2.size() == (5, 1, 4)\n    assert S[~mask].sum() == 0\n    assert round(S[mask].sum().item()) == x.size(0)\n    assert float(loss) > 0\n"
  },
  {
    "path": "test/nn/pool/test_pan_pool.py",
    "content": "import torch\n\nfrom torch_geometric.nn import PANConv, PANPooling\nfrom torch_geometric.testing import is_full_test, withPackage\n\n\n@withPackage('torch_sparse')\ndef test_pan_pooling():\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]])\n    num_nodes = edge_index.max().item() + 1\n    x = torch.randn((num_nodes, 16))\n\n    conv = PANConv(16, 32, filter_size=2)\n    pool = PANPooling(32, ratio=0.5)\n    assert str(pool) == 'PANPooling(32, ratio=0.5, multiplier=1.0)'\n\n    x, M = conv(x, edge_index)\n    h, edge_index, edge_weight, batch, perm, score = pool(x, M)\n\n    assert h.size() == (2, 32)\n    assert edge_index.size() == (2, 4)\n    assert edge_weight.size() == (4, )\n    assert perm.size() == (2, )\n    assert score.size() == (2, )\n\n    if is_full_test():\n        jit = torch.jit.script(pool)\n        out = jit(x, M)\n        assert torch.allclose(h, out[0])\n        assert torch.equal(edge_index, out[1])\n        assert torch.allclose(edge_weight, out[2])\n        assert torch.equal(batch, out[3])\n        assert torch.equal(perm, out[4])\n        assert torch.allclose(score, out[5])\n"
  },
  {
    "path": "test/nn/pool/test_pool.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn import radius_graph\nfrom torch_geometric.testing import onlyFullTest, withPackage\n\n\n@onlyFullTest\n@withPackage('torch_cluster')\ndef test_radius_graph_jit():\n    class Net(torch.nn.Module):\n        def forward(self, x: Tensor, batch: Optional[Tensor] = None) -> Tensor:\n            return radius_graph(x, r=2.5, batch=batch, loop=False)\n\n    x = torch.tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.float)\n    batch = torch.tensor([0, 0, 0, 0])\n\n    model = Net()\n    jit = torch.jit.script(model)\n    assert model(x, batch).size() == jit(x, batch).size()\n"
  },
  {
    "path": "test/nn/pool/test_sag_pool.py",
    "content": "import torch\n\nfrom torch_geometric.nn import (\n    GATConv,\n    GCNConv,\n    GraphConv,\n    SAGEConv,\n    SAGPooling,\n)\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_sag_pooling():\n    in_channels = 16\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]])\n    num_nodes = edge_index.max().item() + 1\n    x = torch.randn((num_nodes, in_channels))\n\n    for GNN in [GraphConv, GCNConv, GATConv, SAGEConv]:\n        pool1 = SAGPooling(in_channels, ratio=0.5, GNN=GNN)\n        assert str(pool1) == (f'SAGPooling({GNN.__name__}, 16, '\n                              f'ratio=0.5, multiplier=1.0)')\n        out1 = pool1(x, edge_index)\n        assert out1[0].size() == (num_nodes // 2, in_channels)\n        assert out1[1].size() == (2, 2)\n\n        pool2 = SAGPooling(in_channels, ratio=None, GNN=GNN, min_score=0.1)\n        assert str(pool2) == (f'SAGPooling({GNN.__name__}, 16, '\n                              f'min_score=0.1, multiplier=1.0)')\n        out2 = pool2(x, edge_index)\n        assert out2[0].size(0) <= x.size(0) and out2[0].size(1) == (16)\n        assert out2[1].size(0) == 2 and out2[1].size(1) <= edge_index.size(1)\n\n        pool3 = SAGPooling(in_channels, ratio=2, GNN=GNN)\n        assert str(pool3) == (f'SAGPooling({GNN.__name__}, 16, '\n                              f'ratio=2, multiplier=1.0)')\n        out3 = pool3(x, edge_index)\n        assert out3[0].size() == (2, in_channels)\n        assert out3[1].size() == (2, 2)\n\n        if is_full_test():\n            jit1 = torch.jit.script(pool1)\n            assert torch.allclose(jit1(x, edge_index)[0], out1[0])\n\n            jit2 = torch.jit.script(pool2)\n            assert torch.allclose(jit2(x, edge_index)[0], out2[0])\n\n            jit3 = torch.jit.script(pool3)\n            assert torch.allclose(jit3(x, edge_index)[0], out3[0])\n"
  },
  {
    "path": "test/nn/pool/test_topk_pool.py",
    "content": "import torch\n\nfrom torch_geometric.nn.pool import TopKPooling\nfrom torch_geometric.nn.pool.connect.filter_edges import filter_adj\nfrom torch_geometric.testing import is_full_test\n\n\ndef test_filter_adj():\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3],\n                               [1, 3, 0, 2, 1, 3, 0, 2]])\n    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])\n    perm = torch.tensor([2, 3])\n\n    out = filter_adj(edge_index, edge_attr, perm)\n    assert out[0].tolist() == [[0, 1], [1, 0]]\n    assert out[1].tolist() == [6.0, 8.0]\n\n    if is_full_test():\n        jit = torch.jit.script(filter_adj)\n\n        out = jit(edge_index, edge_attr, perm)\n        assert out[0].tolist() == [[0, 1], [1, 0]]\n        assert out[1].tolist() == [6.0, 8.0]\n\n\ndef test_topk_pooling():\n    in_channels = 16\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]])\n    num_nodes = edge_index.max().item() + 1\n    x = torch.randn((num_nodes, in_channels))\n\n    pool1 = TopKPooling(in_channels, ratio=0.5)\n    assert str(pool1) == 'TopKPooling(16, ratio=0.5, multiplier=1.0)'\n    out1 = pool1(x, edge_index)\n    assert out1[0].size() == (num_nodes // 2, in_channels)\n    assert out1[1].size() == (2, 2)\n\n    pool2 = TopKPooling(in_channels, ratio=None, min_score=0.1)\n    assert str(pool2) == 'TopKPooling(16, min_score=0.1, multiplier=1.0)'\n    out2 = pool2(x, edge_index)\n    assert out2[0].size(0) <= x.size(0) and out2[0].size(1) == (16)\n    assert out2[1].size(0) == 2 and out2[1].size(1) <= edge_index.size(1)\n\n    pool3 = TopKPooling(in_channels, ratio=2)\n    assert str(pool3) == 'TopKPooling(16, ratio=2, multiplier=1.0)'\n    out3 = pool3(x, edge_index)\n    assert out3[0].size() == (2, in_channels)\n    assert out3[1].size() == (2, 2)\n\n    if is_full_test():\n        jit1 = torch.jit.script(pool1)\n        assert 
torch.allclose(jit1(x, edge_index)[0], out1[0])\n\n        jit2 = torch.jit.script(pool2)\n        assert torch.allclose(jit2(x, edge_index)[0], out2[0])\n\n        jit3 = torch.jit.script(pool3)\n        assert torch.allclose(jit3(x, edge_index)[0], out3[0])\n"
  },
  {
    "path": "test/nn/pool/test_voxel_grid.py",
    "content": "import torch\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.nn import avg_pool, voxel_grid\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('torch_cluster')\ndef test_voxel_grid():\n    pos = torch.tensor([\n        [0.0, 0.0],\n        [11.0, 9.0],\n        [2.0, 8.0],\n        [2.0, 2.0],\n        [8.0, 3.0],\n    ])\n    batch = torch.tensor([0, 0, 0, 1, 1])\n\n    assert voxel_grid(pos, size=5, batch=batch).tolist() == [0, 5, 3, 6, 7]\n    assert voxel_grid(pos, size=5).tolist() == [0, 5, 3, 0, 1]\n\n    cluster = voxel_grid(pos, size=5, batch=batch, start=-1, end=[18, 14])\n    assert cluster.tolist() == [0, 10, 4, 16, 17]\n\n    cluster_no_batch = voxel_grid(pos, size=5, start=-1, end=[18, 14])\n    assert cluster_no_batch.tolist() == [0, 10, 4, 0, 1]\n\n\n@withPackage('torch_cluster')\ndef test_single_voxel_grid():\n    pos = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 1.0],\n        [2.0, 2.0],\n        [3.0, 3.0],\n        [4.0, 4.0],\n    ])\n    edge_index = torch.tensor([[0, 0, 3], [1, 2, 4]])\n    batch = torch.tensor([0, 0, 0, 1, 1])\n    x = torch.randn(5, 16)\n\n    cluster = voxel_grid(pos, size=5, batch=batch)\n    assert cluster.tolist() == [0, 0, 0, 1, 1]\n\n    data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch)\n    data = avg_pool(cluster, data)\n\n    cluster_no_batch = voxel_grid(pos, size=5)\n    assert cluster_no_batch.tolist() == [0, 0, 0, 0, 0]\n\n    data_no_batch = Batch(x=x, edge_index=edge_index, pos=pos)\n    data_no_batch = avg_pool(cluster_no_batch, data_no_batch)\n"
  },
  {
    "path": "test/nn/test_compile_basic.py",
    "content": "import torch\n\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import (\n    onlyFullTest,\n    onlyLinux,\n    withDevice,\n    withPackage,\n)\nfrom torch_geometric.utils import scatter\n\n\n# Basic \"Gather-Apply-Scatter\" patterns commonly used in PyG:\ndef gather_scatter(x, edge_index, reduce='sum'):\n    row, col = edge_index\n    x_j = x[row]\n    return scatter(x_j, col, dim_size=x.size(0), reduce=reduce)\n\n\ndef gather_cat_scatter(x, edge_index, reduce='sum'):\n    row, col = edge_index\n    x_ij = torch.cat([x[col], x[row]], dim=-1)\n    return scatter(x_ij, col, dim_size=x.size(0), reduce=reduce)\n\n\ndef gather_weight_scatter(x, edge_index, edge_weight, reduce='sum'):\n    row, col = edge_index\n    x_j = x[row] * edge_weight.view(-1, 1)\n    return scatter(x_j, col, dim_size=x.size(0), reduce=reduce)\n\n\ndef gather_transform_scatter(x, edge_index, matrix, reduce='sum'):\n    row, col = edge_index\n    x_j = x[row] @ matrix\n    return scatter(x_j, col, dim_size=x.size(0), reduce=reduce)\n\n\ndef fused_gather_scatter(x, edge_index, reduce=('sum', 'mean', 'max')):\n    row, col = edge_index\n    x_j = x[row]\n    outs = [scatter(x_j, col, dim_size=x.size(0), reduce=r) for r in reduce]\n    return torch.cat(outs, dim=-1)\n\n\n@withDevice\n@onlyLinux\n@onlyFullTest\n@withPackage('torch>=2.0.0')\ndef test_torch_compile(device):\n    x = torch.randn(10, 16, device=device)\n    edge_index = torch.randint(0, x.size(0), (2, 40), device=device)\n    edge_weight = torch.rand(edge_index.size(1), device=device)\n    matrix = torch.randn(x.size(-1), x.size(-1), device=device)\n\n    expected = gather_scatter(x, edge_index)\n    compiled_op = torch.compile(gather_scatter)\n    out = compiled_op(x, edge_index)\n    assert torch.allclose(out, expected, atol=1e-6)\n\n    expected = gather_cat_scatter(x, edge_index)\n    compiled_op = torch.compile(gather_cat_scatter)\n    out = compiled_op(x, edge_index)\n    assert 
torch.allclose(out, expected, atol=1e-6)\n\n    expected = gather_weight_scatter(x, edge_index, edge_weight)\n    compiled_op = torch.compile(gather_weight_scatter)\n    out = compiled_op(x, edge_index, edge_weight)\n    assert torch.allclose(out, expected, atol=1e-6)\n\n    expected = gather_transform_scatter(x, edge_index, matrix)\n    compiled_op = torch.compile(gather_transform_scatter)\n    out = compiled_op(x, edge_index, matrix)\n    assert torch.allclose(out, expected, atol=1e-6)\n\n    expected = fused_gather_scatter(x, edge_index)\n    compiled_op = torch.compile(fused_gather_scatter)\n    out = compiled_op(x, edge_index)\n    assert torch.allclose(out, expected, atol=1e-6)\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    num_nodes, num_edges = 10_000, 200_000\n    x = torch.randn(num_nodes, 64, device=args.device)\n    edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)\n    edge_weight = torch.rand(num_edges, device=args.device)\n    matrix = torch.randn(64, 64, device=args.device)\n\n    for reduce in ['sum', 'mean', 'max']:\n        print(f'Aggregator: {reduce}')\n\n        benchmark(\n            funcs=[\n                gather_scatter,\n                torch.compile(gather_scatter),\n            ],\n            func_names=['Vanilla', 'Compiled'],\n            args=(x, edge_index, reduce),\n            num_steps=50 if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n\n        benchmark(\n            funcs=[\n                gather_cat_scatter,\n                torch.compile(gather_cat_scatter),\n            ],\n            func_names=['Vanilla Cat', 'Compiled Cat'],\n            args=(x, edge_index, reduce),\n            num_steps=50 
if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n\n        benchmark(\n            funcs=[\n                gather_weight_scatter,\n                torch.compile(gather_weight_scatter),\n            ],\n            func_names=['Vanilla Weight', 'Compiled Weight'],\n            args=(x, edge_index, edge_weight, reduce),\n            num_steps=50 if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n\n        benchmark(\n            funcs=[\n                gather_transform_scatter,\n                torch.compile(gather_transform_scatter),\n            ],\n            func_names=['Vanilla Transform', 'Compiled Transform'],\n            args=(x, edge_index, matrix, reduce),\n            num_steps=50 if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n\n    benchmark(\n        funcs=[\n            fused_gather_scatter,\n            torch.compile(fused_gather_scatter),\n        ],\n        func_names=['Vanilla Fused', 'Compiled Fused'],\n        args=(x, edge_index),\n        num_steps=50 if args.device == 'cpu' else 500,\n        num_warmups=10 if args.device == 'cpu' else 100,\n        backward=args.backward,\n    )\n"
  },
  {
    "path": "test/nn/test_compile_conv.py",
    "content": "import pytest\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn import GCNConv, SAGEConv\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import (\n    onlyFullTest,\n    onlyLinux,\n    withDevice,\n    withPackage,\n)\nfrom torch_geometric.utils import scatter\n\n\nclass MySAGEConv(torch.nn.Module):\n    def __init__(self, in_channels: int, out_channels: int):\n        super().__init__()\n        self.lin_src = torch.nn.Linear(in_channels, out_channels)\n        self.lin_dst = torch.nn.Linear(in_channels, out_channels)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x_j = x[edge_index[0]]\n        out = scatter(x_j, edge_index[1], dim_size=x.size(0), reduce='mean')\n        return self.lin_src(out) + self.lin_dst(x)\n\n\n@withDevice\n@onlyLinux\n@onlyFullTest\n@withPackage('torch>=2.1.0')\n@pytest.mark.parametrize('Conv', [GCNConv, SAGEConv])\ndef test_compile_conv(device, Conv):\n    import torch._dynamo as dynamo\n\n    x = torch.randn(10, 16, device=device)\n    edge_index = torch.randint(0, x.size(0), (2, 40), device=device)\n\n    if Conv == GCNConv:\n        conv = Conv(16, 32, add_self_loops=False).to(device)\n    else:\n        conv = Conv(16, 32).to(device)\n\n    explanation = dynamo.explain(conv)(x, edge_index)\n    assert explanation.graph_break_count == 0\n\n    out = torch.compile(conv)(x, edge_index)\n    assert torch.allclose(conv(x, edge_index), out, atol=1e-6)\n\n\n@withDevice\n@onlyLinux\n@onlyFullTest\n@withPackage('torch==2.3')\n@pytest.mark.parametrize('Conv', [GCNConv, SAGEConv])\ndef test_compile_conv_edge_index(device, Conv):\n    import torch._dynamo as dynamo\n\n    x = torch.randn(10, 16, device=device)\n    edge_index = torch.randint(0, x.size(0), (2, 40), device=device)\n    edge_index = EdgeIndex(edge_index, sparse_size=(10, 10))\n    edge_index = edge_index.sort_by('col')[0]\n    
edge_index.fill_cache_()\n\n    if Conv == GCNConv:\n        conv = Conv(16, 32, normalize=False).to(device)\n    else:\n        conv = Conv(16, 32).to(device)\n\n    explanation = dynamo.explain(conv)(x, edge_index)\n    assert explanation.graph_break_count == 0\n\n    out = torch.compile(conv, fullgraph=True)(x, edge_index)\n    assert torch.allclose(conv(x, edge_index), out, atol=1e-6)\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    num_nodes, num_edges = 10_000, 200_000\n    x = torch.randn(num_nodes, 64, device=args.device)\n    edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)\n\n    conv = MySAGEConv(64, 64).to(args.device)\n    benchmark(\n        funcs=[conv, torch.compile(conv)],\n        func_names=['Vanilla', 'Compiled'],\n        args=(x, edge_index),\n        num_steps=50 if args.device == 'cpu' else 500,\n        num_warmups=10 if args.device == 'cpu' else 100,\n        backward=args.backward,\n    )\n\n    for Conv in [GCNConv, SAGEConv]:\n        print(f'Conv: {Conv.__name__}')\n\n        conv = Conv(64, 64).to(args.device)\n        compiled_conv = torch.compile(conv)\n\n        benchmark(\n            funcs=[conv, compiled_conv],\n            func_names=['Vanilla', 'Compiled'],\n            args=(x, edge_index),\n            num_steps=50 if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n"
  },
  {
    "path": "test/nn/test_compile_dynamic.py",
    "content": "import random\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.testing import (\n    get_random_edge_index,\n    onlyFullTest,\n    onlyLinux,\n    withDevice,\n    withPackage,\n)\nfrom torch_geometric.utils import scatter\n\n\nclass MySAGEConv(torch.nn.Module):\n    def __init__(self, in_channels: int, out_channels: int):\n        super().__init__()\n        self.lin_src = torch.nn.Linear(in_channels, out_channels)\n        self.lin_dst = torch.nn.Linear(in_channels, out_channels)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x_j = x[edge_index[0]]\n        out = scatter(x_j, edge_index[1], dim_size=x.size(0), reduce='mean')\n        return self.lin_src(out) + self.lin_dst(x)\n\n\n@withDevice\n@onlyLinux\n@onlyFullTest\n@withPackage('torch>2.0.0')\ndef test_dynamic_torch_compile(device):\n    conv = MySAGEConv(64, 64).to(device)\n    conv = torch.compile(conv, dynamic=True)\n\n    optimizer = torch.optim.Adam(conv.parameters(), lr=0.01)\n\n    for _ in range(10):\n        N = random.randrange(100, 500)\n        E = random.randrange(200, 1000)\n\n        x = torch.randn(N, 64, device=device)\n        edge_index = get_random_edge_index(N, N, E, device=device)\n\n        optimizer.zero_grad()\n        expected = conv(x, edge_index)\n        expected.mean().backward()\n        optimizer.step()\n"
  },
  {
    "path": "test/nn/test_data_parallel.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.nn import DataParallel\nfrom torch_geometric.testing import onlyCUDA\n\n\n@onlyCUDA\ndef test_data_parallel_single_gpu():\n    with pytest.warns(UserWarning, match=\"much slower\"):\n        module = DataParallel(torch.nn.Identity())\n    data_list = [Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]]\n    batches = module.scatter(data_list, device_ids=[0])\n    assert len(batches) == 1\n\n\n@onlyCUDA\n@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUs')\ndef test_data_parallel_multi_gpu():\n    with pytest.warns(UserWarning, match=\"much slower\"):\n        module = DataParallel(torch.nn.Identity())\n    data_list = [Data(x=torch.randn(x, 1)) for x in [2, 3, 10, 4]]\n    batches = module.scatter(data_list, device_ids=[0, 1, 0, 1])\n    assert len(batches) == 3\n"
  },
  {
    "path": "test/nn/test_encoding.py",
    "content": "import torch\n\nfrom torch_geometric.nn import PositionalEncoding, TemporalEncoding\nfrom torch_geometric.testing import withDevice\n\n\n@withDevice\ndef test_positional_encoding(device):\n    encoder = PositionalEncoding(64, device=device)\n    assert str(encoder) == 'PositionalEncoding(64)'\n\n    x = torch.tensor([1.0, 2.0, 3.0], device=device)\n    assert encoder(x).size() == (3, 64)\n\n\n@withDevice\ndef test_temporal_encoding(device):\n    encoder = TemporalEncoding(64, device=device)\n    assert str(encoder) == 'TemporalEncoding(64)'\n\n    x = torch.tensor([1.0, 2.0, 3.0], device=device)\n    assert encoder(x).size() == (3, 64)\n"
  },
  {
    "path": "test/nn/test_fvcore.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GraphSAGE\nfrom torch_geometric.testing import get_random_edge_index, withPackage\n\n\n@withPackage('fvcore')\ndef test_fvcore():\n    from fvcore.nn import FlopCountAnalysis\n\n    x = torch.randn(10, 16)\n    edge_index = get_random_edge_index(10, 10, num_edges=100)\n\n    model = GraphSAGE(16, 32, num_layers=2)\n\n    flops = FlopCountAnalysis(model, (x, edge_index))\n\n    # TODO (matthias) Currently, aggregations are not properly registered.\n    assert flops.by_module()['convs.0'] == 2 * 10 * 16 * 32\n    assert flops.by_module()['convs.1'] == 2 * 10 * 32 * 32\n    assert flops.total() == (flops.by_module()['convs.0'] +\n                             flops.by_module()['convs.1'])\n    assert flops.by_operator()['linear'] == flops.total()\n"
  },
  {
    "path": "test/nn/test_fx.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\n\ndef test_dropout():\n    class MyModule(torch.nn.Module):\n        def forward(self, x: Tensor) -> Tensor:\n            return F.dropout(x, p=1.0, training=self.training)\n\n    module = MyModule()\n    graph_module = torch.fx.symbolic_trace(module)\n    graph_module.recompile()\n\n    x = torch.randn(4)\n\n    graph_module.train()\n    assert torch.allclose(graph_module(x), torch.zeros_like(x))\n\n    # This is certainly undesired behavior due to tracing :(\n    graph_module.eval()\n    assert torch.allclose(graph_module(x), torch.zeros_like(x))\n"
  },
  {
    "path": "test/nn/test_inits.py",
    "content": "import torch\nfrom torch.nn import Linear as Lin\nfrom torch.nn import ReLU\nfrom torch.nn import Sequential as Seq\n\nfrom torch_geometric.nn.inits import (\n    glorot,\n    glorot_orthogonal,\n    ones,\n    reset,\n    uniform,\n    zeros,\n)\n\n\ndef test_inits():\n    x = torch.empty(1, 4)\n\n    uniform(size=4, value=x)\n    assert x.min() >= -0.5\n    assert x.max() <= 0.5\n\n    glorot(x)\n    assert x.min() >= -1.1\n    assert x.max() <= 1.1\n\n    glorot_orthogonal(x, scale=1.0)\n    assert x.min() >= -2.5\n    assert x.max() <= 2.5\n\n    zeros(x)\n    assert x.tolist() == [[0, 0, 0, 0]]\n\n    ones(x)\n    assert x.tolist() == [[1, 1, 1, 1]]\n\n    nn = Lin(16, 16)\n    uniform(size=4, value=nn.weight)\n    assert nn.weight[0].min() >= -0.5\n    assert nn.weight[0].max() <= 0.5\n\n    glorot(nn.weight)\n    assert nn.weight[0].min() >= -0.45\n    assert nn.weight[0].max() <= 0.45\n\n    glorot_orthogonal(nn.weight, scale=1.0)\n    assert nn.weight[0].min() >= -2.5\n    assert nn.weight[0].max() <= 2.5\n\n\ndef test_reset():\n    nn = Lin(16, 16)\n    w = nn.weight.clone()\n    reset(nn)\n    assert not torch.allclose(nn.weight, w)\n\n    nn = Seq(Lin(16, 16), ReLU(), Lin(16, 16))\n    w_1, w_2 = nn[0].weight.clone(), nn[2].weight.clone()\n    reset(nn)\n    assert not torch.allclose(nn[0].weight, w_1)\n    assert not torch.allclose(nn[2].weight, w_2)\n"
  },
  {
    "path": "test/nn/test_model_hub.py",
    "content": "import os\nfrom pathlib import Path\nfrom unittest.mock import Mock\n\nimport pytest\nimport torch\n\nfrom torch_geometric.nn import GCN\nfrom torch_geometric.nn.model_hub import PyGModelHubMixin\nfrom torch_geometric.testing import withPackage\n\nREPO_NAME = 'pyg_hugging_test'\nMODEL_NAME = 'pyg_test_model'\nDATASET_NAME = 'pyg_dataset'\nCONFIG = {'hello': 'world'}\n\n\nclass DummyModel(GCN, PyGModelHubMixin):\n    def __init__(self, model_name, dataset_name, model_kwargs):\n        GCN.__init__(self, in_channels=3, hidden_channels=5, num_layers=2)\n        PyGModelHubMixin.__init__(self, model_name, dataset_name, model_kwargs)\n\n\n@pytest.fixture\ndef model():\n    return DummyModel(MODEL_NAME, DATASET_NAME, CONFIG)\n\n\n@withPackage('huggingface_hub')\ndef test_model_init():\n    model = DummyModel(\n        MODEL_NAME, DATASET_NAME, model_kwargs={\n            **CONFIG, 'tensor': torch.randn([1, 2, 3])\n        })\n    assert model.model_config == CONFIG\n\n\n@withPackage('huggingface_hub')\ndef test_save_pretrained(model, tmp_path):\n    save_directory = f'{str(tmp_path / REPO_NAME)}'\n    model.save_pretrained(save_directory)\n    files = os.listdir(save_directory)\n    assert 'model.pth' in files\n    assert len(files) >= 1\n\n\n@withPackage('huggingface_hub')\ndef test_save_pretrained_internal(model, tmp_path):\n    save_directory = f'{str(tmp_path / REPO_NAME)}'\n    model._save_pretrained = Mock()\n    model.save_pretrained(save_directory)\n    model._save_pretrained.assert_called_with(Path(save_directory))\n\n\n@withPackage('huggingface_hub')\ndef test_save_pretrained_with_push_to_hub(model, tmp_path):\n    save_directory = f'{str(tmp_path / REPO_NAME)}'\n\n    model.push_to_hub = Mock()\n    model.construct_model_card = Mock()\n    model._save_pretrained = Mock()  # disable _save_pretrained to speed-up\n\n    # Not pushed to hub\n    model.save_pretrained(save_directory)\n    model.push_to_hub.assert_not_called()\n    
model.construct_model_card.assert_called_with(MODEL_NAME, DATASET_NAME)\n\n    # Push to hub with repo_id\n    model.save_pretrained(save_directory, push_to_hub=True, repo_id='CustomID',\n                          config=CONFIG)\n    model.push_to_hub.assert_called_with(\n        repo_id='CustomID',\n        model_card_kwargs={},\n        config=CONFIG,\n    )\n\n    # Push to hub with default repo_id (based on dir name)\n    model.save_pretrained(save_directory, push_to_hub=True, config=CONFIG)\n    model.push_to_hub.assert_called_with(\n        repo_id=REPO_NAME,\n        model_card_kwargs={},\n        config=CONFIG,\n    )\n\n\n@withPackage('huggingface_hub')\ndef test_from_pretrained(model, tmp_path):\n    save_directory = f'{str(tmp_path / REPO_NAME)}'\n    model.save_pretrained(save_directory)\n\n    model = model.from_pretrained(save_directory)\n    assert isinstance(model, DummyModel)\n\n\n@withPackage('huggingface_hub')\ndef test_from_pretrained_internal(model, monkeypatch):\n    hf_hub_download = Mock(side_effect='model')\n    monkeypatch.setattr('torch_geometric.nn.model_hub.hf_hub_download',\n                        hf_hub_download)\n    monkeypatch.setattr('torch_geometric.nn.model_hub.fs.torch_load',\n                        lambda x, **kwargs: {'state_dict': 1})\n\n    model = model._from_pretrained(\n        model_id=MODEL_NAME,\n        revision=None,\n        cache_dir=None,\n        force_download=False,\n        local_files_only=False,\n        token=False,\n        dataset_name=DATASET_NAME,\n        model_name=MODEL_NAME,\n        map_location='cpu',\n        strict=False,\n        **CONFIG,\n    )\n\n    assert hf_hub_download.call_count == 1\n    assert model.model_config == CONFIG\n"
  },
  {
    "path": "test/nn/test_model_summary.py",
    "content": "from typing import Optional\n\nimport pytest\nimport torch\nfrom torch import Tensor, nn\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import Linear, SAGEConv, summary, to_hetero\nfrom torch_geometric.nn.models import GCN\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.typing import SparseTensor\n\n\nclass GraphSAGE(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = Linear(16, 16)\n        self.conv1 = SAGEConv(16, 32)\n        self.lin2 = Linear(32, 32)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = self.lin1(x).relu()\n        x = self.conv1(x, edge_index).relu()\n        x = self.lin2(x)\n        return x\n\n\nclass ModuleDictModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.acts = nn.ModuleDict({\n            \"lrelu\": nn.LeakyReLU(),\n            \"prelu\": nn.PReLU()\n        })\n\n    def forward(self, x: torch.Tensor, act_type: str) -> torch.Tensor:\n        return self.acts[act_type](x)\n\n\n@pytest.fixture\ndef gcn():\n    torch.manual_seed(1)\n    model = GCN(32, 16, num_layers=2, out_channels=32)\n    x = torch.randn(100, 32)\n    edge_index = torch.randint(100, size=(2, 20))\n    adj_t: Optional[SparseTensor] = None\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t = SparseTensor.from_edge_index(\n            edge_index,\n            sparse_sizes=(100, 100),\n        ).t()\n    return dict(model=model, x=x, edge_index=edge_index, adj_t=adj_t)\n\n\n@withPackage('tabulate')\ndef test_summary_basic(gcn):\n    expected = \"\"\"\n+---------------------+--------------------+----------------+----------+\n| Layer               | Input Shape        | Output Shape   | #Param   |\n|---------------------+--------------------+----------------+----------|\n| GCN                 | [100, 32], [2, 20] | [100, 32]      | 1,072    |\n| ├─(dropout)Dropout  | [100, 16]          | [100, 16]      | --  
     |\n| ├─(act)ReLU         | [100, 16]          | [100, 16]      | --       |\n| ├─(convs)ModuleList | --                 | --             | 1,072    |\n| │    └─(0)GCNConv   | [100, 32], [2, 20] | [100, 16]      | 528      |\n| │    └─(1)GCNConv   | [100, 16], [2, 20] | [100, 32]      | 544      |\n| ├─(norms)ModuleList | --                 | --             | --       |\n| │    └─(0)Identity  | [100, 16]          | [100, 16]      | --       |\n| │    └─(1)Identity  | --                 | --             | --       |\n+---------------------+--------------------+----------------+----------+\n\"\"\"\n    assert summary(gcn['model'], gcn['x'], gcn['edge_index']) == expected[1:-1]\n\n\n@withPackage('tabulate', 'torch_sparse')\ndef test_summary_with_sparse_tensor(gcn):\n    expected = \"\"\"\n+---------------------+-----------------------+----------------+----------+\n| Layer               | Input Shape           | Output Shape   | #Param   |\n|---------------------+-----------------------+----------------+----------|\n| GCN                 | [100, 32], [100, 100] | [100, 32]      | 1,072    |\n| ├─(dropout)Dropout  | [100, 16]             | [100, 16]      | --       |\n| ├─(act)ReLU         | [100, 16]             | [100, 16]      | --       |\n| ├─(convs)ModuleList | --                    | --             | 1,072    |\n| │    └─(0)GCNConv   | [100, 32], [100, 100] | [100, 16]      | 528      |\n| │    └─(1)GCNConv   | [100, 16], [100, 100] | [100, 32]      | 544      |\n| ├─(norms)ModuleList | --                    | --             | --       |\n| │    └─(0)Identity  | [100, 16]             | [100, 16]      | --       |\n| │    └─(1)Identity  | --                    | --             | --       |\n+---------------------+-----------------------+----------------+----------+\n\"\"\"\n    assert summary(gcn['model'], gcn['x'], gcn['adj_t']) == expected[1:-1]\n\n\n@withPackage('tabulate')\ndef test_lazy_gcn():\n    expected = 
\"\"\"\n+---------------------+--------------------+----------------+----------+\n| Layer               | Input Shape        | Output Shape   | #Param   |\n|---------------------+--------------------+----------------+----------|\n| GCN                 | [100, 32], [2, 20] | [100, 32]      | -1       |\n| ├─(dropout)Dropout  | [100, 16]          | [100, 16]      | --       |\n| ├─(act)ReLU         | [100, 16]          | [100, 16]      | --       |\n| ├─(convs)ModuleList | --                 | --             | -1       |\n| │    └─(0)GCNConv   | [100, 32], [2, 20] | [100, 16]      | -1       |\n| │    └─(1)GCNConv   | [100, 16], [2, 20] | [100, 32]      | 544      |\n| ├─(norms)ModuleList | --                 | --             | --       |\n| │    └─(0)Identity  | [100, 16]          | [100, 16]      | --       |\n| │    └─(1)Identity  | --                 | --             | --       |\n+---------------------+--------------------+----------------+----------+\n\"\"\"\n    model = GCN(-1, 16, num_layers=2, out_channels=32)\n    x = torch.randn(100, 32)\n    edge_index = torch.randint(100, size=(2, 20))\n\n    assert summary(model, x, edge_index) == expected[1:-1]\n\n\n@withPackage('tabulate')\ndef test_summary_with_max_depth(gcn):\n    expected = \"\"\"\n+---------------------+--------------------+----------------+----------+\n| Layer               | Input Shape        | Output Shape   | #Param   |\n|---------------------+--------------------+----------------+----------|\n| GCN                 | [100, 32], [2, 20] | [100, 32]      | 1,072    |\n| ├─(dropout)Dropout  | [100, 16]          | [100, 16]      | --       |\n| ├─(act)ReLU         | [100, 16]          | [100, 16]      | --       |\n| ├─(convs)ModuleList | --                 | --             | 1,072    |\n| ├─(norms)ModuleList | --                 | --             | --       |\n+---------------------+--------------------+----------------+----------+\n\"\"\"\n    assert summary(\n        gcn['model'],\n        
gcn['x'],\n        gcn['edge_index'],\n        max_depth=1,\n    ) == expected[1:-1]\n\n\n@withPackage('tabulate')\ndef test_summary_with_leaf_module(gcn):\n    expected = \"\"\"# noqa: E501\n+-----------------------------------------+--------------------+----------------+----------+\n| Layer                                   | Input Shape        | Output Shape   | #Param   |\n|-----------------------------------------+--------------------+----------------+----------|\n| GCN                                     | [100, 32], [2, 20] | [100, 32]      | 1,072    |\n| ├─(dropout)Dropout                      | [100, 16]          | [100, 16]      | --       |\n| ├─(act)ReLU                             | [100, 16]          | [100, 16]      | --       |\n| ├─(convs)ModuleList                     | --                 | --             | 1,072    |\n| │    └─(0)GCNConv                       | [100, 32], [2, 20] | [100, 16]      | 528      |\n| │    │    └─(aggr_module)SumAggregation | [120, 16]          | [100, 16]      | --       |\n| │    │    └─(lin)Linear                 | [100, 32]          | [100, 16]      | 512      |\n| │    └─(1)GCNConv                       | [100, 16], [2, 20] | [100, 32]      | 544      |\n| │    │    └─(aggr_module)SumAggregation | [120, 32]          | [100, 32]      | --       |\n| │    │    └─(lin)Linear                 | [100, 16]          | [100, 32]      | 512      |\n| ├─(norms)ModuleList                     | --                 | --             | --       |\n| │    └─(0)Identity                      | [100, 16]          | [100, 16]      | --       |\n| │    └─(1)Identity                      | --                 | --             | --       |\n+-----------------------------------------+--------------------+----------------+----------+\n\"\"\"\n    assert summary(\n        gcn['model'],\n        gcn['x'],\n        gcn['edge_index'],\n        leaf_module=None,\n    ) == expected[13:-1]\n\n\n@withPackage('tabulate')\ndef 
test_summary_with_reusing_layers():\n    act = nn.ReLU(inplace=True)\n    model1 = nn.Sequential(act, nn.Identity(), act, nn.Identity(), act)\n    model2 = nn.Sequential(\n        nn.ReLU(inplace=True),\n        nn.Identity(),\n        nn.ReLU(inplace=True),\n        nn.Identity(),\n        nn.ReLU(inplace=True),\n    )\n    x = torch.randn(10)\n\n    assert summary(model1, x) == summary(model2, x)\n\n\n@withPackage('tabulate')\ndef test_summary_with_to_hetero_model():\n    x_dict = {\n        'p': torch.randn(100, 16),\n        'a': torch.randn(100, 16),\n    }\n    edge_index_dict = {\n        ('p', 'to', 'p'): torch.randint(100, (2, 200)),\n        ('p', 'to', 'a'): torch.randint(100, (2, 200)),\n        ('a', 'to', 'p'): torch.randint(100, (2, 200)),\n    }\n    metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n    model = to_hetero(GraphSAGE(), metadata)\n\n    expected = \"\"\"\n+---------------------------+---------------------+----------------+----------+\n| Layer                     | Input Shape         | Output Shape   |   #Param |\n|---------------------------+---------------------+----------------+----------|\n| GraphModule               |                     |                |    5,824 |\n| ├─(lin1)ModuleDict        | --                  | --             |      544 |\n| │    └─(p)Linear          | [100, 16]           | [100, 16]      |      272 |\n| │    └─(a)Linear          | [100, 16]           | [100, 16]      |      272 |\n| ├─(conv1)ModuleDict       | --                  | --             |    3,168 |\n| │    └─(p__to__p)SAGEConv | [100, 16], [2, 200] | [100, 32]      |    1,056 |\n| │    └─(p__to__a)SAGEConv | [2, 200]            | [100, 32]      |    1,056 |\n| │    └─(a__to__p)SAGEConv | [2, 200]            | [100, 32]      |    1,056 |\n| ├─(lin2)ModuleDict        | --                  | --             |    2,112 |\n| │    └─(p)Linear          | [100, 32]           | [100, 32]      |    1,056 |\n| │    └─(a)Linear          | [100, 
32]           | [100, 32]      |    1,056 |\n+---------------------------+---------------------+----------------+----------+\n\"\"\"\n    assert summary(model, x_dict, edge_index_dict) == expected[1:-1]\n\n\n@withPackage('tabulate')\ndef test_summary_with_module_dict_model():\n    model = ModuleDictModel()\n    x = torch.randn(100, 32)\n\n    expected = \"\"\"\n+-------------------------+---------------+----------------+----------+\n| Layer                   | Input Shape   | Output Shape   | #Param   |\n|-------------------------+---------------+----------------+----------|\n| ModuleDictModel         | [100, 32]     | [100, 32]      | 1        |\n| ├─(acts)ModuleDict      | --            | --             | 1        |\n| │    └─(lrelu)LeakyReLU | --            | --             | --       |\n| │    └─(prelu)PReLU     | [100, 32]     | [100, 32]      | 1        |\n+-------------------------+---------------+----------------+----------+\n\"\"\"\n    assert summary(model, x, 'prelu') == expected[1:-1]\n\n\n@withPackage('tabulate')\ndef test_summary_with_jit_model():\n    model = nn.Sequential(nn.Linear(32, 16), nn.ReLU(), nn.Linear(16, 8))\n    model = torch.jit.script(model)\n    x = torch.randn(100, 32)\n\n    expected = \"\"\"\n+----------------------------+---------------+----------------+----------+\n| Layer                      | Input Shape   | Output Shape   | #Param   |\n|----------------------------+---------------+----------------+----------|\n| RecursiveScriptModule      | --            | --             | 664      |\n| ├─(0)RecursiveScriptModule | --            | --             | 528      |\n| ├─(1)RecursiveScriptModule | --            | --             | --       |\n| ├─(2)RecursiveScriptModule | --            | --             | 136      |\n+----------------------------+---------------+----------------+----------+\n\"\"\"\n    assert summary(model, x) == expected[1:-1]\n"
  },
  {
    "path": "test/nn/test_module_dict.py",
    "content": "import torch\n\nfrom torch_geometric.nn.module_dict import ModuleDict\n\n\ndef test_internal_external_key_conversion():\n    assert ModuleDict.to_internal_key('a.b') == 'a#b'\n    assert ModuleDict.to_internal_key('ab') == 'ab'\n    assert ModuleDict.to_internal_key('a.b.c') == 'a#b#c'\n    assert ModuleDict.to_internal_key(('a', 'b')) == '<a___b>'\n    assert ModuleDict.to_internal_key(('a.b', 'c')) == '<a#b___c>'\n    assert ModuleDict.to_internal_key('type') == '<type>'\n\n    assert ModuleDict.to_external_key('a#b') == 'a.b'\n    assert ModuleDict.to_external_key('a#b#c') == 'a.b.c'\n    assert ModuleDict.to_external_key('<a___b>') == ('a', 'b')\n    assert ModuleDict.to_external_key('<a#b___c>') == ('a.b', 'c')\n    assert ModuleDict.to_external_key('<type>') == 'type'\n\n\ndef test_dot_syntax_keys():\n    module_dict = ModuleDict({\n        'lin1': torch.nn.Linear(16, 16),\n        'model.lin2': torch.nn.Linear(8, 8),\n        'model.sub_model.lin3': torch.nn.Linear(4, 4),\n    })\n\n    expected_keys = {'lin1', 'model.lin2', 'model.sub_model.lin3'}\n    assert set(module_dict.keys()) == expected_keys\n    assert {key for key, _ in module_dict.items()} == expected_keys\n\n    for key in expected_keys:\n        assert key in module_dict\n\n    del module_dict['model.lin2']\n    assert 'model.lin2' not in module_dict\n\n\ndef test_tuple_keys():\n    module_dict = ModuleDict({\n        ('a', 'b'): torch.nn.Linear(16, 16),\n        ('a.b', 'c'): torch.nn.Linear(8, 8),\n    })\n\n    expected_keys = {('a', 'b'), ('a.b', 'c')}\n    assert set(module_dict.keys()) == expected_keys\n    assert {key for key, _ in module_dict.items()} == expected_keys\n\n    for key in expected_keys:\n        assert key in module_dict\n\n    del module_dict['a', 'b']\n    assert ('a', 'b') not in module_dict\n\n\ndef test_reserved_keys():\n    module_dict = ModuleDict({\n        'type': torch.nn.Linear(16, 16),\n        '__annotations__': torch.nn.Linear(8, 8),\n    
})\n\n    expected_keys = {'type', '__annotations__'}\n    assert set(module_dict.keys()) == expected_keys\n    assert {key for key, _ in module_dict.items()} == expected_keys\n\n    for key in expected_keys:\n        assert key in module_dict\n\n    del module_dict['type']\n    assert 'type' not in module_dict\n"
  },
  {
    "path": "test/nn/test_parameter_dict.py",
    "content": "import torch\n\nfrom torch_geometric.nn.parameter_dict import ParameterDict\n\n\ndef test_internal_external_key_conversion():\n    assert ParameterDict.to_internal_key('a.b') == 'a#b'\n    assert ParameterDict.to_internal_key('ab') == 'ab'\n    assert ParameterDict.to_internal_key('a.b.c') == 'a#b#c'\n    assert ParameterDict.to_internal_key(('a', 'b')) == '<a___b>'\n    assert ParameterDict.to_internal_key(('a.b', 'c')) == '<a#b___c>'\n    assert ParameterDict.to_internal_key('type') == '<type>'\n\n    assert ParameterDict.to_external_key('a#b') == 'a.b'\n    assert ParameterDict.to_external_key('a#b#c') == 'a.b.c'\n    assert ParameterDict.to_external_key('<a___b>') == ('a', 'b')\n    assert ParameterDict.to_external_key('<a#b___c>') == ('a.b', 'c')\n    assert ParameterDict.to_external_key('<type>') == 'type'\n\n\ndef test_dot_syntax_keys():\n    parameter_dict = {\n        'param1': torch.nn.Parameter(torch.randn(16, 16)),\n        'model.param2': torch.nn.Parameter(torch.randn(8, 8)),\n        'model.sub_model.param3': torch.nn.Parameter(torch.randn(4, 4)),\n    }\n    parameter_dict = ParameterDict(parameter_dict)\n\n    expected_keys = {'param1', 'model.param2', 'model.sub_model.param3'}\n    assert set(parameter_dict.keys()) == expected_keys\n    assert {key for key, _ in parameter_dict.items()} == expected_keys\n\n    for key in expected_keys:\n        assert key in parameter_dict\n\n    del parameter_dict['model.param2']\n    assert 'model.param2' not in parameter_dict\n\n\ndef test_tuple_keys():\n    parameter_dict = {\n        ('a', 'b'): torch.nn.Parameter(torch.randn(16, 16)),\n        ('a.b', 'c'): torch.nn.Parameter(torch.randn(8, 8)),\n    }\n    parameter_dict = ParameterDict(parameter_dict)\n\n    expected_keys = {('a', 'b'), ('a.b', 'c')}\n    assert set(parameter_dict.keys()) == expected_keys\n    assert {key for key, _ in parameter_dict.items()} == expected_keys\n\n    for key in expected_keys:\n        assert key in 
parameter_dict\n\n    del parameter_dict['a', 'b']\n    assert ('a', 'b') not in parameter_dict\n\n\ndef test_reserved_keys():\n    parameter_dict = {\n        'type': torch.nn.Parameter(torch.randn(16, 16)),\n        '__annotations__': torch.nn.Parameter(torch.randn(8, 8)),\n    }\n    parameter_dict = ParameterDict(parameter_dict)\n\n    expected_keys = {'type', '__annotations__'}\n    assert set(parameter_dict.keys()) == expected_keys\n    assert {key for key, _ in parameter_dict.items()} == expected_keys\n\n    for key in expected_keys:\n        assert key in parameter_dict\n\n    del parameter_dict['type']\n    assert 'type' not in parameter_dict\n"
  },
  {
    "path": "test/nn/test_reshape.py",
    "content": "import torch\n\nfrom torch_geometric.nn.reshape import Reshape\n\n\ndef test_reshape():\n    x = torch.randn(10, 4)\n\n    op = Reshape(5, 2, 4)\n    assert str(op) == 'Reshape(5, 2, 4)'\n\n    assert op(x).size() == (5, 2, 4)\n    assert torch.equal(op(x).view(10, 4), x)\n"
  },
  {
    "path": "test/nn/test_resolver.py",
    "content": "import pytest\nimport torch\nfrom torch.optim.lr_scheduler import ConstantLR, LambdaLR, ReduceLROnPlateau\n\nimport torch_geometric\nfrom torch_geometric.nn.resolver import (\n    activation_resolver,\n    aggregation_resolver,\n    lr_scheduler_resolver,\n    normalization_resolver,\n    optimizer_resolver,\n)\n\n\ndef test_activation_resolver():\n    assert isinstance(activation_resolver(torch.nn.ELU()), torch.nn.ELU)\n    assert isinstance(activation_resolver(torch.nn.ReLU()), torch.nn.ReLU)\n    assert isinstance(activation_resolver(torch.nn.PReLU()), torch.nn.PReLU)\n\n    assert isinstance(activation_resolver('elu'), torch.nn.ELU)\n    assert isinstance(activation_resolver('relu'), torch.nn.ReLU)\n    assert isinstance(activation_resolver('prelu'), torch.nn.PReLU)\n\n\n@pytest.mark.parametrize('aggr_tuple', [\n    (torch_geometric.nn.MeanAggregation, 'mean'),\n    (torch_geometric.nn.SumAggregation, 'sum'),\n    (torch_geometric.nn.SumAggregation, 'add'),\n    (torch_geometric.nn.MaxAggregation, 'max'),\n    (torch_geometric.nn.MinAggregation, 'min'),\n    (torch_geometric.nn.MulAggregation, 'mul'),\n    (torch_geometric.nn.VarAggregation, 'var'),\n    (torch_geometric.nn.StdAggregation, 'std'),\n    (torch_geometric.nn.SoftmaxAggregation, 'softmax'),\n    (torch_geometric.nn.PowerMeanAggregation, 'powermean'),\n])\ndef test_aggregation_resolver(aggr_tuple):\n    aggr_module, aggr_repr = aggr_tuple\n    assert isinstance(aggregation_resolver(aggr_module()), aggr_module)\n    assert isinstance(aggregation_resolver(aggr_repr), aggr_module)\n\n\ndef test_multi_aggregation_resolver():\n    aggr = aggregation_resolver(None)\n    assert aggr is None\n\n    aggr = aggregation_resolver(['sum', 'mean', None])\n    assert isinstance(aggr, torch_geometric.nn.MultiAggregation)\n    assert len(aggr.aggrs) == 3\n    assert isinstance(aggr.aggrs[0], torch_geometric.nn.SumAggregation)\n    assert isinstance(aggr.aggrs[1], torch_geometric.nn.MeanAggregation)\n 
   assert aggr.aggrs[2] is None\n\n\n@pytest.mark.parametrize('norm_tuple', [\n    (torch_geometric.nn.BatchNorm, 'batch', (16, )),\n    (torch_geometric.nn.BatchNorm, 'batch_norm', (16, )),\n    (torch_geometric.nn.InstanceNorm, 'instance_norm', (16, )),\n    (torch_geometric.nn.LayerNorm, 'layer_norm', (16, )),\n    (torch_geometric.nn.GraphNorm, 'graph_norm', (16, )),\n    (torch_geometric.nn.GraphSizeNorm, 'graphsize_norm', ()),\n    (torch_geometric.nn.PairNorm, 'pair_norm', ()),\n    (torch_geometric.nn.MessageNorm, 'message_norm', ()),\n    (torch_geometric.nn.DiffGroupNorm, 'diffgroup_norm', (16, 4)),\n])\ndef test_normalization_resolver(norm_tuple):\n    norm_module, norm_repr, norm_args = norm_tuple\n    assert isinstance(normalization_resolver(norm_module(*norm_args)),\n                      norm_module)\n    assert isinstance(normalization_resolver(norm_repr, *norm_args),\n                      norm_module)\n\n\ndef test_optimizer_resolver():\n    params = [torch.nn.Parameter(torch.randn(1))]\n\n    assert isinstance(optimizer_resolver(torch.optim.SGD(params, lr=0.01)),\n                      torch.optim.SGD)\n    assert isinstance(optimizer_resolver(torch.optim.Adam(params)),\n                      torch.optim.Adam)\n    assert isinstance(optimizer_resolver(torch.optim.Rprop(params)),\n                      torch.optim.Rprop)\n\n    assert isinstance(optimizer_resolver('sgd', params, lr=0.01),\n                      torch.optim.SGD)\n    assert isinstance(optimizer_resolver('adam', params), torch.optim.Adam)\n    assert isinstance(optimizer_resolver('rprop', params), torch.optim.Rprop)\n\n\n@pytest.mark.parametrize('scheduler_args', [\n    ('constant_with_warmup', LambdaLR),\n    ('linear_with_warmup', LambdaLR),\n    ('cosine_with_warmup', LambdaLR),\n    ('cosine_with_warmup_restarts', LambdaLR),\n    ('polynomial_with_warmup', LambdaLR),\n    ('constant', ConstantLR),\n    ('ReduceLROnPlateau', ReduceLROnPlateau),\n])\ndef 
test_lr_scheduler_resolver(scheduler_args):\n    scheduler_name, scheduler_cls = scheduler_args\n\n    model = torch.nn.Linear(10, 5)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    lr_scheduler = lr_scheduler_resolver(\n        scheduler_name,\n        optimizer,\n        num_training_steps=100,\n    )\n    assert isinstance(lr_scheduler, scheduler_cls)\n"
  },
  {
    "path": "test/nn/test_sequential.py",
    "content": "from collections import OrderedDict\n\nimport torch\nimport torch.fx\nfrom torch.nn import Dropout, Linear, ReLU\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import (\n    GCNConv,\n    JumpingKnowledge,\n    MessagePassing,\n    SAGEConv,\n    Sequential,\n    global_mean_pool,\n    to_hetero,\n)\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_sequential_basic():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    batch = torch.zeros(4, dtype=torch.long)\n\n    model = Sequential('x, edge_index', [\n        (GCNConv(16, 64), 'x, edge_index -> x'),\n        ReLU(inplace=True),\n        (GCNConv(64, 64), 'x, edge_index -> x'),\n        ReLU(inplace=True),\n        Linear(64, 7),\n    ]).cpu()\n    model.reset_parameters()\n\n    assert len(model) == 5\n    assert str(model) == (\n        'Sequential(\\n'\n        '  (0) - GCNConv(16, 64): x, edge_index -> x\\n'\n        '  (1) - ReLU(inplace=True): x -> x\\n'\n        '  (2) - GCNConv(64, 64): x, edge_index -> x\\n'\n        '  (3) - ReLU(inplace=True): x -> x\\n'\n        '  (4) - Linear(in_features=64, out_features=7, bias=True): x -> x\\n'\n        ')')\n\n    assert isinstance(model[0], GCNConv)\n    assert isinstance(model[1], ReLU)\n    assert isinstance(model[2], GCNConv)\n    assert isinstance(model[3], ReLU)\n    assert isinstance(model[4], Linear)\n\n    out = model(x, edge_index)\n    assert out.size() == (4, 7)\n\n    model = Sequential('x, edge_index, batch', [\n        (Dropout(p=0.5), 'x -> x'),\n        (GCNConv(16, 64), 'x, edge_index -> x1'),\n        ReLU(inplace=True),\n        (GCNConv(64, 64), 'x1, edge_index -> x2'),\n        ReLU(inplace=True),\n        (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),\n        (JumpingKnowledge('cat', 64, num_layers=2), 'xs -> x'),\n        (global_mean_pool, 'x, batch -> x'),\n        Linear(2 * 64, 7),\n    ])\n    model.reset_parameters()\n\n    out = model(x, 
edge_index, batch)\n    assert out.size() == (1, 7)\n\n\ndef test_sequential_jit():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    model = Sequential('x: Tensor, edge_index: Tensor', [\n        (GCNConv(16, 64), 'x, edge_index -> x'),\n        ReLU(inplace=True),\n        (GCNConv(64, 64), 'x, edge_index -> x'),\n        ReLU(inplace=True),\n        Linear(64, 7),\n    ])\n    torch.jit.script(model)(x, edge_index)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t = SparseTensor.from_edge_index(edge_index).t()\n\n        model = Sequential('x: Tensor, edge_index: SparseTensor', [\n            (GCNConv(16, 64), 'x, edge_index -> x'),\n            ReLU(inplace=True),\n            (GCNConv(64, 64), 'x, edge_index -> x'),\n            ReLU(inplace=True),\n            Linear(64, 7),\n        ])\n        torch.jit.script(model)(x, adj_t)\n\n\ndef symbolic_trace(module):\n    class Tracer(torch.fx.Tracer):\n        def is_leaf_module(self, module, *args, **kwargs) -> bool:\n            return (isinstance(module, MessagePassing)\n                    or super().is_leaf_module(module, *args, **kwargs))\n\n    return torch.fx.GraphModule(module, Tracer().trace(module))\n\n\ndef test_sequential_tracable():\n    model = Sequential('x, edge_index', [\n        (GCNConv(16, 64), 'x, edge_index -> x1'),\n        ReLU(inplace=True),\n        (GCNConv(64, 64), 'x1, edge_index -> x2'),\n        ReLU(inplace=True),\n        (lambda x1, x2: x1 + x2, 'x1, x2 -> x'),\n        Linear(64, 7),\n    ])\n    symbolic_trace(model)\n\n\ndef test_sequential_with_multiple_return_values():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    model = Sequential('x, edge_index', [\n        (GCNConv(16, 32), 'x, edge_index -> x1'),\n        (GCNConv(32, 64), 'x1, edge_index -> x2'),\n        (lambda x1, x2: (x1, x2), 'x1, x2 -> x1, x2'),\n    ])\n\n    x1, x2 = model(x, 
edge_index)\n    assert x1.size() == (4, 32)\n    assert x2.size() == (4, 64)\n\n\ndef test_sequential_with_ordered_dict():\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n\n    model = Sequential(\n        'x, edge_index', modules=OrderedDict([\n            ('conv1', (GCNConv(16, 32), 'x, edge_index -> x')),\n            ('conv2', (GCNConv(32, 64), 'x, edge_index -> x')),\n        ]))\n\n    assert isinstance(model.conv1, GCNConv)\n    assert isinstance(model.conv2, GCNConv)\n\n    x = model(x, edge_index)\n    assert x.size() == (4, 64)\n\n\ndef test_sequential_to_hetero():\n    model = Sequential('x, edge_index', [\n        (SAGEConv((-1, -1), 32), 'x, edge_index -> x1'),\n        ReLU(),\n        (SAGEConv((-1, -1), 64), 'x1, edge_index -> x2'),\n        ReLU(),\n    ])\n\n    x_dict = {\n        'paper': torch.randn(100, 16),\n        'author': torch.randn(100, 16),\n    }\n    edge_index_dict = {\n        ('paper', 'cites', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('paper', 'written_by', 'author'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('author', 'writes', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n    }\n    metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n\n    model = to_hetero(model, metadata, debug=False)\n\n    out_dict = model(x_dict, edge_index_dict)\n    assert isinstance(out_dict, dict) and len(out_dict) == 2\n    assert out_dict['paper'].size() == (100, 64)\n    assert out_dict['author'].size() == (100, 64)\n"
  },
  {
    "path": "test/nn/test_to_fixed_size_transformer.py",
    "content": "import torch\n\nfrom torch_geometric.nn import SumAggregation\nfrom torch_geometric.nn.to_fixed_size_transformer import to_fixed_size\n\n\nclass Model(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.aggr = SumAggregation()\n\n    def forward(self, x, batch):\n        return self.aggr(x, batch, dim=0)\n\n\ndef test_to_fixed_size():\n    x = torch.randn(10, 16)\n    batch = torch.zeros(10, dtype=torch.long)\n\n    model = Model()\n    assert model(x, batch).size() == (1, 16)\n\n    model = to_fixed_size(model, batch_size=10)\n    assert model(x, batch).size() == (10, 16)\n"
  },
  {
    "path": "test/nn/test_to_hetero_module.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn.conv import SAGEConv\nfrom torch_geometric.nn.dense import Linear\nfrom torch_geometric.nn.to_hetero_module import (\n    ToHeteroLinear,\n    ToHeteroMessagePassing,\n)\n\n\n@pytest.mark.parametrize('LinearCls', [torch.nn.Linear, Linear])\ndef test_to_hetero_linear(LinearCls):\n    x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)}\n    x = torch.cat([x_dict['1'], x_dict['2']], dim=0)\n    type_vec = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1])\n\n    module = ToHeteroLinear(LinearCls(16, 32), list(x_dict.keys()))\n\n    out_dict = module(x_dict)\n    assert len(out_dict) == 2\n    assert out_dict['1'].size() == (5, 32)\n    assert out_dict['2'].size() == (4, 32)\n\n    out = module(x, type_vec)\n    assert out.size() == (9, 32)\n\n    assert torch.allclose(out_dict['1'], out[0:5])\n    assert torch.allclose(out_dict['2'], out[5:9])\n\n\ndef test_to_hetero_message_passing():\n    x_dict = {'1': torch.randn(5, 16), '2': torch.randn(4, 16)}\n    x = torch.cat([x_dict['1'], x_dict['2']], dim=0)\n    node_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1])\n\n    edge_index_dict = {\n        ('1', 'to', '2'): torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 3]]),\n        ('2', 'to', '1'): torch.tensor([[0, 0, 1, 2, 3], [0, 1, 2, 3, 4]]),\n    }\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5, 5, 6, 7, 8],\n        [5, 5, 6, 7, 8, 0, 1, 2, 3, 4],\n    ])\n    edge_type = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n\n    module = ToHeteroMessagePassing(SAGEConv(16, 32), list(x_dict.keys()),\n                                    list(edge_index_dict.keys()))\n\n    out_dict = module(x_dict, edge_index_dict)\n    assert len(out_dict) == 2\n    assert out_dict['1'].size() == (5, 32)\n    assert out_dict['2'].size() == (4, 32)\n\n    out = module(x, edge_index, node_type, edge_type)\n    assert out.size() == (9, 32)\n\n    assert torch.allclose(out_dict['1'], out[0:5])\n    assert 
torch.allclose(out_dict['2'], out[5:9])\n"
  },
  {
    "path": "test/nn/test_to_hetero_transformer.py",
    "content": "from typing import Tuple\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Linear, ReLU, Sequential\n\nimport torch_geometric.typing\nfrom torch_geometric.datasets import FakeHeteroDataset\nfrom torch_geometric.nn import (\n    GAT,\n    BatchNorm,\n    GATv2Conv,\n    GCNConv,\n    GINEConv,\n    GraphSAGE,\n)\nfrom torch_geometric.nn import Linear as LazyLinear\nfrom torch_geometric.nn import (\n    MeanAggregation,\n    MessagePassing,\n    RGCNConv,\n    SAGEConv,\n    to_hetero,\n)\nfrom torch_geometric.testing import onlyCUDA\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import dropout_edge\n\ntorch.fx.wrap('dropout_edge')\n\n\nclass Net1(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = Linear(16, 32)\n        self.lin2 = Linear(8, 16)\n\n    def forward(self, x: Tensor, edge_attr: Tensor) -> Tuple[Tensor, Tensor]:\n        x = self.lin1(x)\n        edge_attr = self.lin2(edge_attr)\n        return x, edge_attr\n\n\nclass Net2(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = Linear(16, 16)\n        self.conv1 = SAGEConv(16, 32)\n        self.lin2 = Linear(32, 32)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = self.lin1(x).relu_()\n        x = self.conv1(x, edge_index).relu_()\n        x = self.lin2(x)\n        return x\n\n\nclass Net3(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = Linear(8, 16)\n        self.conv1 = GINEConv(nn=Linear(16, 32))\n\n    def forward(self, x: Tensor, edge_index: Tensor,\n                edge_attr: Tensor) -> Tensor:\n        x = self.conv1(x, edge_index, self.lin1(edge_attr))\n        return x\n\n\nclass Net4(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SAGEConv(16, 16)\n        self.conv2 = SAGEConv(16, 16)\n        
self.lin1 = Linear(3 * 16, 32)\n\n    def forward(self, x0: Tensor, edge_index: Tensor) -> Tensor:\n        x1 = self.conv1(x0, edge_index).relu_()\n        x2 = self.conv2(x1, edge_index).relu_()\n        return self.lin1(torch.cat([x0, x1, x2], dim=-1))\n\n\nclass Net5(torch.nn.Module):\n    def __init__(self, num_layers):\n        super().__init__()\n        self.lins = torch.nn.ModuleList()\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            self.lins.append(Linear(16, 16))\n            self.convs.append(SAGEConv(16, 16))\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for lin, conv in zip(self.lins, self.convs):\n            x = (conv(x, edge_index) + lin(x))\n        return x\n\n\nclass Net6(torch.nn.Module):\n    def __init__(self, num_layers):\n        super().__init__()\n        self.lins = torch.nn.ModuleDict()\n        self.convs = torch.nn.ModuleDict()\n        for i in range(num_layers):\n            self.lins[str(i)] = Linear(16, 16)\n            self.convs[str(i)] = SAGEConv(16, 16)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for i in range(len(self.lins)):\n            x = (self.convs[str(i)](x, edge_index) + self.lins[str(i)](x))\n        return x\n\n\nclass Net7(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.mlp1 = Sequential(Linear(16, 16), ReLU(), Linear(16, 16))\n        self.conv1 = SAGEConv(16, 32)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = self.mlp1(x)\n        x = self.conv1(x, edge_index)\n        return x\n\n\nclass Net8(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = LazyLinear(-1, 32)\n\n    def forward(self, x: Tensor) -> Tensor:\n        x = self.lin1(x)\n        return x\n\n\nclass Net9(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.batch_norm = BatchNorm(16)\n\n    def 
forward(self, x: Tensor) -> Tensor:\n        return self.batch_norm(x)\n\n\nclass Net10(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = SAGEConv(16, 32)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = F.dropout(x, p=0.5, training=self.training)\n        edge_index, _ = dropout_edge(edge_index, p=0.5, training=self.training)\n        return self.conv(x, edge_index)\n\n\nclass Net11(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = SAGEConv(16, 16)\n        self.num_layers = 3\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        xs = [x]\n        for _ in range(self.num_layers):\n            xs.append(self.conv(xs[-1], edge_index))\n        return torch.cat(xs, dim=-1)\n\n\nclass Net12(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = Net8()\n\n    def forward(self, x: Tensor) -> Tensor:\n        return self.conv(x)\n\n\ndef test_to_hetero_basic():\n    x_dict = {\n        'paper': torch.randn(100, 16),\n        'author': torch.randn(100, 16),\n    }\n    edge_index_dict = {\n        ('paper', 'cites', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('paper', 'written_by', 'author'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('author', 'writes', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n    }\n    edge_attr_dict = {\n        ('paper', 'cites', 'paper'): torch.randn(200, 8),\n        ('paper', 'written_by', 'author'): torch.randn(200, 8),\n        ('author', 'writes', 'paper'): torch.randn(200, 8),\n    }\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict = {}\n        for edge_type, (row, col) in edge_index_dict.items():\n            adj_t_dict[edge_type] = SparseTensor(\n                row=col,\n                col=row,\n                sparse_sizes=(100, 100),\n            )\n\n    
metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n\n    model = Net1()\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_attr_dict)\n    assert isinstance(out, tuple) and len(out) == 2\n    assert isinstance(out[0], dict) and len(out[0]) == 2\n    assert out[0]['paper'].size() == (100, 32)\n    assert out[0]['author'].size() == (100, 32)\n    assert isinstance(out[1], dict) and len(out[1]) == 3\n    assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)\n    assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)\n    assert out[1][('author', 'writes', 'paper')].size() == (200, 16)\n    assert sum(p.numel() for p in model.parameters()) == 1520\n\n    for aggr in ['sum', 'mean', 'min', 'max', 'mul']:\n        model = Net2()\n        model = to_hetero(model, metadata, aggr=aggr, debug=False)\n        assert sum(p.numel() for p in model.parameters()) == 5824\n\n        out1 = model(x_dict, edge_index_dict)\n        assert isinstance(out1, dict) and len(out1) == 2\n        assert out1['paper'].size() == (100, 32)\n        assert out1['author'].size() == (100, 32)\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            out2 = model(x_dict, adj_t_dict)\n            assert isinstance(out2, dict) and len(out2) == 2\n            for key in x_dict.keys():\n                assert torch.allclose(out1[key], out2[key], atol=1e-6)\n\n    model = Net3()\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict, edge_attr_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n\n    model = Net4()\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n\n    model = Net5(num_layers=2)\n    model = 
to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 16)\n    assert out['author'].size() == (100, 16)\n\n    model = Net6(num_layers=2)\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 16)\n    assert out['author'].size() == (100, 16)\n\n    model = Net7()\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n\n    model = Net8()\n    model = to_hetero(model, metadata, debug=False)\n    out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)})\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (4, 32)\n    assert out['author'].size() == (8, 32)\n\n    model = Net9()\n    model = to_hetero(model, metadata, debug=False)\n    out = model({'paper': torch.randn(4, 16), 'author': torch.randn(8, 16)})\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (4, 16)\n    assert out['author'].size() == (8, 16)\n\n    model = Net10()\n    with pytest.warns(UserWarning, match=\"with keyword argument 'training'\"):\n        model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n\n    model = Net11()\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 64)\n    assert out['author'].size() == (100, 64)\n\n    model = Net12()\n    with 
pytest.warns(UserWarning, match=\"parameters cannot be reset\"):\n        model = to_hetero(model, metadata, debug=False)\n    out = model({'paper': torch.randn(4, 8), 'author': torch.randn(8, 16)})\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (4, 32)\n    assert out['author'].size() == (8, 32)\n\n\nclass GCN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(16, 32)\n        self.conv2 = GCNConv(32, 64)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = self.conv1(x, edge_index).relu()\n        x = self.conv2(x, edge_index).relu()\n        return x\n\n\ndef test_to_hetero_with_gcn():\n    x_dict = {\n        'paper': torch.randn(100, 16),\n    }\n    edge_index_dict = {\n        ('paper', 'cites', 'paper'): torch.randint(100, (2, 200)),\n        ('paper', 'rev_cites', 'paper'): torch.randint(100, (2, 200)),\n    }\n    metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n\n    model = GCN()\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 1\n    assert out['paper'].size() == (100, 64)\n\n\ndef test_to_hetero_with_basic_model():\n    x_dict = {\n        'paper': torch.randn(100, 16),\n        'author': torch.randn(100, 16),\n    }\n    edge_index_dict = {\n        ('paper', 'cites', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('paper', 'written_by', 'author'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('author', 'writes', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n    }\n\n    metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n\n    model = GraphSAGE((-1, -1), 32, num_layers=3)\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n\n    model = GAT((-1, -1), 
32, num_layers=3, add_self_loops=False)\n    model = to_hetero(model, metadata, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n\n\nclass GraphConv(MessagePassing):\n    def __init__(self, in_channels, out_channels):\n        super().__init__(aggr='sum')\n        self.lin = Linear(in_channels, out_channels, bias=False)\n\n    def reset_parameters(self):\n        self.lin.reset_parameters()\n\n    def forward(self, x, edge_index):\n        if isinstance(x, Tensor):\n            x = (x, x)\n        return self.propagate(edge_index, x=(self.lin(x[0]), x[1]))\n\n\nclass RGCN(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = GraphConv(in_channels, out_channels)\n        self.lin = Linear(in_channels, out_channels, bias=True)\n\n    def forward(self, x, edge_index):\n        return self.lin(x) + self.conv(x, edge_index)\n\n\ndef test_to_hetero_and_rgcn_equal_output():\n    torch.manual_seed(1234)\n\n    # Run `RGCN`:\n    x = torch.randn(10, 16)  # 6 paper nodes, 4 author nodes\n    adj = (torch.rand(10, 10) > 0.5)\n    adj[6:, 6:] = False\n    edge_index = adj.nonzero(as_tuple=False).t().contiguous()\n    row, col = edge_index\n\n    # # 0 = paper<->paper, 1 = paper->author, 2 = author->paper\n    edge_type = torch.full((edge_index.size(1), ), -1, dtype=torch.long)\n    edge_type[(row < 6) & (col < 6)] = 0\n    edge_type[(row < 6) & (col >= 6)] = 1\n    edge_type[(row >= 6) & (col < 6)] = 2\n    assert edge_type.min() == 0\n\n    conv = RGCNConv(16, 32, num_relations=3, aggr='sum')\n    out1 = conv(x, edge_index, edge_type)\n\n    # Run `to_hetero`:\n    x_dict = {\n        'paper': x[:6],\n        'author': x[6:],\n    }\n    edge_index_dict = {\n        ('paper', '_', 'paper'):\n        edge_index[:, edge_type == 0],\n        ('paper', '_', 'author'):\n        edge_index[:, edge_type == 1] - torch.tensor([[0], [6]]),\n        ('author', '_', 
'paper'):\n        edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]),\n    }\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict = {\n            key: SparseTensor.from_edge_index(edge_index).t()\n            for key, edge_index in edge_index_dict.items()\n        }\n\n    node_types, edge_types = list(x_dict.keys()), list(edge_index_dict.keys())\n\n    model = to_hetero(RGCN(16, 32), (node_types, edge_types))\n\n    # Set model weights:\n    for i, edge_type in enumerate(edge_types):\n        weight = model.conv['__'.join(edge_type)].lin.weight\n        weight.data = conv.weight[i].data.t()\n    for node_type in node_types:\n        model.lin[node_type].weight.data = conv.root.data.t()\n        model.lin[node_type].bias.data = conv.bias.data\n\n    out2 = model(x_dict, edge_index_dict)\n    out2 = torch.cat([out2['paper'], out2['author']], dim=0)\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        out3 = model(x_dict, adj_t_dict)\n        out3 = torch.cat([out3['paper'], out3['author']], dim=0)\n        assert torch.allclose(out1, out3, atol=1e-6)\n\n\nclass GraphLevelGNN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = SAGEConv(16, 32)\n        self.pool = MeanAggregation()\n        self.lin = Linear(32, 64)\n\n    def forward(self, x: Tensor, edge_index: Tensor, batch: Tensor) -> Tensor:\n        x = self.conv(x, edge_index)\n        x = self.pool(x, batch)\n        x = self.lin(x)\n        return x\n\n\ndef test_graph_level_to_hetero():\n    x_dict = {\n        'paper': torch.randn(100, 16),\n        'author': torch.randn(100, 16),\n    }\n    edge_index_dict = {\n        ('paper', 'written_by', 'author'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('author', 'writes', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n    }\n    batch_dict = {\n        'paper': torch.zeros(100, 
dtype=torch.long),\n        'author': torch.zeros(100, dtype=torch.long),\n    }\n\n    metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n\n    model = GraphLevelGNN()\n    model = to_hetero(model, metadata, aggr='mean', debug=False)\n    out = model(x_dict, edge_index_dict, batch_dict)\n    assert out.size() == (1, 64)\n\n\nclass MessagePassingLoops(MessagePassing):\n    def __init__(self):\n        super().__init__()\n        self.add_self_loops = True\n\n    def forward(self, x):\n        return x\n\n\nclass ModelLoops(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv = MessagePassingLoops()\n\n    def forward(self, x):\n        return self.conv(x)\n\n\ndef test_hetero_transformer_self_loop_error():\n    to_hetero(ModelLoops(), metadata=(['a'], [('a', 'to', 'a')]))\n    with pytest.raises(ValueError, match=\"incorrect message passing\"):\n        to_hetero(ModelLoops(), metadata=(['a', 'b'], [('a', 'to', 'b'),\n                                                       ('b', 'to', 'a')]))\n\n\ndef test_to_hetero_validate():\n    model = Net1()\n    metadata = (['my test'], [('my test', 'rel', 'my test')])\n\n    with pytest.warns(UserWarning, match=\"letters, numbers and underscores\"):\n        model = to_hetero(model, metadata, debug=False)\n\n\ndef test_to_hetero_on_static_graphs():\n    x_dict = {\n        'paper': torch.randn(4, 100, 16),\n        'author': torch.randn(4, 100, 16),\n    }\n    edge_index_dict = {\n        ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),\n        ('author', 'writes', 'paper'): torch.randint(100, (2, 200)),\n    }\n\n    metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n    model = to_hetero(Net4(), metadata, debug=False)\n    out_dict = model(x_dict, edge_index_dict)\n\n    assert len(out_dict) == 2\n    assert out_dict['paper'].size() == (4, 100, 32)\n    assert out_dict['author'].size() == (4, 100, 32)\n\n\n@onlyCUDA\ndef 
test_to_hetero_lazy_cuda():\n    class Model(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.conv = GATv2Conv(\n                (-1, -1),\n                out_channels=2,\n                add_self_loops=False,\n                edge_dim=-1,\n                heads=1,\n            ).to('cuda')\n\n        def forward(self, x, edge_index, edge_attr):\n            return self.conv(x, edge_index, edge_attr)\n\n    data = FakeHeteroDataset(edge_dim=10)[0].to('cuda')\n    model = to_hetero(Model(), data.metadata())\n    out_dict = model(data.x_dict, data.edge_index_dict, data.edge_attr_dict)\n    assert len(out_dict) == len(data.node_types)\n    for out in out_dict.values():\n        assert out.is_cuda\n        assert out.size(-1) == 2\n"
  },
  {
    "path": "test/nn/test_to_hetero_with_bases_transformer.py",
    "content": "import os.path as osp\nfrom typing import Tuple\n\nimport pytest\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear, ReLU, Sequential\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import (\n    GINEConv,\n    MessagePassing,\n    RGCNConv,\n    SAGEConv,\n    to_hetero_with_bases,\n)\nfrom torch_geometric.typing import SparseTensor\n\n\nclass Net1(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = Linear(16, 32)\n        self.lin2 = Linear(8, 16)\n\n    def forward(self, x: Tensor, edge_attr: Tensor) -> Tuple[Tensor, Tensor]:\n        x = self.lin1(x)\n        edge_attr = self.lin2(edge_attr)\n        return x, edge_attr\n\n\nclass Net2(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = Linear(16, 16)\n        self.conv1 = SAGEConv(16, 32)\n        self.lin2 = Linear(32, 32)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = self.lin1(x).relu_()\n        x = self.conv1(x, edge_index).relu_()\n        x = self.lin2(x)\n        return x\n\n\nclass Net3(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.lin1 = Linear(8, 16)\n        self.conv1 = GINEConv(nn=Linear(16, 32))\n\n    def forward(self, x: Tensor, edge_index: Tensor,\n                edge_attr: Tensor) -> Tensor:\n        x = self.conv1(x, edge_index, self.lin1(edge_attr))\n        return x\n\n\nclass Net4(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SAGEConv(16, 16)\n        self.conv2 = SAGEConv(16, 16)\n        self.lin1 = Linear(3 * 16, 32)\n\n    def forward(self, x0: Tensor, edge_index: Tensor) -> Tensor:\n        x1 = self.conv1(x0, edge_index).relu_()\n        x2 = self.conv2(x1, edge_index).relu_()\n        return self.lin1(torch.cat([x0, x1, x2], dim=-1))\n\n\nclass Net5(torch.nn.Module):\n    def __init__(self, num_layers):\n        super().__init__()\n        
self.lins = torch.nn.ModuleList()\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            self.lins.append(Linear(16, 16))\n            self.convs.append(SAGEConv(16, 16))\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for lin, conv in zip(self.lins, self.convs):\n            x = (conv(x, edge_index) + lin(x))\n        return x\n\n\nclass Net6(torch.nn.Module):\n    def __init__(self, num_layers):\n        super().__init__()\n        self.lins = torch.nn.ModuleDict()\n        self.convs = torch.nn.ModuleDict()\n        for i in range(num_layers):\n            self.lins[str(i)] = Linear(16, 16)\n            self.convs[str(i)] = SAGEConv(16, 16)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        for i in range(len(self.lins)):\n            x = (self.convs[str(i)](x, edge_index) + self.lins[str(i)](x))\n        return x\n\n\nclass Net7(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.mlp1 = Sequential(Linear(16, 16), ReLU(), Linear(16, 16))\n        self.conv1 = SAGEConv(16, 32)\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        x = self.mlp1(x)\n        x = self.conv1(x, edge_index)\n        return x\n\n\ndef test_to_hetero_with_bases():\n    metadata = (['paper', 'author'], [('paper', 'cites', 'paper'),\n                                      ('paper', 'written_by', 'author'),\n                                      ('author', 'writes', 'paper')])\n\n    x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 8)}\n    edge_index_dict = {\n        ('paper', 'cites', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('paper', 'written_by', 'author'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('author', 'writes', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n    }\n    edge_attr_dict = {\n        ('paper', 'cites', 'paper'): 
torch.randn(200, 8),\n        ('paper', 'written_by', 'author'): torch.randn(200, 8),\n        ('author', 'writes', 'paper'): torch.randn(200, 8),\n    }\n\n    model = Net1()\n    in_channels = {'x': 16, 'edge_attr': 8}\n    model = to_hetero_with_bases(model, metadata, num_bases=4,\n                                 in_channels=in_channels, debug=False)\n    out = model(x_dict, edge_attr_dict)\n    assert isinstance(out, tuple) and len(out) == 2\n    assert isinstance(out[0], dict) and len(out[0]) == 2\n    assert out[0]['paper'].size() == (100, 32)\n    assert out[0]['author'].size() == (100, 32)\n    assert isinstance(out[1], dict) and len(out[1]) == 3\n    assert out[1][('paper', 'cites', 'paper')].size() == (200, 16)\n    assert out[1][('paper', 'written_by', 'author')].size() == (200, 16)\n    assert out[1][('author', 'writes', 'paper')].size() == (200, 16)\n    assert sum(p.numel() for p in model.parameters()) == 1264\n\n    model = Net2()\n    in_channels = {'x': 16}\n    model = to_hetero_with_bases(model, metadata, num_bases=4,\n                                 in_channels=in_channels, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n    assert sum(p.numel() for p in model.parameters()) == 5948\n\n    model = Net3()\n    in_channels = {'x': 16, 'edge_attr': 8}\n    model = to_hetero_with_bases(model, metadata, num_bases=4,\n                                 in_channels=in_channels, debug=False)\n    out = model(x_dict, edge_index_dict, edge_attr_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n\n    model = Net4()\n    in_channels = {'x0': 16}\n    model = to_hetero_with_bases(model, metadata, num_bases=4,\n                                 in_channels=in_channels, debug=False)\n    out = model(x_dict, 
edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n\n    model = Net5(num_layers=2)\n    in_channels = {'x': 16}\n    model = to_hetero_with_bases(model, metadata, num_bases=4,\n                                 in_channels=in_channels, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 16)\n    assert out['author'].size() == (100, 16)\n\n    model = Net6(num_layers=2)\n    in_channels = {'x': 16}\n    model = to_hetero_with_bases(model, metadata, num_bases=4,\n                                 in_channels=in_channels, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 16)\n    assert out['author'].size() == (100, 16)\n\n    model = Net7()\n    in_channels = {'x': 16}\n    model = to_hetero_with_bases(model, metadata, num_bases=4,\n                                 in_channels=in_channels, debug=False)\n    out = model(x_dict, edge_index_dict)\n    assert isinstance(out, dict) and len(out) == 2\n    assert out['paper'].size() == (100, 32)\n    assert out['author'].size() == (100, 32)\n\n\nclass GraphConv(MessagePassing):\n    def __init__(self, in_channels, out_channels):\n        super().__init__(aggr='add')\n        self.lin = Linear(in_channels, out_channels, bias=False)\n\n    def reset_parameters(self):\n        self.lin.reset_parameters()\n\n    def forward(self, x, edge_index):\n        if isinstance(x, Tensor):\n            x = (x, x)\n        return self.propagate(edge_index, x=(self.lin(x[0]), x[1]))\n\n\nclass RGCN(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv = GraphConv(in_channels, out_channels)\n        self.lin = Linear(in_channels, out_channels, bias=True)\n\n    def forward(self, x, 
edge_index):\n        return self.lin(x) + self.conv(x, edge_index)\n\n\ndef test_to_hetero_with_bases_and_rgcn_equal_output():\n    torch.manual_seed(1234)\n\n    # Run `RGCN` with basis decomposition:\n    x = torch.randn(10, 16)  # 6 paper nodes, 4 author nodes\n    adj = (torch.rand(10, 10) > 0.5)\n    adj[6:, 6:] = False\n    edge_index = adj.nonzero(as_tuple=False).t().contiguous()\n    row, col = edge_index\n\n    # # 0 = paper<->paper, 1 = paper->author, 2 = author->paper\n    edge_type = torch.full((edge_index.size(1), ), -1, dtype=torch.long)\n    edge_type[(row < 6) & (col < 6)] = 0\n    edge_type[(row < 6) & (col >= 6)] = 1\n    edge_type[(row >= 6) & (col < 6)] = 2\n    assert edge_type.min() == 0\n\n    num_bases = 4\n    conv = RGCNConv(16, 32, num_relations=3, num_bases=num_bases, aggr='add')\n    out1 = conv(x, edge_index, edge_type)\n\n    # Run `to_hetero_with_bases`:\n    x_dict = {\n        'paper': x[:6],\n        'author': x[6:],\n    }\n    edge_index_dict = {\n        ('paper', '_', 'paper'):\n        edge_index[:, edge_type == 0],\n        ('paper', '_', 'author'):\n        edge_index[:, edge_type == 1] - torch.tensor([[0], [6]]),\n        ('author', '_', 'paper'):\n        edge_index[:, edge_type == 2] - torch.tensor([[6], [0]]),\n    }\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj_t_dict = {\n            key: SparseTensor.from_edge_index(edge_index).t()\n            for key, edge_index in edge_index_dict.items()\n        }\n\n    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))\n    model = to_hetero_with_bases(RGCN(16, 32), metadata, num_bases=num_bases,\n                                 debug=False)\n\n    # Set model weights:\n    for i in range(num_bases):\n        model.conv.convs[i].lin.weight.data = conv.weight[i].data.t()\n        model.conv.convs[i].edge_type_weight.data = conv.comp[:, i].data.t()\n\n    model.lin.weight.data = conv.root.data.t()\n    model.lin.bias.data = conv.bias.data\n\n    
out2 = model(x_dict, edge_index_dict)\n    out2 = torch.cat([out2['paper'], out2['author']], dim=0)\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        out3 = model(x_dict, adj_t_dict)\n        out3 = torch.cat([out3['paper'], out3['author']], dim=0)\n        assert torch.allclose(out1, out3, atol=1e-6)\n\n\ndef test_to_hetero_with_bases_validate():\n    model = Net1()\n    metadata = (['my test'], [('my test', 'rel', 'my test')])\n\n    with pytest.warns(UserWarning, match=\"letters, numbers and underscores\"):\n        model = to_hetero_with_bases(model, metadata, num_bases=4, debug=False)\n\n\ndef test_to_hetero_with_bases_on_static_graphs():\n    x_dict = {\n        'paper': torch.randn(4, 100, 16),\n        'author': torch.randn(4, 100, 16),\n    }\n    edge_index_dict = {\n        ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),\n        ('author', 'writes', 'paper'): torch.randint(100, (2, 200)),\n    }\n\n    metadata = list(x_dict.keys()), list(edge_index_dict.keys())\n    model = to_hetero_with_bases(Net4(), metadata, num_bases=4,\n                                 in_channels={'x0': 16}, debug=False)\n\n    out_dict = model(x_dict, edge_index_dict)\n\n    assert len(out_dict) == 2\n    assert out_dict['paper'].size() == (4, 100, 32)\n\n\ndef test_to_hetero_with_bases_save(tmp_path):\n    x_dict = {'paper': torch.randn(100, 16), 'author': torch.randn(100, 8)}\n    edge_index_dict = {\n        ('paper', 'cites', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('paper', 'written_by', 'author'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n        ('author', 'writes', 'paper'):\n        torch.randint(100, (2, 200), dtype=torch.long),\n    }\n\n    model = to_hetero_with_bases(\n        Net2(),\n        (list(x_dict.keys()), list(edge_index_dict.keys())),\n        num_bases=4,\n        in_channels={'x': 16},\n        debug=False,\n    )\n    
model(x_dict, edge_index_dict)\n\n    path = osp.join(tmp_path, 'model.pt')\n    torch.save(model, path)\n"
  },
  {
    "path": "test/nn/unpool/test_knn_interpolate.py",
    "content": "import torch\n\nfrom torch_geometric.nn import knn_interpolate\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('torch_cluster')\ndef test_knn_interpolate():\n    x = torch.tensor([[1.0], [10.0], [100.0], [-1.0], [-10.0], [-100.0]])\n    pos_x = torch.tensor([\n        [-1.0, 0.0],\n        [0.0, 0.0],\n        [1.0, 0.0],\n        [-2.0, 0.0],\n        [0.0, 0.0],\n        [2.0, 0.0],\n    ])\n    pos_y = torch.tensor([\n        [-1.0, -1.0],\n        [1.0, 1.0],\n        [-2.0, -2.0],\n        [2.0, 2.0],\n    ])\n    batch_x = torch.tensor([0, 0, 0, 1, 1, 1])\n    batch_y = torch.tensor([0, 0, 1, 1])\n\n    y = knn_interpolate(x, pos_x, pos_y, batch_x, batch_y, k=2)\n    assert y.tolist() == [[4.0], [70.0], [-4.0], [-70.0]]\n"
  },
  {
    "path": "test/profile/test_benchmark.py",
    "content": "import torch\n\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import withPackage\n\n\n@withPackage('tabulate')\ndef test_benchmark(capfd):\n    def add(x, y):\n        return x + y\n\n    benchmark(\n        funcs=[add],\n        args=(torch.randn(10), torch.randn(10)),\n        num_steps=1,\n        num_warmups=1,\n        backward=True,\n    )\n\n    out, _ = capfd.readouterr()\n    assert '| Name   | Forward   | Backward   | Total   |' in out\n    assert '| add    |' in out\n"
  },
  {
    "path": "test/profile/test_nvtx.py",
    "content": "from unittest.mock import call, patch\n\nfrom torch_geometric.profile import nvtxit\n\n\ndef _setup_mock(torch_cuda_mock):\n    torch_cuda_mock.is_available.return_value = True\n    torch_cuda_mock.cudart.return_value.cudaProfilerStart.return_value = None\n    torch_cuda_mock.cudart.return_value.cudaProfilerStop.return_value = None\n    return torch_cuda_mock\n\n\n@patch('torch_geometric.profile.nvtx.torch.cuda')\ndef test_nvtxit_base(torch_cuda_mock):\n    torch_cuda_mock = _setup_mock(torch_cuda_mock)\n\n    # dummy func calls a calls b\n\n    @nvtxit()\n    def call_b():\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n        return 42\n\n    @nvtxit()\n    def call_a():\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n        return call_b()\n\n    def dummy_func():\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0  # noqa: E501\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n        return call_a()\n\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n    dummy_func()\n\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1  # noqa: E501\n    assert torch_cuda_mock.nvtx.range_push.call_args_list == [\n        call('call_a_0'), call('call_b_0')\n    ]\n\n\n@patch('torch_geometric.profile.nvtx.torch.cuda')\ndef test_nvtxit_rename(torch_cuda_mock):\n    torch_cuda_mock = _setup_mock(torch_cuda_mock)\n\n    
# dummy func calls a calls b\n\n    @nvtxit()\n    def call_b():\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n        return 42\n\n    @nvtxit('a_nvtx')\n    def call_a():\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n        return call_b()\n\n    def dummy_func():\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0  # noqa: E501\n        assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n        return call_a()\n\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n    dummy_func()\n\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1  # noqa: E501\n    assert torch_cuda_mock.nvtx.range_push.call_args_list == [\n        call('a_nvtx_0'), call('call_b_0')\n    ]\n\n\n@patch('torch_geometric.profile.nvtx.torch.cuda')\ndef test_nvtxit_iters(torch_cuda_mock):\n    torch_cuda_mock = _setup_mock(torch_cuda_mock)\n\n    # dummy func calls a calls b\n\n    @nvtxit(n_iters=1)\n    def call_b():\n        return 42\n\n    @nvtxit()\n    def call_a():\n        return call_b()\n\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n\n    call_b()\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n    assert 
torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1  # noqa: E501\n    call_a()\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 2  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 2  # noqa: E501\n\n    assert torch_cuda_mock.nvtx.range_push.call_args_list == [\n        call('call_b_0'), call('call_a_0')\n    ]\n\n\n@patch('torch_geometric.profile.nvtx.torch.cuda')\ndef test_nvtxit_warmups(torch_cuda_mock):\n    torch_cuda_mock = _setup_mock(torch_cuda_mock)\n\n    # dummy func calls a calls b\n\n    @nvtxit(n_warmups=1)\n    def call_b():\n        return 42\n\n    @nvtxit()\n    def call_a():\n        return call_b()\n\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n\n    call_b()\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 0  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 0  # noqa: E501\n    call_a()\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStart.call_count == 1  # noqa: E501\n    assert torch_cuda_mock.cudart.return_value.cudaProfilerStop.call_count == 1  # noqa: E501\n\n    assert torch_cuda_mock.nvtx.range_push.call_args_list == [\n        call('call_a_0'), call('call_b_1')\n    ]\n"
  },
  {
    "path": "test/profile/test_profile.py",
    "content": "import os\nimport os.path as osp\nimport warnings\n\nimport pytest\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.nn import GraphSAGE\nfrom torch_geometric.profile import (\n    get_stats_summary,\n    profileit,\n    rename_profile_file,\n    timeit,\n)\nfrom torch_geometric.profile.profile import torch_profile, xpu_profile\nfrom torch_geometric.testing import (\n    onlyCUDA,\n    onlyLinux,\n    onlyOnline,\n    onlyXPU,\n    withDevice,\n    withPackage,\n)\n\n\n@withDevice\n@onlyLinux\ndef test_timeit(device):\n    x = torch.randn(100, 16, device=device)\n    lin = torch.nn.Linear(16, 32).to(device)\n\n    with timeit(log=False) as t:\n        assert not hasattr(t, 'duration')\n\n        with torch.no_grad():\n            lin(x)\n        t.reset()\n        assert t.duration > 0\n\n        del t.duration\n        assert not hasattr(t, 'duration')\n    assert t.duration > 0\n\n\n@onlyCUDA\n@onlyOnline\n@withPackage('pytorch_memlab')\ndef test_profileit_cuda(get_dataset):\n    warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*')\n\n    dataset = get_dataset(name='karate')\n    data = dataset[0].cuda()\n    model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,\n                      out_channels=dataset.num_classes).cuda()\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    @profileit('cuda')\n    def train(model, x, edge_index, y):\n        model.train()\n        optimizer.zero_grad()\n        out = model(x, edge_index)\n        loss = F.cross_entropy(out, y)\n        loss.backward()\n        return float(loss.detach())\n\n    stats_list = []\n    for epoch in range(5):\n        _, stats = train(model, data.x, data.edge_index, data.y)\n        assert stats.time > 0\n        assert stats.max_allocated_gpu > 0\n        assert stats.max_reserved_gpu > 0\n        assert stats.max_active_gpu > 0\n        assert stats.nvidia_smi_free_cuda >= 0\n        assert 
stats.nvidia_smi_used_cuda >= 0\n\n        if epoch >= 2:  # Warm-up\n            stats_list.append(stats)\n\n    stats_summary = get_stats_summary(stats_list)\n    assert stats_summary.time_mean > 0\n    assert stats_summary.time_std > 0\n    assert stats_summary.max_allocated_gpu > 0\n    assert stats_summary.max_reserved_gpu > 0\n    assert stats_summary.max_active_gpu > 0\n    assert stats_summary.min_nvidia_smi_free_cuda >= 0\n    assert stats_summary.max_nvidia_smi_used_cuda >= 0\n\n\n@onlyXPU\ndef test_profileit_xpu(get_dataset):\n    warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*')\n\n    dataset = get_dataset(name='karate')\n    device = torch.device('xpu')\n    data = dataset[0].to(device)\n    model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,\n                      out_channels=dataset.num_classes).to(device)\n    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n\n    @profileit('xpu')\n    def train(model, x, edge_index, y):\n        model.train()\n        optimizer.zero_grad()\n        out = model(x, edge_index)\n        loss = F.cross_entropy(out, y)\n        loss.backward()\n        return float(loss.detach())\n\n    stats_list = []\n    for epoch in range(5):\n        _, stats = train(model, data.x, data.edge_index, data.y)\n        assert stats.time > 0\n        assert stats.max_allocated_gpu > 0\n        assert stats.max_reserved_gpu > 0\n        assert stats.max_active_gpu > 0\n        assert not hasattr(stats, 'nvidia_smi_free_cuda')\n        assert not hasattr(stats, 'nvidia_smi_used_cuda')\n\n        if epoch >= 2:  # Warm-up\n            stats_list.append(stats)\n\n    stats_summary = get_stats_summary(stats_list)\n    assert stats_summary.time_mean > 0\n    assert stats_summary.time_std > 0\n    assert stats_summary.max_allocated_gpu > 0\n    assert stats_summary.max_reserved_gpu > 0\n    assert stats_summary.max_active_gpu > 0\n    assert not hasattr(stats_summary, 
'min_nvidia_smi_free_cuda')\n    assert not hasattr(stats_summary, 'max_nvidia_smi_used_cuda')\n\n\n@withDevice\n@onlyOnline\ndef test_torch_profile(capfd, get_dataset, device):\n    dataset = get_dataset(name='karate')\n    data = dataset[0].to(device)\n    model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,\n                      out_channels=dataset.num_classes).to(device)\n\n    with torch_profile():\n        model(data.x, data.edge_index)\n\n    out, _ = capfd.readouterr()\n    assert 'Self CPU time total' in out\n    if data.x.is_cuda:\n        assert 'Self CUDA time total' in out\n\n    rename_profile_file('test_profile')\n    assert osp.exists('profile-test_profile.json')\n    os.remove('profile-test_profile.json')\n\n\n@onlyXPU\n@onlyOnline\n@pytest.mark.parametrize('export_chrome_trace', [False, True])\ndef test_xpu_profile(capfd, get_dataset, export_chrome_trace):\n    dataset = get_dataset(name='karate')\n    device = torch.device('xpu')\n    data = dataset[0].to(device)\n    model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3,\n                      out_channels=dataset.num_classes).to(device)\n\n    with xpu_profile(export_chrome_trace):\n        model(data.x, data.edge_index)\n\n    out, _ = capfd.readouterr()\n    assert 'Self CPU' in out\n    if data.x.is_xpu:\n        assert 'Self XPU' in out\n\n    f_name = 'timeline.json'\n    f_exists = osp.exists(f_name)\n    if not export_chrome_trace:\n        assert not f_exists\n    else:\n        assert f_exists\n        os.remove(f_name)\n"
  },
  {
    "path": "test/profile/test_profile_utils.py",
    "content": "import torch\nfrom torch.nn import Linear\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.profile import (\n    count_parameters,\n    get_cpu_memory_from_gc,\n    get_data_size,\n    get_gpu_memory_from_gc,\n    get_gpu_memory_from_ipex,\n    get_gpu_memory_from_nvidia_smi,\n    get_model_size,\n)\nfrom torch_geometric.profile.utils import (\n    byte_to_megabyte,\n    medibyte_to_megabyte,\n)\nfrom torch_geometric.testing import onlyCUDA, onlyXPU, withPackage\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_count_parameters():\n    assert count_parameters(Linear(32, 128)) == 32 * 128 + 128\n\n\ndef test_get_model_size():\n    model_size = get_model_size(Linear(32, 128, bias=False))\n    assert model_size >= 32 * 128 * 4 and model_size < 32 * 128 * 4 + 2000\n\n\ndef test_get_data_size():\n    x = torch.randn(10, 128)\n    data = Data(x=x, y=x)\n\n    data_size = get_data_size(data)\n    assert data_size == 10 * 128 * 4\n\n\n@withPackage('torch_sparse')\ndef test_get_data_size_with_sparse_tensor():\n    x = torch.randn(10, 128)\n    row, col = torch.randint(0, 10, (2, 100), dtype=torch.long)\n    adj_t = SparseTensor(row=row, col=col, value=None, sparse_sizes=(10, 10))\n    data = Data(x=x, y=x, adj_t=adj_t)\n\n    data_size = get_data_size(data)\n    assert data_size == 10 * 128 * 4 + 11 * 8 + 100 * 8\n\n\ndef test_get_cpu_memory_from_gc():\n    old_mem = get_cpu_memory_from_gc()\n    _ = torch.randn(10, 128)\n    new_mem = get_cpu_memory_from_gc()\n    assert new_mem - old_mem == 10 * 128 * 4\n\n\n@onlyCUDA\ndef test_get_gpu_memory_from_gc():\n    old_mem = get_gpu_memory_from_gc()\n    _ = torch.randn(10, 128, device='cuda')\n    new_mem = get_gpu_memory_from_gc()\n    assert new_mem - old_mem == 10 * 128 * 4\n\n\n@onlyCUDA\ndef test_get_gpu_memory_from_nvidia_smi():\n    free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device=0, digits=2)\n    assert free_mem >= 0\n    assert used_mem >= 0\n\n\n@onlyXPU\ndef 
test_get_gpu_memory_from_ipex():\n    max_allocated, max_reserved, max_active = get_gpu_memory_from_ipex()\n    assert max_allocated >= 0\n    assert max_reserved >= 0\n    assert max_active >= 0\n\n\ndef test_bytes_function():\n    assert byte_to_megabyte(1024 * 1024) == 1.00\n    assert medibyte_to_megabyte(1 / 1.0485) == 1.00\n"
  },
  {
    "path": "test/profile/test_profiler.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import GraphSAGE\nfrom torch_geometric.profile.profiler import Profiler\nfrom torch_geometric.testing import withDevice\n\n\n@withDevice\ndef test_profiler(capfd, get_dataset, device):\n    x = torch.randn(10, 16, device=device)\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9],\n        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8],\n    ], device=device)\n\n    model = GraphSAGE(16, hidden_channels=32, num_layers=2).to(device)\n\n    with Profiler(model, profile_memory=True, use_cuda=x.is_cuda) as prof:\n        model(x, edge_index)\n\n    _, err = capfd.readouterr()\n    if not torch_geometric.typing.WITH_PT24:\n        assert 'Completed Stage' in err\n\n    _, heading_list, raw_results, layer_names, layer_stats = prof.get_trace()\n    assert 'Self CPU total' in heading_list\n    assert 'aten::relu' in raw_results\n    assert '-act--aten::relu' in layer_names\n"
  },
  {
    "path": "test/sampler/test_sampler_base.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.sampler.base import (\n    HeteroSamplerOutput,\n    NumNeighbors,\n    SamplerOutput,\n)\nfrom torch_geometric.sampler.utils import global_to_local_node_idx\nfrom torch_geometric.testing import get_random_edge_index\nfrom torch_geometric.utils import is_undirected\n\n\ndef test_homogeneous_num_neighbors():\n    with pytest.raises(ValueError, match=\"'default' must be set to 'None'\"):\n        num_neighbors = NumNeighbors([25, 10], default=[-1, -1])\n\n    num_neighbors = NumNeighbors([25, 10])\n    assert str(num_neighbors) == 'NumNeighbors(values=[25, 10], default=None)'\n\n    assert num_neighbors.get_values() == [25, 10]\n    assert num_neighbors.__dict__['_values'] == [25, 10]\n    assert num_neighbors.get_values() == [25, 10]  # Test caching.\n\n    assert num_neighbors.get_mapped_values() == [25, 10]\n    assert num_neighbors.__dict__['_mapped_values'] == [25, 10]\n    assert num_neighbors.get_mapped_values() == [25, 10]  # Test caching.\n\n    assert num_neighbors.num_hops == 2\n    assert num_neighbors.__dict__['_num_hops'] == 2\n    assert num_neighbors.num_hops == 2  # Test caching.\n\n\n'''\nMerge and collate tests use the following graph:\n\n    #############                    ###########\n    # Alice (0) # -> \"works with\" -> # Bob (1) #\n    #############                    ###########\n         |\n         v\n      \"leads\"\n         |\n         v\n    #############                    ############\n    # Carol (2) # -> \"works with\" -> # Dave (3) #\n    #############                    ############\n\n'''\n\n\ndef _init_merge_sampler_outputs(hetero=False, disjoint=False):\n    if not hetero:\n        output1 = SamplerOutput(\n            node=torch.tensor([0, 1, 2]),\n            row=torch.tensor([0, 0]),\n            col=torch.tensor([1, 2]),\n            edge=torch.tensor([0, 1]),\n            batch=torch.tensor([0, 0, 0]) if disjoint else None,\n            
num_sampled_nodes=list([1, 2]),\n            num_sampled_edges=list([2]),\n            orig_row=None,\n            orig_col=None,\n            metadata=(None, None),\n        )\n        output2 = SamplerOutput(\n            node=torch.tensor([0, 2, 3]),\n            row=torch.tensor([0, 1]),\n            col=torch.tensor([1, 2]),\n            edge=torch.tensor([1, 2]),\n            batch=torch.tensor([0, 0, 0]) if disjoint else None,\n            num_sampled_nodes=list([1, 1, 1]),\n            num_sampled_edges=list([1, 1]),\n            orig_row=None,\n            orig_col=None,\n            metadata=(None, None),\n        )\n\n        return output1, output2\n    else:\n        # TODO(zaristei)\n        raise NotImplementedError(\"Heterogeneous merge not implemented\")\n\n\n@pytest.mark.parametrize(\"disjoint\", [True, False])\n@pytest.mark.parametrize(\"bidirectional\", [True, False])\ndef test_homogeneous_merge(disjoint, bidirectional):\n    \"\"\"Merge an output representing 1<-0->2 with one representing 0->2->3.\"\"\"\n    output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint)\n    if bidirectional:\n        output1 = output1.to_bidirectional(keep_orig_edges=True)\n        output2 = output2.to_bidirectional(keep_orig_edges=True)\n\n    expected_output = SamplerOutput(\n        node=torch.tensor([0, 1, 2, 3]),\n        row=torch.tensor([0, 0, 2]),\n        col=torch.tensor([1, 2, 3]),\n        edge=torch.tensor([0, 1, 2]),\n        batch=torch.tensor([0, 0, 0, 0]) if disjoint else None,\n        num_sampled_nodes=[1, 2, 0, 0, 1],\n        num_sampled_edges=[2, 0, 1],\n        orig_row=None,\n        orig_col=None,\n        metadata=[(None, None), (None, None)],\n    )\n    if bidirectional:\n        expected_output = expected_output.to_bidirectional(\n            keep_orig_edges=True)\n    merged_output = output1.merge_with(output2)\n\n    assert str(merged_output) == str(expected_output)\n\n\n@pytest.mark.parametrize(\"disjoint\", [True, 
False])\n@pytest.mark.parametrize(\"bidirectional\", [True, False])\ndef test_homogeneous_merge_no_replace(disjoint, bidirectional):\n    \"\"\"Merge an output representing 1<-0->2 with one representing 0->2->3.\n    replace=False makes it so that merged output is a simple concatenation\n    instead of removing already sampled nodes/edges.\n    \"\"\"\n    output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint)\n    if bidirectional:\n        output1 = output1.to_bidirectional(keep_orig_edges=True)\n        output2 = output2.to_bidirectional(keep_orig_edges=True)\n\n    expected_output = SamplerOutput(\n        node=torch.tensor([0, 1, 2, 0, 2, 3]),\n        row=torch.tensor([0, 0, 3, 4]),\n        col=torch.tensor([1, 2, 4, 5]),\n        edge=torch.tensor([0, 1, 1, 2]),\n        batch=torch.tensor([0, 0, 0, 3, 3, 3]) if disjoint else None,\n        num_sampled_nodes=[1, 2, 1, 1, 1],\n        num_sampled_edges=[2, 1, 1],\n        orig_row=None,\n        orig_col=None,\n        metadata=[(None, None), (None, None)],\n    )\n    if bidirectional:\n        expected_output = expected_output.to_bidirectional(\n            keep_orig_edges=True)\n    merged_output = output1.merge_with(output2, replace=False)\n\n    assert str(merged_output) == str(expected_output)\n\n\ndef _init_collate_sampler_outputs(disjoint=False):\n    output1, output2 = _init_merge_sampler_outputs(disjoint=disjoint)\n    # new edge not present in graph above\n    output3 = SamplerOutput(\n        node=torch.tensor([3, 4]),\n        row=torch.tensor([0]),\n        col=torch.tensor([1]),\n        edge=torch.tensor([3]),\n        batch=torch.tensor([0, 0]) if disjoint else None,\n        num_sampled_nodes=list([1, 1]),\n        num_sampled_edges=list([1]),\n        orig_row=None,\n        orig_col=None,\n        metadata=(None, None),\n    )\n    return [output1, output2, output3]\n\n\n@pytest.mark.parametrize(\"replace\", [True, False])\n@pytest.mark.parametrize(\"disjoint\", [True, 
False])\ndef test_homogeneous_collate(disjoint, replace):\n    output1, output2, output3 = _init_collate_sampler_outputs(disjoint)\n    collated = SamplerOutput.collate([output1, output2, output3],\n                                     replace=replace)\n    assert str(collated) == str(\n        (output1.merge_with(output2, replace=replace)).merge_with(\n            output3, replace=replace))\n\n\ndef test_homogeneous_collate_empty():\n    with pytest.raises(ValueError,\n                       match=\"Cannot collate an empty list of SamplerOutputs\"):\n        SamplerOutput.collate([])\n\n\ndef test_homogeneous_collate_single():\n    output, _ = _init_merge_sampler_outputs()\n    collated = SamplerOutput.collate([output])\n    assert str(collated) == str(output)\n\n\ndef test_homogeneous_collate_missing_fields():\n    output1, output2, output3 = _init_collate_sampler_outputs()\n    output3.edge = None\n    with pytest.raises(\n            ValueError,\n            match=\"Output 3 has a different field than the first output\"):\n        SamplerOutput.collate([output1, output2, output3])\n\n\ndef test_heterogeneous_num_neighbors_list():\n    num_neighbors = NumNeighbors([25, 10])\n\n    values = num_neighbors.get_values([('A', 'B'), ('B', 'A')])\n    assert values == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]}\n\n    values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')])\n    assert values == {'A__to__B': [25, 10], 'B__to__A': [25, 10]}\n\n    assert num_neighbors.num_hops == 2\n\n\ndef test_heterogeneous_num_neighbors_dict_and_default():\n    num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1])\n    with pytest.raises(ValueError, match=\"hops must be the same across all\"):\n        values = num_neighbors.get_values([('A', 'B'), ('B', 'A')])\n\n    num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1, -1])\n\n    with pytest.raises(ValueError, match=\"Not all edge types\"):\n        num_neighbors.get_values([('A', 'C'), ('B', 
'A')])\n\n    values = num_neighbors.get_values([('A', 'B'), ('B', 'A')])\n    assert values == {('A', 'B'): [25, 10], ('B', 'A'): [-1, -1]}\n\n    values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')])\n    assert values == {'A__to__B': [25, 10], 'B__to__A': [-1, -1]}\n\n    assert num_neighbors.num_hops == 2\n\n\ndef test_heterogeneous_num_neighbors_empty_dict():\n    num_neighbors = NumNeighbors({}, default=[25, 10])\n\n    values = num_neighbors.get_values([('A', 'B'), ('B', 'A')])\n    assert values == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]}\n\n    values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')])\n    assert values == {'A__to__B': [25, 10], 'B__to__A': [25, 10]}\n\n    assert num_neighbors.num_hops == 2\n\n\ndef test_homogeneous_to_bidirectional():\n    edge_index = get_random_edge_index(10, 10, num_edges=20)\n\n    obj = SamplerOutput(\n        node=torch.arange(10),\n        row=edge_index[0],\n        col=edge_index[0],\n        edge=torch.arange(edge_index.size(1)),\n    ).to_bidirectional()\n\n    assert is_undirected(torch.stack([obj.row, obj.col], dim=0))\n\n\ndef test_heterogeneous_to_bidirectional():\n    edge_index1 = get_random_edge_index(10, 5, num_edges=20)\n    edge_index2 = get_random_edge_index(5, 10, num_edges=20)\n    edge_index3 = get_random_edge_index(10, 10, num_edges=20)\n\n    obj = HeteroSamplerOutput(\n        node={\n            'v1': torch.arange(10),\n            'v2': torch.arange(5)\n        },\n        row={\n            ('v1', 'to', 'v2'): edge_index1[0],\n            ('v2', 'rev_to', 'v1'): edge_index2[0],\n            ('v1', 'to', 'v1'): edge_index3[0],\n        },\n        col={\n            ('v1', 'to', 'v2'): edge_index1[1],\n            ('v2', 'rev_to', 'v1'): edge_index2[1],\n            ('v1', 'to', 'v1'): edge_index3[1],\n        },\n        edge={},\n    ).to_bidirectional()\n\n    assert torch.equal(\n        obj.row['v1', 'to', 'v2'].sort().values,\n        obj.col['v2', 'rev_to', 
'v1'].sort().values,\n    )\n    assert torch.equal(\n        obj.col['v1', 'to', 'v2'].sort().values,\n        obj.row['v2', 'rev_to', 'v1'].sort().values,\n    )\n    assert is_undirected(\n        torch.stack([obj.row['v1', 'to', 'v1'], obj.col['v1', 'to', 'v1']], 0))\n\n\ndef test_homogeneous_sampler_output_global_fields():\n    output = SamplerOutput(\n        node=torch.tensor([0, 2, 3]),\n        row=torch.tensor([0, 1]),\n        col=torch.tensor([1, 2]),\n        edge=torch.tensor([1, 2]),\n        batch=torch.tensor([0, 0, 0]),\n        num_sampled_nodes=[1, 1, 1],\n        num_sampled_edges=[1, 1],\n        orig_row=None,\n        orig_col=None,\n        metadata=(None, None),\n    )\n\n    local_values = []\n    global_values = []\n\n    global_row, global_col = output.global_row, output.global_col\n    assert torch.equal(global_row, torch.tensor([0, 2]))\n    assert torch.equal(global_col, torch.tensor([2, 3]))\n    local_values.append(output.row)\n    local_values.append(output.col)\n    global_values.append(global_row)\n    global_values.append(global_col)\n\n    seed_node = output.seed_node\n    assert torch.equal(seed_node, torch.tensor([0, 0, 0]))\n    local_values.append(output.batch)\n    global_values.append(seed_node)\n\n    output_bidirectional = output.to_bidirectional(keep_orig_edges=True)\n    global_bidir_row, global_bidir_col = \\\n        output_bidirectional.global_row, output_bidirectional.global_col\n    assert torch.equal(global_bidir_row, torch.tensor([2, 0, 3, 2]))\n    assert torch.equal(global_bidir_col, torch.tensor([0, 2, 2, 3]))\n    local_values.append(output_bidirectional.row)\n    local_values.append(output_bidirectional.col)\n    global_values.append(global_bidir_row)\n    global_values.append(global_bidir_col)\n\n    assert torch.equal(output.global_row, output_bidirectional.global_orig_row)\n    assert torch.equal(output.global_col, output_bidirectional.global_orig_col)\n\n    # Make sure reverse mapping is correct\n    
for local_value, global_value in zip(local_values, global_values):\n        assert torch.equal(global_to_local_node_idx(output.node, global_value),\n                           local_value)\n\n\ndef test_heterogeneous_sampler_output_global_fields():\n    def _tensor_dict_equal(dict1, dict2):\n        is_equal = True\n        is_equal &= dict1.keys() == dict2.keys()\n        for key in dict1.keys():\n            is_equal &= torch.equal(dict1[key], dict2[key])\n        return is_equal\n\n    output = HeteroSamplerOutput(\n        node={\"person\": torch.tensor([0, 2, 3])},\n        row={\n            (\"person\", \"works_with\", \"person\"): torch.tensor([1]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([0])\n        },\n        col={\n            (\"person\", \"works_with\", \"person\"): torch.tensor([2]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([1])\n        },\n        edge={\n            (\"person\", \"works_with\", \"person\"): torch.tensor([1]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([0])\n        },\n        batch={\"person\": torch.tensor([0, 0, 0])},\n        num_sampled_nodes={\"person\": torch.tensor([1, 1, 1])},\n        num_sampled_edges={\n            (\"person\", \"works_with\", \"person\"): torch.tensor([1]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([1])\n        },\n        orig_row=None,\n        orig_col=None,\n        metadata=(None, None),\n    )\n\n    global_row, global_col = output.global_row, output.global_col\n    assert _tensor_dict_equal(\n        global_row, {\n            (\"person\", \"works_with\", \"person\"): torch.tensor([2]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([0])\n        })\n    assert _tensor_dict_equal(\n        global_col, {\n            (\"person\", \"works_with\", \"person\"): torch.tensor([3]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([2])\n        })\n\n    local_row_dict = {\n        k: 
global_to_local_node_idx(output.node[k[0]], v)\n        for k, v in global_row.items()\n    }\n    assert _tensor_dict_equal(local_row_dict, output.row)\n\n    local_col_dict = {\n        k: global_to_local_node_idx(output.node[k[2]], v)\n        for k, v in global_col.items()\n    }\n    assert _tensor_dict_equal(local_col_dict, output.col)\n\n    seed_node = output.seed_node\n    assert _tensor_dict_equal(seed_node, {\"person\": torch.tensor([0, 0, 0])})\n\n    local_batch_dict = {\n        k: global_to_local_node_idx(output.node[k], v)\n        for k, v in seed_node.items()\n    }\n    assert _tensor_dict_equal(local_batch_dict, output.batch)\n\n    output_bidirectional = output.to_bidirectional(keep_orig_edges=True)\n    global_bidir_row, global_bidir_col = \\\n        output_bidirectional.global_row, output_bidirectional.global_col\n    assert _tensor_dict_equal(\n        global_bidir_row, {\n            (\"person\", \"works_with\", \"person\"): torch.tensor([3, 2]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([2, 0])\n        })\n    assert _tensor_dict_equal(\n        global_bidir_col, {\n            (\"person\", \"works_with\", \"person\"): torch.tensor([2, 3]),\n            (\"person\", \"leads\", \"person\"): torch.tensor([0, 2])\n        })\n\n    local_bidir_row_dict = {\n        k: global_to_local_node_idx(output_bidirectional.node[k[0]], v)\n        for k, v in global_bidir_row.items()\n    }\n    assert _tensor_dict_equal(local_bidir_row_dict, output_bidirectional.row)\n\n    local_bidir_col_dict = {\n        k: global_to_local_node_idx(output_bidirectional.node[k[2]], v)\n        for k, v in global_bidir_col.items()\n    }\n    assert _tensor_dict_equal(local_bidir_col_dict, output_bidirectional.col)\n\n    assert _tensor_dict_equal(output.global_row,\n                              output_bidirectional.global_orig_row)\n    assert _tensor_dict_equal(output.global_col,\n                              
output_bidirectional.global_orig_col)\n"
  },
  {
    "path": "test/sampler/test_sampler_neighbor_sampler.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.sampler.base import NodeSamplerInput, SamplerOutput\nfrom torch_geometric.sampler.neighbor_sampler import (\n    BidirectionalNeighborSampler,\n    NeighborSampler,\n)\nfrom torch_geometric.testing import (\n    MyFeatureStore,\n    MyGraphStore,\n    onlyNeighborSampler,\n)\n\n\ndef _init_sample_graph(hetero=False):\n    \"\"\"Initializes the following graph.\n\n    #############                    ###########\n    # Alice (0) # -> \"works with\" -> # Bob (1) #\n    #############                    ###########\n         |\n         v\n      \"leads\"\n         |\n         v\n    #############                    ############\n    # Carol (2) # -> \"works with\" -> # Dave (3) #\n    #############                    ############\n    \"\"\"\n    sample_attr = None\n    sample_edge_attr = None\n    sample_edge_indices = None\n    if not hetero:\n        sample_attr = torch.tensor([[0], [1], [2], [3]])\n        sample_edge_attr = torch.tensor([[1], [2], [3]])\n        sample_edge_indices = torch.tensor([[0, 0, 2], [1, 2, 3]])\n    else:\n        sample_attr = dict({\n            \"person\": dict({\"x\": torch.tensor([[1], [2], [3]])}),\n            \"manager\": dict({\"x\": torch.tensor([[0]])})\n        })\n        sample_edge_attr = dict({\n            ('person', 'works_with', 'person'):\n            dict({\"edge_attr\": torch.tensor([[3]])}),\n            ('manager', 'leads', 'person'):\n            dict({\"edge_attr\": torch.tensor([[1]])}),\n            ('manager', 'works_with', 'person'):\n            dict({\"edge_attr\": torch.tensor([[2]])})\n        })\n        sample_edge_indices = dict({\n            ('person', 'works_with', 'person'):\n            dict({\"edge_index\": torch.tensor([[1], [2]])}),\n            ('manager', 'leads', 'person'):\n            dict({\"edge_index\": torch.tensor([[0], [1]])}),\n            ('manager', 'works_with', 
'person'):\n            dict({\"edge_index\": torch.tensor([[0], [0]])})\n        })\n    return sample_attr, sample_edge_attr, sample_edge_indices\n\n\ndef _init_graph_to_sample(graph_dtype, hetero=False, reverse=False):\n    sample_attr, sample_edge_attr, sample_edge_indices = _init_sample_graph(\n        hetero)\n    if reverse:\n        if not hetero:\n            sample_edge_indices = sample_edge_indices.flip(0)\n        else:\n            reversed_edge_indices = dict()\n            reversed_edge_attr = dict()\n            for edge_type, edge_index in sample_edge_indices.items():\n                edge_index = edge_index[\"edge_index\"]\n                edge_attr = sample_edge_attr[edge_type][\"edge_attr\"]\n                flipped_edge_index = edge_index.flip(0)\n                flipped_edge_type = (edge_type[2], edge_type[1], edge_type[0])\n                reversed_edge_indices[flipped_edge_type] = dict(\n                    {\"edge_index\": flipped_edge_index})\n                reversed_edge_attr[flipped_edge_type] = dict(\n                    {\"edge_attr\": edge_attr})\n            sample_edge_indices = reversed_edge_indices\n            sample_edge_attr = reversed_edge_attr\n    graph_to_sample = None\n    if graph_dtype == 'data' and not hetero:\n        graph_to_sample = Data(edge_index=sample_edge_indices, x=sample_attr,\n                               time=sample_attr.squeeze(-1),\n                               edge_attr=sample_edge_attr.squeeze(-1))\n    elif graph_dtype == 'remote' and not hetero:\n        graph_store = MyGraphStore()\n        graph_store.put_edge_index(sample_edge_indices, edge_type=None,\n                                   layout='coo', is_sorted=True, size=(4, 4))\n        feature_store = MyFeatureStore()\n        feature_store.put_tensor(sample_attr, group_name='default',\n                                 attr_name='x', index=None)\n        # temporal node sampling on (fs, gs) needs 'time' attr\n        
feature_store.put_tensor(sample_attr.squeeze(-1), group_name='default',\n                                 attr_name='time', index=None)\n        feature_store.put_tensor(sample_edge_attr.squeeze(-1),\n                                 group_name='default', attr_name='edge_attr',\n                                 index=None)\n        graph_to_sample = (feature_store, graph_store)\n    elif graph_dtype == 'data' and hetero:\n        graph_to_sample = HeteroData()\n        for node_type, node_attr in sample_attr.items():\n            graph_to_sample[node_type].x = node_attr['x']\n            graph_to_sample[node_type].time = node_attr['x'].squeeze(-1)\n        for edge_type in sample_edge_indices.keys():\n            graph_to_sample[edge_type].edge_index = sample_edge_indices[\n                edge_type][\"edge_index\"]\n            graph_to_sample[edge_type].edge_attr = sample_edge_attr[edge_type][\n                \"edge_attr\"].squeeze(-1)\n    elif graph_dtype == 'remote' and hetero:\n        graph_store = MyGraphStore()\n        for edge_type, edge_index in sample_edge_indices.items():\n            edge_index = edge_index[\"edge_index\"]\n            graph_store.put_edge_index(\n                edge_index, edge_type=edge_type, layout='coo', is_sorted=True,\n                size=(len(sample_attr[edge_type[0]][\"x\"]),\n                      len(sample_attr[edge_type[2]][\"x\"])))\n        feature_store = MyFeatureStore()\n        for node_type, node_attr in sample_attr.items():\n            feature_store.put_tensor(node_attr[\"x\"], group_name=node_type,\n                                     attr_name='x', index=None)\n            # temporal node sampling on (fs, gs) needs 'time' attr\n            feature_store.put_tensor(node_attr[\"x\"].squeeze(-1),\n                                     group_name=node_type, attr_name='time',\n                                     index=None)\n        for edge_type, edge_attr in sample_edge_attr.items():\n            
feature_store.put_tensor(edge_attr[\"edge_attr\"].squeeze(-1),\n                                     group_name=edge_type,\n                                     attr_name='edge_attr', index=None)\n        graph_to_sample = (feature_store, graph_store)\n    return graph_to_sample\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\ndef test_homogeneous_neighbor_sampler_basic(input_type):\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=False)\n\n    # NeighborSampler default parameters 1 node\n    # disjoint = False\n    # replace = False\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n    }\n\n    # Sampling from Bob should yield only Alice\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([1]))\n    expected_output = SamplerOutput(\n        node=torch.tensor([1, 0]), row=torch.tensor([1]),\n        col=torch.tensor([0]), edge=torch.tensor([0]), batch=None,\n        num_sampled_nodes=[1, 1], num_sampled_edges=[1], orig_row=None,\n        orig_col=None, metadata=(None, None))\n    sampler = NeighborSampler(**sampler_kwargs)\n    sampler_output = sampler.sample_from_nodes(node_sampler_input)\n    assert str(sampler_output) == str(expected_output)\n\n    # Sampling Alice should yield no edges\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([0]))\n    expected_output = SamplerOutput(node=torch.tensor([0]), row=torch.empty(\n        0, dtype=torch.int64), col=torch.empty(0, dtype=torch.int64),\n                                    edge=torch.empty(0, dtype=torch.int64),\n                                    batch=None, num_sampled_nodes=[1, 0],\n                                    num_sampled_edges=[0], orig_row=None,\n                                    orig_col=None, metadata=(None, None))\n    sampler = NeighborSampler(**sampler_kwargs)\n    
sampler_output = sampler.sample_from_nodes(node_sampler_input)\n    assert str(sampler_output) == str(expected_output)\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\ndef test_heterogeneous_neighbor_sampler_basic(input_type):\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=True)\n\n    # NeighborSampler default parameters 1 node\n    # disjoint = False\n    # replace = False\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n    }\n\n    # Sampling from Bob should yield only Alice\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([0]),\n                                          input_type=\"person\")\n    sampler = NeighborSampler(**sampler_kwargs)\n    sampler_output = sampler.sample_from_nodes(node_sampler_input)\n\n    assert sampler_output.node['person'].tolist() == [0]\n    assert sampler_output.node['manager'].tolist() == [0]\n\n    assert sampler_output.row[('manager', 'works_with',\n                               'person')] == torch.tensor([0])\n    assert sampler_output.row[('manager', 'leads', 'person')].numel() == 0\n    assert sampler_output.row[('person', 'works_with', 'person')].numel() == 0\n    assert sampler_output.col[('manager', 'works_with',\n                               'person')] == torch.tensor([0])\n    assert sampler_output.col[('manager', 'leads', 'person')].numel() == 0\n    assert sampler_output.col[('person', 'works_with', 'person')].numel() == 0\n    assert sampler_output.edge[('manager', 'works_with',\n                                'person')] == torch.tensor([0])\n    assert sampler_output.edge[('manager', 'leads', 'person')].numel() == 0\n    assert sampler_output.edge[('person', 'works_with', 'person')].numel() == 0\n\n    # Sampling Alice should yield no edges\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          
node=torch.tensor([0]),\n                                          input_type=\"manager\")\n    sampler_output = sampler.sample_from_nodes(node_sampler_input)\n\n    assert sampler_output.node['manager'].tolist() == [0]\n    assert sampler_output.node['person'].numel() == 0\n\n    assert sampler_output.row[('manager', 'works_with', 'person')].numel() == 0\n    assert sampler_output.row[('manager', 'leads', 'person')].numel() == 0\n    assert sampler_output.row[('person', 'works_with', 'person')].numel() == 0\n    assert sampler_output.col[('manager', 'works_with', 'person')].numel() == 0\n    assert sampler_output.col[('manager', 'leads', 'person')].numel() == 0\n    assert sampler_output.col[('person', 'works_with', 'person')].numel() == 0\n    assert sampler_output.edge[('manager', 'works_with',\n                                'person')].numel() == 0\n    assert sampler_output.edge[('manager', 'leads', 'person')].numel() == 0\n    assert sampler_output.edge[('person', 'works_with', 'person')].numel() == 0\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\ndef test_homogeneous_neighbor_sampler_backwards(input_type):\n\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=False)\n\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n    }\n\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([2]))\n\n    sampler = NeighborSampler(**sampler_kwargs)\n    # This output should have Carol and Alice\n    sampler_output = sampler.sample_from_nodes(node_sampler_input)\n\n    backward_sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n        'sample_direction': 'backward',\n    }\n    backward_sampler = NeighborSampler(**backward_sampler_kwargs)\n    # This output should have Carol and Dave\n    backward_sampler_output = backward_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    
reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=False,\n                                                    reverse=True)\n\n    reverse_sampler_kwargs = {\n        'data': reverse_graph_to_sample,\n        'num_neighbors': [1],\n    }\n\n    reverse_sampler = NeighborSampler(**reverse_sampler_kwargs)\n    # This output should have Carol and Dave\n    reverse_sampler_output = reverse_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    reverse_backward_sampler_kwargs = {\n        'data': reverse_graph_to_sample,\n        'num_neighbors': [1],\n        'sample_direction': 'backward',\n    }\n\n    reverse_backward_sampler = NeighborSampler(\n        **reverse_backward_sampler_kwargs)\n    # This output should have Carol and Alice\n    reverse_backward_sampler_output = \\\n        reverse_backward_sampler.sample_from_nodes(node_sampler_input)\n\n    assert torch.equal(sampler_output.node,\n                       reverse_backward_sampler_output.node)\n    assert torch.equal(sampler_output.row, reverse_backward_sampler_output.col)\n    assert torch.equal(sampler_output.col, reverse_backward_sampler_output.row)\n    assert torch.equal(sampler_output.edge,\n                       reverse_backward_sampler_output.edge)\n\n    assert torch.equal(backward_sampler_output.node,\n                       reverse_sampler_output.node)\n    assert torch.equal(backward_sampler_output.row, reverse_sampler_output.col)\n    assert torch.equal(backward_sampler_output.col, reverse_sampler_output.row)\n    assert torch.equal(backward_sampler_output.edge,\n                       reverse_sampler_output.edge)\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\ndef test_homogeneous_neighbor_sampler_weighted_backwards(input_type):\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=False)\n    reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=False,\n                                                    
reverse=True)\n\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1, 1],\n        'weight_attr': 'weight',\n        'sample_direction': 'backward'\n    }\n    reverse_sampler_kwargs = {\n        'data': reverse_graph_to_sample,\n        'num_neighbors': [1, 1],\n        'weight_attr': 'weight',\n        'sample_direction': 'forward',\n    }\n\n    if input_type == 'remote':\n        with pytest.raises(NotImplementedError):\n            NeighborSampler(**sampler_kwargs)\n        return\n\n    graph_to_sample['weight'] = torch.tensor([1.0, 0.0, 1.0])\n    reverse_graph_to_sample['weight'] = torch.tensor([1.0, 0.0, 1.0])\n\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([0]))\n\n    # Sampling from Alice should yield Bob\n    backward_sampler = NeighborSampler(**sampler_kwargs)\n    backward_sampler_output = backward_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    reverse_sampler = NeighborSampler(**reverse_sampler_kwargs)\n    reverse_sampler_output = reverse_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    assert torch.equal(backward_sampler_output.node,\n                       reverse_sampler_output.node)\n    assert torch.equal(backward_sampler_output.row, reverse_sampler_output.col)\n    assert torch.equal(backward_sampler_output.col, reverse_sampler_output.row)\n    assert torch.equal(backward_sampler_output.edge,\n                       reverse_sampler_output.edge)\n\n    graph_to_sample['weight'] = torch.tensor([0.0, 1.0, 1.0])\n    reverse_graph_to_sample['weight'] = torch.tensor([0.0, 1.0, 1.0])\n\n    # Sampling from Alice should yield Carol and Dave\n    backward_sampler = NeighborSampler(**sampler_kwargs)\n    backward_sampler_output = backward_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    reverse_sampler = NeighborSampler(**reverse_sampler_kwargs)\n    reverse_sampler_output = 
reverse_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    assert torch.equal(backward_sampler_output.node,\n                       reverse_sampler_output.node)\n    assert torch.equal(backward_sampler_output.row, reverse_sampler_output.col)\n    assert torch.equal(backward_sampler_output.col, reverse_sampler_output.row)\n    assert torch.equal(backward_sampler_output.edge,\n                       reverse_sampler_output.edge)\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\n@pytest.mark.parametrize('time_attr', ['time', 'edge_attr'])\ndef test_homogeneous_neighbor_sampler_temporal_backwards(\n        input_type, time_attr):\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=False)\n    reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=False,\n                                                    reverse=True)\n\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [2, 2],\n        'time_attr': time_attr,\n    }\n    reverse_sampler_kwargs = {\n        'data': reverse_graph_to_sample,\n        'num_neighbors': [2, 2],\n        'time_attr': time_attr,\n    }\n\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([1]),\n                                          time=torch.tensor([1]))\n    reverse_node_sampler_input = NodeSamplerInput(input_id=None,\n                                                  node=torch.tensor([0]),\n                                                  time=torch.tensor([1]))\n\n    # sampling from Dave should yield Carol, Alice\n    sampler = NeighborSampler(**sampler_kwargs)\n    sampler_output = sampler.sample_from_nodes(node_sampler_input)\n\n    reverse_sampler = NeighborSampler(**reverse_sampler_kwargs)\n    reverse_sampler_output = reverse_sampler.sample_from_nodes(\n        reverse_node_sampler_input)\n\n    assert torch.equal(sampler_output.node, torch.tensor([1, 
0]))\n    assert torch.equal(reverse_sampler_output.node, torch.tensor([0, 1]))\n    \"\"\"\n    TODO (zaristei) Negative cases for temporal sampling,\n    then verify that the output is correct for backwards sampling.\n    \"\"\"\n    pytest.skip(\"still TODO\")\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\ndef test_heterogeneous_neighbor_sampler_backwards(input_type):\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=True)\n\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n    }\n\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([1]),\n                                          input_type=\"person\")\n\n    sampler = NeighborSampler(**sampler_kwargs)\n    # This output should have Carol and Alice\n    sampler_output = sampler.sample_from_nodes(node_sampler_input)\n\n    backward_sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n        'sample_direction': 'backward',\n    }\n    backward_sampler = NeighborSampler(**backward_sampler_kwargs)\n    # This output should have Carol and Dave\n    backward_sampler_output = backward_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    reverse_graph_to_sample = _init_graph_to_sample(input_type, hetero=True,\n                                                    reverse=True)\n\n    reverse_sampler_kwargs = {\n        'data': reverse_graph_to_sample,\n        'num_neighbors': [1],\n    }\n\n    reverse_sampler = NeighborSampler(**reverse_sampler_kwargs)\n    # This output should have Carol and Dave\n    reverse_sampler_output = reverse_sampler.sample_from_nodes(\n        node_sampler_input)\n\n    reverse_backward_sampler_kwargs = {\n        'data': reverse_graph_to_sample,\n        'num_neighbors': [1],\n        'sample_direction': 'backward',\n    }\n\n    reverse_backward_sampler = NeighborSampler(\n        
**reverse_backward_sampler_kwargs)\n    # This output should have Carol and Alice\n    reverse_backward_sampler_output = \\\n        reverse_backward_sampler.sample_from_nodes(node_sampler_input)\n\n    def reverse_key(key):\n        return (key[2], key[1], key[0])\n\n    assert sampler_output.node.keys(\n    ) == reverse_backward_sampler_output.node.keys()\n    assert reverse_sampler_output.node.keys(\n    ) == backward_sampler_output.node.keys()\n    for key in sampler_output.node.keys():\n        assert torch.equal(sampler_output.node[key],\n                           reverse_backward_sampler_output.node[key])\n    for key in reverse_sampler_output.node.keys():\n        assert torch.equal(reverse_sampler_output.node[key],\n                           backward_sampler_output.node[key])\n\n    assert len(sampler_output.row.keys()) == len(\n        reverse_backward_sampler_output.row.keys())\n    for key in sampler_output.row.keys():\n        assert reverse_key(key) in reverse_backward_sampler_output.row.keys()\n\n        assert torch.equal(\n            sampler_output.row[key],\n            reverse_backward_sampler_output.col[reverse_key(key)])\n        assert torch.equal(\n            sampler_output.col[key],\n            reverse_backward_sampler_output.row[reverse_key(key)])\n        assert torch.equal(\n            sampler_output.edge[key],\n            reverse_backward_sampler_output.edge[reverse_key(key)])\n\n    assert len(reverse_sampler_output.row.keys()) == len(\n        backward_sampler_output.row.keys())\n    for key in reverse_sampler_output.row.keys():\n        assert reverse_key(key) in backward_sampler_output.row.keys()\n\n        assert torch.equal(reverse_sampler_output.row[key],\n                           backward_sampler_output.col[reverse_key(key)])\n        assert torch.equal(reverse_sampler_output.col[key],\n                           backward_sampler_output.row[reverse_key(key)])\n        assert 
torch.equal(reverse_sampler_output.edge[key],\n                           backward_sampler_output.edge[reverse_key(key)])\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\ndef test_bidirectional_neighbor_sampler(input_type):\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=False)\n\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n    }\n\n    node_sampler_input = NodeSamplerInput(input_id=None,\n                                          node=torch.tensor([2]))\n    sampler = BidirectionalNeighborSampler(**sampler_kwargs)\n    sampler_output = sampler.sample_from_nodes(node_sampler_input)\n\n    expected_output = SamplerOutput(\n        # Union between forward and backward nodes\n        node=torch.tensor([2, 0, 3]),\n        # Reindexed to be relative to new nodes field\n        row=torch.tensor([1, 0]),\n        # Reindexed to be relative to new nodes field\n        col=torch.tensor([0, 2]),\n        # Union between forward and backward edges\n        edge=torch.tensor([1, 2]),\n        # Will be part of node uid if disjoint=True\n        batch=None,\n        # nodes are only counted on their first sample\n        num_sampled_nodes=[1, 1, 0, 1],\n        # edges are only counted on their first sample\n        num_sampled_edges=[1, 1],\n        # Will be used as edge uid if bidirectional=True with\n        # keep_orig_edges=True\n        orig_row=None,\n        # Will be used as edge uid if bidirectional=True with\n        # keep_orig_edges=True\n        orig_col=None,\n        # simple concat of forward and backward metadata\n        metadata=(None, None))\n    assert str(sampler_output) == str(expected_output)\n\n    adv_sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [2, 2, 2, 2],\n    }\n\n    adv_sampler_input = NodeSamplerInput(input_id=None,\n                                         node=torch.tensor([1, 3]))\n\n    adv_sampler = 
BidirectionalNeighborSampler(**adv_sampler_kwargs)\n    adv_sampler_output = adv_sampler.sample_from_nodes(adv_sampler_input)\n\n    adv_expected_output = SamplerOutput(\n        node=torch.tensor([1, 3, 0, 2]),\n        row=torch.tensor([2, 3, 2]),\n        col=torch.tensor([0, 1, 3]),\n        edge=torch.tensor([0, 2, 1]),\n        batch=None,\n        # 8 _sample calls total, each have 2 num_sampled_nodes slots\n        num_sampled_nodes=[2, 2] + [0] * 14,\n        num_sampled_edges=[2, 0, 1, 0, 0, 0, 0, 0],\n        orig_row=None,\n        orig_col=None,\n        metadata=(None, None))\n    assert str(adv_sampler_output) == str(adv_expected_output)\n\n    adv_sampler_kwargs['disjoint'] = True\n\n    adv_sampler_disjoint = BidirectionalNeighborSampler(**adv_sampler_kwargs)\n    adv_sampler_disjoint_output = adv_sampler_disjoint.sample_from_nodes(\n        adv_sampler_input)\n\n    adv_expected_disjoint_output = SamplerOutput(\n        node=torch.tensor([1, 3, 0, 2, 0, 2, 1, 3]),\n        row=torch.tensor([2, 3, 4, 2, 4, 5]),\n        col=torch.tensor([0, 1, 3, 5, 6, 7]),\n        edge=torch.tensor([0, 2, 1, 1, 0, 2]),\n        batch=torch.tensor([0, 1, 0, 1, 1, 0, 1, 0]),\n        num_sampled_nodes=[\n            # First forward iteration:\n            # (Bob, Dave) -> (Alice, Carol)\n            2,\n            2,\n            # First backward iteration:\n            # (Bob (seen), Dave (seen)) -> (None, None)\n            0,\n            0,\n            # Second forward iteration:\n            # (Alice (seen), Carol (seen)) -> (None, Alice)\n            0,\n            1,\n            # Second backward iteration:\n            # (Alice (seen), Carol (seen)) ->\n            #   (Bob (seen) and Carol, Alice and Dave (seen))\n            0,\n            1,\n            # Third forward iteration:\n            # (Carol (seen), Alice (seen)) -> (Alice (seen), None)\n            0,\n            0,\n            # Third backward iteration:\n            # (Alice (seen), 
Carol (seen)) -> (Bob and Carol (seen), Dave)\n            0,\n            2,\n            # Fourth forward iteration:\n            # (Bob (seen), Dave (seen)) -> (Alice (seen), Carol (seen))\n            0,\n            0,\n            # Fourth backward iteration:\n            # (Bob (seen), Dave (seen)) -> (None, None)\n            0,\n            0\n        ],\n        num_sampled_edges=[2, 0, 1, 1, 0, 2, 0, 0],\n        orig_row=None,\n        orig_col=None,\n        metadata=(None, None))\n    assert str(adv_sampler_disjoint_output) == str(\n        adv_expected_disjoint_output)\n\n\n@pytest.mark.skip(reason=\"BidirectionalNeighborSampler is not implemented yet \"\n                  \"for heterogeneous graphs.\")\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\ndef test_bidirectional_neighbor_sampler_hetero(input_type):\n    raise NotImplementedError\n\n\n@onlyNeighborSampler\n@pytest.mark.parametrize('input_type', ['data', 'remote'])\n@pytest.mark.parametrize('hetero', [False, True])\ndef test_neighbor_sampler_backwards_not_supported(input_type, hetero):\n    graph_to_sample = _init_graph_to_sample(input_type, hetero=hetero)\n\n    sampler_kwargs = {\n        'data': graph_to_sample,\n        'num_neighbors': [1],\n        'sample_direction': 'backward',\n        'time_attr': 'time'\n    }\n\n    with pytest.raises(NotImplementedError):\n        NeighborSampler(**sampler_kwargs)\n"
  },
  {
    "path": "test/test_config_mixin.py",
    "content": "from dataclasses import asdict, dataclass, is_dataclass\nfrom typing import Sequence\n\nimport pytest\nimport torch\n\nfrom torch_geometric.config_mixin import ConfigMixin\nfrom torch_geometric.config_store import clear_config_store, register\n\n\n@pytest.fixture(scope=\"session\", autouse=True)\ndef teardown_once():\n    yield  # This allows tests to run before teardown is executed\n    clear_config_store()\n\n\n@dataclass\nclass Dataclass:\n    x: int\n    y: int\n\n\nclass Base(torch.nn.Module, ConfigMixin):\n    pass\n\n\n@register(with_target=True)\nclass Module(Base):\n    def __init__(self, x: int, data: Dataclass):\n        super().__init__()\n        self.x = x\n        self.data = data\n\n\n@register(with_target=True)\nclass SubModule(Base):\n    def __init__(self, p: float):\n        super().__init__()\n        self.p = p\n\n\n@register(with_target=True)\nclass CompoundModule(torch.nn.Module, ConfigMixin):\n    def __init__(\n        self,\n        z: str,\n        module: Module,\n        submodules: list[SubModule],\n        key_modules: dict[str, torch.nn.Module],\n    ):\n        super().__init__()\n        self.z = z\n        self.module = module\n        self.submodules = torch.nn.ModuleList(submodules)\n        self.key_modules = torch.nn.ModuleDict(key_modules)\n\n\ndef test_config_mixin() -> None:\n    x = 0\n    data = Dataclass(x=1, y=2)\n\n    model = Module(x, data)\n    cfg = model.config()\n    assert is_dataclass(cfg)\n    assert cfg.x == 0\n    assert isinstance(cfg.data, Dataclass)\n    assert cfg.data.x == 1\n    assert cfg.data.y == 2\n    assert cfg._target_ == 'test_config_mixin.Module'\n\n    model = Module.from_config(cfg)\n    assert isinstance(model, Module)\n    assert model.x == 0\n    assert isinstance(model.data, Dataclass)\n    assert model.data.x == 1\n    assert model.data.y == 2\n\n    model = Base.from_config(cfg)\n    assert isinstance(model, Module)\n    assert model.x == 0\n    assert 
isinstance(model.data, Dataclass)\n    assert model.data.x == 1\n    assert model.data.y == 2\n\n    model = Base.from_config(cfg, 3)\n    assert isinstance(model, Module)\n    assert model.x == 3\n    assert isinstance(model.data, Dataclass)\n    assert model.data.x == 1\n    assert model.data.y == 2\n\n    model = Base.from_config(cfg, data=Dataclass(x=2, y=3))\n    assert isinstance(model, Module)\n    assert model.x == 0\n    assert isinstance(model.data, Dataclass)\n    assert model.data.x == 2\n    assert model.data.y == 3\n\n    cfg = asdict(cfg)\n\n    model = Module.from_config(cfg)\n    assert isinstance(model, Module)\n    assert model.x == 0\n    assert isinstance(model.data, dict)\n    assert model.data['x'] == 1\n    assert model.data['y'] == 2\n\n    model = Base.from_config(cfg)\n    assert isinstance(model, Module)\n    assert model.x == 0\n    assert isinstance(model.data, dict)\n    assert model.data['x'] == 1\n    assert model.data['y'] == 2\n\n\ndef test_config_mixin_compound() -> None:\n    module = Module(x=0, data=Dataclass(x=1, y=2))\n    submodules = [SubModule(1.41), SubModule(3.14)]\n    key_modules = {\n        \"key1\": Module(x=10, data=Dataclass(x=11, y=12)),\n        \"key2\": SubModule(2.71),\n    }\n    model = CompoundModule(z=\"foo\", module=module, submodules=submodules,\n                           key_modules=key_modules)\n\n    cfg = model.config()\n    assert is_dataclass(cfg)\n    assert cfg._target_ == 'test_config_mixin.CompoundModule'\n    assert cfg.z == \"foo\"\n    assert cfg.module._target_ == 'test_config_mixin.Module'\n    assert cfg.module.x == 0\n    assert isinstance(cfg.module.data, Dataclass)\n    assert cfg.module.data.x == 1\n    assert cfg.module.data.y == 2\n\n    assert len(cfg.submodules) == 2\n    assert isinstance(cfg.submodules, Sequence)\n    assert cfg.submodules[0]._target_ == 'test_config_mixin.SubModule'\n    assert cfg.submodules[0].p == 1.41\n    assert cfg.submodules[1]._target_ == 
'test_config_mixin.SubModule'\n    assert cfg.submodules[1].p == 3.14\n\n    assert len(cfg.key_modules) == 2\n    assert cfg.key_modules[\"key1\"]._target_ == 'test_config_mixin.Module'\n    assert cfg.key_modules[\"key1\"].x == 10\n    assert isinstance(cfg.key_modules[\"key1\"].data, Dataclass)\n    assert cfg.key_modules[\"key1\"].data.x == 11\n    assert cfg.key_modules[\"key1\"].data.y == 12\n\n    assert cfg.key_modules[\"key2\"]._target_ == 'test_config_mixin.SubModule'\n    assert cfg.key_modules[\"key2\"].p == 2.71\n\n    model = CompoundModule.from_config(cfg)\n    assert isinstance(model, CompoundModule)\n    assert model.z == \"foo\"\n    assert isinstance(model.module, Module)\n    assert model.module.x == 0\n    assert isinstance(model.module.data, Dataclass)\n    assert model.module.data.x == 1\n    assert model.module.data.y == 2\n    assert isinstance(model.submodules, torch.nn.ModuleList)\n    assert len(model.submodules) == 2\n    assert isinstance(model.submodules[0], SubModule)\n    assert model.submodules[0].p == 1.41\n    assert isinstance(model.submodules[1], SubModule)\n    assert model.submodules[1].p == 3.14\n    assert isinstance(model.key_modules, torch.nn.ModuleDict)\n    assert len(model.key_modules) == 2\n    assert isinstance(model.key_modules[\"key1\"], Module)\n    assert model.key_modules[\"key1\"].x == 10\n    assert isinstance(model.key_modules[\"key1\"].data, Dataclass)\n    assert model.key_modules[\"key1\"].data.x == 11\n    assert model.key_modules[\"key1\"].data.y == 12\n    assert isinstance(model.key_modules[\"key2\"], SubModule)\n    assert model.key_modules[\"key2\"].p == 2.71\n"
  },
  {
    "path": "test/test_config_store.py",
    "content": "from typing import Any, Dict, List, Tuple\n\nfrom torch_geometric.config_store import (\n    class_from_dataclass,\n    clear_config_store,\n    dataclass_from_class,\n    fill_config_store,\n    get_config_store,\n    map_annotation,\n    register,\n    to_dataclass,\n)\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import AddSelfLoops\n\n\ndef teardown_function():\n    clear_config_store()\n\n\ndef test_to_dataclass():\n    from torch_geometric.transforms import AddSelfLoops\n\n    AddSelfLoopsConfig = to_dataclass(AddSelfLoops, with_target=True)\n    assert AddSelfLoopsConfig.__name__ == 'AddSelfLoops'\n\n    fields = AddSelfLoopsConfig.__dataclass_fields__\n\n    assert fields['attr'].name == 'attr'\n    assert fields['attr'].type == str\n    assert fields['attr'].default == 'edge_weight'\n\n    assert fields['fill_value'].name == 'fill_value'\n    assert fields['fill_value'].type == Any\n    assert fields['fill_value'].default == 1.0\n\n    assert fields['_target_'].name == '_target_'\n    assert fields['_target_'].type == str\n    assert fields['_target_'].default == (\n        'torch_geometric.transforms.add_self_loops.AddSelfLoops')\n\n    cfg = AddSelfLoopsConfig()\n    assert str(cfg) == (\"AddSelfLoops(attr='edge_weight', fill_value=1.0, \"\n                        \"_target_='torch_geometric.transforms.add_self_loops.\"\n                        \"AddSelfLoops')\")\n\n\ndef test_map_annotation():\n    mapping = {int: Any}\n    assert map_annotation(dict[str, int], mapping) == dict[str, Any]\n    assert map_annotation(Dict[str, float], mapping) == Dict[str, float]\n    assert map_annotation(List[str], mapping) == List[str]\n    assert map_annotation(List[int], mapping) == List[Any]\n    assert map_annotation(Tuple[int], mapping) == Tuple[Any]\n    assert map_annotation(dict[str, int], mapping) == dict[str, Any]\n    assert map_annotation(dict[str, float], mapping) == dict[str, float]\n    assert 
map_annotation(list[str], mapping) == list[str]\n    assert map_annotation(list[int], mapping) == list[Any]\n    assert map_annotation(tuple[int], mapping) == tuple[Any]\n\n\ndef test_register():\n    register(AddSelfLoops, group='transform')\n    assert 'transform' in get_config_store().repo\n\n    AddSelfLoopsConfig = dataclass_from_class('AddSelfLoops')\n\n    Cls = class_from_dataclass('AddSelfLoops')\n    assert Cls == AddSelfLoops\n    Cls = class_from_dataclass(AddSelfLoopsConfig)\n    assert Cls == AddSelfLoops\n\n    ConfigCls = dataclass_from_class('AddSelfLoops')\n    assert ConfigCls == AddSelfLoopsConfig\n    ConfigCls = dataclass_from_class(ConfigCls)\n    assert ConfigCls == AddSelfLoopsConfig\n\n\ndef test_fill_config_store():\n    fill_config_store()\n\n    assert {\n        'transform',\n        'dataset',\n        'model',\n        'optimizer',\n        'lr_scheduler',\n    }.issubset(get_config_store().repo.keys())\n\n\n@withPackage('hydra')\ndef test_hydra_config_store():\n    import hydra\n    from omegaconf import DictConfig\n\n    fill_config_store()\n\n    with hydra.initialize(config_path='.', version_base='1.1'):\n        cfg = hydra.compose(config_name='my_config')\n\n    assert len(cfg) == 4\n    assert 'dataset' in cfg\n    assert 'model' in cfg\n    assert 'optimizer' in cfg\n    assert 'lr_scheduler' in cfg\n\n    # Check `cfg.dataset`:\n    assert len(cfg.dataset) == 2\n    assert cfg.dataset._target_.split('.')[-1] == 'KarateClub'\n\n    # Check `cfg.dataset.transform`:\n    assert isinstance(cfg.dataset.transform, DictConfig)\n    assert len(cfg.dataset.transform) == 2\n    assert 'NormalizeFeatures' in cfg.dataset.transform\n    assert 'AddSelfLoops' in cfg.dataset.transform\n\n    assert isinstance(cfg.dataset.transform.NormalizeFeatures, DictConfig)\n    assert (cfg.dataset.transform.NormalizeFeatures._target_.split('.')[-1] ==\n            'NormalizeFeatures')\n    assert cfg.dataset.transform.NormalizeFeatures.attrs == 
['x']\n\n    assert isinstance(cfg.dataset.transform.AddSelfLoops, DictConfig)\n    assert (cfg.dataset.transform.AddSelfLoops._target_.split('.')[-1] ==\n            'AddSelfLoops')\n    assert cfg.dataset.transform.AddSelfLoops.attr == 'edge_weight'\n    assert cfg.dataset.transform.AddSelfLoops.fill_value == 1.0\n\n    # Check `cfg.model`:\n    assert len(cfg.model) == 12\n    assert cfg.model._target_.split('.')[-1] == 'GCN'\n    assert cfg.model.in_channels == 34\n    assert cfg.model.out_channels == 4\n    assert cfg.model.hidden_channels == 16\n    assert cfg.model.num_layers == 2\n    assert cfg.model.dropout == 0.0\n    assert cfg.model.act == 'relu'\n    assert cfg.model.norm is None\n    assert cfg.model.norm_kwargs is None\n    assert cfg.model.jk is None\n    assert not cfg.model.act_first\n    assert cfg.model.act_kwargs is None\n\n    # Check `cfg.optimizer`:\n    assert cfg.optimizer._target_.split('.')[-1] == 'Adam'\n    assert cfg.optimizer.lr == 0.001\n    assert cfg.optimizer.betas == [0.9, 0.999]\n    assert cfg.optimizer.eps == 1e-08\n    assert cfg.optimizer.weight_decay == 0\n    assert not cfg.optimizer.amsgrad\n    if hasattr(cfg.optimizer, 'maximize'):\n        assert not cfg.optimizer.maximize\n\n    # Check `cfg.lr_scheduler`:\n    assert cfg.lr_scheduler._target_.split('.')[-1] == 'ReduceLROnPlateau'\n    assert cfg.lr_scheduler.mode == 'min'\n    assert cfg.lr_scheduler.factor == 0.1\n    assert cfg.lr_scheduler.patience == 10\n    assert cfg.lr_scheduler.threshold == 0.0001\n    assert cfg.lr_scheduler.threshold_mode == 'rel'\n    assert cfg.lr_scheduler.cooldown == 0\n    assert cfg.lr_scheduler.min_lr == 0\n    assert cfg.lr_scheduler.eps == 1e-08\n"
  },
  {
    "path": "test/test_debug.py",
    "content": "from torch_geometric import debug, is_debug_enabled, set_debug\n\n\ndef test_debug():\n    assert is_debug_enabled() is False\n    set_debug(True)\n    assert is_debug_enabled() is True\n    set_debug(False)\n    assert is_debug_enabled() is False\n\n    assert is_debug_enabled() is False\n    with set_debug(True):\n        assert is_debug_enabled() is True\n    assert is_debug_enabled() is False\n\n    assert is_debug_enabled() is False\n    set_debug(True)\n    assert is_debug_enabled() is True\n    with set_debug(False):\n        assert is_debug_enabled() is False\n    assert is_debug_enabled() is True\n    set_debug(False)\n    assert is_debug_enabled() is False\n\n    assert is_debug_enabled() is False\n    with debug():\n        assert is_debug_enabled() is True\n    assert is_debug_enabled() is False\n"
  },
  {
    "path": "test/test_edge_index.py",
    "content": "import os.path as osp\nimport warnings\nfrom typing import List, Optional\n\nimport numpy as np\nimport pytest\nimport torch\nfrom torch import Tensor, tensor\n\nimport torch_geometric\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.edge_index import (\n    ReduceType,\n    SortReturnType,\n    _scatter_spmm,\n    _torch_sparse_spmm,\n    _TorchSPMM,\n    set_tuple_item,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import (\n    onlyCUDA,\n    onlyLinux,\n    withCUDA,\n    withoutExtensions,\n    withPackage,\n)\nfrom torch_geometric.typing import INDEX_DTYPES, SparseTensor\nfrom torch_geometric.utils import scatter\n\nDTYPES = [pytest.param(dtype, id=str(dtype)[6:]) for dtype in INDEX_DTYPES]\nIS_UNDIRECTED = [\n    pytest.param(False, id='directed'),\n    pytest.param(True, id='undirected'),\n]\nTRANSPOSE = [\n    pytest.param(False, id=''),\n    pytest.param(True, id='transpose'),\n]\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_basic(dtype, device):\n    kwargs = dict(dtype=dtype, device=device, sparse_size=(3, 3))\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)\n    adj.validate()\n    assert isinstance(adj, EdgeIndex)\n\n    assert str(adj).startswith('EdgeIndex([[0, 1, 1, 2],\\n'\n                               '           [1, 0, 2, 1]], ')\n    assert 'sparse_size=(3, 3), nnz=4' in str(adj)\n    assert (f\"device='{device}'\" in str(adj)) == adj.is_cuda\n    assert (f'dtype={dtype}' in str(adj)) == (dtype != torch.long)\n\n    assert adj.dtype == dtype\n    assert adj.device == device\n    assert adj.sparse_size() == (3, 3)\n    assert adj.sparse_size(0) == 3\n    assert adj.sparse_size(-1) == 3\n\n    assert adj.sort_order is None\n    assert not adj.is_sorted\n    assert not adj.is_sorted_by_row\n    assert not adj.is_sorted_by_col\n\n    assert not adj.is_undirected\n\n    out = adj.as_tensor()\n    assert not 
isinstance(out, EdgeIndex)\n    assert out.dtype == dtype\n    assert out.device == device\n\n    out = adj * 1\n    assert not isinstance(out, EdgeIndex)\n    assert out.dtype == dtype\n    assert out.device == device\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_identity(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs)\n\n    out = EdgeIndex(adj)\n    assert not isinstance(out.as_tensor(), EdgeIndex)\n    assert out.data_ptr() == adj.data_ptr()\n    assert out.dtype == adj.dtype\n    assert out.device == adj.device\n    assert out.sparse_size() == adj.sparse_size()\n    assert out.sort_order == adj.sort_order\n    assert out.is_undirected == adj.is_undirected\n\n    out = EdgeIndex(adj, sparse_size=(4, 4), sort_order='row')\n    assert out.sparse_size() == (4, 4)\n    assert out.sort_order == 'row'\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_sparse_tensor(dtype, device):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=True)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n\n    out = EdgeIndex(adj.to_sparse_coo())\n    assert out.equal(adj)\n    assert out.sort_order == 'row'\n    assert out.sparse_size() == (3, 3)\n    assert out._indptr is None\n\n    out = EdgeIndex(adj.to_sparse_csr())\n    assert out.equal(adj)\n    assert out.sort_order == 'row'\n    assert out.sparse_size() == (3, 3)\n    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n    out = EdgeIndex(adj.to_sparse_csc())\n    assert out.equal(adj.sort_by('col')[0])\n    assert out.sort_order == 'col'\n    assert out.sparse_size() == (3, 3)\n    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n\ndef test_set_tuple_item():\n    tmp = (0, 1, 2)\n    assert set_tuple_item(tmp, 0, 3) == (3, 
1, 2)\n    assert set_tuple_item(tmp, 1, 3) == (0, 3, 2)\n    assert set_tuple_item(tmp, 2, 3) == (0, 1, 3)\n    with pytest.raises(IndexError, match=\"tuple index out of range\"):\n        set_tuple_item(tmp, 3, 3)\n    assert set_tuple_item(tmp, -1, 3) == (0, 1, 3)\n    assert set_tuple_item(tmp, -2, 3) == (0, 3, 2)\n    assert set_tuple_item(tmp, -3, 3) == (3, 1, 2)\n    with pytest.raises(IndexError, match=\"tuple index out of range\"):\n        set_tuple_item(tmp, -4, 3)\n\n\ndef test_validate():\n    with pytest.raises(TypeError, match=\"tensors of a single element\"):\n        EdgeIndex([torch.tensor([0, 1]), torch.tensor([1, 0])])\n    with pytest.raises(ValueError, match=\"unsupported data type\"):\n        EdgeIndex([[0.0, 1.0], [1.0, 0.0]])\n    with pytest.raises(ValueError, match=\"needs to be two-dimensional\"):\n        EdgeIndex([[[0], [1]], [[1], [0]]])\n    with pytest.raises(ValueError, match=\"needs to have a shape of\"):\n        EdgeIndex([[0, 1], [1, 0], [1, 1]])\n    with pytest.raises(ValueError, match=\"received a non-symmetric size\"):\n        EdgeIndex([[0, 1], [1, 0]], is_undirected=True, sparse_size=(2, 3))\n    with pytest.raises(TypeError, match=\"invalid combination of arguments\"):\n        EdgeIndex(tensor([[0, 1], [1, 0]]), torch.long)\n    with pytest.raises(TypeError, match=\"invalid keyword arguments\"):\n        EdgeIndex(tensor([[0, 1], [1, 0]]), dtype=torch.long)\n    with pytest.raises(ValueError, match=\"contains negative indices\"):\n        EdgeIndex([[-1, 0], [0, 1]]).validate()\n    with pytest.raises(ValueError, match=\"than its number of rows\"):\n        EdgeIndex([[0, 10], [1, 0]], sparse_size=(2, 2)).validate()\n    with pytest.raises(ValueError, match=\"than its number of columns\"):\n        EdgeIndex([[0, 1], [10, 0]], sparse_size=(2, 2)).validate()\n    with pytest.raises(ValueError, match=\"not sorted by row indices\"):\n        EdgeIndex([[1, 0], [0, 1]], sort_order='row').validate()\n    with 
pytest.raises(ValueError, match=\"not sorted by column indices\"):\n        EdgeIndex([[0, 1], [1, 0]], sort_order='col').validate()\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_undirected(dtype, device):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=True)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)\n    assert isinstance(adj, EdgeIndex)\n    assert adj.is_undirected\n\n    assert adj.sparse_size() == (None, None)\n    adj.get_num_rows()\n    assert adj.sparse_size() == (3, 3)\n    adj.validate()\n\n    adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(3, None), **kwargs)\n    assert adj.sparse_size() == (3, 3)\n    adj.validate()\n\n    adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(None, 3), **kwargs)\n    assert adj.sparse_size() == (3, 3)\n    adj.validate()\n\n    with pytest.raises(ValueError, match=\"'EdgeIndex' is not undirected\"):\n        EdgeIndex([[0, 1, 1, 2], [0, 0, 1, 1]], **kwargs).validate()\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_fill_cache_(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    adj.validate().fill_cache_()\n    assert adj.sparse_size() == (3, 3)\n    assert adj._indptr.dtype == dtype\n    assert adj._indptr.equal(tensor([0, 1, 3, 4], device=device))\n    assert adj._T_perm.dtype == dtype\n    assert (adj._T_perm.equal(tensor([1, 0, 3, 2], device=device))\n            or adj._T_perm.equal(tensor([1, 3, 0, 2], device=device)))\n    assert adj._T_index[0].dtype == dtype\n    assert (adj._T_index[0].equal(tensor([1, 0, 2, 1], device=device))\n            or adj._T_index[0].equal(tensor([1, 2, 0, 1], device=device)))\n    assert adj._T_index[1].dtype == dtype\n    assert adj._T_index[1].equal(tensor([0, 1, 1, 2], device=device))\n    if is_undirected:\n        
assert adj._T_indptr is None\n    else:\n        assert adj._T_indptr.dtype == dtype\n        assert adj._T_indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)\n    adj.validate().fill_cache_()\n    assert adj.sparse_size() == (3, 3)\n    assert adj._indptr.dtype == dtype\n    assert adj._indptr.equal(tensor([0, 1, 3, 4], device=device))\n    assert (adj._T_perm.equal(tensor([1, 0, 3, 2], device=device))\n            or adj._T_perm.equal(tensor([1, 3, 0, 2], device=device)))\n    assert adj._T_index[0].dtype == dtype\n    assert adj._T_index[0].equal(tensor([0, 1, 1, 2], device=device))\n    assert adj._T_index[1].dtype == dtype\n    assert (adj._T_index[1].equal(tensor([1, 0, 2, 1], device=device))\n            or adj._T_index[1].equal(tensor([1, 2, 0, 1], device=device)))\n    if is_undirected:\n        assert adj._T_indptr is None\n    else:\n        assert adj._T_indptr.dtype == dtype\n        assert adj._T_indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_clone(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n\n    out = adj.clone()\n    assert isinstance(out, EdgeIndex)\n    assert out.dtype == dtype\n    assert out.device == device\n    assert out.is_sorted_by_row\n    assert out.is_undirected == is_undirected\n\n    out = torch.clone(adj)\n    assert isinstance(out, EdgeIndex)\n    assert out.dtype == dtype\n    assert out.device == device\n    assert out.is_sorted_by_row\n    assert out.is_undirected == is_undirected\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_to_function(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, 
is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    adj.fill_cache_()\n\n    adj = adj.to(device)\n    assert isinstance(adj, EdgeIndex)\n    assert adj.device == device\n    assert adj._indptr.dtype == dtype\n    assert adj._indptr.device == device\n    assert adj._T_perm.dtype == dtype\n    assert adj._T_perm.device == device\n\n    out = adj.cpu()\n    assert isinstance(out, EdgeIndex)\n    assert out.device == torch.device('cpu')\n\n    out = adj.to(torch.int)\n    assert out.dtype == torch.int\n    if torch_geometric.typing.WITH_PT20:\n        assert isinstance(out, EdgeIndex)\n        assert out._indptr.dtype == torch.int\n        assert out._T_perm.dtype == torch.int\n    else:\n        assert not isinstance(out, EdgeIndex)\n\n    out = adj.to(torch.float)\n    assert not isinstance(out, EdgeIndex)\n    assert out.dtype == torch.float\n\n    out = adj.long()\n    assert isinstance(out, EdgeIndex)\n    assert out.dtype == torch.int64\n\n    out = adj.int()\n    assert out.dtype == torch.int\n    if torch_geometric.typing.WITH_PT20:\n        assert isinstance(out, EdgeIndex)\n    else:\n        assert not isinstance(out, EdgeIndex)\n\n\n@onlyCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_cpu_cuda(dtype):\n    kwargs = dict(dtype=dtype)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)\n    assert adj.is_cpu\n\n    out = adj.cuda()\n    assert isinstance(out, EdgeIndex)\n    assert out.is_cuda\n\n    out = out.cpu()\n    assert isinstance(out, EdgeIndex)\n    assert out.is_cpu\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_share_memory(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    adj.fill_cache_()\n\n    out = adj.share_memory_()\n    assert isinstance(out, EdgeIndex)\n    assert out.is_shared()\n    assert out._data.is_shared()\n    assert 
out._indptr.is_shared()\n    assert out.data_ptr() == adj.data_ptr()\n\n\n@onlyCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_pin_memory(dtype):\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=dtype)\n    assert not adj.is_pinned()\n    out = adj.pin_memory()\n    assert out.is_pinned()\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_contiguous(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    data = tensor([[0, 1], [1, 0], [1, 2], [2, 1]], **kwargs).t()\n\n    with pytest.raises(ValueError, match=\"needs to be contiguous\"):\n        EdgeIndex(data)\n\n    adj = EdgeIndex(data.contiguous()).contiguous()\n    assert isinstance(adj, EdgeIndex)\n    assert adj.is_contiguous()\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_sort_by(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    out = adj.sort_by('row')\n    assert isinstance(out, SortReturnType)\n    assert isinstance(out.values, EdgeIndex)\n    assert not isinstance(out.indices, EdgeIndex)\n    assert out.values.equal(adj)\n    assert out.indices is None\n\n    adj = EdgeIndex([[0, 1, 2, 1], [1, 0, 1, 2]], **kwargs)\n    out = adj.sort_by('row')\n    assert isinstance(out, SortReturnType)\n    assert isinstance(out.values, EdgeIndex)\n    assert not isinstance(out.indices, EdgeIndex)\n    assert out.values[0].equal(tensor([0, 1, 1, 2], device=device))\n    assert (out.values[1].equal(tensor([1, 0, 2, 1], device=device))\n            or out.values[1].equal(tensor([1, 2, 0, 1], device=device)))\n    assert (out.indices.equal(tensor([0, 1, 3, 2], device=device))\n            or out.indices.equal(tensor([0, 3, 1, 2], device=device)))\n\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n\n    out, perm = 
adj.sort_by('col')\n    assert adj._T_perm is not None  # Check caches.\n    assert adj._T_index[0] is not None and adj._T_index[1] is not None\n    assert (out[0].equal(tensor([1, 0, 2, 1], device=device))\n            or out[0].equal(tensor([1, 2, 0, 1], device=device)))\n    assert out[1].equal(tensor([0, 1, 1, 2], device=device))\n    assert (perm.equal(tensor([1, 0, 3, 2], device=device))\n            or perm.equal(tensor([1, 3, 0, 2], device=device)))\n    assert out._T_perm is None\n    assert out._T_index[0] is None and out._T_index[1] is None\n\n    out, perm = out.sort_by('row')\n    assert out[0].equal(tensor([0, 1, 1, 2], device=device))\n    assert (out[1].equal(tensor([1, 0, 2, 1], device=device))\n            or out[1].equal(tensor([1, 2, 0, 1], device=device)))\n    assert (perm.equal(tensor([1, 0, 3, 2], device=device))\n            or perm.equal(tensor([2, 3, 0, 1], device=device)))\n    assert out._T_perm is None\n    assert out._T_index[0] is None and out._T_index[1] is None\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_cat(dtype, device, is_undirected):\n    args = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **args)\n    adj2 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], sparse_size=(4, 4), **args)\n    adj3 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], dtype=dtype, device=device)\n\n    out = torch.cat([adj1, adj2], dim=1)\n    assert out.size() == (2, 8)\n    assert isinstance(out, EdgeIndex)\n    assert out.sparse_size() == (4, 4)\n    assert not out.is_sorted\n    assert out.is_undirected == is_undirected\n\n    assert out._cat_metadata.nnz == [4, 4]\n    assert out._cat_metadata.sparse_size == [(3, 3), (4, 4)]\n    assert out._cat_metadata.sort_order == [None, None]\n    assert out._cat_metadata.is_undirected == [is_undirected, is_undirected]\n\n    out = torch.cat([adj1, adj2, 
adj3], dim=1)\n    assert out.size() == (2, 12)\n    assert isinstance(out, EdgeIndex)\n    assert out.sparse_size() == (None, None)\n    assert not out.is_sorted\n    assert not out.is_undirected\n\n    out = torch.cat([adj1, adj2], dim=0)\n    assert out.size() == (4, 4)\n    assert not isinstance(out, EdgeIndex)\n\n    inplace = torch.empty(2, 8, dtype=dtype, device=device)\n    out = torch.cat([adj1, adj2], dim=1, out=inplace)\n    assert out.data_ptr() == inplace.data_ptr()\n    assert not isinstance(out, EdgeIndex)\n    assert not isinstance(inplace, EdgeIndex)\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_flip(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    adj.fill_cache_()\n\n    out = adj.flip(0)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[1, 0, 2, 1], [0, 1, 1, 2]], device=device))\n    assert out.is_sorted_by_col\n    assert out.is_undirected == is_undirected\n    assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n    out = adj.flip([0, 1])\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[1, 2, 0, 1], [2, 1, 1, 0]], device=device))\n    assert not out.is_sorted\n    assert out.is_undirected == is_undirected\n    assert out._T_indptr is None\n\n    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)\n    out = adj.flip(0)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device))\n    assert out.is_sorted_by_row\n    assert out.is_undirected == is_undirected\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_index_select(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, 
is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n\n    index = tensor([1, 3], device=device)\n    out = adj.index_select(1, index)\n    assert out.equal(tensor([[1, 2], [0, 1]], device=device))\n    assert isinstance(out, EdgeIndex)\n    assert not out.is_sorted\n    assert not out.is_undirected\n\n    index = tensor([0], device=device)\n    out = adj.index_select(0, index)\n    assert out.equal(tensor([[0, 1, 1, 2]], device=device))\n    assert not isinstance(out, EdgeIndex)\n\n    index = tensor([1, 3], device=device)\n    inplace = torch.empty(2, 2, dtype=dtype, device=device)\n    out = torch.index_select(adj, 1, index, out=inplace)\n    assert out.data_ptr() == inplace.data_ptr()\n    assert not isinstance(out, EdgeIndex)\n    assert not isinstance(inplace, EdgeIndex)\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_narrow(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n\n    out = adj.narrow(dim=1, start=1, length=2)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[1, 1], [0, 2]], device=device))\n    assert out.is_sorted_by_row\n    assert not out.is_undirected\n\n    out = adj.narrow(dim=0, start=0, length=1)\n    assert not isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[0, 1, 1, 2]], device=device))\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_getitem(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n\n    out = adj[:, tensor([False, True, False, True], device=device)]\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[1, 
2], [0, 1]], device=device))\n    assert out.is_sorted_by_row\n    assert not out.is_undirected\n\n    out = adj[..., tensor([1, 3], device=device)]\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[1, 2], [0, 1]], device=device))\n    assert not out.is_sorted\n    assert not out.is_undirected\n\n    out = adj[..., 1::2]\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[1, 2], [0, 1]], device=device))\n    assert out.is_sorted_by_row\n    assert not out.is_undirected\n\n    out = adj[...]\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device))\n    assert out.is_sorted_by_row\n    assert out.is_undirected == is_undirected\n\n    out = adj[None]\n    assert not isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[[0, 1, 1, 2], [1, 0, 2, 1]]], device=device))\n\n    out = adj[0, 0]\n    assert not isinstance(out, EdgeIndex)\n    assert out.equal(tensor(0, device=device))\n\n    out = adj[:, 0]\n    assert not isinstance(out, EdgeIndex)\n\n    out = adj[tensor([0], device=device)]\n    assert not isinstance(out, EdgeIndex)\n\n    out = adj[tensor([0], device=device), tensor([0], device=device)]\n    assert not isinstance(out, EdgeIndex)\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_select(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n\n    adj = EdgeIndex(\n        [[0, 1, 1, 2], [1, 0, 2, 1]],\n        sort_order='row',\n        sparse_size=(4, 5),\n        **kwargs,\n    ).fill_cache_()\n\n    out = adj[0]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([0, 1, 1, 2], device=device))\n    assert out.dim_size == 4\n    assert out.is_sorted\n    assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))\n\n    out = adj[-1]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 0, 2, 1], device=device))\n    assert out.dim_size == 5\n    assert not out.is_sorted\n    assert out._indptr is 
None\n\n    out = adj[-2, 2:4]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 2], device=device))\n    assert out.dim_size == 4\n    assert out.is_sorted\n    assert out._indptr is None\n\n    adj = EdgeIndex(\n        [[1, 0, 2, 1], [0, 1, 1, 2]],\n        sort_order='col',\n        sparse_size=(5, 4),\n        **kwargs,\n    ).fill_cache_()\n\n    out = adj[1]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([0, 1, 1, 2], device=device))\n    assert out.dim_size == 4\n    assert out.is_sorted\n    assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))\n\n    out = adj[-2]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 0, 2, 1], device=device))\n    assert out.dim_size == 5\n    assert not out.is_sorted\n    assert out._indptr is None\n\n    out = adj[-1, 2:4]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 2], device=device))\n    assert out.dim_size == 4\n    assert out.is_sorted\n    assert out._indptr is None\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_unbind(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n\n    adj = EdgeIndex(\n        [[0, 1, 1, 2], [1, 0, 2, 1]],\n        sort_order='row',\n        sparse_size=(4, 5),\n        **kwargs,\n    ).fill_cache_()\n\n    row, col = adj\n\n    assert isinstance(row, Index)\n    assert row.equal(tensor([0, 1, 1, 2], device=device))\n    assert row.dim_size == 4\n    assert row.is_sorted\n    assert row._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))\n\n    assert isinstance(col, Index)\n    assert col.equal(tensor([1, 0, 2, 1], device=device))\n    assert col.dim_size == 5\n    assert not col.is_sorted\n    assert col._indptr is None\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('value_dtype', [None, torch.double])\ndef test_to_dense(dtype, device, value_dtype):\n    kwargs = dict(dtype=dtype, device=device)\n    adj = EdgeIndex([[1, 0, 2, 1], [0, 
1, 1, 2]], **kwargs)\n\n    out = adj.to_dense(dtype=value_dtype)\n    assert isinstance(out, Tensor)\n    assert out.size() == (3, 3)\n    expected = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]\n    assert out.equal(tensor(expected, dtype=value_dtype, device=device))\n\n    value = torch.arange(1, 5, dtype=value_dtype or torch.float, device=device)\n    out = adj.to_dense(value)\n    assert isinstance(out, Tensor)\n    assert out.size() == (3, 3)\n    expected = [[0.0, 2.0, 0.0], [1.0, 0.0, 4.0], [0.0, 3.0, 0.0]]\n    assert out.equal(tensor(expected, dtype=value_dtype, device=device))\n\n    value = torch.arange(1, 5, dtype=value_dtype or torch.float, device=device)\n    out = adj.to_dense(value.view(-1, 1))\n    assert isinstance(out, Tensor)\n    assert out.size() == (3, 3, 1)\n    expected = [\n        [[0.0], [2.0], [0.0]],\n        [[1.0], [0.0], [4.0]],\n        [[0.0], [3.0], [0.0]],\n    ]\n    assert out.equal(tensor(expected, dtype=value_dtype, device=device))\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_to_sparse_coo(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs)\n\n    if torch_geometric.typing.WITH_PT20:\n        with pytest.raises(ValueError, match=\"Unexpected tensor layout\"):\n            adj.to_sparse(layout='int64')\n\n    if torch_geometric.typing.WITH_PT20:\n        out = adj.to_sparse(layout=torch.sparse_coo)\n    else:\n        out = adj.to_sparse()\n    assert isinstance(out, Tensor)\n    assert out.dtype == torch.float\n    assert out.device == device\n    assert out.layout == torch.sparse_coo\n    assert out.size() == (3, 3)\n    assert adj.equal(out._indices())\n    assert not out.is_coalesced()\n\n    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs)\n    out = adj.to_sparse_coo()\n    assert isinstance(out, Tensor)\n    assert out.dtype == torch.float\n    assert out.device == device\n    assert out.layout == 
torch.sparse_coo\n    assert out.size() == (3, 3)\n    assert adj.equal(out._indices())\n    assert not out.is_coalesced()\n\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    out = adj.to_sparse_coo()\n    assert isinstance(out, Tensor)\n    assert out.dtype == torch.float\n    assert out.device == device\n    assert out.layout == torch.sparse_coo\n    assert out.size() == (3, 3)\n    assert adj.equal(out._indices())\n    assert out.is_coalesced()\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_to_sparse_csr(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    with pytest.raises(ValueError, match=\"not sorted\"):\n        EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs).to_sparse_csr()\n\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    if torch_geometric.typing.WITH_PT20:\n        out = adj.to_sparse(layout=torch.sparse_csr)\n    else:\n        out = adj.to_sparse_csr()\n    assert isinstance(out, Tensor)\n    assert out.dtype == torch.float\n    assert out.device == device\n    assert out.layout == torch.sparse_csr\n    assert out.size() == (3, 3)\n    assert adj._indptr.equal(out.crow_indices())\n    assert adj[1].equal(out.col_indices())\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_to_sparse_csc(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    with pytest.raises(ValueError, match=\"not sorted\"):\n        EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs).to_sparse_csc()\n\n    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)\n    if torch_geometric.typing.WITH_PT20:\n        out = adj.to_sparse(layout=torch.sparse_csc)\n    else:\n        out = adj.to_sparse_csc()\n    assert isinstance(out, Tensor)\n    assert out.dtype == torch.float\n    assert out.device == device\n    assert out.layout == torch.sparse_csc\n    assert out.size() == (3, 3)\n    assert adj._indptr.equal(out.ccol_indices())\n 
   assert adj[0].equal(out.row_indices())\n\n\n@withCUDA\n@withPackage('torch_sparse')\ndef test_to_sparse_tensor(device):\n    kwargs = dict(device=device)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)\n    out = adj.to_sparse_tensor()\n    assert isinstance(out, SparseTensor)\n    assert out.sizes() == [3, 3]\n    row, col, _ = out.coo()\n    assert row.equal(adj[0])\n    assert col.equal(adj[1])\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_add(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3), **kwargs)\n\n    out = torch.add(adj, 2, alpha=2)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[4, 5, 5, 6], [5, 4, 6, 5]], device=device))\n    assert out.is_undirected == is_undirected\n    assert out.sparse_size() == (7, 7)\n\n    out = adj + tensor([2], dtype=dtype, device=device)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))\n    assert out.is_undirected == is_undirected\n    assert out.sparse_size() == (5, 5)\n\n    out = adj + tensor([[2], [1]], dtype=dtype, device=device)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[2, 3, 3, 4], [2, 1, 3, 2]], device=device))\n    assert not out.is_undirected\n    assert out.sparse_size() == (5, 4)\n\n    out = adj + tensor([[2], [2]], dtype=dtype, device=device)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))\n    assert out.is_undirected == is_undirected\n    assert out.sparse_size() == (5, 5)\n\n    out = adj.add(adj)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[0, 2, 2, 4], [2, 0, 4, 2]], device=device))\n    assert not out.is_undirected\n    assert out.sparse_size() == (6, 6)\n\n    adj += 2\n    assert 
isinstance(adj, EdgeIndex)\n    assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))\n    assert adj.is_undirected == is_undirected\n    assert adj.sparse_size() == (5, 5)\n\n    with pytest.raises(RuntimeError, match=\"can't be cast\"):\n        adj += 2.5\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_sub(dtype, device, is_undirected):\n    kwargs = dict(dtype=dtype, device=device, is_undirected=is_undirected)\n    adj = EdgeIndex([[4, 5, 5, 6], [5, 4, 6, 5]], sparse_size=(7, 7), **kwargs)\n\n    out = torch.sub(adj, 2, alpha=2)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device))\n    assert out.is_undirected == is_undirected\n    assert out.sparse_size() == (3, 3)\n\n    out = adj - tensor([2], dtype=dtype, device=device)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))\n    assert out.is_undirected == is_undirected\n    assert out.sparse_size() == (5, 5)\n\n    out = adj - tensor([[2], [1]], dtype=dtype, device=device)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[2, 3, 3, 4], [4, 3, 5, 4]], device=device))\n    assert not out.is_undirected\n    assert out.sparse_size() == (5, 6)\n\n    out = adj - tensor([[2], [2]], dtype=dtype, device=device)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))\n    assert out.is_undirected == is_undirected\n    assert out.sparse_size() == (5, 5)\n\n    out = adj.sub(adj)\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(tensor([[0, 0, 0, 0], [0, 0, 0, 0]], device=device))\n    assert not out.is_undirected\n    assert out.sparse_size() == (None, None)\n\n    adj -= 2\n    assert isinstance(adj, EdgeIndex)\n    assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))\n    assert adj.is_undirected == 
is_undirected\n    assert adj.sparse_size() == (5, 5)\n\n    with pytest.raises(RuntimeError, match=\"can't be cast\"):\n        adj -= 2.5\n\n\n@withCUDA\n@withPackage('torch_sparse')\n@pytest.mark.parametrize('reduce', ReduceType.__args__)\n@pytest.mark.parametrize('transpose', TRANSPOSE)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_torch_sparse_spmm(device, reduce, transpose, is_undirected):\n    if is_undirected:\n        kwargs = dict(is_undirected=True)\n        adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs)\n    else:\n        adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)\n    adj = adj.sort_by('col' if transpose else 'row').values\n\n    # Basic:\n    x = torch.randn(3, 1, device=device)\n\n    out = _torch_sparse_spmm(adj, x, None, reduce, transpose)\n    exp = _scatter_spmm(adj, x, None, reduce, transpose)\n    assert out.allclose(exp, atol=1e-6)\n\n    # With non-zero values:\n    x = torch.randn(3, 1, device=device)\n    value = torch.rand(adj.size(1), device=device)\n\n    out = _torch_sparse_spmm(adj, x, value, reduce, transpose)\n    exp = _scatter_spmm(adj, x, value, reduce, transpose)\n    assert out.allclose(exp, atol=1e-6)\n\n    # Gradients w.r.t. other:\n    x1 = torch.randn(3, 1, device=device, requires_grad=True)\n    x2 = x1.detach().requires_grad_()\n    grad = torch.randn_like(x1)\n\n    out = _torch_sparse_spmm(adj, x1, None, reduce, transpose)\n    out.backward(grad)\n    exp = _scatter_spmm(adj, x2, None, reduce, transpose)\n    exp.backward(grad)\n    assert x1.grad.allclose(x2.grad, atol=1e-6)\n\n    # Gradients w.r.t. 
value:\n    x = torch.randn(3, 1, device=device)\n    value1 = torch.rand(adj.size(1), device=device, requires_grad=True)\n    value2 = value1.detach().requires_grad_()\n    grad = torch.randn_like(x)\n\n    out = _torch_sparse_spmm(adj, x, value1, reduce, transpose)\n    out.backward(grad)\n    exp = _scatter_spmm(adj, x, value2, reduce, transpose)\n    exp.backward(grad)\n    assert value1.grad.allclose(value2.grad, atol=1e-6)\n\n\n@withCUDA\n@pytest.mark.parametrize('reduce', ReduceType.__args__)\n@pytest.mark.parametrize('transpose', TRANSPOSE)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_torch_spmm(device, reduce, transpose, is_undirected):\n    if is_undirected:\n        kwargs = dict(is_undirected=True)\n        adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs)\n    else:\n        adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)\n    adj, perm = adj.sort_by('col' if transpose else 'row')\n\n    # Basic:\n    x = torch.randn(3, 2, device=device)\n\n    if ((not x.is_cuda and torch_geometric.typing.WITH_PT20)\n            or reduce in ['sum', 'add']):\n        out = _TorchSPMM.apply(adj, x, None, reduce, transpose)\n        exp = _scatter_spmm(adj, x, None, reduce, transpose)\n        assert out.allclose(exp)\n    else:\n        with pytest.raises(AssertionError):\n            _TorchSPMM.apply(adj, x, None, reduce, transpose)\n\n    # With non-zero values:\n    x = torch.randn(3, 1, device=device)\n    value = torch.rand(adj.size(1), device=device)\n\n    if ((not x.is_cuda and torch_geometric.typing.WITH_PT20)\n            or reduce in ['sum', 'add']):\n        out = _TorchSPMM.apply(adj, x, value, reduce, transpose)\n        exp = _scatter_spmm(adj, x, value, reduce, transpose)\n        assert out.allclose(exp)\n    else:\n        with pytest.raises(AssertionError):\n            _TorchSPMM.apply(adj, x, value, reduce, transpose)\n\n    # Gradients w.r.t. 
other:\n    x1 = torch.randn(3, 1, device=device, requires_grad=True)\n    x2 = x1.detach().requires_grad_()\n    grad = torch.randn_like(x1)\n\n    if reduce in ['sum', 'add']:\n        out = _TorchSPMM.apply(adj, x1, None, reduce, transpose)\n        out.backward(grad)\n        exp = _scatter_spmm(adj, x2, None, reduce, transpose)\n        exp.backward(grad)\n        assert x1.grad.allclose(x2.grad)\n    else:\n        with pytest.raises(AssertionError):\n            out = _TorchSPMM.apply(adj, x1, None, reduce, transpose)\n            out.backward(grad)\n\n    # Gradients w.r.t. value:\n    x = torch.randn(3, 1, device=device)\n    value1 = torch.rand(adj.size(1), device=device, requires_grad=True)\n    grad = torch.randn_like(x)\n\n    with pytest.raises((AssertionError, NotImplementedError)):\n        out = _TorchSPMM.apply(adj, x, value1, reduce, transpose)\n        out.backward(grad)\n\n\n@withCUDA\n@withoutExtensions\n@pytest.mark.parametrize('reduce', ReduceType.__args__)\n@pytest.mark.parametrize('transpose', TRANSPOSE)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_spmm(without_extensions, device, reduce, transpose, is_undirected):\n    warnings.filterwarnings('ignore', '.*can be accelerated via.*')\n\n    if is_undirected:\n        kwargs = dict(is_undirected=True)\n        adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, **kwargs)\n    else:\n        adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], device=device)\n    adj = adj.sort_by('col' if transpose else 'row').values\n\n    # Basic:\n    x = torch.randn(3, 1, device=device)\n\n    with pytest.raises(ValueError, match=\"to be sorted by\"):\n        adj.matmul(x, reduce=reduce, transpose=not transpose)\n\n    out = adj.matmul(x, reduce=reduce, transpose=transpose)\n    exp = _scatter_spmm(adj, x, None, reduce, transpose)\n    assert out.allclose(exp)\n\n    # With non-zero values:\n    x = torch.randn(3, 1, device=device)\n    value = torch.rand(adj.size(1), 
device=device)\n\n    with pytest.raises(ValueError, match=\"'other_value' not supported\"):\n        adj.matmul(x, reduce=reduce, other_value=value, transpose=transpose)\n\n    out = adj.matmul(x, value, reduce=reduce, transpose=transpose)\n    exp = _scatter_spmm(adj, x, value, reduce, transpose)\n    assert out.allclose(exp)\n\n    # Gradients w.r.t. other:\n    x1 = torch.randn(3, 1, device=device, requires_grad=True)\n    x2 = x1.detach().requires_grad_()\n    grad = torch.randn_like(x1)\n\n    out = adj.matmul(x1, reduce=reduce, transpose=transpose)\n    out.backward(grad)\n    exp = _scatter_spmm(adj, x2, None, reduce, transpose)\n    exp.backward(grad)\n    assert x1.grad.allclose(x2.grad)\n\n    # Gradients w.r.t. value:\n    x = torch.randn(3, 1, device=device)\n    value1 = torch.rand(adj.size(1), device=device, requires_grad=True)\n    value2 = value1.detach().requires_grad_()\n    grad = torch.randn_like(x)\n\n    out = adj.matmul(x, value1, reduce=reduce, transpose=transpose)\n    out.backward(grad)\n    exp = _scatter_spmm(adj, x, value2, reduce, transpose)\n    exp.backward(grad)\n    assert value1.grad.allclose(value2.grad)\n\n\n@withCUDA\n@pytest.mark.parametrize('reduce', ReduceType.__args__)\n@pytest.mark.parametrize('transpose', TRANSPOSE)\n@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED)\ndef test_spspmm(device, reduce, transpose, is_undirected):\n    if is_undirected:\n        kwargs = dict(device=device, sort_order='row', is_undirected=True)\n        adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)\n    else:\n        kwargs = dict(device=device, sort_order='row')\n        adj1 = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], **kwargs)\n\n    adj1_dense = adj1.to_dense().t() if transpose else adj1.to_dense()\n    adj2 = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col',\n                     device=device)\n    adj2_dense = adj2.to_dense()\n\n    if reduce in ['sum', 'add']:\n        out, value = adj1.matmul(adj2, 
reduce=reduce, transpose=transpose)\n        assert isinstance(out, EdgeIndex)\n        assert out.is_sorted_by_row\n        assert out._sparse_size == (3, 3)\n        if not torch_geometric.typing.NO_MKL:\n            assert out._indptr is not None\n        assert torch.allclose(out.to_dense(value), adj1_dense @ adj2_dense)\n    else:\n        with pytest.raises(NotImplementedError, match=\"not yet supported\"):\n            adj1.matmul(adj2, reduce=reduce, transpose=transpose)\n\n\n@withCUDA\n@withoutExtensions\ndef test_matmul(without_extensions, device):\n    kwargs = dict(sort_order='row', device=device)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs)\n    x = torch.randn(3, 1, device=device)\n    expected = adj.to_dense() @ x\n\n    out = adj @ x\n    assert torch.allclose(out, expected)\n\n    out = adj.matmul(x)\n    assert torch.allclose(out, expected)\n\n    out = torch.mm(adj, x)\n    assert torch.allclose(out, expected)\n\n    out = torch.matmul(adj, x)\n    assert torch.allclose(out, expected)\n\n    if torch_geometric.typing.WITH_PT20:\n        out = torch.sparse.mm(adj, x, reduce='sum')\n    else:\n        with pytest.raises(TypeError, match=\"got an unexpected keyword\"):\n            torch.sparse.mm(adj, x, reduce='sum')\n        out = torch.sparse.mm(adj, x)\n    assert torch.allclose(out, expected)\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_sparse_row_narrow(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n\n    out = adj.sparse_narrow(dim=0, start=1, length=1)\n    assert out.equal(tensor([[0, 0], [0, 2]], device=device))\n    assert out.sparse_size() == (1, None)\n    assert out.sort_order == 'row'\n    assert out._indptr.equal(tensor([0, 2], device=device))\n\n    out = adj.sparse_narrow(dim=0, start=2, length=0)\n    assert out.equal(tensor([[], []], device=device))\n    assert out.sparse_size() == (0, None)\n    
assert out.sort_order == 'row'\n    assert out._indptr is None\n\n    out = adj.sparse_narrow(dim=1, start=1, length=1)\n    assert out.equal(tensor([[0, 2], [0, 0]], device=device))\n    assert out.sparse_size() == (3, 1)\n    assert out.sort_order == 'row'\n    assert out._indptr is None\n\n    out = adj.sparse_narrow(dim=1, start=2, length=0)\n    assert out.equal(tensor([[], []], device=device))\n    assert out.sparse_size() == (3, 0)\n    assert out.sort_order == 'row'\n    assert out._indptr is None\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_sparse_col_narrow(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs)\n\n    out = adj.sparse_narrow(dim=1, start=1, length=1)\n    assert out.equal(tensor([[0, 2], [0, 0]], device=device))\n    assert out.sparse_size() == (None, 1)\n    assert out.sort_order == 'col'\n    assert out._indptr.equal(tensor([0, 2], device=device))\n\n    out = adj.sparse_narrow(dim=1, start=2, length=0)\n    assert out.equal(tensor([[], []], device=device))\n    assert out.sparse_size() == (None, 0)\n    assert out.sort_order == 'col'\n    assert out._indptr is None\n\n    out = adj.sparse_narrow(dim=0, start=1, length=1)\n    assert out.equal(tensor([[0, 0], [0, 2]], device=device))\n    assert out.sparse_size() == (1, 3)\n    assert out.sort_order == 'col'\n    assert out._indptr is None\n\n    out = adj.sparse_narrow(dim=0, start=2, length=0)\n    assert out.equal(tensor([[], []], device=device))\n    assert out.sparse_size() == (0, 3)\n    assert out.sort_order == 'col'\n    assert out._indptr is None\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_sparse_resize(dtype, device):\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=dtype, device=device)\n\n    out = adj.sort_by('row')[0].fill_cache_()\n    assert out.sparse_size() == (3, 3)\n    assert out._indptr.equal(tensor([0, 1, 3, 4], 
device=device))\n    assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))\n    out = out.sparse_resize_(4, 5)\n    assert out.sparse_size() == (4, 5)\n    assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))\n    assert out._T_indptr.equal(tensor([0, 1, 3, 4, 4, 4], device=device))\n    out = out.sparse_resize_(3, 3)\n    assert out.sparse_size() == (3, 3)\n    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))\n    assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))\n    out = out.sparse_resize_(None, None)\n    assert out.sparse_size() == (None, None)\n    assert out._indptr is None\n    assert out._T_indptr is None\n\n    out = adj.sort_by('col')[0].fill_cache_()\n    assert out.sparse_size() == (3, 3)\n    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))\n    assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))\n    out = out.sparse_resize_(4, 5)\n    assert out.sparse_size() == (4, 5)\n    assert out._indptr.equal(tensor([0, 1, 3, 4, 4, 4], device=device))\n    assert out._T_indptr.equal(tensor([0, 1, 3, 4, 4], device=device))\n    out = out.sparse_resize_(3, 3)\n    assert out.sparse_size() == (3, 3)\n    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))\n    assert out._T_indptr.equal(tensor([0, 1, 3, 4], device=device))\n    out = out.sparse_resize_(None, None)\n    assert out.sparse_size() == (None, None)\n    assert out._indptr is None\n    assert out._T_indptr is None\n\n\ndef test_tolist():\n    data = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    adj = EdgeIndex(data)\n    assert adj.tolist() == data.tolist()\n\n\ndef test_numpy():\n    data = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    adj = EdgeIndex(data)\n    assert np.array_equal(adj.numpy(), data.numpy())\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_global_mapping(device, dtype):\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, dtype=dtype)\n    n_id = tensor([10, 20, 
30], device=device, dtype=dtype)\n\n    expected = tensor([[10, 20, 20, 30], [20, 10, 30, 20]], device=device)\n    out = n_id[adj]\n    assert not isinstance(out, EdgeIndex)\n    assert out.equal(expected)\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_to_vector(device, dtype):\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], device=device, dtype=dtype)\n    out = adj.to_vector()\n    assert not isinstance(out, EdgeIndex)\n    assert out.equal(tensor([1, 3, 5, 7], device=device))\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_save_and_load(dtype, device, tmp_path):\n    kwargs = dict(dtype=dtype, device=device)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    adj.fill_cache_()\n\n    assert adj.sort_order == 'row'\n    assert adj._indptr is not None\n\n    path = osp.join(tmp_path, 'edge_index.pt')\n    torch.save(adj, path)\n    out = fs.torch_load(path)\n\n    assert isinstance(out, EdgeIndex)\n    assert out.equal(adj)\n    assert out.sort_order == 'row'\n    assert out._indptr.equal(adj._indptr)\n\n\ndef _collate_fn(edge_indices: List[EdgeIndex]) -> List[EdgeIndex]:\n    return edge_indices\n\n\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('num_workers', [0, 2])\n@pytest.mark.parametrize('pin_memory', [False, True])\ndef test_data_loader(dtype, num_workers, pin_memory):\n    kwargs = dict(dtype=dtype)\n    adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs)\n    adj.fill_cache_()\n\n    loader = torch.utils.data.DataLoader(\n        [adj] * 4,\n        batch_size=2,\n        num_workers=num_workers,\n        collate_fn=_collate_fn,\n        pin_memory=pin_memory,\n        drop_last=True,\n    )\n\n    assert len(loader) == 2\n    for batch in loader:\n        assert isinstance(batch, list)\n        assert len(batch) == 2\n        for adj in batch:\n            assert isinstance(adj, EdgeIndex)\n            assert adj.dtype == dtype\n      
      assert adj.is_shared() != (num_workers == 0) or pin_memory\n            assert adj._data.is_shared() != (num_workers == 0) or pin_memory\n\n\ndef test_torch_script():\n    class Model(torch.nn.Module):\n        def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:\n            row, col = edge_index[0], edge_index[1]\n            x_j = x[row]\n            out = scatter(x_j, col, dim_size=edge_index.num_cols)\n            return out\n\n    x = torch.randn(3, 8)\n    # Test that `num_cols` gets picked up by making last node isolated.\n    edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_size=(3, 3))\n\n    model = Model()\n    expected = model(x, edge_index)\n    assert expected.size() == (3, 8)\n\n    # `torch.jit.script` does not support inheritance at the `Tensor` level :(\n    with pytest.raises(RuntimeError, match=\"attribute or method 'num_cols'\"):\n        torch.jit.script(model)\n\n    # A valid workaround is to treat `EdgeIndex` as a regular PyTorch tensor\n    # whenever we are in script mode:\n    class ScriptableModel(torch.nn.Module):\n        def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:\n            row, col = edge_index[0], edge_index[1]\n            x_j = x[row]\n            dim_size: Optional[int] = None\n            if (not torch.jit.is_scripting()\n                    and isinstance(edge_index, EdgeIndex)):\n                dim_size = edge_index.num_cols\n            out = scatter(x_j, col, dim_size=dim_size)\n            return out\n\n    script_model = torch.jit.script(ScriptableModel())\n    out = script_model(x, edge_index)\n    assert out.size() == (2, 8)\n    assert torch.allclose(out, expected[:2])\n\n\n@onlyLinux\n@withPackage('torch>=2.3')\n@pytest.mark.skip(reason=\"Does not work currently\")\ndef test_compile_basic():\n    import torch._dynamo as dynamo\n\n    class Model(torch.nn.Module):\n        def forward(self, x: Tensor, edge_index: EdgeIndex) -> Tensor:\n            x_j = 
x[edge_index[0]]\n            out = scatter(x_j, edge_index[1], dim_size=edge_index.num_cols)\n            return out\n\n    x = torch.randn(3, 8)\n    # Test that `num_cols` gets picked up by making last node isolated.\n    edge_index = EdgeIndex(\n        [[0, 1, 1, 2], [1, 0, 0, 1]],\n        sparse_size=(3, 3),\n        sort_order='row',\n    ).fill_cache_()\n\n    model = Model()\n    expected = model(x, edge_index)\n    assert expected.size() == (3, 8)\n\n    explanation = dynamo.explain(model)(x, edge_index)\n    assert explanation.graph_break_count == 0\n\n    compiled_model = torch.compile(model, fullgraph=True)\n    out = compiled_model(x, edge_index)\n    assert torch.allclose(out, expected)\n\n\n@onlyLinux\n@withPackage('torch>=2.3')\n@pytest.mark.skip(reason=\"Does not work currently\")\ndef test_compile_create_edge_index():\n    import torch._dynamo as dynamo\n\n    class Model(torch.nn.Module):\n        def forward(self) -> EdgeIndex:\n            # Wait for: https://github.com/pytorch/pytorch/issues/117806\n            edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]])\n            return edge_index\n\n    model = Model()\n\n    explanation = dynamo.explain(model)()\n    assert explanation.graph_break_count == 0\n\n    compiled_model = torch.compile(model, fullgraph=True)\n    assert compiled_model() is None\n\n\nif __name__ == '__main__':\n    import argparse\n\n    warnings.filterwarnings('ignore', \".*Sparse CSR tensor support.*\")\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    channels = 128\n    num_nodes = 20_000\n    num_edges = 200_000\n\n    x = torch.randn(num_nodes, channels, device=args.device)\n    edge_index = EdgeIndex(\n        torch.randint(0, num_nodes, size=(2, num_edges), device=args.device),\n        sparse_size=(num_nodes, num_nodes),\n    ).sort_by('row')[0]\n    
edge_index.fill_cache_()\n    adj1 = edge_index.to_sparse_csr()\n    adj2 = SparseTensor(\n        row=edge_index[0],\n        col=edge_index[1],\n        sparse_sizes=(num_nodes, num_nodes),\n    )\n\n    def edge_index_mm(edge_index, x, reduce):\n        return edge_index.matmul(x, reduce=reduce)\n\n    def torch_sparse_mm(adj, x):\n        return adj @ x\n\n    def sparse_tensor_mm(adj, x, reduce):\n        return adj.matmul(x, reduce=reduce)\n\n    def scatter_mm(edge_index, x, reduce):\n        return _scatter_spmm(edge_index, x, reduce=reduce)\n\n    funcs = [edge_index_mm, torch_sparse_mm, sparse_tensor_mm, scatter_mm]\n    func_names = ['edge_index', 'torch.sparse', 'SparseTensor', 'scatter']\n\n    for reduce in ['sum', 'mean', 'amin', 'amax']:\n        func_args = [(edge_index, x, reduce), (adj1, x), (adj2, x, reduce),\n                     (edge_index, x, reduce)]\n        print(f\"reduce='{reduce}':\")\n\n        benchmark(\n            funcs=funcs,\n            func_names=func_names,\n            args=func_args,\n            num_steps=100 if args.device == 'cpu' else 1000,\n            num_warmups=50 if args.device == 'cpu' else 500,\n            backward=args.backward,\n        )\n"
  },
  {
    "path": "test/test_experimental.py",
    "content": "import pytest\n\nfrom torch_geometric import (\n    experimental_mode,\n    is_experimental_mode_enabled,\n    set_experimental_mode,\n)\n\n\n@pytest.mark.parametrize('options', ['disable_dynamic_shapes'])\ndef test_experimental_mode(options):\n    assert is_experimental_mode_enabled(options) is False\n    with experimental_mode(options):\n        assert is_experimental_mode_enabled(options) is True\n    assert is_experimental_mode_enabled(options) is False\n\n    with set_experimental_mode(True, options):\n        assert is_experimental_mode_enabled(options) is True\n    assert is_experimental_mode_enabled(options) is False\n\n    with set_experimental_mode(False, options):\n        assert is_experimental_mode_enabled(options) is False\n    assert is_experimental_mode_enabled(options) is False\n\n    set_experimental_mode(True, options)\n    assert is_experimental_mode_enabled(options) is True\n    set_experimental_mode(False, options)\n    assert is_experimental_mode_enabled(options) is False\n"
  },
  {
    "path": "test/test_hash_tensor.py",
    "content": "import os.path as osp\nfrom typing import List\n\nimport numpy as np\nimport pytest\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import HashTensor\nfrom torch_geometric.io import fs\nfrom torch_geometric.testing import (\n    onlyCUDA,\n    onlyLinux,\n    withCUDA,\n    withHashTensor,\n    withPackage,\n)\n\nKEY_DTYPES = [\n    pytest.param(torch.bool, id='bool'),\n    pytest.param(torch.uint8, id='uint8'),\n    pytest.param(torch.int8, id='int8'),\n    pytest.param(torch.int16, id='int16'),\n    pytest.param(torch.int32, id='int32'),\n    pytest.param(torch.int64, id='int64'),\n    pytest.param(torch.float16, id='float16'),\n    pytest.param(torch.bfloat16, id='bfloat16'),\n    pytest.param(torch.float32, id='float32'),\n    pytest.param(torch.float64, id='float64'),\n]\n\n\n@withCUDA\n@withHashTensor\n@pytest.mark.parametrize('dtype', KEY_DTYPES)\ndef test_basic(dtype, device):\n    if dtype != torch.bool:\n        key = torch.tensor([2, 1, 0], dtype=dtype, device=device)\n    else:\n        key = torch.tensor([True, False], device=device)\n\n    tensor = HashTensor(key)\n    if tensor.is_cuda:\n        assert str(tensor) == (f\"HashTensor({tensor.as_tensor().tolist()}, \"\n                               f\"device='{tensor.device}')\")\n    else:\n        assert str(tensor) == f\"HashTensor({tensor.as_tensor().tolist()})\"\n\n    assert tensor.dtype == torch.int64\n    assert tensor.device == device\n    assert tensor.size() == (key.size(0), )\n\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n    assert str(tensor).startswith(\"HashTensor([\")\n    assert tensor.dtype == torch.float\n    assert tensor.device == device\n    assert tensor.size() == (key.size(0), 2)\n\n\n@withCUDA\n@withHashTensor\n@pytest.mark.parametrize('dtype', KEY_DTYPES)\ndef test_empty(dtype, device):\n    key = torch.empty(0, dtype=dtype, device=device)\n    tensor = 
HashTensor(key)\n    assert tensor.dtype == torch.int64\n    assert tensor.device == device\n    assert tensor.size() == (0, )\n\n    out = tensor.index_select(0, torch.empty(0, dtype=dtype, device=device))\n    assert not isinstance(out, HashTensor)\n    assert out.dtype == torch.int64\n    assert out.device == device\n    assert out.size() == (0, )\n\n    value = torch.empty(0, device=device)\n    tensor = HashTensor(key, value)\n    assert tensor.dtype == value.dtype\n    assert tensor.device == device\n    assert tensor.size() == (0, )\n\n    out = tensor.index_select(0, torch.empty(0, dtype=dtype, device=device))\n    assert not isinstance(out, HashTensor)\n    assert out.dtype == value.dtype\n    assert out.device == device\n    assert out.size() == (0, )\n\n\n@withCUDA\n@withHashTensor\ndef test_string_key(device):\n    tensor = HashTensor(['1', '2', '3'], device=device)\n    out = tensor[['3', '2', '4']]\n    assert out.equal(torch.tensor([2, 1, -1], device=device))\n\n\n@withCUDA\n@withHashTensor\ndef test_clone(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n\n    out = tensor.clone()\n    assert isinstance(out, HashTensor)\n    assert out.dtype == tensor.dtype\n    assert out.device == tensor.device\n    assert out._value.data_ptr() != tensor._value.data_ptr()\n\n    out = torch.clone(tensor)\n    assert isinstance(out, HashTensor)\n    assert out.dtype == tensor.dtype\n    assert out.device == tensor.device\n    assert out._value.data_ptr() != tensor._value.data_ptr()\n\n\n@withCUDA\n@withHashTensor\ndef test_share_memory(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n\n    out = tensor.share_memory_()\n    assert isinstance(out, HashTensor)\n    assert out.is_shared()\n    assert out._value.is_shared()\n    assert out.data_ptr() == 
tensor.data_ptr()\n\n\n@onlyCUDA\n@withHashTensor\ndef test_pin_memory():\n    key = torch.tensor([2, 1, 0])\n    value = torch.randn(key.size(0), 2)\n    tensor = HashTensor(key, value)\n\n    assert not tensor.is_pinned()\n    out = tensor.pin_memory()\n    assert isinstance(out, HashTensor)\n    assert out.is_pinned()\n\n\n@withCUDA\n@withHashTensor\ndef test_detach(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    value = torch.randn(key.size(0), 2, device=device, requires_grad=True)\n    tensor = HashTensor(key, value)\n\n    assert tensor.requires_grad\n    out = tensor.detach()\n    assert isinstance(out, HashTensor)\n    assert not out.requires_grad\n    assert not out._value.requires_grad\n\n    tensor.detach_()\n    assert not tensor.requires_grad\n    assert not tensor._value.requires_grad\n\n\n@withCUDA\n@withHashTensor\ndef test_contiguous(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    value = torch.randn(2, key.size(0), device=device).t()\n    assert not value.is_contiguous()\n\n    tensor = HashTensor(key, value)\n    assert not tensor.is_contiguous()\n    out = tensor.contiguous()\n    assert out.is_contiguous()\n    assert out._value.is_contiguous()\n\n\n@withCUDA\n@withHashTensor\ndef test_save_and_load(device, tmp_path):\n    key = torch.tensor([2, 1, 0], device=device)\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n\n    path = osp.join(tmp_path, 'hash_tensor.pt')\n    torch.save(tensor, path)\n    out = fs.torch_load(path)\n\n    assert isinstance(out, HashTensor)\n    assert out._value.equal(value)\n    assert out._min_key.equal(key.min())\n    assert out._max_key.equal(key.max())\n\n\n@withCUDA\n@withHashTensor\ndef test_to_function(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n\n    out = tensor.to(device)\n    assert isinstance(out, HashTensor)\n    assert 
id(out) == id(tensor)\n    assert out.device == device\n    assert out._value.device == device\n    assert out._min_key.device == device\n    assert out._max_key.device == device\n\n    out = tensor.to('cpu')\n    assert isinstance(out, HashTensor)\n    if key.is_cuda:\n        assert id(out) != id(tensor)\n    else:\n        assert id(out) == id(tensor)\n    assert out.device == torch.device('cpu')\n    assert out._value.device == torch.device('cpu')\n    assert out._min_key.device == torch.device('cpu')\n    assert out._max_key.device == torch.device('cpu')\n\n    out = tensor.double()\n    assert isinstance(out, HashTensor)\n    assert out._value.dtype == torch.double\n\n\n@withCUDA\n@withHashTensor\ndef test_unsqueeze(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    tensor = HashTensor(key)\n\n    with pytest.raises(IndexError, match=\"in the first dimension\"):\n        tensor.unsqueeze(0)\n\n    with pytest.raises(IndexError, match=\"in the first dimension\"):\n        tensor.unsqueeze(-2)\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        tensor.unsqueeze(2)\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        tensor.unsqueeze(-3)\n\n    out = tensor.unsqueeze(-1)\n    assert out.size() == (3, 1)\n    assert out._value is not None\n\n    out = tensor[..., None]\n    assert out.size() == (3, 1)\n    assert out._value is not None\n\n    out = tensor[..., None, None]\n    assert out.size() == (3, 1, 1)\n    assert out._value is not None\n\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n\n    out = tensor.unsqueeze(-1)\n    assert out.size() == (3, 2, 1)\n    assert out._value is not None\n\n    out = tensor[..., None]\n    assert out.size() == (3, 2, 1)\n    assert out._value is not None\n\n    out = tensor[..., None, None]\n    assert out.size() == (3, 2, 1, 1)\n    assert out._value is not None\n\n    out = tensor.unsqueeze(1)\n    assert out.size() == (3, 1, 
2)\n    assert out._value is not None\n\n\n@withCUDA\n@withHashTensor\n@pytest.mark.parametrize('num_keys', [3, 1])\ndef test_squeeze(num_keys, device):\n    key = torch.tensor([2, 1, 0][:num_keys], device=device)\n    tensor = HashTensor(key)\n\n    out = tensor.squeeze()\n    assert isinstance(out, HashTensor)\n    assert out.size() == (num_keys, )\n\n    out = tensor.squeeze(0)\n    assert isinstance(out, HashTensor)\n    assert out.size() == (num_keys, )\n\n    out = tensor.squeeze(-1)\n    assert isinstance(out, HashTensor)\n    assert out.size() == (num_keys, )\n\n    if torch_geometric.typing.WITH_PT20:\n        out = tensor.squeeze([0])\n        assert isinstance(out, HashTensor)\n        assert out.size() == (num_keys, )\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        tensor.squeeze(1)\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        tensor.squeeze(-2)\n\n    value = torch.randn(key.size(0), 1, 1, device=device)\n    tensor = HashTensor(key, value)\n\n    out = tensor.squeeze()\n    assert isinstance(out, HashTensor)\n    assert out.size() == (num_keys, )\n\n    out = tensor.squeeze(0)\n    assert isinstance(out, HashTensor)\n    assert out.size() == (num_keys, 1, 1)\n\n    out = tensor.squeeze(-1)\n    assert isinstance(out, HashTensor)\n    assert out.size() == (num_keys, 1)\n\n    if torch_geometric.typing.WITH_PT20:\n        out = tensor.squeeze([0, 1, 2])\n        assert isinstance(out, HashTensor)\n        assert out.size() == (num_keys, )\n\n\n@withCUDA\n@withHashTensor\ndef test_slice(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    tensor = HashTensor(key)\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        torch.narrow(tensor, dim=-2, start=0, length=2)\n\n    out = tensor[:]\n    assert isinstance(out, HashTensor)\n    assert out._value is None\n\n    out = tensor[-2:4]\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(torch.tensor([1, 2], 
device=device))\n\n    out = tensor[..., 0:2]\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(torch.tensor([0, 1], device=device))\n\n    out = torch.narrow(tensor, dim=0, start=2, length=1)\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(torch.tensor([2], device=device))\n\n    out = tensor.narrow(dim=0, start=1, length=2)\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(torch.tensor([1, 2], device=device))\n\n    value = torch.randn(key.size(0), 4, device=device)\n    tensor = HashTensor(key, value)\n\n    out = tensor[0:2]\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(value[0:2])\n\n    out = tensor[..., 0:2]\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(value[..., 0:2])\n\n    out = torch.narrow(tensor, dim=1, start=2, length=1)\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(value[..., 2:3])\n\n\n@withCUDA\n@withHashTensor\n@pytest.mark.parametrize('dtype', KEY_DTYPES)\ndef test_index_select(dtype, device):\n    if dtype != torch.bool:\n        key = torch.tensor([2, 1, 0], dtype=dtype, device=device)\n        query = torch.tensor([0, 3, 2], dtype=dtype, device=device)\n    else:\n        key = torch.tensor([True, False], device=device)\n        query = torch.tensor([False, True], device=device)\n\n    tensor = HashTensor(key)\n\n    out = torch.index_select(tensor, 0, query)\n    assert not isinstance(out, HashTensor)\n    if dtype != torch.bool:\n        assert out.equal(torch.tensor([2, -1, 0], device=device))\n    else:\n        assert out.equal(torch.tensor([1, 0], device=device))\n\n    out = tensor.index_select(-1, query)\n    assert not isinstance(out, HashTensor)\n    if dtype != torch.bool:\n        assert out.equal(torch.tensor([2, -1, 0], device=device))\n    else:\n        assert out.equal(torch.tensor([1, 0], device=device))\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n   
     torch.index_select(tensor, 1, query)\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        tensor.index_select(-2, query)\n\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n\n    out = torch.index_select(tensor, 0, query)\n    assert not isinstance(out, HashTensor)\n    if dtype != torch.bool:\n        expected = torch.full_like(value, float('NaN'))\n        expected[0] = value[2]\n        expected[2] = value[0]\n        assert out.allclose(expected, equal_nan=True)\n    else:\n        assert out.allclose(value.flip(dims=[0]))\n\n    index = torch.tensor([1], device=device)\n    out = tensor.index_select(1, index)\n    assert isinstance(out, HashTensor)\n    assert out.size() == (3 if dtype != torch.bool else 2, 1)\n    assert out.as_tensor().allclose(value[:, 1:])\n\n\n@withCUDA\n@withHashTensor\ndef test_select(device):\n    key = torch.tensor([2, 1, 0], device=device)\n    tensor = HashTensor(key)\n\n    out = tensor[0]\n    assert not isinstance(out, HashTensor)\n    assert out.dim() == 0\n    assert int(out) == 2\n\n    out = tensor.select(0, 4)\n    assert not isinstance(out, HashTensor)\n    assert out.dim() == 0\n    assert int(out) == -1\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        torch.select(tensor, 1, 0)\n\n    with pytest.raises(IndexError, match=\"out of range\"):\n        tensor.select(-2, 0)\n\n    value = torch.randn(key.size(0), 2, device=device)\n    tensor = HashTensor(key, value)\n\n    out = tensor[0]\n    assert not isinstance(out, HashTensor)\n    assert out.equal(value[2])\n\n    out = tensor.select(-1, 0)\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().equal(value[:, 0])\n\n\n@withHashTensor\ndef test_tolist():\n    key = torch.tensor([2, 1, 0])\n    value = torch.randn(key.size(0), 2)\n    assert HashTensor(key, value).tolist() == value.tolist()\n\n\n@withHashTensor\ndef test_numpy():\n    key = torch.tensor([2, 1, 0])\n    
value = torch.randn(key.size(0), 2)\n    assert np.allclose(HashTensor(key, value).numpy(), value.numpy())\n\n\ndef _collate_fn(hash_tensors: List[HashTensor]) -> List[HashTensor]:\n    return hash_tensors\n\n\n@pytest.mark.parametrize('num_workers', [0, 2])\n@pytest.mark.parametrize('pin_memory', [False, True])\ndef test_data_loader(num_workers, pin_memory):\n    key = torch.tensor([2, 1, 0])\n    value = torch.randn(key.size(0), 2)\n    tensor = HashTensor(key, value)\n\n    loader = torch.utils.data.DataLoader(\n        [tensor] * 4,\n        batch_size=2,\n        num_workers=num_workers,\n        collate_fn=_collate_fn,\n        pin_memory=pin_memory,\n        drop_last=True,\n    )\n\n    assert len(loader) == 2\n    for batch in loader:\n        assert isinstance(batch, list)\n        assert len(batch) == 2\n        for tensor in batch:\n            assert isinstance(tensor, HashTensor)\n            assert tensor.dtype == value.dtype\n            assert tensor.is_shared() != (num_workers == 0) or pin_memory\n\n\n@withCUDA\n@withHashTensor\n@pytest.mark.parametrize('dtype', KEY_DTYPES[:1])\ndef test_getitem(dtype, device):\n    if dtype != torch.bool:\n        key = torch.tensor([20, 10, 0], dtype=dtype, device=device)\n    else:\n        key = torch.tensor([True, False], device=device)\n\n    value = torch.randn(key.size(0), 2, 4, device=device)\n    tensor = HashTensor(key, value)\n\n    if dtype != torch.bool:\n        out = tensor[10, :, None, torch.tensor([1, 2])]\n    else:\n        out = tensor[False, :, None, torch.tensor([1, 2])]\n    assert not isinstance(out, HashTensor)\n    assert out.allclose(value[1, :, None, torch.tensor([1, 2])])\n\n    if dtype != torch.bool:\n        out = tensor[..., 10, :, None, torch.tensor([1, 2])]\n    else:\n        out = tensor[..., False, :, None, torch.tensor([1, 2])]\n    assert not isinstance(out, HashTensor)\n    assert out.allclose(value[1, :, None, torch.tensor([1, 2])])\n\n    if dtype != torch.bool:\n        
out = tensor[..., [10, 20], 1, None, 0:2]\n    else:\n        out = tensor[..., [False, True], 1, None, 0:2]\n    assert not isinstance(out, HashTensor)\n    assert out.allclose(value[torch.tensor([1, 0]), 1, None, 0:2])\n\n    if dtype != torch.bool:\n        out = tensor[[10, 20], 1, None, 0:2]\n    else:\n        out = tensor[[False, True], 1, None, 0:2]\n    assert not isinstance(out, HashTensor)\n    assert out.allclose(value[torch.tensor([1, 0]), 1, None, 0:2])\n\n    out = tensor[..., None, torch.tensor([1, 2])]\n    assert isinstance(out, HashTensor)\n    assert out.as_tensor().allclose(value[..., None, torch.tensor([1, 2])])\n\n    out = tensor[...]\n    assert isinstance(out, HashTensor)\n    assert out.size() == value.size()\n\n    out = tensor[:2]\n    assert isinstance(out, HashTensor)\n    assert out.size() == (2, ) + value.size()[1:]\n\n\n@onlyLinux\n@withHashTensor\n@withPackage('torch>=2.3')\n@pytest.mark.skip(reason=\"Does not work currently\")\ndef test_compile_basic():\n    import torch._dynamo as dynamo\n\n    class Model(torch.nn.Module):\n        def forward(self, key: Tensor, query: Tensor) -> Tensor:\n            _map = HashTensor(key)\n            return _map[query]\n\n    key = torch.randperm(10)\n    query = key[:5]\n\n    model = Model()\n    expected = model(key, query)\n    assert expected.equal(torch.arange(query.numel()))\n\n    explanation = dynamo.explain(model)(key, query)\n    assert explanation.graph_break_count == 0\n"
  },
  {
    "path": "test/test_home.py",
    "content": "import os\nimport os.path as osp\n\nfrom torch_geometric import get_home_dir, set_home_dir\nfrom torch_geometric.home import DEFAULT_CACHE_DIR\n\n\ndef test_home():\n    os.environ.pop('PYG_HOME', None)\n    home_dir = osp.expanduser(DEFAULT_CACHE_DIR)\n    assert get_home_dir() == home_dir\n\n    home_dir = '/tmp/test_pyg1'\n    os.environ['PYG_HOME'] = home_dir\n    assert get_home_dir() == home_dir\n\n    home_dir = '/tmp/test_pyg2'\n    set_home_dir(home_dir)\n    assert get_home_dir() == home_dir\n"
  },
  {
    "path": "test/test_index.py",
    "content": "import os.path as osp\nfrom typing import List\n\nimport numpy as np\nimport pytest\nimport torch\nfrom torch import tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import Index\nfrom torch_geometric.io import fs\nfrom torch_geometric.testing import onlyCUDA, withCUDA\nfrom torch_geometric.typing import INDEX_DTYPES\n\nDTYPES = [pytest.param(dtype, id=str(dtype)[6:]) for dtype in INDEX_DTYPES]\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_basic(dtype, device):\n    kwargs = dict(dtype=dtype, device=device, dim_size=3)\n    index = Index([0, 1, 1, 2], **kwargs)\n    index.validate()\n    assert isinstance(index, Index)\n\n    assert str(index).startswith('Index([0, 1, 1, 2], ')\n    assert 'dim_size=3' in str(index)\n    assert (f\"device='{device}'\" in str(index)) == index.is_cuda\n    assert (f'dtype={dtype}' in str(index)) == (dtype != torch.long)\n\n    assert index.dtype == dtype\n    assert index.device == device\n    assert index.dim_size == 3\n    assert not index.is_sorted\n\n    out = index.as_tensor()\n    assert not isinstance(out, Index)\n    assert out.dtype == dtype\n    assert out.device == device\n\n    out = index * 1\n    assert not isinstance(out, Index)\n    assert out.dtype == dtype\n    assert out.device == device\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_identity(dtype, device):\n    kwargs = dict(dtype=dtype, device=device, dim_size=3, is_sorted=True)\n    index = Index([0, 1, 1, 2], **kwargs)\n\n    out = Index(index)\n    assert not isinstance(out.as_tensor(), Index)\n    assert out.data_ptr() == index.data_ptr()\n    assert out.dtype == index.dtype\n    assert out.device == index.device\n    assert out.dim_size == index.dim_size\n    assert out.is_sorted == index.is_sorted\n\n    out = Index(index, dim_size=4, is_sorted=False)\n    assert out.dim_size == 4\n    assert out.is_sorted == index.is_sorted\n\n\ndef test_validate():\n    with pytest.raises(ValueError, 
match=\"unsupported data type\"):\n        Index([0.0, 1.0])\n    with pytest.raises(ValueError, match=\"needs to be one-dimensional\"):\n        Index([[0], [1]])\n    with pytest.raises(TypeError, match=\"invalid combination of arguments\"):\n        Index(tensor([0, 1]), torch.long)\n    with pytest.raises(TypeError, match=\"invalid keyword arguments\"):\n        Index(tensor([0, 1]), dtype=torch.long)\n    with pytest.raises(ValueError, match=\"contains negative indices\"):\n        Index([-1, 0]).validate()\n    with pytest.raises(ValueError, match=\"than its registered size\"):\n        Index([0, 10], dim_size=2).validate()\n    with pytest.raises(ValueError, match=\"not sorted\"):\n        Index([1, 0], is_sorted=True).validate()\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_fill_cache_(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], is_sorted=True, **kwargs)\n    index.validate().fill_cache_()\n    assert index.dim_size == 3\n    assert index._indptr.dtype == dtype\n    assert index._indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n    index = Index([1, 0, 2, 1], **kwargs)\n    index.validate().fill_cache_()\n    assert index.dim_size == 3\n    assert index._indptr is None\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_dim_resize(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], is_sorted=True, **kwargs).fill_cache_()\n\n    assert index.dim_size == 3\n    assert index._indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n    out = index.dim_resize_(4)\n    assert out.dim_size == 4\n    assert out._indptr.equal(tensor([0, 1, 3, 4, 4], device=device))\n\n    out = index.dim_resize_(3)\n    assert out.dim_size == 3\n    assert out._indptr.equal(tensor([0, 1, 3, 4], device=device))\n\n    out = index.dim_resize_(None)\n    assert out.dim_size is None\n    assert out._indptr is 
None\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_clone(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], is_sorted=True, dim_size=3, **kwargs)\n\n    out = index.clone()\n    assert isinstance(out, Index)\n    assert out.dtype == dtype\n    assert out.device == device\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n    out = torch.clone(index)\n    assert isinstance(out, Index)\n    assert out.dtype == dtype\n    assert out.device == device\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_to_function(dtype, device):\n    kwargs = dict(dtype=dtype)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n    index.fill_cache_()\n\n    index = index.to(device)\n    assert isinstance(index, Index)\n    assert index.device == device\n    assert index._indptr.dtype == dtype\n    assert index._indptr.device == device\n\n    out = index.cpu()\n    assert isinstance(out, Index)\n    assert out.device == torch.device('cpu')\n\n    out = index.to(torch.int)\n    assert out.dtype == torch.int\n    if torch_geometric.typing.WITH_PT20:\n        assert isinstance(out, Index)\n        assert out._indptr.dtype == torch.int\n    else:\n        assert not isinstance(out, Index)\n\n    out = index.to(torch.float)\n    assert not isinstance(out, Index)\n    assert out.dtype == torch.float\n\n    out = index.long()\n    assert isinstance(out, Index)\n    assert out.dtype == torch.int64\n\n    out = index.int()\n    assert out.dtype == torch.int\n    if torch_geometric.typing.WITH_PT20:\n        assert isinstance(out, Index)\n    else:\n        assert not isinstance(out, Index)\n\n\n@onlyCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_cpu_cuda(dtype):\n    kwargs = dict(dtype=dtype)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n    assert index.is_cpu\n\n    out = index.cuda()\n    
assert isinstance(out, Index)\n    assert out.is_cuda\n\n    out = out.cpu()\n    assert isinstance(out, Index)\n    assert out.is_cpu\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_share_memory(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n    index.fill_cache_()\n\n    out = index.share_memory_()\n    assert isinstance(out, Index)\n    assert out.is_shared()\n    assert out._data.is_shared()\n    assert out._indptr.is_shared()\n    assert out.data_ptr() == index.data_ptr()\n\n\n@onlyCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_pin_memory(dtype):\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, dtype=dtype)\n    assert not index.is_pinned()\n    out = index.pin_memory()\n    assert out.is_pinned()\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_contiguous(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n\n    assert index.is_contiguous\n    out = index.contiguous()\n    assert isinstance(out, Index)\n    assert out.is_contiguous\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_sort(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([1, 0, 2, 1], dim_size=3, **kwargs)\n\n    index, _ = index.sort()\n    assert isinstance(index, Index)\n    assert index.equal(tensor([0, 1, 1, 2], device=device))\n    assert index.dim_size == 3\n    assert index.is_sorted\n\n    out, perm = index.sort()\n    assert isinstance(out, Index)\n    assert out._data.data_ptr() == index._data.data_ptr()\n    assert perm.equal(tensor([0, 1, 2, 3], device=device))\n    assert out.dim_size == 3\n\n    index, _ = index.sort(descending=True)\n    assert isinstance(index, Index)\n    assert index.equal(tensor([2, 1, 1, 0], device=device))\n    assert index.dim_size == 3\n    assert not 
index.is_sorted\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_sort_stable(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([1, 0, 2, 1], dim_size=3, **kwargs)\n\n    index, perm = index.sort(stable=True)\n    assert isinstance(index, Index)\n    assert index.equal(tensor([0, 1, 1, 2], device=device))\n    assert perm.equal(tensor([1, 0, 3, 2], device=device))\n    assert index.dim_size == 3\n    assert index.is_sorted\n\n    out, perm = index.sort(stable=True)\n    assert isinstance(out, Index)\n    assert out._data.data_ptr() == index._data.data_ptr()\n    assert perm.equal(tensor([0, 1, 2, 3], device=device))\n    assert out.dim_size == 3\n\n    index, perm = index.sort(descending=True, stable=True)\n    assert isinstance(index, Index)\n    assert index.equal(tensor([2, 1, 1, 0], device=device))\n    assert perm.equal(tensor([3, 1, 2, 0], device=device))\n    assert index.dim_size == 3\n    assert not index.is_sorted\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_cat(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index1 = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n    index2 = Index([1, 2, 2, 3], dim_size=4, is_sorted=True, **kwargs)\n    index3 = Index([1, 2, 2, 3], **kwargs)\n\n    out = torch.cat([index1, index2])\n    assert out.equal(tensor([0, 1, 1, 2, 1, 2, 2, 3], device=device))\n    assert out.size() == (8, )\n    assert isinstance(out, Index)\n    assert out.dim_size == 4\n    assert not out.is_sorted\n\n    assert out._cat_metadata.nnz == [4, 4]\n    assert out._cat_metadata.dim_size == [3, 4]\n    assert out._cat_metadata.is_sorted == [True, True]\n\n    out = torch.cat([index1, index2, index3])\n    assert out.size() == (12, )\n    assert isinstance(out, Index)\n    assert out.dim_size is None\n    assert not out.is_sorted\n\n    out = torch.cat([index1, index2.as_tensor()])\n    assert out.size() == (8, )\n    assert not isinstance(out, 
Index)\n\n    inplace = torch.empty(8, dtype=dtype, device=device)\n    out = torch.cat([index1, index2], out=inplace)\n    assert out.equal(tensor([0, 1, 1, 2, 1, 2, 2, 3], device=device))\n    assert out.data_ptr() == inplace.data_ptr()\n    assert not isinstance(out, Index)\n    assert not isinstance(inplace, Index)\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_flip(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n\n    out = index.flip(0)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([2, 1, 1, 0], device=device))\n    assert out.dim_size == 3\n    assert not out.is_sorted\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_index_select(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n\n    i = tensor([1, 3], device=device)\n    out = index.index_select(0, i)\n    assert out.equal(tensor([1, 2], device=device))\n    assert isinstance(out, Index)\n    assert out.dim_size == 3\n    assert not out.is_sorted\n\n    inplace = torch.empty(2, dtype=dtype, device=device)\n    out = torch.index_select(index, 0, i, out=inplace)\n    assert out.equal(tensor([1, 2], device=device))\n    assert out.data_ptr() == inplace.data_ptr()\n    assert not isinstance(out, Index)\n    assert not isinstance(inplace, Index)\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_narrow(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n\n    out = index.narrow(0, start=1, length=2)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 1], device=device))\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_getitem(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index 
= Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n\n    out = index[:]\n    assert isinstance(out, Index)\n    assert out._data.data_ptr() == index._data.data_ptr()\n    assert out.equal(tensor([0, 1, 1, 2], device=device))\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n    out = index[tensor([False, True, False, True], device=device)]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 2], device=device))\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n    out = index[tensor([1, 3], device=device)]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 2], device=device))\n    assert out.dim_size == 3\n    assert not out.is_sorted\n\n    out = index[1:3]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 1], device=device))\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n    out = index[...]\n    assert isinstance(out, Index)\n    assert out._data.data_ptr() == index._data.data_ptr()\n    assert out.equal(tensor([0, 1, 1, 2], device=device))\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n    out = index[..., 1:3]\n    assert isinstance(out, Index)\n    assert out.equal(tensor([1, 1], device=device))\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n    out = index[None, 1:3]\n    assert not isinstance(out, Index)\n    assert out.equal(tensor([[1, 1]], device=device))\n\n    out = index[1:3, None]\n    assert not isinstance(out, Index)\n    assert out.equal(tensor([[1], [1]], device=device))\n\n    out = index[0]\n    assert not isinstance(out, Index)\n    assert out.equal(tensor(0, device=device))\n\n    tmp = torch.randn(3, device=device)\n    out = tmp[index]\n    assert not isinstance(out, Index)\n    assert out.equal(tmp[index.as_tensor()])\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_add(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n\n    
out = torch.add(index, 2, alpha=2)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([4, 5, 5, 6], device=device))\n    assert out.dim_size == 7\n    assert out.is_sorted\n\n    out = index + tensor([2], dtype=dtype, device=device)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([2, 3, 3, 4], device=device))\n    assert out.dim_size == 5\n    assert out.is_sorted\n\n    out = tensor([2], dtype=dtype, device=device) + index\n    assert isinstance(out, Index)\n    assert out.equal(tensor([2, 3, 3, 4], device=device))\n    assert out.dim_size == 5\n    assert out.is_sorted\n\n    out = index.add(index)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([0, 2, 2, 4], device=device))\n    assert out.dim_size == 6\n    assert not out.is_sorted\n\n    index += 2\n    assert isinstance(index, Index)\n    assert index.equal(tensor([2, 3, 3, 4], device=device))\n    assert index.dim_size == 5\n    assert index.is_sorted\n\n    with pytest.raises(RuntimeError, match=\"can't be cast\"):\n        index += 2.5\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_sub(dtype, device):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([4, 5, 5, 6], dim_size=7, is_sorted=True, **kwargs)\n\n    out = torch.sub(index, 2, alpha=2)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([0, 1, 1, 2], device=device))\n    assert out.dim_size == 3\n    assert out.is_sorted\n\n    out = index - tensor([2], dtype=dtype, device=device)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([2, 3, 3, 4], device=device))\n    assert out.dim_size == 5\n    assert out.is_sorted\n\n    out = tensor([6], dtype=dtype, device=device) - index\n    assert isinstance(out, Index)\n    assert out.equal(tensor([2, 1, 1, 0], device=device))\n    assert out.dim_size is None\n    assert not out.is_sorted\n\n    out = index.sub(index)\n    assert isinstance(out, Index)\n    assert out.equal(tensor([0, 0, 0, 0], 
device=device))\n    assert out.dim_size is None\n    assert not out.is_sorted\n\n    index -= 2\n    assert isinstance(index, Index)\n    assert index.equal(tensor([2, 3, 3, 4], device=device))\n    assert index.dim_size == 5\n    assert index.is_sorted\n\n    with pytest.raises(RuntimeError, match=\"can't be cast\"):\n        index -= 2.5\n\n\ndef test_to_list():\n    data = torch.tensor([0, 1, 1, 2])\n    index = Index(data)\n    assert index.tolist() == data.tolist()\n\n\ndef test_numpy():\n    data = torch.tensor([0, 1, 1, 2])\n    index = Index(data)\n    assert np.array_equal(index.numpy(), data.numpy())\n\n\n@withCUDA\n@pytest.mark.parametrize('dtype', DTYPES)\ndef test_save_and_load(dtype, device, tmp_path):\n    kwargs = dict(dtype=dtype, device=device)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n    index.fill_cache_()\n\n    path = osp.join(tmp_path, 'edge_index.pt')\n    torch.save(index, path)\n    out = fs.torch_load(path)\n\n    assert isinstance(out, Index)\n    assert out.equal(index)\n    assert out.dim_size == 3\n    assert out.is_sorted\n    assert out._indptr.equal(index._indptr)\n\n\ndef _collate_fn(indices: List[Index]) -> List[Index]:\n    return indices\n\n\n@pytest.mark.parametrize('dtype', DTYPES)\n@pytest.mark.parametrize('num_workers', [0, 2])\n@pytest.mark.parametrize('pin_memory', [False, True])\ndef test_data_loader(dtype, num_workers, pin_memory):\n    kwargs = dict(dtype=dtype)\n    index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True, **kwargs)\n    index.fill_cache_()\n\n    loader = torch.utils.data.DataLoader(\n        [index] * 4,\n        batch_size=2,\n        num_workers=num_workers,\n        collate_fn=_collate_fn,\n        pin_memory=pin_memory,\n        drop_last=True,\n    )\n\n    assert len(loader) == 2\n    for batch in loader:\n        assert isinstance(batch, list)\n        assert len(batch) == 2\n        for index in batch:\n            assert isinstance(index, Index)\n            
assert index.dtype == dtype\n            assert index.is_shared() != (num_workers == 0) or pin_memory\n            assert index._data.is_shared() != (num_workers == 0) or pin_memory\n"
  },
  {
    "path": "test/test_inspector.py",
    "content": "import inspect\nfrom typing import Any, Dict, Final, List, Optional, Set, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.inspector import Inspector, Parameter, Signature\nfrom torch_geometric.nn import GATConv, SAGEConv\nfrom torch_geometric.typing import OptPairTensor\n\n\ndef test_eval_type() -> None:\n    inspector = Inspector(SAGEConv)\n\n    assert inspector.eval_type('Tensor') == Tensor\n    assert inspector.eval_type('List[Tensor]') == List[Tensor]\n    assert inspector.eval_type('Tuple[Tensor, int]') == Tuple[Tensor, int]\n    assert inspector.eval_type('Tuple[int, ...]') == Tuple[int, ...]\n\n\ndef test_type_repr() -> None:\n    inspector = Inspector(SAGEConv)\n\n    assert inspector.type_repr(Any) == 'typing.Any'\n    assert inspector.type_repr(Final) == 'typing.Final'\n    assert inspector.type_repr(OptPairTensor) == (\n        'Tuple[Tensor, Optional[Tensor]]')\n    assert inspector.type_repr(\n        Final[Optional[Tensor]]) == ('typing.Final[Optional[Tensor]]')\n    assert inspector.type_repr(Union[None, Tensor]) == 'Optional[Tensor]'\n    assert inspector.type_repr(Optional[Tensor]) == 'Optional[Tensor]'\n    assert inspector.type_repr(Set[Tensor]) == 'typing.Set[Tensor]'\n    assert inspector.type_repr(List) == 'List'\n    assert inspector.type_repr(Tuple) == 'Tuple'\n    assert inspector.type_repr(Set) == 'typing.Set'\n    assert inspector.type_repr(Dict) == 'typing.Dict'\n    assert inspector.type_repr(Dict[str, Tuple[Tensor, Tensor]]) == (  #\n        'typing.Dict[str, Tuple[Tensor, Tensor]]')\n    assert inspector.type_repr(Tuple[int, ...]) == 'Tuple[int, ...]'\n    assert inspector.type_repr(Union[int, str, None]) == (  #\n        'Union[int, str, None]')\n\n\ndef test_inspector_sage_conv() -> None:\n    inspector = Inspector(SAGEConv)\n    assert str(inspector) == 'Inspector(SAGEConv)'\n    assert inspector.implements('message')\n    assert inspector.implements('message_and_aggregate')\n\n    
out = inspector.inspect_signature(SAGEConv.message)\n    assert isinstance(out, Signature)\n    assert out.param_dict == {\n        'x_j': Parameter('x_j', Tensor, 'Tensor', inspect._empty)\n    }\n    assert out.return_type == Tensor\n    assert inspector.get_flat_params(['message', 'message']) == [\n        Parameter('x_j', Tensor, 'Tensor', inspect._empty),\n    ]\n    assert inspector.get_flat_param_names(['message']) == ['x_j']\n\n    kwargs = {'x_j': torch.randn(5), 'x_i': torch.randn(5)}\n    data = inspector.collect_param_data('message', kwargs)\n    assert len(data) == 1\n    assert torch.allclose(data['x_j'], kwargs['x_j'])\n\n    assert inspector.get_params_from_method_call(SAGEConv.propagate) == {\n        'x': Parameter('x', OptPairTensor, 'OptPairTensor', inspect._empty),\n    }\n\n\ndef test_inspector_gat_conv() -> None:\n    inspector = Inspector(GATConv)\n    assert str(inspector) == 'Inspector(GATConv)'\n    assert inspector.implements('message')\n    assert not inspector.implements('message_and_aggregate')\n\n    out = inspector.inspect_signature(GATConv.message)\n    assert isinstance(out, Signature)\n    assert out.param_dict == {\n        'x_j': Parameter('x_j', Tensor, 'Tensor', inspect._empty),\n        'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty),\n    }\n    assert out.return_type == Tensor\n    assert inspector.get_flat_params(['message', 'message']) == [\n        Parameter('x_j', Tensor, 'Tensor', inspect._empty),\n        Parameter('alpha', Tensor, 'Tensor', inspect._empty),\n    ]\n    assert inspector.get_flat_param_names(['message']) == ['x_j', 'alpha']\n\n    kwargs = {'x_j': torch.randn(5), 'alpha': torch.randn(5)}\n    data = inspector.collect_param_data('message', kwargs)\n    assert len(data) == 2\n    assert torch.allclose(data['x_j'], kwargs['x_j'])\n    assert torch.allclose(data['alpha'], kwargs['alpha'])\n\n    assert inspector.get_params_from_method_call(GATConv.propagate) == {\n        'x': 
Parameter('x', OptPairTensor, 'OptPairTensor', inspect._empty),\n        'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty),\n    }\n\n\ndef test_get_params_from_method_call() -> None:\n    class FromMethodCall1:\n        propagate_type = {'x': Tensor}\n\n    inspector = Inspector(FromMethodCall1)\n    assert inspector.get_params_from_method_call('propagate') == {\n        'x': Parameter('x', Tensor, 'Tensor', inspect._empty),\n    }\n\n    class FromMethodCall2:\n        # propagate_type: (x: Tensor)\n        pass\n\n    inspector = Inspector(FromMethodCall2)\n    assert inspector.get_params_from_method_call('propagate') == {\n        'x': Parameter('x', Tensor, 'Tensor', inspect._empty),\n    }\n\n    class FromMethodCall3:\n        def forward(self) -> None:\n            self.propagate(  # type: ignore\n                torch.randn(5, 5),\n                x=None,\n                size=None,\n            )\n\n    inspector = Inspector(FromMethodCall3)\n    exclude = [0, 'size']\n    assert inspector.get_params_from_method_call('propagate', exclude) == {\n        'x': Parameter('x', Tensor, 'Tensor', inspect._empty),\n    }\n\n    class FromMethodCall4:\n        pass\n\n    inspector = Inspector(FromMethodCall4)\n    assert inspector.get_params_from_method_call('propagate') == {}\n"
  },
  {
    "path": "test/test_isinstance.py",
    "content": "import torch\n\nfrom torch_geometric import is_torch_instance\nfrom torch_geometric.testing import onlyLinux, withPackage\n\n\ndef test_basic():\n    assert is_torch_instance(torch.nn.Linear(1, 1), torch.nn.Linear)\n\n\n@onlyLinux\n@withPackage('torch>=2.0.0')\ndef test_compile():\n    model = torch.compile(torch.nn.Linear(1, 1))\n    assert not isinstance(model, torch.nn.Linear)\n    assert is_torch_instance(model, torch.nn.Linear)\n"
  },
  {
    "path": "test/test_onnx.py",
    "content": "import os\nimport tempfile\nimport warnings\nfrom typing import Any\nfrom unittest.mock import patch\n\nimport pytest\nimport torch\n\nfrom torch_geometric import is_in_onnx_export, safe_onnx_export\n\n# Global mock to prevent ANY real ONNX calls in tests\n# This ensures no deprecation warnings or real ONNX issues\npytestmark = pytest.mark.filterwarnings(\"ignore::DeprecationWarning\")\n\n\nclass SimpleModel(torch.nn.Module):\n    \"\"\"Simple model for testing ONNX export.\"\"\"\n    def __init__(self) -> None:\n        super().__init__()\n        self.linear = torch.nn.Linear(4, 2)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.linear(x)\n\n\ndef test_is_in_onnx_export() -> None:\n    \"\"\"Test is_in_onnx_export function.\"\"\"\n    assert not is_in_onnx_export()\n\n\ndef test_safe_onnx_export_ci_resilient() -> None:\n    \"\"\"Test safe_onnx_export handles CI environment issues gracefully.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    # Use mocking to prevent real ONNX calls and deprecation warnings\n    with patch('torch.onnx.export', return_value=None) as mock_export:\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                # Test with skip_on_error=True - should never fail\n                result = safe_onnx_export(model, (x, ), f.name,\n                                          skip_on_error=True)\n                # Should always succeed with mocking\n                assert result is True\n\n                # Verify the mock was called correctly\n                mock_export.assert_called_once()\n                call_args = mock_export.call_args[0]\n                assert call_args[0] is model\n                assert isinstance(call_args[1], tuple)\n                assert call_args[2] == f.name\n\n            finally:\n                if os.path.exists(f.name):\n                    try:\n                        os.unlink(f.name)\n     
               except (PermissionError, OSError):\n                        pass  # Ignore file lock issues\n\n\ndef test_safe_onnx_export_success() -> None:\n    \"\"\"Test successful ONNX export with pure mocking.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    # Use comprehensive mocking to avoid any real ONNX calls\n    with patch('torch.onnx.export', return_value=None) as mock_export:\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                # Test with tuple args - should succeed with mock\n                result = safe_onnx_export(model, (x, ), f.name)\n                assert result is True\n\n                # Verify torch.onnx.export was called with correct args\n                mock_export.assert_called()\n                call_args = mock_export.call_args[0]\n                assert call_args[0] is model  # model\n                assert isinstance(call_args[1], tuple)  # args as tuple\n                assert call_args[2] == f.name  # file path\n\n                # Reset mock for second test\n                mock_export.reset_mock()\n\n                # Test with single tensor (should be converted to tuple)\n                result = safe_onnx_export(model, x, f.name)\n                assert result is True\n\n                # Verify single tensor was converted to tuple\n                call_args = mock_export.call_args[0]\n                assert isinstance(call_args[1], tuple)\n\n            finally:\n                if os.path.exists(f.name):\n                    try:\n                        os.unlink(f.name)\n                    except (PermissionError, OSError):\n                        pass  # Ignore file lock issues\n\n\ndef test_safe_onnx_export_with_skip_on_error() -> None:\n    \"\"\"Test safe_onnx_export with skip_on_error=True.\"\"\"\n    model = SimpleModel()\n    x = 
torch.randn(3, 4)\n\n    # Mock torch.onnx.export to raise SerdeError\n    with patch('torch.onnx.export') as mock_export:\n        mock_export.side_effect = Exception(\n            \"onnx_ir.serde.SerdeError: allowzero\")\n\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                # Should return False instead of raising\n                result = safe_onnx_export(model, (x, ), f.name,\n                                          skip_on_error=True)\n                assert result is False\n            finally:\n                if os.path.exists(f.name):\n                    try:\n\n                        os.unlink(f.name)\n\n                    except (PermissionError, OSError):\n\n                        pass\n\n\ndef test_serde_error_patterns() -> None:\n    \"\"\"Test detection of various SerdeError patterns.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    error_patterns = [\n        \"onnx_ir.serde.SerdeError: allowzero attribute\",\n        \"ValueError: Value out of range: 1\", \"serialize_model_into failed\",\n        \"serialize_attribute_into failed\"\n    ]\n\n    for error_msg in error_patterns:\n        # Use multiple patch targets to ensure comprehensive mocking\n        with patch('torch.onnx.export') as mock_export, \\\n             patch('torch_geometric._onnx.torch.onnx.export') as mock_export2:\n\n            mock_export.side_effect = Exception(error_msg)\n            mock_export2.side_effect = Exception(error_msg)\n\n            with tempfile.NamedTemporaryFile(suffix='.onnx',\n                                             delete=False) as f:\n                try:\n                    result = safe_onnx_export(model, (x, ), f.name,\n                                              skip_on_error=True)\n                    assert result is False\n                finally:\n                    if os.path.exists(f.name):\n                        try:\n                            
try:\n\n                                os.unlink(f.name)\n\n                            except (PermissionError, OSError):\n\n                                pass\n                        except (PermissionError, OSError):\n                            pass  # Ignore file lock issues\n\n\ndef test_non_serde_error_reraise() -> None:\n    \"\"\"Test that non-SerdeError exceptions are re-raised.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    # Use comprehensive mocking to prevent real ONNX calls\n    with patch('torch.onnx.export') as mock_export:\n        mock_export.side_effect = ValueError(\"Some other error\")\n\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                with pytest.raises(ValueError, match=\"Some other error\"):\n                    safe_onnx_export(model, (x, ), f.name)\n            finally:\n                if os.path.exists(f.name):\n                    try:\n\n                        os.unlink(f.name)\n\n                    except (PermissionError, OSError):\n\n                        pass\n\n\ndef test_dynamo_fallback() -> None:\n    \"\"\"Test dynamo=False fallback strategy.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    call_count = 0\n\n    def mock_export_side_effect(*_args: Any, **kwargs: Any) -> None:\n        nonlocal call_count\n        call_count += 1\n        if call_count == 1:\n            # First call fails\n            raise Exception(\"onnx_ir.serde.SerdeError: allowzero\")\n        elif call_count == 2 and not kwargs.get('dynamo', True):\n            # Second call succeeds with dynamo=False\n            return None\n        else:\n            raise Exception(\"Unexpected call\")\n\n    with patch('torch.onnx.export', side_effect=mock_export_side_effect):\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                result = safe_onnx_export(model, (x, ), f.name, dynamo=True)\n         
       assert result is True\n                assert call_count == 2\n            finally:\n                if os.path.exists(f.name):\n                    try:\n                        os.unlink(f.name)\n                    except (PermissionError, OSError):\n                        pass  # Ignore file lock issues\n\n\ndef test_opset_fallback() -> None:\n    \"\"\"Test opset version fallback strategy.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    call_count = 0\n\n    def mock_export_side_effect(*_args: Any, **kwargs: Any) -> None:\n        nonlocal call_count\n        call_count += 1\n        # Fail until we get to opset_version=17\n        if kwargs.get('opset_version') == 17:\n            # This call succeeds\n            return None\n        else:\n            # All other calls fail\n            raise Exception(\"onnx_ir.serde.SerdeError: allowzero\")\n\n    with patch('torch.onnx.export', side_effect=mock_export_side_effect):\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                result = safe_onnx_export(model, (x, ), f.name,\n                                          opset_version=18)\n                # Should succeed when opset_version=17 is tried\n                assert result is True\n            finally:\n                if os.path.exists(f.name):\n                    try:\n                        os.unlink(f.name)\n                    except (PermissionError, OSError):\n                        pass  # Ignore file lock issues\n\n\ndef test_all_strategies_fail() -> None:\n    \"\"\"Test when all workaround strategies fail.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    with patch('torch.onnx.export') as mock_export:\n        mock_export.side_effect = Exception(\n            \"onnx_ir.serde.SerdeError: allowzero\")\n\n        with tempfile.NamedTemporaryFile(suffix='.onnx', 
delete=False) as f:\n            try:\n                # Should raise RuntimeError when skip_on_error=False\n                with pytest.raises(RuntimeError,\n                                   match=\"Failed to export model to ONNX\"):\n                    safe_onnx_export(model, (x, ), f.name, skip_on_error=False)\n\n                # Should return False when skip_on_error=True\n                result = safe_onnx_export(model, (x, ), f.name,\n                                          skip_on_error=True)\n                assert result is False\n            finally:\n                if os.path.exists(f.name):\n                    try:\n\n                        os.unlink(f.name)\n\n                    except (PermissionError, OSError):\n\n                        pass\n\n\ndef test_pytest_environment_detection() -> None:\n    \"\"\"Test pytest environment detection for better error messages.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    with patch('torch.onnx.export') as mock_export:\n        mock_export.side_effect = Exception(\n            \"onnx_ir.serde.SerdeError: allowzero\")\n\n        # Set pytest environment variable\n        with patch.dict(os.environ, {'PYTEST_CURRENT_TEST': 'test_something'}):\n            with tempfile.NamedTemporaryFile(suffix='.onnx',\n                                             delete=False) as f:\n                try:\n                    with pytest.raises(RuntimeError) as exc_info:\n                        safe_onnx_export(model, (x, ), f.name,\n                                         skip_on_error=False)\n\n                    # Should contain pytest-specific guidance\n                    assert \"pytest environments\" in str(exc_info.value)\n                    assert \"torch.jit.script()\" in str(exc_info.value)\n                finally:\n                    if os.path.exists(f.name):\n                        try:\n\n                            os.unlink(f.name)\n\n                        except 
(PermissionError, OSError):\n\n                            pass\n\n\ndef test_warnings_emitted() -> None:\n    \"\"\"Test that appropriate warnings are emitted during workarounds.\"\"\"\n    model = SimpleModel()\n    x = torch.randn(3, 4)\n\n    call_count = 0\n\n    def mock_export_side_effect(*_args: Any, **_kwargs: Any) -> None:\n        nonlocal call_count\n        call_count += 1\n        if call_count == 1:\n            raise Exception(\"onnx_ir.serde.SerdeError: allowzero\")\n        elif call_count == 2:\n            return None  # Success on dynamo fallback\n        else:\n            raise Exception(\"Unexpected call\")\n\n    with patch('torch.onnx.export', side_effect=mock_export_side_effect):\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                with warnings.catch_warnings(record=True) as w:\n                    warnings.simplefilter(\"always\")\n                    result = safe_onnx_export(model, (x, ), f.name,\n                                              dynamo=True)\n\n                    assert result is True\n                    assert len(w) >= 2  # Initial error + dynamo fallback\n                    assert any(\"allowzero boolean attribute bug\" in str(\n                        warning.message) for warning in w)\n                    assert any(\n                        \"dynamo=False as workaround\" in str(warning.message)\n                        for warning in w)\n            finally:\n                if os.path.exists(f.name):\n                    try:\n\n                        os.unlink(f.name)\n\n                    except (PermissionError, OSError):\n\n                        pass\n\n\n@pytest.mark.parametrize(\n    \"args_input\",\n    [\n        torch.randn(3, 4),  # Single tensor\n        (torch.randn(3, 4), ),  # Tuple with one tensor\n        (torch.randn(3, 4), torch.randn(3, 2)),  # Tuple with multiple tensors\n    ])\ndef test_args_conversion(args_input: Any) -> 
None:\n    \"\"\"Test that args are properly converted to tuple format.\"\"\"\n    model = SimpleModel()\n\n    with patch('torch.onnx.export') as mock_export:\n        mock_export.return_value = None\n\n        with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f:\n            try:\n                result = safe_onnx_export(model, args_input, f.name)\n                assert result is True\n\n                # Check that torch.onnx.export was called with tuple args\n                mock_export.assert_called_once()\n                call_args = mock_export.call_args[0]\n                assert isinstance(call_args[1], tuple)  # args should be tuple\n            finally:\n                if os.path.exists(f.name):\n                    try:\n\n                        os.unlink(f.name)\n\n                    except (PermissionError, OSError):\n\n                        pass\n"
  },
  {
    "path": "test/test_seed.py",
    "content": "import random\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric import seed_everything\n\n\ndef test_seed_everything():\n    seed_everything(0)\n\n    assert random.randint(0, 100) == 49\n    assert random.randint(0, 100) == 97\n    assert np.random.randint(0, 100) == 44\n    assert np.random.randint(0, 100) == 47\n    assert int(torch.randint(0, 100, (1, ))) == 44\n    assert int(torch.randint(0, 100, (1, ))) == 39\n"
  },
  {
    "path": "test/test_typing.py",
    "content": "import pytest\n\nfrom torch_geometric.typing import EdgeTypeStr\n\n\ndef test_edge_type_str():\n    edge_type_str = EdgeTypeStr('a__links__b')\n    assert isinstance(edge_type_str, str)\n    assert edge_type_str == 'a__links__b'\n    assert edge_type_str.to_tuple() == ('a', 'links', 'b')\n\n    edge_type_str = EdgeTypeStr('a', 'b')\n    assert isinstance(edge_type_str, str)\n    assert edge_type_str == 'a__to__b'\n    assert edge_type_str.to_tuple() == ('a', 'to', 'b')\n\n    edge_type_str = EdgeTypeStr(('a', 'b'))\n    assert isinstance(edge_type_str, str)\n    assert edge_type_str == 'a__to__b'\n    assert edge_type_str.to_tuple() == ('a', 'to', 'b')\n\n    edge_type_str = EdgeTypeStr('a', 'links', 'b')\n    assert isinstance(edge_type_str, str)\n    assert edge_type_str == 'a__links__b'\n    assert edge_type_str.to_tuple() == ('a', 'links', 'b')\n\n    edge_type_str = EdgeTypeStr(('a', 'links', 'b'))\n    assert isinstance(edge_type_str, str)\n    assert edge_type_str == 'a__links__b'\n    assert edge_type_str.to_tuple() == ('a', 'links', 'b')\n\n    with pytest.raises(ValueError, match=\"invalid edge type\"):\n        EdgeTypeStr('a', 'b', 'c', 'd')\n\n    with pytest.raises(ValueError, match=\"Cannot convert the edge type\"):\n        EdgeTypeStr('a__b__c__d').to_tuple()\n"
  },
  {
    "path": "test/test_warnings.py",
    "content": "import warnings\nfrom unittest.mock import patch\n\nimport pytest\n\nfrom torch_geometric.warnings import WarningCache, warn\n\n\ndef test_warn():\n    with pytest.warns(UserWarning, match='test'):\n        warn('test')\n\n\n@patch('torch_geometric.is_compiling', return_value=True)\ndef test_no_warn_if_compiling(_):\n    \"\"\"No warning should be raised to avoid graph breaks when compiling.\"\"\"\n    with warnings.catch_warnings():\n        warnings.simplefilter('error')\n        warn('test')\n\n\ndef test_warning_cache():\n    cache = WarningCache()\n    assert len(cache) == 0\n\n    cache.warn('test')\n    assert len(cache) == 1\n    assert 'test' in cache\n\n    cache.warn('test')\n    assert len(cache) == 1\n\n    cache.warn('test2')\n    assert len(cache) == 2\n    assert 'test2' in cache\n"
  },
  {
    "path": "test/testing/test_decorators.py",
    "content": "import torch_geometric.typing\nfrom torch_geometric.testing import disableExtensions\n\n\ndef test_enable_extensions():\n    try:\n        import pyg_lib  # noqa\n        assert torch_geometric.typing.WITH_PYG_LIB\n    except (ImportError, OSError):\n        assert not torch_geometric.typing.WITH_PYG_LIB\n\n    try:\n        import torch_scatter  # noqa\n        assert torch_geometric.typing.WITH_TORCH_SCATTER\n    except (ImportError, OSError):\n        assert not torch_geometric.typing.WITH_TORCH_SCATTER\n\n    try:\n        import torch_sparse  # noqa\n        assert torch_geometric.typing.WITH_TORCH_SPARSE\n    except (ImportError, OSError):\n        assert not torch_geometric.typing.WITH_TORCH_SPARSE\n\n\n@disableExtensions\ndef test_disable_extensions():\n    assert not torch_geometric.typing.WITH_PYG_LIB\n    assert not torch_geometric.typing.WITH_TORCH_SCATTER\n    assert not torch_geometric.typing.WITH_TORCH_SPARSE\n"
  },
  {
    "path": "test/transforms/test_add_gpse.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.nn import GPSE\nfrom torch_geometric.nn.models.gpse import IdentityHead\nfrom torch_geometric.transforms import AddGPSE\n\nnum_nodes = 6\ngpse_inner_dim = 512\n\n\ndef test_gpse():\n    x = torch.randn(num_nodes, 4)\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])\n    data = Data(x=x, edge_index=edge_index)\n\n    model = GPSE()\n    model.post_mp = IdentityHead()\n    transform = AddGPSE(model)\n\n    assert str(transform) == 'AddGPSE()'\n    out = transform(data)\n    assert out.pestat_GPSE.size() == (num_nodes, gpse_inner_dim)\n"
  },
  {
    "path": "test/transforms/test_add_metapaths.py",
    "content": "import torch\nfrom torch import tensor\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.transforms import AddMetaPaths, AddRandomMetaPaths\nfrom torch_geometric.utils import coalesce\n\n\ndef generate_data() -> HeteroData:\n    data = HeteroData()\n    data['p'].x = torch.ones(5)\n    data['a'].x = torch.ones(6)\n    data['c'].x = torch.ones(3)\n    data['p', 'p'].edge_index = tensor([[0, 1, 2, 3], [1, 2, 4, 2]])\n    data['p', 'a'].edge_index = tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]])\n    data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])\n    data['c', 'p'].edge_index = tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]])\n    data['p', 'c'].edge_index = data['c', 'p'].edge_index.flip([0])\n    return data\n\n\ndef test_add_metapaths() -> None:\n    data = generate_data()\n    # Test transform options:\n    metapaths = [[('p', 'c'), ('c', 'p')]]\n\n    transform = AddMetaPaths(metapaths)\n    assert str(transform) == 'AddMetaPaths()'\n    meta1 = transform(data)\n\n    transform = AddMetaPaths(metapaths, drop_orig_edge_types=True)\n    assert str(transform) == 'AddMetaPaths()'\n    meta2 = transform(data)\n\n    transform = AddMetaPaths(metapaths, drop_orig_edge_types=True,\n                             keep_same_node_type=True)\n    assert str(transform) == 'AddMetaPaths()'\n    meta3 = transform(data)\n\n    transform = AddMetaPaths(metapaths, drop_orig_edge_types=True,\n                             keep_same_node_type=True,\n                             drop_unconnected_node_types=True)\n    assert str(transform) == 'AddMetaPaths()'\n    meta4 = transform(data)\n\n    assert meta1['metapath_0'].edge_index.size() == (2, 9)\n    assert meta2['metapath_0'].edge_index.size() == (2, 9)\n    assert meta3['metapath_0'].edge_index.size() == (2, 9)\n    assert meta4['metapath_0'].edge_index.size() == (2, 9)\n\n    assert all([i in meta1.edge_types for i in data.edge_types])\n    assert meta2.edge_types == [('p', 'metapath_0', 
'p')]\n    assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]\n    assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]\n\n    assert meta3.node_types == ['p', 'a', 'c']\n    assert meta4.node_types == ['p']\n\n    # Test 4-hop metapath:\n    metapaths = [\n        [('a', 'p'), ('p', 'c')],\n        [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')],\n    ]\n    transform = AddMetaPaths(metapaths)\n    meta = transform(data)\n    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]\n    assert meta['metapath_0'].edge_index.size() == (2, 4)\n    assert meta['metapath_1'].edge_index.size() == (2, 4)\n\n    # Test `metapath_dict` information:\n    assert list(meta.metapath_dict.values()) == metapaths\n    assert list(meta.metapath_dict.keys()) == new_edge_types\n\n\ndef test_add_metapaths_max_sample() -> None:\n    torch.manual_seed(12345)\n\n    data = generate_data()\n\n    metapaths = [[('p', 'c'), ('c', 'p')]]\n    transform = AddMetaPaths(metapaths, max_sample=1)\n\n    meta = transform(data)\n    assert meta['metapath_0'].edge_index.size(1) < 9\n\n\ndef test_add_weighted_metapaths() -> None:\n    torch.manual_seed(12345)\n\n    data = HeteroData()\n    data['a'].num_nodes = 2\n    data['b'].num_nodes = 3\n    data['c'].num_nodes = 2\n    data['d'].num_nodes = 2\n    data['a', 'b'].edge_index = tensor([[0, 1, 1], [0, 1, 2]])\n    data['b', 'a'].edge_index = data['a', 'b'].edge_index.flip([0])\n    data['b', 'c'].edge_index = tensor([[0, 1, 2], [0, 1, 1]])\n    data['c', 'b'].edge_index = data['b', 'c'].edge_index.flip([0])\n    data['c', 'd'].edge_index = tensor([[0, 1], [0, 0]])\n    data['d', 'c'].edge_index = data['c', 'd'].edge_index.flip([0])\n\n    metapaths = [\n        [('a', 'b'), ('b', 'c')],\n        [('a', 'b'), ('b', 'c'), ('c', 'd')],\n        [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'c'), ('c', 'b'),\n         ('b', 'a')],\n    ]\n    transform = AddMetaPaths(metapaths, weighted=True)\n    
out = transform(data)\n\n    # Make sure manually added metapaths compute the correct number of edges:\n    edge_index = out['a', 'a'].edge_index\n    edge_weight = out['a', 'a'].edge_weight\n    edge_index, edge_weight = coalesce(edge_index, edge_weight)\n    assert edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]\n    assert edge_weight.tolist() == [1, 2, 2, 4]\n\n    edge_index = out['a', 'c'].edge_index\n    edge_weight = out['a', 'c'].edge_weight\n    edge_index, edge_weight = coalesce(edge_index, edge_weight)\n    assert edge_index.tolist() == [[0, 1], [0, 1]]\n    assert edge_weight.tolist() == [1, 2]\n\n    edge_index = out['a', 'd'].edge_index\n    edge_weight = out['a', 'd'].edge_weight\n    edge_index, edge_weight = coalesce(edge_index, edge_weight)\n    assert edge_index.tolist() == [[0, 1], [0, 0]]\n    assert edge_weight.tolist() == [1, 2]\n\n    # Compute intra-table metapaths efficiently:\n    metapaths = [[('a', 'b'), ('b', 'c'), ('c', 'd')]]\n    out = AddMetaPaths(metapaths, weighted=True)(data)\n    out['d', 'a'].edge_index = out['a', 'd'].edge_index.flip([0])\n    out['d', 'a'].edge_weight = out['a', 'd'].edge_weight\n    metapaths = [[('a', 'd'), ('d', 'a')]]\n    out = AddMetaPaths(metapaths, weighted=True)(out)\n\n    edge_index = out['a', 'a'].edge_index\n    edge_weight = out['a', 'a'].edge_weight\n    edge_index, edge_weight = coalesce(edge_index, edge_weight)\n    assert edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]\n    assert edge_weight.tolist() == [1, 2, 2, 4]\n\n\ndef test_add_random_metapaths() -> None:\n    data = generate_data()\n\n    # Test transform options:\n    metapaths = [[('p', 'c'), ('c', 'p')]]\n    torch.manual_seed(12345)\n\n    transform = AddRandomMetaPaths(metapaths)\n    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '\n                              'walks_per_node=[1])')\n    meta1 = transform(data)\n\n    transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True)\n    assert 
str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '\n                              'walks_per_node=[1])')\n    meta2 = transform(data)\n\n    transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True,\n                                   keep_same_node_type=True)\n    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '\n                              'walks_per_node=[1])')\n    meta3 = transform(data)\n\n    transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True,\n                                   keep_same_node_type=True,\n                                   drop_unconnected_node_types=True)\n    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '\n                              'walks_per_node=[1])')\n    meta4 = transform(data)\n\n    transform = AddRandomMetaPaths(metapaths, sample_ratio=0.8,\n                                   drop_orig_edge_types=True,\n                                   keep_same_node_type=True,\n                                   drop_unconnected_node_types=True)\n    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=0.8, '\n                              'walks_per_node=[1])')\n    meta5 = transform(data)\n\n    transform = AddRandomMetaPaths(metapaths, walks_per_node=5,\n                                   drop_orig_edge_types=True,\n                                   keep_same_node_type=True,\n                                   drop_unconnected_node_types=True)\n    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '\n                              'walks_per_node=[5])')\n    meta6 = transform(data)\n\n    assert meta1['metapath_0'].edge_index.size() == (2, 5)\n    assert meta2['metapath_0'].edge_index.size() == (2, 5)\n    assert meta3['metapath_0'].edge_index.size() == (2, 5)\n    assert meta4['metapath_0'].edge_index.size() == (2, 5)\n    assert meta5['metapath_0'].edge_index.size() == (2, 4)\n    assert meta6['metapath_0'].edge_index.size() == (2, 7)\n\n    
assert all([i in meta1.edge_types for i in data.edge_types])\n    assert meta2.edge_types == [('p', 'metapath_0', 'p')]\n    assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]\n    assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')]\n\n    assert meta3.node_types == ['p', 'a', 'c']\n    assert meta4.node_types == ['p']\n\n    # Test 4-hop metapath:\n    metapaths = [\n        [('a', 'p'), ('p', 'c')],\n        [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')],\n    ]\n    transform = AddRandomMetaPaths(metapaths)\n    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '\n                              'walks_per_node=[1, 1])')\n\n    meta1 = transform(data)\n    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]\n    assert meta1['metapath_0'].edge_index.size() == (2, 2)\n    assert meta1['metapath_1'].edge_index.size() == (2, 2)\n\n    # Test `metapath_dict` information:\n    assert list(meta1.metapath_dict.values()) == metapaths\n    assert list(meta1.metapath_dict.keys()) == new_edge_types\n\n    transform = AddRandomMetaPaths(metapaths, walks_per_node=[2, 5])\n    assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, '\n                              'walks_per_node=[2, 5])')\n\n    meta2 = transform(data)\n    new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')]\n    assert meta2['metapath_0'].edge_index.size() == (2, 2)\n    assert meta2['metapath_1'].edge_index.size() == (2, 3)\n"
  },
  {
    "path": "test/transforms/test_add_positional_encoding.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import (\n    AddLaplacianEigenvectorPE,\n    AddRandomWalkPE,\n)\n\n\n@withPackage('scipy')\ndef test_add_laplacian_eigenvector_pe():\n    x = torch.randn(6, 4)\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])\n    data = Data(x=x, edge_index=edge_index)\n\n    transform = AddLaplacianEigenvectorPE(k=3)\n    assert str(transform) == 'AddLaplacianEigenvectorPE()'\n    out = transform(data)\n    assert out.laplacian_eigenvector_pe.size() == (6, 3)\n\n    transform = AddLaplacianEigenvectorPE(k=3, attr_name=None)\n    out = transform(data)\n    assert out.x.size() == (6, 4 + 3)\n\n    transform = AddLaplacianEigenvectorPE(k=3, attr_name='x')\n    out = transform(data)\n    assert out.x.size() == (6, 3)\n\n    # Output tests:\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5, 2, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3, 5, 2]])\n    data = Data(x=x, edge_index=edge_index)\n\n    transform1 = AddLaplacianEigenvectorPE(k=1, is_undirected=True)\n    transform2 = AddLaplacianEigenvectorPE(k=1, is_undirected=False)\n\n    # Clustering test with first non-trivial eigenvector (Fiedler vector)\n    pe = transform1(data).laplacian_eigenvector_pe\n    pe_cluster_1 = pe[[0, 1, 4]]\n    pe_cluster_2 = pe[[2, 3, 5]]\n    assert not torch.allclose(pe_cluster_1, pe_cluster_2)\n    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())\n    assert torch.allclose(pe_cluster_2, pe_cluster_2.mean())\n\n    pe = transform2(data).laplacian_eigenvector_pe\n    pe_cluster_1 = pe[[0, 1, 4]]\n    pe_cluster_2 = pe[[2, 3, 5]]\n    assert not torch.allclose(pe_cluster_1, pe_cluster_2)\n    assert torch.allclose(pe_cluster_1, pe_cluster_1.mean())\n    assert torch.allclose(pe_cluster_2, 
pe_cluster_2.mean())\n\n\n@withPackage('scipy')\ndef test_eigenvector_permutation_invariance():\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])\n    data = Data(edge_index=edge_index, num_nodes=6)\n\n    perm = torch.randperm(data.num_nodes)\n    transform = AddLaplacianEigenvectorPE(\n        k=2,\n        is_undirected=True,\n        attr_name='x',\n    )\n    out1 = transform(data)\n\n    transform = AddLaplacianEigenvectorPE(\n        k=2,\n        is_undirected=True,\n        attr_name='x',\n    )\n    out2 = transform(data.subgraph(perm))\n\n    assert torch.allclose(out1.x[perm].abs(), out2.x.abs(), atol=1e-6)\n\n\ndef test_add_random_walk_pe():\n    x = torch.randn(6, 4)\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])\n    data = Data(x=x, edge_index=edge_index)\n\n    transform = AddRandomWalkPE(walk_length=3)\n    assert str(transform) == 'AddRandomWalkPE()'\n    out = transform(data)\n    assert out.random_walk_pe.size() == (6, 3)\n\n    transform = AddRandomWalkPE(walk_length=3, attr_name=None)\n    out = transform(data)\n    assert out.x.size() == (6, 4 + 3)\n\n    transform = AddRandomWalkPE(walk_length=3, attr_name='x')\n    out = transform(data)\n    assert out.x.size() == (6, 3)\n\n    # Output tests:\n    assert out.x.tolist() == [\n        [0.0, 0.5, 0.25],\n        [0.0, 0.5, 0.25],\n        [0.0, 0.5, 0.00],\n        [0.0, 1.0, 0.00],\n        [0.0, 0.5, 0.25],\n        [0.0, 0.5, 0.00],\n    ]\n\n    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])\n    data = Data(edge_index=edge_index, num_nodes=4)\n    out = transform(data)\n\n    assert out.x.tolist() == [\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [1.0, 1.0, 1.0],\n        [0.0, 0.0, 0.0],\n    ]\n"
  },
  {
    "path": "test/transforms/test_add_remaining_self_loops.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import AddRemainingSelfLoops\n\n\ndef test_add_remaining_self_loops():\n    assert str(AddRemainingSelfLoops()) == 'AddRemainingSelfLoops()'\n\n    assert len(AddRemainingSelfLoops()(Data())) == 0\n\n    # No self-loops in `edge_index`.\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_weight = torch.tensor([1, 2, 3, 4])\n    edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])\n\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = AddRemainingSelfLoops()(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],\n                                        [1, 0, 2, 1, 0, 1, 2]]\n    assert data.num_nodes == 3\n\n    # Single self-loop in `edge_index`.\n    edge_index = torch.tensor([[0, 0, 1, 2], [1, 0, 2, 1]])\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = AddRemainingSelfLoops()(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [1, 2, 1, 0, 1, 2]]\n    assert data.num_nodes == 3\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)\n    data = AddRemainingSelfLoops(attr='edge_weight', fill_value=5)(data)\n    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [1, 2, 1, 0, 1, 2]]\n    assert data.num_nodes == 3\n    assert data.edge_weight.tolist() == [1, 3, 4, 2, 5, 5]\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)\n    data = AddRemainingSelfLoops(attr='edge_attr', fill_value='add')(data)\n    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [1, 2, 1, 0, 1, 2]]\n    assert data.num_nodes == 3\n    assert data.edge_attr.tolist() == [[1, 2], [5, 6], [7, 8], [3, 4], [8, 10],\n                                       [5, 6]]\n\n\ndef test_add_remaining_self_loops_all_loops_exist():\n    # All self-loops already exist in the data object.\n    edge_index = 
torch.tensor([[0, 1, 2], [0, 1, 2]])\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = AddRemainingSelfLoops()(data)\n    assert data.edge_index.tolist() == edge_index.tolist()\n\n    # All self-loops already exist in the data object, some of them appear\n    # multiple times.\n    edge_index = torch.tensor([[0, 0, 1, 1, 2], [0, 0, 1, 1, 2]])\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = AddRemainingSelfLoops()(data)\n    assert data.edge_index.tolist() == [[0, 1, 2], [0, 1, 2]]\n\n\ndef test_hetero_add_remaining_self_loops():\n    edge_index = torch.tensor([[0, 0, 1, 2], [1, 0, 2, 1]])\n\n    data = HeteroData()\n    data['v'].num_nodes = 3\n    data['w'].num_nodes = 3\n    data['v', 'v'].edge_index = edge_index\n    data['v', 'w'].edge_index = edge_index\n    data = AddRemainingSelfLoops()(data)\n    assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 0, 1, 2],\n                                                  [1, 2, 1, 0, 1, 2]]\n    assert data['v', 'w'].edge_index.tolist() == edge_index.tolist()\n"
  },
  {
    "path": "test/transforms/test_add_self_loops.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import AddSelfLoops\n\n\ndef test_add_self_loops():\n    assert str(AddSelfLoops()) == 'AddSelfLoops()'\n\n    assert len(AddSelfLoops()(Data())) == 0\n\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_weight = torch.tensor([1, 2, 3, 4])\n    edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])\n\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = AddSelfLoops()(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],\n                                        [1, 0, 2, 1, 0, 1, 2]]\n    assert data.num_nodes == 3\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)\n    data = AddSelfLoops(attr='edge_weight', fill_value=5)(data)\n    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],\n                                        [1, 0, 2, 1, 0, 1, 2]]\n    assert data.num_nodes == 3\n    assert data.edge_weight.tolist() == [1, 2, 3, 4, 5, 5, 5]\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)\n    data = AddSelfLoops(attr='edge_attr', fill_value='add')(data)\n    assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],\n                                        [1, 0, 2, 1, 0, 1, 2]]\n    assert data.num_nodes == 3\n    assert data.edge_attr.tolist() == [[1, 2], [3, 4], [5, 6], [7, 8], [3, 4],\n                                       [8, 10], [5, 6]]\n\n\ndef test_add_self_loops_with_existing_self_loops():\n    edge_index = torch.tensor([[0, 1, 2], [0, 1, 2]])\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = AddSelfLoops()(data)\n    assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2]]\n    assert data.num_nodes == 3\n\n\ndef test_hetero_add_self_loops():\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    data = HeteroData()\n    data['v'].num_nodes = 3\n    
data['w'].num_nodes = 3\n    data['v', 'v'].edge_index = edge_index\n    data['v', 'w'].edge_index = edge_index\n    data = AddSelfLoops()(data)\n    assert data['v', 'v'].edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2],\n                                                  [1, 0, 2, 1, 0, 1, 2]]\n    assert data['v', 'w'].edge_index.tolist() == edge_index.tolist()\n"
  },
  {
    "path": "test/transforms/test_cartesian.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import Cartesian\n\n\ndef test_cartesian():\n    assert str(Cartesian()) == 'Cartesian(norm=True, max_value=None)'\n\n    pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0])\n\n    data = Data(edge_index=edge_index, pos=pos)\n    data = Cartesian(norm=False)(data)\n    assert len(data) == 3\n    assert torch.equal(data.pos, pos)\n    assert torch.equal(data.edge_index, edge_index)\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[-1.0, 0.0], [1.0, 0.0], [-2.0, 0.0], [2.0, 0.0]]),\n    )\n\n    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)\n    data = Cartesian(norm=True)(data)\n    assert len(data) == 3\n    assert torch.equal(data.pos, pos)\n    assert torch.equal(data.edge_index, edge_index)\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([\n            [1, 0.25, 0.5],\n            [2, 0.75, 0.5],\n            [3, 0, 0.5],\n            [4, 1, 0.5],\n        ]),\n    )\n"
  },
  {
    "path": "test/transforms/test_center.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import Center\n\n\ndef test_center():\n    transform = Center()\n    assert str(transform) == 'Center()'\n\n    pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])\n    data = Data(pos=pos)\n\n    data = transform(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[-2, 0], [0, 0], [2, 0]]\n"
  },
  {
    "path": "test/transforms/test_compose.py",
    "content": "import torch\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.data import Data\n\n\ndef test_compose():\n    transform = T.Compose([T.Center(), T.AddSelfLoops()])\n    assert str(transform) == ('Compose([\\n'\n                              '  Center(),\\n'\n                              '  AddSelfLoops()\\n'\n                              '])')\n\n    pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    data = Data(edge_index=edge_index, pos=pos)\n    data = transform(data)\n    assert len(data) == 2\n    assert data.pos.tolist() == [[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]]\n    assert data.edge_index.size() == (2, 7)\n\n\ndef test_compose_data_list():\n    transform = T.Compose([T.Center(), T.AddSelfLoops()])\n\n    pos = torch.tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    data_list = [Data(edge_index=edge_index, pos=pos) for _ in range(3)]\n    data_list = transform(data_list)\n    assert len(data_list) == 3\n    for data in data_list:\n        assert len(data) == 2\n        assert data.pos.tolist() == [[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]]\n        assert data.edge_index.size() == (2, 7)\n\n\ndef test_compose_filters():\n    filter_fn = T.ComposeFilters([\n        lambda d: d.num_nodes > 2,\n        lambda d: d.num_edges > 2,\n    ])\n    assert str(filter_fn)[:16] == 'ComposeFilters(['\n\n    data1 = Data(x=torch.arange(3))\n    assert not filter_fn(data1)\n\n    data2 = Data(x=torch.arange(2), edge_index=torch.tensor([\n        [0, 0, 1],\n        [0, 1, 1],\n    ]))\n    assert not filter_fn(data2)\n\n    data3 = Data(x=torch.arange(3), edge_index=torch.tensor([\n        [0, 0, 1],\n        [0, 1, 1],\n    ]))\n    assert filter_fn(data3)\n\n    # Test tuple of data objects:\n    assert filter_fn((data1, data2, data3)) is False\n"
  },
  {
    "path": "test/transforms/test_constant.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import Constant\n\n\ndef test_constant():\n    assert str(Constant()) == 'Constant(value=1.0)'\n\n    x = torch.tensor([[-1, 0], [0, 0], [2, 0]], dtype=torch.float)\n    edge_index = torch.tensor([[0, 1], [1, 2]])\n\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = Constant()(data)\n    assert len(data) == 3\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.x.tolist() == [[1], [1], [1]]\n    assert data.num_nodes == 3\n\n    data = Data(edge_index=edge_index, x=x)\n    data = Constant()(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]]\n\n    data = HeteroData()\n    data['v'].x = x\n    data = Constant()(data)\n    assert len(data) == 1\n    assert data['v'].x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]]\n\n    data = HeteroData()\n    data['v'].x = x\n    data['w'].x = x\n    data = Constant(node_types='w')(data)\n    assert len(data) == 1\n    assert data['v'].x.tolist() == x.tolist()\n    assert data['w'].x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]]\n"
  },
  {
    "path": "test/transforms/test_delaunay.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import Delaunay\n\n\ndef assert_one_point(transform: Delaunay) -> None:\n    data = Data(pos=torch.rand(1, 2))\n    data = transform(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [[], []]\n\n\ndef assert_two_points(transform: Delaunay) -> None:\n    data = Data(pos=torch.rand(2, 2))\n    data = transform(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [[0, 1], [1, 0]]\n\n\ndef assert_three_points(transform: Delaunay) -> None:\n    data = Data(pos=torch.rand(3, 2))\n    data = transform(data)\n    assert len(data) == 2\n    assert data.face.tolist() == [[0], [1], [2]]\n\n\ndef assert_four_points(transform: Delaunay) -> None:\n    pos = torch.tensor([\n        [-1.0, -1.0],\n        [-1.0, 1.0],\n        [1.0, 1.0],\n        [0.5, -0.5],\n    ])\n    data = Data(pos=pos)\n    data = transform(data)\n    assert len(data) == 2\n\n    # The order of the simplices does not matter, therefore assert sets.\n    faces = set(map(tuple, data.face.tolist()))\n    assert faces == {(3, 1), (1, 3), (0, 2)}\n\n\n@withPackage('scipy')\ndef test_qhull_delaunay() -> None:\n    transform = Delaunay()\n\n    assert str(transform) == 'Delaunay()'\n    assert_one_point(transform)\n    assert_two_points(transform)\n    assert_three_points(transform)\n    assert_four_points(transform)\n\n\n@withPackage('torch_delaunay')\ndef test_shull_delaunay() -> None:\n    transform = Delaunay()\n\n    assert str(transform) == 'Delaunay()'\n    assert_one_point(transform)\n    assert_two_points(transform)\n    assert_three_points(transform)\n    assert_four_points(transform)\n"
  },
  {
    "path": "test/transforms/test_distance.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import Distance\n\n\ndef test_distance():\n    assert str(Distance()) == 'Distance(norm=True, max_value=None)'\n\n    pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_attr = torch.tensor([1.0, 1.0, 1.0, 1.0])\n\n    data = Data(edge_index=edge_index, pos=pos)\n    data = Distance(norm=False)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.edge_attr.tolist() == [[1.0], [1.0], [2.0], [2.0]]\n\n    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)\n    data = Distance(norm=True)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.edge_attr.tolist() == [\n        [1.0, 0.5],\n        [1.0, 0.5],\n        [1.0, 1.0],\n        [1.0, 1.0],\n    ]\n"
  },
  {
    "path": "test/transforms/test_face_to_edge.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import FaceToEdge\n\n\ndef test_2d_face_to_edge() -> None:\n    transform = FaceToEdge()\n    assert str(transform) == 'FaceToEdge()'\n\n    face = torch.tensor([[0, 0], [1, 1], [2, 3]])\n    data = Data(face=face, num_nodes=4)\n\n    data = transform(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [\n        [0, 0, 0, 1, 1, 1, 2, 2, 3, 3],\n        [1, 2, 3, 0, 2, 3, 0, 1, 0, 1],\n    ]\n    assert data.num_nodes == 4\n\n\ndef test_3d_face_to_edge() -> None:\n    transform = FaceToEdge()\n    assert str(transform) == 'FaceToEdge()'\n\n    face = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]).t()\n    data = Data(face=face, num_nodes=5)\n\n    data = transform(data)\n    assert data.edge_index.tolist() == [\n        [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4],\n        [1, 2, 3, 0, 2, 3, 4, 0, 1, 3, 4, 0, 1, 2, 4, 1, 2, 3],\n    ]\n    assert data.num_nodes == 5\n"
  },
  {
    "path": "test/transforms/test_feature_propagation.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import FeaturePropagation, ToSparseTensor\n\n\ndef test_feature_propagation():\n    x = torch.randn(6, 4)\n    x[0, 1] = float('nan')\n    x[2, 3] = float('nan')\n    missing_mask = torch.isnan(x)\n    edge_index = torch.tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5],\n                               [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]])\n\n    transform = FeaturePropagation(missing_mask)\n    assert str(transform) == ('FeaturePropagation(missing_features=8.3%, '\n                              'num_iterations=40)')\n\n    data1 = Data(x=x, edge_index=edge_index)\n    assert torch.isnan(data1.x).sum() == 2\n    data1 = FeaturePropagation(missing_mask)(data1)\n    assert torch.isnan(data1.x).sum() == 0\n    assert data1.x.size() == x.size()\n\n    data2 = Data(x=x, edge_index=edge_index)\n    assert torch.isnan(data2.x).sum() == 2\n    data2 = ToSparseTensor()(data2)\n    data2 = transform(data2)\n    assert torch.isnan(data2.x).sum() == 0\n    assert torch.allclose(data1.x, data2.x)\n"
  },
  {
    "path": "test/transforms/test_fixed_points.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import FixedPoints\n\n\ndef test_fixed_points():\n    assert str(FixedPoints(1024)) == 'FixedPoints(1024, replace=True)'\n\n    data = Data(\n        pos=torch.randn(100, 3),\n        x=torch.randn(100, 16),\n        y=torch.randn(1),\n        edge_attr=torch.randn(100, 3),\n        num_nodes=100,\n    )\n\n    out = FixedPoints(50, replace=True)(data)\n    assert len(out) == 5\n    assert out.pos.size() == (50, 3)\n    assert out.x.size() == (50, 16)\n    assert out.y.size() == (1, )\n    assert out.edge_attr.size() == (100, 3)\n    assert out.num_nodes == 50\n\n    out = FixedPoints(200, replace=True)(data)\n    assert len(out) == 5\n    assert out.pos.size() == (200, 3)\n    assert out.x.size() == (200, 16)\n    assert out.y.size() == (1, )\n    assert out.edge_attr.size() == (100, 3)\n    assert out.num_nodes == 200\n\n    out = FixedPoints(50, replace=False, allow_duplicates=False)(data)\n    assert len(out) == 5\n    assert out.pos.size() == (50, 3)\n    assert out.x.size() == (50, 16)\n    assert out.y.size() == (1, )\n    assert out.edge_attr.size() == (100, 3)\n    assert out.num_nodes == 50\n\n    out = FixedPoints(200, replace=False, allow_duplicates=False)(data)\n    assert len(out) == 5\n    assert out.pos.size() == (100, 3)\n    assert out.x.size() == (100, 16)\n    assert out.y.size() == (1, )\n    assert out.edge_attr.size() == (100, 3)\n    assert out.num_nodes == 100\n\n    out = FixedPoints(50, replace=False, allow_duplicates=True)(data)\n    assert len(out) == 5\n    assert out.pos.size() == (50, 3)\n    assert out.x.size() == (50, 16)\n    assert out.y.size() == (1, )\n    assert out.edge_attr.size() == (100, 3)\n    assert out.num_nodes == 50\n\n    out = FixedPoints(200, replace=False, allow_duplicates=True)(data)\n    assert len(out) == 5\n    assert out.pos.size() == (200, 3)\n    assert out.x.size() == (200, 16)\n    assert out.y.size() 
== (1, )\n    assert out.edge_attr.size() == (100, 3)\n    assert out.num_nodes == 200\n"
  },
  {
    "path": "test/transforms/test_gcn_norm.py",
    "content": "import torch\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import GCNNorm\nfrom torch_geometric.typing import SparseTensor\n\n\ndef test_gcn_norm():\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_weight = torch.ones(edge_index.size(1))\n\n    transform = GCNNorm()\n    assert str(transform) == 'GCNNorm(add_self_loops=True)'\n\n    expected_edge_index = [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]\n    expected_edge_weight = torch.tensor(\n        [0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)\n    data = transform(data)\n    assert len(data) == 3\n    assert data.num_nodes == 3\n    assert data.edge_index.tolist() == expected_edge_index\n    assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)\n\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = transform(data)\n    assert len(data) == 3\n    assert data.num_nodes == 3\n    assert data.edge_index.tolist() == expected_edge_index\n    assert torch.allclose(data.edge_weight, expected_edge_weight, atol=1e-4)\n\n    # For `SparseTensor`, expected outputs will be sorted:\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]\n        expected_edge_weight = torch.tensor(\n            [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000])\n\n        adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t()\n        data = Data(adj_t=adj_t)\n        data = transform(data)\n        assert len(data) == 1\n        row, col, value = data.adj_t.coo()\n        assert row.tolist() == expected_edge_index[0]\n        assert col.tolist() == expected_edge_index[1]\n        assert torch.allclose(value, expected_edge_weight, atol=1e-4)\n"
  },
  {
    "path": "test/transforms/test_gdc.py",
    "content": "import torch\n\nfrom torch_geometric.datasets import KarateClub\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import GDC\nfrom torch_geometric.utils import to_dense_adj\n\n\n@withPackage('numba')\ndef test_gdc():\n    data = KarateClub()[0]\n\n    gdc = GDC(\n        self_loop_weight=1,\n        normalization_in='sym',\n        normalization_out='sym',\n        diffusion_kwargs=dict(method='ppr', alpha=0.15),\n        sparsification_kwargs=dict(method='threshold', avg_degree=2),\n        exact=True,\n    )\n    out = gdc(data)\n    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()\n    assert torch.all(mat >= -1e-8)\n    assert torch.allclose(mat, mat.t(), atol=1e-4)\n\n    gdc = GDC(\n        self_loop_weight=1,\n        normalization_in='sym',\n        normalization_out='sym',\n        diffusion_kwargs=dict(method='heat', t=10),\n        sparsification_kwargs=dict(method='threshold', avg_degree=2),\n        exact=True,\n    )\n    out = gdc(data)\n    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()\n    assert torch.all(mat >= -1e-8)\n    assert torch.allclose(mat, mat.t(), atol=1e-4)\n\n    gdc = GDC(\n        self_loop_weight=1,\n        normalization_in='col',\n        normalization_out='col',\n        diffusion_kwargs=dict(method='heat', t=10),\n        sparsification_kwargs=dict(method='topk', k=2, dim=0),\n        exact=True,\n    )\n    out = gdc(data)\n    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()\n    col_sum = mat.sum(0)\n    assert torch.all(mat >= -1e-8)\n    assert torch.all(\n        torch.isclose(col_sum, torch.tensor(1.0))\n        | torch.isclose(col_sum, torch.tensor(0.0)))\n    assert torch.all((~torch.isclose(mat, torch.tensor(0.0))).sum(0) == 2)\n\n    gdc = GDC(\n        self_loop_weight=1,\n        normalization_in='row',\n        normalization_out='row',\n        diffusion_kwargs=dict(method='heat', t=5),\n        
sparsification_kwargs=dict(method='topk', k=2, dim=1),\n        exact=True,\n    )\n    out = gdc(data)\n    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()\n    row_sum = mat.sum(1)\n    assert torch.all(mat >= -1e-8)\n    assert torch.all(\n        torch.isclose(row_sum, torch.tensor(1.0))\n        | torch.isclose(row_sum, torch.tensor(0.0)))\n    assert torch.all((~torch.isclose(mat, torch.tensor(0.0))).sum(1) == 2)\n\n    gdc = GDC(\n        self_loop_weight=1,\n        normalization_in='row',\n        normalization_out='row',\n        diffusion_kwargs=dict(method='coeff', coeffs=[0.8, 0.3, 0.1]),\n        sparsification_kwargs=dict(method='threshold', eps=0.1),\n        exact=True,\n    )\n    out = gdc(data)\n    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()\n    row_sum = mat.sum(1)\n    assert torch.all(mat >= -1e-8)\n    assert torch.all(\n        torch.isclose(row_sum, torch.tensor(1.0))\n        | torch.isclose(row_sum, torch.tensor(0.0)))\n\n    gdc = GDC(\n        self_loop_weight=1,\n        normalization_in='sym',\n        normalization_out='col',\n        diffusion_kwargs=dict(method='ppr', alpha=0.15, eps=1e-4),\n        sparsification_kwargs=dict(method='threshold', avg_degree=2),\n        exact=False,\n    )\n    out = gdc(data)\n    mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze()\n    col_sum = mat.sum(0)\n    assert torch.all(mat >= -1e-8)\n    assert torch.all(\n        torch.isclose(col_sum, torch.tensor(1.0))\n        | torch.isclose(col_sum, torch.tensor(0.0)))\n"
  },
  {
    "path": "test/transforms/test_generate_mesh_normals.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import GenerateMeshNormals\n\n\ndef test_generate_mesh_normals():\n    transform = GenerateMeshNormals()\n    assert str(transform) == 'GenerateMeshNormals()'\n\n    pos = torch.tensor([\n        [0.0, 0.0, 0.0],\n        [-2.0, 1.0, 0.0],\n        [-1.0, 1.0, 0.0],\n        [0.0, 1.0, 0.0],\n        [1.0, 1.0, 0.0],\n        [2.0, 1.0, 0.0],\n    ])\n    face = torch.tensor([\n        [0, 0, 0, 0],\n        [1, 2, 3, 4],\n        [2, 3, 4, 5],\n    ])\n\n    data = transform(Data(pos=pos, face=face))\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.face.tolist() == face.tolist()\n    assert data.norm.tolist() == [[0.0, 0.0, -1.0]] * 6\n"
  },
  {
    "path": "test/transforms/test_grid_sampling.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import GridSampling\n\n\n@withPackage('torch_cluster')\ndef test_grid_sampling():\n    assert str(GridSampling(5)) == 'GridSampling(size=5)'\n\n    pos = torch.tensor([\n        [0.0, 2.0],\n        [3.0, 2.0],\n        [3.0, 2.0],\n        [2.0, 8.0],\n        [2.0, 6.0],\n    ])\n    y = torch.tensor([0, 1, 1, 2, 2])\n    batch = torch.tensor([0, 0, 0, 0, 0])\n\n    data = Data(pos=pos, y=y, batch=batch)\n    data = GridSampling(size=5, start=0)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == [[2, 2], [2, 7]]\n    assert data.y.tolist() == [1, 2]\n    assert data.batch.tolist() == [0, 0]\n"
  },
  {
    "path": "test/transforms/test_half_hop.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import HalfHop\n\n\ndef test_half_hop():\n    edge_index = torch.tensor([[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]])\n    x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],\n                     dtype=torch.float)\n    data = Data(x=x, edge_index=edge_index)\n\n    transform = HalfHop()\n    assert str(transform) == 'HalfHop(alpha=0.5, p=1.0)'\n    data = transform(data)\n\n    expected_edge_index = [[0, 1, 2, 0, 1, 1, 2, 3, 4, 5, 6, 1, 0, 2, 1],\n                           [0, 1, 2, 3, 4, 5, 6, 1, 0, 2, 1, 3, 4, 5, 6]]\n    expected_x = torch.tensor(\n        [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [3, 4, 5, 6],\n         [3, 4, 5, 6], [7, 8, 9, 10], [7, 8, 9, 10]], dtype=torch.float)\n    assert len(data) == 3\n    assert data.num_nodes == 7\n    assert data.edge_index.tolist() == expected_edge_index\n    assert torch.allclose(data.x, expected_x, atol=1e-4)\n    assert data.slow_node_mask.tolist() == [\n        False, False, False, True, True, True, True\n    ]\n\n    torch.manual_seed(1)\n    data = Data(x=x, edge_index=edge_index)\n    transform = HalfHop(p=0.5)\n    assert str(transform) == 'HalfHop(alpha=0.5, p=0.5)'\n    data = transform(data)\n\n    expected_edge_index = [[1, 0, 1, 2, 0, 1, 2, 3, 4, 5, 1, 2, 1],\n                           [0, 0, 1, 2, 3, 4, 5, 1, 2, 1, 3, 4, 5]]\n    expected_x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12],\n                               [3, 4, 5, 6], [7, 8, 9, 10], [7, 8, 9, 10]],\n                              dtype=torch.float)\n    assert data.num_nodes == 6\n    assert data.edge_index.tolist() == expected_edge_index\n    assert torch.allclose(data.x, expected_x, atol=1e-4)\n    assert data.slow_node_mask.tolist() == [\n        False, False, False, True, True, True\n    ]\n"
  },
  {
    "path": "test/transforms/test_knn_graph.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import KNNGraph\n\n\n@withPackage('torch_cluster')\ndef test_knn_graph():\n    assert str(KNNGraph()) == 'KNNGraph(k=6)'\n\n    pos = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 0.0],\n        [2.0, 0.0],\n        [0.0, 1.0],\n        [-2.0, 0.0],\n        [0.0, -2.0],\n    ])\n\n    expected_row = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5]\n    expected_col = [1, 2, 3, 4, 5, 0, 2, 3, 5, 0, 1, 0, 1, 4, 0, 3, 0, 1]\n\n    data = Data(pos=pos)\n    data = KNNGraph(k=2, force_undirected=True)(data)\n    assert len(data) == 2\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index[0].tolist() == expected_row\n    assert data.edge_index[1].tolist() == expected_col\n"
  },
  {
    "path": "test/transforms/test_laplacian_lambda_max.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import LaplacianLambdaMax\n\n\n@withPackage('scipy')\ndef test_laplacian_lambda_max():\n    out = str(LaplacianLambdaMax())\n    assert out == 'LaplacianLambdaMax(normalization=None)'\n\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)\n    edge_attr = torch.tensor([1, 1, 2, 2], dtype=torch.float)\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)\n    out = LaplacianLambdaMax(normalization=None, is_undirected=True)(data)\n    assert len(out) == 4\n    assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(4.732049))\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)\n    out = LaplacianLambdaMax(normalization='sym', is_undirected=True)(data)\n    assert len(out) == 4\n    assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(2.0))\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)\n    out = LaplacianLambdaMax(normalization='rw', is_undirected=True)(data)\n    assert len(out) == 4\n    assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(2.0))\n\n    data = Data(edge_index=edge_index, edge_attr=torch.randn(4, 2),\n                num_nodes=3)\n    out = LaplacianLambdaMax(normalization=None)(data)\n    assert len(out) == 4\n    assert torch.allclose(torch.tensor(out.lambda_max), torch.tensor(3.0))\n"
  },
  {
    "path": "test/transforms/test_largest_connected_components.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import LargestConnectedComponents\n\n\n@withPackage('scipy')\ndef test_largest_connected_components():\n    assert str(LargestConnectedComponents()) == 'LargestConnectedComponents(1)'\n\n    edge_index = torch.tensor([\n        [0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 5, 6, 8, 9],\n        [1, 2, 0, 2, 0, 1, 3, 2, 4, 3, 6, 7, 9, 8],\n    ])\n    data = Data(edge_index=edge_index, num_nodes=10)\n\n    # Testing without `connection` specified:\n    transform = LargestConnectedComponents(num_components=2)\n    out = transform(data)\n    assert out.num_nodes == 8\n    assert out.edge_index.tolist() == data.edge_index[:, :12].tolist()\n\n    # Testing with `connection = strong`:\n    transform = LargestConnectedComponents(num_components=2,\n                                           connection='strong')\n    out = transform(data)\n    assert out.num_nodes == 7\n    assert out.edge_index.tolist() == [[0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 5, 6],\n                                       [1, 2, 0, 2, 0, 1, 3, 2, 4, 3, 6, 5]]\n\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 3, 4],\n        [1, 0, 3, 2, 4, 3],\n    ])\n    data = Data(edge_index=edge_index, num_nodes=5)\n\n    # Testing without `num_components` and `connection` specified:\n    transform = LargestConnectedComponents()\n    out = transform(data)\n    assert out.num_nodes == 3\n    assert out.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n\n    # Testing with larger `num_components` than actual number of components:\n    transform = LargestConnectedComponents(num_components=3)\n    out = transform(data)\n    assert out.num_nodes == 5\n    assert out.edge_index.tolist() == data.edge_index.tolist()\n"
  },
  {
    "path": "test/transforms/test_line_graph.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import LineGraph\n\n\ndef test_line_graph():\n    transform = LineGraph()\n    assert str(transform) == 'LineGraph()'\n\n    # Directed.\n    edge_index = torch.tensor([\n        [0, 1, 2, 2, 3],\n        [1, 2, 0, 3, 0],\n    ])\n    data = Data(edge_index=edge_index, num_nodes=4)\n    data = transform(data)\n    assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 2, 3, 0, 4, 0]]\n    assert data.num_nodes == data.edge_index.max().item() + 1\n\n    # Undirected.\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4],\n                               [1, 2, 3, 0, 4, 0, 3, 0, 2, 4, 1, 3]])\n    edge_attr = torch.ones(edge_index.size(1))\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=5)\n    data = transform(data)\n    assert data.edge_index.max().item() + 1 == data.x.size(0)\n    assert data.edge_index.tolist() == [\n        [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],\n        [1, 2, 3, 0, 2, 4, 0, 1, 4, 5, 0, 5, 1, 2, 5, 2, 3, 4],\n    ]\n    assert data.x.tolist() == [2, 2, 2, 2, 2, 2]\n    assert data.num_nodes == data.edge_index.max().item() + 1\n"
  },
  {
    "path": "test/transforms/test_linear_transformation.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import LinearTransformation\n\n\n@pytest.mark.parametrize('matrix', [\n    [[2.0, 0.0], [0.0, 2.0]],\n    torch.tensor([[2.0, 0.0], [0.0, 2.0]]),\n])\ndef test_linear_transformation(matrix):\n    pos = torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]])\n\n    transform = LinearTransformation(matrix)\n    assert str(transform) == ('LinearTransformation(\\n'\n                              '[[2. 0.]\\n'\n                              ' [0. 2.]]\\n'\n                              ')')\n\n    out = transform(Data(pos=pos))\n    assert len(out) == 1\n    assert torch.allclose(out.pos, 2 * pos)\n\n    out = transform(Data())\n    assert len(out) == 0\n"
  },
  {
    "path": "test/transforms/test_local_cartesian.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import LocalCartesian\n\n\ndef test_local_cartesian():\n    transform = LocalCartesian()\n    assert str(transform) == 'LocalCartesian()'\n\n    pos = torch.tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]])\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0])\n    data = Data(edge_index=edge_index, pos=pos)\n\n    data = transform(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.edge_attr.tolist() == [[0.25, 0.5], [1.0, 0.5], [0.0, 0.5],\n                                       [1.0, 0.5]]\n\n    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)\n    data = transform(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.edge_attr.tolist() == [[1, 0.25, 0.5], [2, 1.0, 0.5],\n                                       [3, 0.0, 0.5], [4, 1.0, 0.5]]\n"
  },
  {
    "path": "test/transforms/test_local_degree_profile.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import LocalDegreeProfile\n\n\ndef test_target_indegree():\n    assert str(LocalDegreeProfile()) == 'LocalDegreeProfile()'\n\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    x = torch.tensor([[1.0], [1.0], [1.0], [1.0]])  # One isolated node.\n\n    expected = torch.tensor([\n        [1, 2, 2, 2, 0],\n        [2, 1, 1, 1, 0],\n        [1, 2, 2, 2, 0],\n        [0, 0, 0, 0, 0],\n    ], dtype=torch.float)\n\n    data = Data(edge_index=edge_index, num_nodes=x.size(0))\n    data = LocalDegreeProfile()(data)\n    assert torch.allclose(data.x, expected, atol=1e-2)\n\n    data = Data(edge_index=edge_index, x=x)\n    data = LocalDegreeProfile()(data)\n    assert torch.allclose(data.x[:, :1], x)\n    assert torch.allclose(data.x[:, 1:], expected, atol=1e-2)\n"
  },
  {
    "path": "test/transforms/test_mask_transform.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import IndexToMask, MaskToIndex\n\n\ndef test_index_to_mask():\n    assert str(IndexToMask()) == ('IndexToMask(attrs=None, sizes=None, '\n                                  'replace=False)')\n\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],\n                               [1, 0, 2, 1, 3, 2, 4, 3]])\n    train_index = torch.arange(0, 3)\n    test_index = torch.arange(3, 5)\n    data = Data(edge_index=edge_index, train_index=train_index,\n                test_index=test_index, num_nodes=5)\n\n    out = IndexToMask(replace=True)(data)\n    assert len(out) == len(data)\n    assert out.train_mask.tolist() == [True, True, True, False, False]\n    assert out.test_mask.tolist() == [False, False, False, True, True]\n\n    out = IndexToMask(replace=False)(data)\n    assert len(out) == len(data) + 2\n\n    out = IndexToMask(sizes=6, replace=True)(data)\n    assert out.train_mask.tolist() == [True, True, True, False, False, False]\n    assert out.test_mask.tolist() == [False, False, False, True, True, False]\n\n    out = IndexToMask(attrs='train_index')(data)\n    assert len(out) == len(data) + 1\n    assert 'train_index' in out\n    assert 'train_mask' in out\n    assert 'test_index' in out\n    assert 'test_mask' not in out\n\n\ndef test_mask_to_index():\n    assert str(MaskToIndex()) == 'MaskToIndex(attrs=None, replace=False)'\n\n    train_mask = torch.tensor([True, True, True, False, False])\n    test_mask = torch.tensor([False, False, False, True, True])\n    data = Data(train_mask=train_mask, test_mask=test_mask)\n\n    out = MaskToIndex(replace=True)(data)\n    assert len(out) == len(data)\n    assert out.train_index.tolist() == [0, 1, 2]\n    assert out.test_index.tolist() == [3, 4]\n\n    out = MaskToIndex(replace=False)(data)\n    assert len(out) == len(data) + 2\n\n    out = MaskToIndex(attrs='train_mask')(data)\n    assert len(out) == 
len(data) + 1\n    assert 'train_mask' in out\n    assert 'train_index' in out\n    assert 'test_mask' in out\n    assert 'test_index' not in out\n\n\ndef test_hetero_index_to_mask():\n    data = HeteroData()\n    data['u'].train_index = torch.arange(0, 3)\n    data['u'].test_index = torch.arange(3, 5)\n    data['u'].num_nodes = 5\n\n    data['v'].train_index = torch.arange(0, 3)\n    data['v'].test_index = torch.arange(3, 5)\n    data['v'].num_nodes = 5\n\n    out = IndexToMask()(data)\n    assert len(out) == len(data) + 2\n    assert 'train_mask' in out['u']\n    assert 'test_mask' in out['u']\n    assert 'train_mask' in out['v']\n    assert 'test_mask' in out['v']\n\n\ndef test_hetero_mask_to_index():\n    data = HeteroData()\n    data['u'].train_mask = torch.tensor([True, True, True, False, False])\n    data['u'].test_mask = torch.tensor([False, False, False, True, True])\n\n    data['v'].train_mask = torch.tensor([True, True, True, False, False])\n    data['v'].test_mask = torch.tensor([False, False, False, True, True])\n\n    out = MaskToIndex()(data)\n    assert len(out) == len(data) + 2\n    assert 'train_index' in out['u']\n    assert 'test_index' in out['u']\n    assert 'train_index' in out['v']\n    assert 'test_index' in out['v']\n"
  },
  {
    "path": "test/transforms/test_node_property_split.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.datasets import graph_generator\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import NodePropertySplit\n\n\n@withPackage('networkx', 'scipy')\n@pytest.mark.parametrize('property_name', [\n    'popularity',\n    'locality',\n    'density',\n])\ndef test_node_property_split(property_name):\n    ratios = [0.3, 0.1, 0.1, 0.2, 0.3]\n\n    transform = NodePropertySplit(property_name, ratios)\n    assert str(transform) == f'NodePropertySplit({property_name})'\n\n    data = graph_generator.ERGraph(num_nodes=100, edge_prob=0.4)()\n    data = transform(data)\n\n    node_ids = []\n    for name, ratio in zip([\n            'id_train_mask',\n            'id_val_mask',\n            'id_test_mask',\n            'ood_val_mask',\n            'ood_test_mask',\n    ], ratios):\n        assert data[name].dtype == torch.bool\n        assert data[name].size() == (100, )\n        assert int(data[name].sum()) == 100 * ratio\n        node_ids.extend(data[name].nonzero().view(-1).tolist())\n\n    # Check that masks are non-intersecting and cover all nodes:\n    node_ids = torch.tensor(node_ids)\n    assert node_ids.numel() == torch.unique(node_ids).numel() == 100\n"
  },
  {
    "path": "test/transforms/test_normalize_features.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import NormalizeFeatures\n\n\ndef test_normalize_scale():\n    transform = NormalizeFeatures()\n    assert str(transform) == 'NormalizeFeatures()'\n\n    x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float)\n    data = Data(x=x)\n\n    data = transform(data)\n    assert len(data) == 1\n    assert data.x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]\n\n\ndef test_hetero_normalize_scale():\n    x = torch.tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=torch.float)\n\n    data = HeteroData()\n    data['v'].x = x\n    data['w'].x = x\n    data = NormalizeFeatures()(data)\n    assert data['v'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]\n    assert data['w'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]]\n"
  },
  {
    "path": "test/transforms/test_normalize_rotation.py",
    "content": "from math import sqrt\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import NormalizeRotation\n\n\ndef test_normalize_rotation():\n    assert str(NormalizeRotation()) == 'NormalizeRotation()'\n\n    pos = torch.tensor([\n        [-2.0, -2.0],\n        [-1.0, -1.0],\n        [0.0, 0.0],\n        [1.0, 1.0],\n        [2.0, 2.0],\n    ])\n    normal = torch.tensor([\n        [-1.0, 1.0],\n        [-1.0, 1.0],\n        [-1.0, 1.0],\n        [-1.0, 1.0],\n        [-1.0, 1.0],\n    ])\n    data = Data(pos=pos)\n    data.normal = normal\n    data = NormalizeRotation()(data)\n    assert len(data) == 2\n\n    expected_pos = torch.tensor([\n        [-2 * sqrt(2), 0.0],\n        [-sqrt(2), 0.0],\n        [0.0, 0.0],\n        [sqrt(2), 0.0],\n        [2 * sqrt(2), 0.0],\n    ])\n    expected_normal = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]]\n\n    assert torch.allclose(data.pos, expected_pos, atol=1e-04)\n    assert data.normal.tolist() == expected_normal\n\n    data = Data(pos=pos)\n    data.normal = normal\n    data = NormalizeRotation(max_points=3)(data)\n    assert len(data) == 2\n\n    assert torch.allclose(data.pos, expected_pos, atol=1e-04)\n    assert data.normal.tolist() == expected_normal\n\n    data = Data(pos=pos)\n    data.normal = normal\n    data = NormalizeRotation(sort=True)(data)\n    assert len(data) == 2\n\n    assert torch.allclose(data.pos, expected_pos, atol=1e-04)\n    assert data.normal.tolist() == expected_normal\n"
  },
  {
    "path": "test/transforms/test_normalize_scale.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import NormalizeScale\n\n\ndef test_normalize_scale():\n    transform = NormalizeScale()\n    assert str(transform) == 'NormalizeScale()'\n\n    pos = torch.randn((10, 3))\n    data = Data(pos=pos)\n\n    data = transform(data)\n    assert len(data) == 1\n    assert data.pos.min().item() > -1\n    assert data.pos.max().item() < 1\n"
  },
  {
    "path": "test/transforms/test_one_hot_degree.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import OneHotDegree\n\n\ndef test_one_hot_degree():\n    assert str(OneHotDegree(max_degree=3)) == 'OneHotDegree(3)'\n\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    x = torch.tensor([1.0, 1.0, 1.0, 1.0])\n\n    data = Data(edge_index=edge_index, num_nodes=4)\n    data = OneHotDegree(max_degree=3)(data)\n    assert len(data) == 3\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.x.tolist() == [\n        [0.0, 0.0, 0.0, 1.0],\n        [0.0, 1.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0, 0.0],\n    ]\n    assert data.num_nodes == 4\n\n    data = Data(edge_index=edge_index, x=x)\n    data = OneHotDegree(max_degree=3)(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.x.tolist() == [\n        [1.0, 0.0, 0.0, 0.0, 1.0],\n        [1.0, 0.0, 1.0, 0.0, 0.0],\n        [1.0, 0.0, 1.0, 0.0, 0.0],\n        [1.0, 0.0, 1.0, 0.0, 0.0],\n    ]\n"
  },
  {
    "path": "test/transforms/test_pad.py",
    "content": "import numbers\nfrom typing import Dict, Generator, List, Optional, Tuple, Union\n\nimport pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.datasets import FakeDataset, FakeHeteroDataset\nfrom torch_geometric.transforms import Pad\nfrom torch_geometric.transforms.pad import (\n    AttrNamePadding,\n    EdgeTypePadding,\n    NodeTypePadding,\n    Padding,\n    UniformPadding,\n)\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\ndef fake_data() -> Data:\n    return FakeDataset(avg_num_nodes=10, avg_degree=5, edge_dim=2)[0]\n\n\ndef fake_hetero_data(node_types=2, edge_types=5) -> HeteroData:\n    return FakeHeteroDataset(num_node_types=node_types,\n                             num_edge_types=edge_types, avg_num_nodes=10,\n                             edge_dim=2)[0]\n\n\ndef _generate_homodata_node_attrs(data: Data) -> Generator[str, None, None]:\n    for attr in data.keys():\n        if data.is_node_attr(attr):\n            yield attr\n\n\ndef _generate_homodata_edge_attrs(data: Data) -> Generator[str, None, None]:\n    for attr in data.keys():\n        if data.is_edge_attr(attr):\n            yield attr\n\n\ndef _generate_heterodata_nodes(\n    data: HeteroData\n) -> Generator[Tuple[NodeType, str, torch.Tensor], None, None]:\n    for node_type, store in data.node_items():\n        for attr in store.keys():\n            yield node_type, attr\n\n\ndef _generate_heterodata_edges(\n    data: HeteroData\n) -> Generator[Tuple[EdgeType, str, torch.Tensor], None, None]:\n    for edge_type, store in data.edge_items():\n        for attr in store.keys():\n            yield edge_type, attr\n\n\ndef _check_homo_data_nodes(\n    original: Data,\n    padded: Data,\n    max_num_nodes: Union[int, Dict[NodeType, int]],\n    node_pad_value: Optional[Padding] = None,\n    is_mask_available: bool = False,\n    exclude_keys: Optional[List[str]] = None,\n):\n    assert padded.num_nodes == max_num_nodes\n\n    
compare_pad_start_idx = original.num_nodes\n\n    if is_mask_available:\n        assert padded.pad_node_mask.numel() == padded.num_nodes\n        assert torch.all(padded.pad_node_mask[:compare_pad_start_idx])\n        assert not torch.any(padded.pad_node_mask[compare_pad_start_idx:])\n\n    for attr in _generate_homodata_node_attrs(original):\n        if attr in exclude_keys:\n            assert attr not in padded.keys()\n            continue\n\n        assert attr in padded.keys()\n\n        if not isinstance(padded[attr], torch.Tensor):\n            continue\n\n        assert padded[attr].shape[0] == max_num_nodes\n\n        # Check values in padded area.\n        pad_value = node_pad_value.get_value(\n            None, attr) if node_pad_value is not None else 0.0\n        assert all(\n            i == pad_value\n            for i in torch.flatten(padded[attr][compare_pad_start_idx:]))\n\n        # Check values in non-padded area.\n        assert torch.equal(original[attr],\n                           padded[attr][:compare_pad_start_idx])\n\n\ndef _check_homo_data_edges(\n    original: Data,\n    padded: Data,\n    max_num_edges: Optional[int] = None,\n    edge_pad_value: Optional[Padding] = None,\n    is_mask_available: bool = False,\n    exclude_keys: Optional[List[str]] = None,\n):\n    # Check edge index attribute.\n    if max_num_edges is None:\n        max_num_edges = padded.num_nodes**2\n    assert padded.num_edges == max_num_edges\n    assert padded.edge_index.shape[1] == max_num_edges\n\n    compare_pad_start_idx = original.num_edges\n    expected_node = original.num_nodes\n\n    # Check values in padded area.\n    assert all(\n        padded.edge_index[0, i] == padded.edge_index[1, i] == expected_node\n        for i in range(compare_pad_start_idx, max_num_edges))\n    # Check values in non-padded area.\n    assert torch.equal(original.edge_index,\n                       padded.edge_index[:, 
:compare_pad_start_idx])\n\n    if is_mask_available:\n        assert padded.pad_edge_mask.numel() == padded.num_edges\n        assert torch.all(padded.pad_edge_mask[:compare_pad_start_idx])\n        assert not torch.any(padded.pad_edge_mask[compare_pad_start_idx:])\n\n    # Check other attributes.\n    for attr in _generate_homodata_edge_attrs(original):\n        if attr == 'edge_index':\n            continue\n        if attr in exclude_keys:\n            assert attr not in padded.keys()\n            continue\n\n        assert attr in padded.keys()\n\n        if not isinstance(padded[attr], torch.Tensor):\n            continue\n\n        assert padded[attr].shape[0] == max_num_edges\n\n        # Check values in padded area.\n        pad_value = edge_pad_value.get_value(\n            None, attr) if edge_pad_value is not None else 0.0\n        assert all(\n            i == pad_value\n            for i in torch.flatten(padded[attr][compare_pad_start_idx:, :]))\n\n        # Check values in non-padded area.\n        assert torch.equal(original[attr],\n                           padded[attr][:compare_pad_start_idx, :])\n\n\ndef _check_hetero_data_nodes(\n    original: HeteroData,\n    padded: HeteroData,\n    max_num_nodes: Union[int, Dict[NodeType, int]],\n    node_pad_value: Optional[Padding] = None,\n    is_mask_available: bool = False,\n    exclude_keys: Optional[List[str]] = None,\n):\n    if is_mask_available:\n        for store in padded.node_stores:\n            assert 'pad_node_mask' in store\n\n    expected_nodes = max_num_nodes\n\n    for node_type, attr in _generate_heterodata_nodes(original):\n        if attr in exclude_keys:\n            assert attr not in padded[node_type].keys()\n            continue\n\n        assert attr in padded[node_type].keys()\n\n        if not isinstance(padded[node_type][attr], torch.Tensor):\n            continue\n\n        compare_pad_start_idx = original[node_type].num_nodes\n        padded_tensor = 
padded[node_type][attr]\n\n        if attr == 'pad_node_mask':\n            assert padded_tensor.numel() == padded[node_type].num_nodes\n            assert torch.all(padded_tensor[:compare_pad_start_idx])\n            assert not torch.any(padded_tensor[compare_pad_start_idx:])\n            continue\n\n        original_tensor = original[node_type][attr]\n\n        # Check the number of nodes.\n        if isinstance(max_num_nodes, dict):\n            expected_nodes = max_num_nodes[node_type]\n\n        assert padded_tensor.shape[0] == expected_nodes\n\n        compare_pad_start_idx = original_tensor.shape[0]\n        pad_value = node_pad_value.get_value(\n            node_type, attr) if node_pad_value is not None else 0.0\n        assert all(\n            i == pad_value\n            for i in torch.flatten(padded_tensor[compare_pad_start_idx:]))\n        # Compare non-padded area with the original.\n        assert torch.equal(original_tensor,\n                           padded_tensor[:compare_pad_start_idx])\n\n\ndef _check_hetero_data_edges(\n    original: HeteroData,\n    padded: HeteroData,\n    max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None,\n    edge_pad_value: Optional[Padding] = None,\n    is_mask_available: bool = False,\n    exclude_keys: Optional[List[str]] = None,\n):\n    if is_mask_available:\n        for store in padded.edge_stores:\n            assert 'pad_edge_mask' in store\n\n    for edge_type, attr in _generate_heterodata_edges(padded):\n        if attr in exclude_keys:\n            assert attr not in padded[edge_type].keys()\n            continue\n\n        assert attr in padded[edge_type].keys()\n\n        if not isinstance(padded[edge_type][attr], torch.Tensor):\n            continue\n\n        compare_pad_start_idx = original[edge_type].num_edges\n        padded_tensor = padded[edge_type][attr]\n\n        if attr == 'pad_edge_mask':\n            assert padded_tensor.numel() == padded[edge_type].num_edges\n            assert 
torch.all(padded_tensor[:compare_pad_start_idx])\n            assert not torch.any(padded_tensor[compare_pad_start_idx:])\n            continue\n\n        original_tensor = original[edge_type][attr]\n\n        if isinstance(max_num_edges, numbers.Number):\n            expected_num_edges = max_num_edges\n        elif max_num_edges is None or edge_type not in max_num_edges.keys():\n            v1, _, v2 = edge_type\n            expected_num_edges = padded[v1].num_nodes * padded[v2].num_nodes\n        else:\n            expected_num_edges = max_num_edges[edge_type]\n\n        if attr == 'edge_index':\n            # Check the number of edges.\n            assert padded_tensor.shape[1] == expected_num_edges\n\n            # Check padded area values.\n            src_nodes = original[edge_type[0]].num_nodes\n            assert all(\n                i == src_nodes\n                for i in torch.flatten(padded_tensor[0,\n                                                     compare_pad_start_idx:]))\n            dst_nodes = original[edge_type[2]].num_nodes\n            assert all(\n                i == dst_nodes\n                for i in torch.flatten(padded_tensor[1,\n                                                     compare_pad_start_idx:]))\n\n            # Compare non-padded area with the original.\n            assert torch.equal(original_tensor,\n                               padded_tensor[:, :compare_pad_start_idx])\n        else:\n            # Check padded area size.\n            assert padded_tensor.shape[0] == expected_num_edges\n\n            # Check padded area values.\n            pad_value = edge_pad_value.get_value(\n                edge_type, attr) if edge_pad_value is not None else 0.0\n            assert all(i == pad_value for i in torch.flatten(padded_tensor[\n                compare_pad_start_idx:, :]))\n\n            # Compare non-padded area with the original.\n            assert torch.equal(original_tensor,\n                               
padded_tensor[:compare_pad_start_idx, :])\n\n\ndef _check_data(\n    original: Union[Data, HeteroData],\n    padded: Union[Data, HeteroData],\n    max_num_nodes: Union[int, Dict[NodeType, int]],\n    max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None,\n    node_pad_value: Optional[Union[Padding, int, float]] = None,\n    edge_pad_value: Optional[Union[Padding, int, float]] = None,\n    is_mask_available: bool = False,\n    exclude_keys: Optional[List[str]] = None,\n):\n\n    if not isinstance(node_pad_value, Padding) and node_pad_value is not None:\n        node_pad_value = UniformPadding(node_pad_value)\n    if not isinstance(edge_pad_value, Padding) and edge_pad_value is not None:\n        edge_pad_value = UniformPadding(edge_pad_value)\n\n    if is_mask_available is None:\n        is_mask_available = False\n\n    if exclude_keys is None:\n        exclude_keys = []\n\n    if isinstance(original, Data):\n        _check_homo_data_nodes(original, padded, max_num_nodes, node_pad_value,\n                               is_mask_available, exclude_keys)\n        _check_homo_data_edges(original, padded, max_num_edges, edge_pad_value,\n                               is_mask_available, exclude_keys)\n    else:\n        _check_hetero_data_nodes(original, padded, max_num_nodes,\n                                 node_pad_value, is_mask_available,\n                                 exclude_keys)\n        _check_hetero_data_edges(original, padded, max_num_edges,\n                                 edge_pad_value, is_mask_available,\n                                 exclude_keys)\n\n\ndef test_pad_repr():\n    pad_str = 'Pad(max_num_nodes=10, max_num_edges=15, ' \\\n        'node_pad_value=UniformPadding(value=3.0), ' \\\n        'edge_pad_value=UniformPadding(value=1.5))'\n    assert str(eval(pad_str)) == pad_str\n\n\n@pytest.mark.parametrize('data', [fake_data(), fake_hetero_data()])\n@pytest.mark.parametrize('num_nodes', [32, 
64])\n@pytest.mark.parametrize('add_pad_mask', [True, False])\ndef test_pad_auto_edges(data, num_nodes, add_pad_mask):\n    transform = Pad(max_num_nodes=num_nodes, add_pad_mask=add_pad_mask)\n\n    out = transform(data)\n    _check_data(data, out, num_nodes, is_mask_available=add_pad_mask)\n\n\n@pytest.mark.parametrize('num_nodes', [32, 64])\n@pytest.mark.parametrize('num_edges', [300, 411])\n@pytest.mark.parametrize('add_pad_mask', [True, False])\ndef test_pad_data_explicit_edges(num_nodes, num_edges, add_pad_mask):\n    data = fake_data()\n    transform = Pad(max_num_nodes=num_nodes, max_num_edges=num_edges,\n                    add_pad_mask=add_pad_mask)\n\n    out = transform(data)\n    _check_data(data, out, num_nodes, num_edges,\n                is_mask_available=add_pad_mask)\n\n\n@pytest.mark.parametrize('num_nodes', [32, {'v0': 64, 'v1': 36}])\n@pytest.mark.parametrize('num_edges', [300, {('v0', 'e0', 'v1'): 397}])\n@pytest.mark.parametrize('add_pad_mask', [True, False])\ndef test_pad_heterodata_explicit_edges(num_nodes, num_edges, add_pad_mask):\n    data = fake_hetero_data()\n    transform = Pad(max_num_nodes=num_nodes, max_num_edges=num_edges,\n                    add_pad_mask=add_pad_mask)\n\n    out = transform(data)\n    _check_data(data, out, num_nodes, num_edges,\n                is_mask_available=add_pad_mask)\n\n\n@pytest.mark.parametrize('node_pad_value', [10, AttrNamePadding({'x': 3.0})])\n@pytest.mark.parametrize('edge_pad_value',\n                         [11, AttrNamePadding({'edge_attr': 2.0})])\ndef test_pad_data_pad_values(node_pad_value, edge_pad_value):\n    data = fake_data()\n    num_nodes = 32\n    transform = Pad(max_num_nodes=num_nodes, node_pad_value=node_pad_value,\n                    edge_pad_value=edge_pad_value)\n    out = transform(data)\n    _check_data(data, out, num_nodes, node_pad_value=node_pad_value,\n                edge_pad_value=edge_pad_value)\n\n\n@pytest.mark.parametrize('node_pad_value', [\n    
UniformPadding(12),\n    AttrNamePadding({'x': 0}),\n    NodeTypePadding({\n        'v0': UniformPadding(12),\n        'v1': AttrNamePadding({'x': 7})\n    })\n])\n@pytest.mark.parametrize('edge_pad_value', [\n    UniformPadding(13),\n    EdgeTypePadding({\n        ('v0', 'e0', 'v1'):\n        UniformPadding(13),\n        ('v1', 'e0', 'v0'):\n        AttrNamePadding({'edge_attr': UniformPadding(-1.0)})\n    })\n])\ndef test_pad_heterodata_pad_values(node_pad_value, edge_pad_value):\n    data = fake_hetero_data()\n    num_nodes = 32\n    transform = Pad(max_num_nodes=num_nodes, node_pad_value=node_pad_value,\n                    edge_pad_value=edge_pad_value)\n\n    out = transform(data)\n    _check_data(data, out, num_nodes, node_pad_value=node_pad_value,\n                edge_pad_value=edge_pad_value)\n\n\n@pytest.mark.parametrize('data', [fake_data(), fake_hetero_data()])\n@pytest.mark.parametrize('add_pad_mask', [True, False])\n@pytest.mark.parametrize('exclude_keys', [\n    ['y'],\n    ['edge_attr'],\n    ['y', 'edge_attr'],\n])\ndef test_pad_data_exclude_keys(data, add_pad_mask, exclude_keys):\n    num_nodes = 32\n    transform = Pad(max_num_nodes=num_nodes, add_pad_mask=add_pad_mask,\n                    exclude_keys=exclude_keys)\n\n    out = transform(data)\n    _check_data(data, out, num_nodes, is_mask_available=add_pad_mask,\n                exclude_keys=exclude_keys)\n\n\n@pytest.mark.parametrize('is_hetero', [False, True])\ndef test_pad_invalid_max_num_nodes(is_hetero):\n    if is_hetero:\n        data = fake_hetero_data(node_types=1)\n    else:\n        data = fake_data()\n\n    transform = Pad(max_num_nodes=data.num_nodes - 1)\n\n    with pytest.raises(AssertionError, match=\"after padding\"):\n        transform(data)\n\n\n@pytest.mark.parametrize('is_hetero', [False, True])\ndef test_pad_invalid_max_num_edges(is_hetero):\n    if is_hetero:\n        data = fake_hetero_data(node_types=1, edge_types=1)\n    else:\n        data = fake_data()\n\n    
transform = Pad(max_num_nodes=data.num_nodes + 10,\n                    max_num_edges=data.num_edges - 1)\n\n    with pytest.raises(AssertionError, match=\"after padding\"):\n        transform(data)\n\n\ndef test_pad_num_nodes_not_complete():\n    data = fake_hetero_data(node_types=2, edge_types=1)\n    transform = Pad(max_num_nodes={'v0': 100})\n\n    with pytest.raises(KeyError):\n        transform(data)\n\n\ndef test_pad_invalid_padding_type():\n    with pytest.raises(ValueError, match=\"to be an integer or float\"):\n        Pad(max_num_nodes=100, node_pad_value='somestring')\n    with pytest.raises(ValueError, match=\"to be an integer or float\"):\n        Pad(max_num_nodes=100, edge_pad_value='somestring')\n\n\ndef test_pad_data_non_tensor_attr():\n    data = fake_data()\n    batch_size = 13\n    data.batch_size = batch_size\n\n    transform = Pad(max_num_nodes=100)\n    padded = transform(data)\n    assert padded.batch_size == batch_size\n\n    exclude_transform = Pad(max_num_nodes=101, exclude_keys=('batch_size', ))\n    padded = exclude_transform(data)\n    assert 'batch_size' not in padded.keys()\n\n\n@pytest.mark.parametrize('mask_pad_value', [True, False])\ndef test_pad_node_additional_attr_mask(mask_pad_value):\n    data = fake_data()\n    mask = torch.randn(data.num_nodes) > 0\n    mask_names = ['train_mask', 'test_mask', 'val_mask']\n    for mask_name in mask_names:\n        setattr(data, mask_name, mask)\n    padding_num = 20\n\n    max_num_nodes = int(data.num_nodes) + padding_num\n    max_num_edges = data.num_edges + padding_num\n\n    transform = Pad(max_num_nodes, max_num_edges, node_pad_value=0.1,\n                    mask_pad_value=mask_pad_value)\n    padded = transform(data)\n    padded_masks = [getattr(padded, mask_name) for mask_name in mask_names]\n\n    for padded_mask in padded_masks:\n        assert padded_mask.ndim == 1\n        assert padded_mask.size()[0] == max_num_nodes\n        assert torch.all(padded_mask[-padding_num:] == 
mask_pad_value)\n\n\ndef test_uniform_padding():\n    pad_val = 10.0\n    p = UniformPadding(pad_val)\n    assert p.get_value() == pad_val\n    assert p.get_value(\"v1\", \"x\") == pad_val\n\n    p = UniformPadding()\n    assert p.get_value() == 0.0\n\n    with pytest.raises(ValueError, match=\"to be an integer or float\"):\n        UniformPadding('')\n\n\ndef test_attr_name_padding():\n    x_val = 10.0\n    y_val = 15.0\n    default = 3.0\n    padding_dict = {'x': x_val, 'y': UniformPadding(y_val)}\n    padding = AttrNamePadding(padding_dict, default=default)\n\n    assert padding.get_value(attr_name='x') == x_val\n    assert padding.get_value('v1', 'x') == x_val\n    assert padding.get_value(attr_name='y') == y_val\n    assert padding.get_value('v1', 'y') == y_val\n    assert padding.get_value(attr_name='x2') == default\n\n    padding = AttrNamePadding({})\n    assert padding.get_value(attr_name='x') == 0.0\n\n\ndef test_attr_name_padding_invalid():\n    with pytest.raises(ValueError, match=\"to be a dictionary\"):\n        AttrNamePadding(10.0)\n\n    with pytest.raises(ValueError, match=\"to be a string\"):\n        AttrNamePadding({10: 10.0})\n\n    with pytest.raises(ValueError, match=\"to be of type\"):\n        AttrNamePadding({\"x\": {}})\n\n    node_type_padding = NodeTypePadding({\"x\": 10.0})\n    with pytest.raises(ValueError, match=\"to be of type\"):\n        AttrNamePadding({'x': node_type_padding})\n\n\n@pytest.mark.parametrize('store_type', ['node', 'edge'])\ndef test_node_edge_type_padding(store_type):\n    if store_type == \"node\":\n        stores = ['v1', 'v2', 'v3', 'v4']\n        padding_cls = NodeTypePadding\n    else:\n        stores = [('v1', 'e1', 'v1'), ('v1', 'e2', 'v1'), ('v1', 'e1', 'v2'),\n                  ('v2', 'e1', 'v1')]\n        padding_cls = EdgeTypePadding\n\n    s0_default = 3.0\n    s0_padding_dict = {'x': 10.0, 'y': 
-12.0}\n    s0_padding = AttrNamePadding(s0_padding_dict, s0_default)\n    s1_default = 0.1\n    s1_padding_dict = {'y': 0.0, 'p': 13.0}\n    s1_padding = AttrNamePadding(s1_padding_dict, s1_default)\n\n    s2_default = 7.5\n    store_default = -11.0\n    padding_dict = {\n        stores[0]: s0_padding,\n        stores[1]: s1_padding,\n        stores[2]: s2_default\n    }\n    padding = padding_cls(padding_dict, store_default)\n\n    assert padding.get_value(stores[0], 'x') == s0_padding_dict['x']\n    assert padding.get_value(stores[0], 'y') == s0_padding_dict['y']\n    assert padding.get_value(stores[0], 'p') == s0_default\n    assert padding.get_value(stores[0], 'z') == s0_default\n\n    assert padding.get_value(stores[1], 'x') == s1_default\n    assert padding.get_value(stores[1], 'y') == s1_padding_dict['y']\n    assert padding.get_value(stores[1], 'p') == s1_padding_dict['p']\n    assert padding.get_value(stores[1], 'z') == s1_default\n\n    assert padding.get_value(stores[2], 'x') == s2_default\n    assert padding.get_value(stores[2], 'z') == s2_default\n\n    assert padding.get_value(stores[3], 'x') == store_default\n\n\ndef test_edge_padding_invalid():\n    with pytest.raises(ValueError, match=\"to be a tuple\"):\n        EdgeTypePadding({'v1': 10.0})\n\n    with pytest.raises(ValueError, match=\"got 1\"):\n        EdgeTypePadding({('v1', ): 10.0})\n\n    with pytest.raises(ValueError, match=\"got 2\"):\n        EdgeTypePadding({('v1', 'v2'): 10.0})\n\n    with pytest.raises(ValueError, match=\"got 4\"):\n        EdgeTypePadding({('v1', 'e2', 'v1', 'v2'): 10.0})\n"
  },
  {
    "path": "test/transforms/test_point_pair_features.py",
    "content": "from math import pi as PI\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import PointPairFeatures\n\n\ndef test_point_pair_features():\n    transform = PointPairFeatures()\n    assert str(transform) == 'PointPairFeatures()'\n\n    pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    norm = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])\n    edge_attr = torch.tensor([1.0, 1.0])\n    data = Data(edge_index=edge_index, pos=pos, norm=norm)\n\n    data = transform(data)\n    assert len(data) == 4\n    assert data.pos.tolist() == pos.tolist()\n    assert data.norm.tolist() == norm.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 0.0, 0.0, 0.0], [1.0, PI, PI, 0.0]]),\n        atol=1e-4,\n    )\n\n    data = Data(edge_index=edge_index, pos=pos, norm=norm, edge_attr=edge_attr)\n    data = transform(data)\n    assert len(data) == 4\n    assert data.pos.tolist() == pos.tolist()\n    assert data.norm.tolist() == norm.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 1.0, 0.0, 0.0, 0.0], [1.0, 1.0, PI, PI, 0.0]]),\n        atol=1e-4,\n    )\n"
  },
  {
    "path": "test/transforms/test_polar.py",
    "content": "from math import pi as PI\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import Polar\n\n\ndef test_polar():\n    assert str(Polar()) == 'Polar(norm=True, max_value=None)'\n\n    pos = torch.tensor([[0.0, 0.0], [1.0, 0.0]])\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    edge_attr = torch.tensor([1.0, 1.0])\n\n    data = Data(edge_index=edge_index, pos=pos)\n    data = Polar(norm=False)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 0.0], [1.0, PI]]),\n        atol=1e-4,\n    )\n\n    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)\n    data = Polar(norm=True)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.5]]),\n        atol=1e-4,\n    )\n"
  },
  {
    "path": "test/transforms/test_radius_graph.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import RadiusGraph\nfrom torch_geometric.utils import coalesce\n\n\n@withPackage('torch_cluster')\ndef test_radius_graph():\n    assert str(RadiusGraph(r=1)) == 'RadiusGraph(r=1)'\n\n    pos = torch.tensor([\n        [0.0, 0.0],\n        [1.0, 0.0],\n        [2.0, 0.0],\n        [0.0, 1.0],\n        [-2.0, 0.0],\n        [0.0, -2.0],\n    ])\n\n    data = Data(pos=pos)\n    data = RadiusGraph(r=1.5)(data)\n    assert len(data) == 2\n    assert data.pos.tolist() == pos.tolist()\n    assert coalesce(data.edge_index).tolist() == [[0, 0, 1, 1, 1, 2, 3, 3],\n                                                  [1, 3, 0, 2, 3, 1, 0, 1]]\n"
  },
  {
    "path": "test/transforms/test_random_flip.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import RandomFlip\n\n\ndef test_random_flip():\n    assert str(RandomFlip(axis=0)) == 'RandomFlip(axis=0, p=0.5)'\n\n    pos = torch.tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]])\n\n    data = Data(pos=pos)\n    data = RandomFlip(axis=0, p=1)(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[1.0, 1.0], [3.0, 0.0], [-2.0, -1.0]]\n\n    data = Data(pos=pos)\n    data = RandomFlip(axis=1, p=1)(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[-1.0, -1.0], [-3.0, 0.0], [2.0, 1.0]]\n"
  },
  {
    "path": "test/transforms/test_random_jitter.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import RandomJitter\n\n\ndef test_random_jitter():\n    assert str(RandomJitter(0.1)) == 'RandomJitter(0.1)'\n\n    pos = torch.tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])\n\n    data = Data(pos=pos)\n    data = RandomJitter(0)(data)\n    assert len(data) == 1\n    assert torch.allclose(data.pos, pos)\n\n    data = Data(pos=pos)\n    data = RandomJitter(0.1)(data)\n    assert len(data) == 1\n    assert data.pos.min() >= -0.1\n    assert data.pos.max() <= 0.1\n\n    data = Data(pos=pos)\n    data = RandomJitter([0.1, 1])(data)\n    assert len(data) == 1\n    assert data.pos[:, 0].min() >= -0.1\n    assert data.pos[:, 0].max() <= 0.1\n    assert data.pos[:, 1].min() >= -1\n    assert data.pos[:, 1].max() <= 1\n"
  },
  {
    "path": "test/transforms/test_random_link_split.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.testing import (\n    get_random_edge_index,\n    onlyFullTest,\n    onlyOnline,\n)\nfrom torch_geometric.transforms import RandomLinkSplit, ToSparseTensor\nfrom torch_geometric.utils import is_undirected, to_undirected\n\n\ndef test_random_link_split():\n    assert str(RandomLinkSplit()) == ('RandomLinkSplit('\n                                      'num_val=0.1, num_test=0.2)')\n\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],\n                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])\n    edge_attr = torch.randn(edge_index.size(1), 3)\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=100)\n\n    # No test split:\n    transform = RandomLinkSplit(num_val=2, num_test=0, is_undirected=True)\n    train_data, val_data, test_data = transform(data)\n\n    assert len(train_data) == 5\n    assert train_data.num_nodes == 100\n    assert train_data.edge_index.size() == (2, 6)\n    assert train_data.edge_attr.size() == (6, 3)\n    assert train_data.edge_label_index.size(1) == 6\n    assert train_data.edge_label.size(0) == 6\n\n    assert len(val_data) == 5\n    assert val_data.num_nodes == 100\n    assert val_data.edge_index.size() == (2, 6)\n    assert val_data.edge_attr.size() == (6, 3)\n    assert val_data.edge_label_index.size(1) == 4\n    assert val_data.edge_label.size(0) == 4\n\n    assert len(test_data) == 5\n    assert test_data.num_nodes == 100\n    assert test_data.edge_index.size() == (2, 10)\n    assert test_data.edge_attr.size() == (10, 3)\n    assert test_data.edge_label_index.size() == (2, 0)\n    assert test_data.edge_label.size() == (0, )\n\n    # Percentage split:\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2,\n                                neg_sampling_ratio=2.0, is_undirected=False)\n    train_data, val_data, test_data = transform(data)\n\n    assert len(train_data) == 5\n    
assert train_data.num_nodes == 100\n    assert train_data.edge_index.size() == (2, 6)\n    assert train_data.edge_attr.size() == (6, 3)\n    assert train_data.edge_label_index.size(1) == 18\n    assert train_data.edge_label.size(0) == 18\n\n    assert len(val_data) == 5\n    assert val_data.num_nodes == 100\n    assert val_data.edge_index.size() == (2, 6)\n    assert val_data.edge_attr.size() == (6, 3)\n    assert val_data.edge_label_index.size(1) == 6\n    assert val_data.edge_label.size(0) == 6\n\n    assert len(test_data) == 5\n    assert test_data.num_nodes == 100\n    assert test_data.edge_index.size() == (2, 8)\n    assert test_data.edge_attr.size() == (8, 3)\n    assert test_data.edge_label_index.size(1) == 6\n    assert test_data.edge_label.size(0) == 6\n\n    # Disjoint training split:\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=False,\n                                disjoint_train_ratio=0.5)\n    train_data, val_data, test_data = transform(data)\n\n    assert len(train_data) == 5\n    assert train_data.num_nodes == 100\n    assert train_data.edge_index.size() == (2, 3)\n    assert train_data.edge_attr.size() == (3, 3)\n    assert train_data.edge_label_index.size(1) == 6\n    assert train_data.edge_label.size(0) == 6\n\n\ndef test_random_link_split_with_to_sparse_tensor():\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],\n                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])\n    data = Data(edge_index=edge_index, num_nodes=6)\n\n    transform = RandomLinkSplit(num_val=2, num_test=2, neg_sampling_ratio=0.0)\n    train_data1, _, _ = transform(data)\n    assert train_data1.edge_index.size(1) == train_data1.edge_label.size(0)\n\n    train_data2 = ToSparseTensor()(train_data1)\n    assert train_data1.edge_label.equal(train_data2.edge_label)\n    assert train_data1.edge_label_index.equal(train_data2.edge_label_index)\n\n\ndef test_random_link_split_with_label():\n    edge_index = torch.tensor([[0, 1, 1, 2, 
2, 3, 3, 4, 4, 5],\n                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])\n    edge_label = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n\n    data = Data(edge_index=edge_index, edge_label=edge_label, num_nodes=6)\n\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2,\n                                neg_sampling_ratio=0.0)\n    train_data, _, _ = transform(data)\n    assert len(train_data) == 4\n    assert train_data.num_nodes == 6\n    assert train_data.edge_index.size() == (2, 6)\n    assert train_data.edge_label_index.size() == (2, 6)\n    assert train_data.edge_label.size() == (6, )\n    assert train_data.edge_label.min() == 0\n    assert train_data.edge_label.max() == 1\n\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2,\n                                neg_sampling_ratio=1.0)\n    train_data, _, _ = transform(data)\n    assert len(train_data) == 4\n    assert train_data.num_nodes == 6\n    assert train_data.edge_index.size() == (2, 6)\n    assert train_data.edge_label_index.size() == (2, 12)\n    assert train_data.edge_label.size() == (12, )\n    assert train_data.edge_label.min() == 0\n    assert train_data.edge_label.max() == 2\n    assert train_data.edge_label[6:].sum() == 0\n\n\ndef test_random_link_split_increment_label():\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5],\n                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]])\n    edge_label = torch.tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])\n\n    data = Data(edge_index=edge_index, edge_label=edge_label, num_nodes=6)\n\n    transform = RandomLinkSplit(num_val=0, num_test=0, neg_sampling_ratio=0.0)\n    train_data, _, _ = transform(data)\n    assert train_data.edge_label.numel() == edge_index.size(1)\n    assert train_data.edge_label.min() == 0\n    assert train_data.edge_label.max() == 1\n\n    transform = RandomLinkSplit(num_val=0, num_test=0, neg_sampling_ratio=1.0)\n    train_data, _, _ = transform(data)\n    assert train_data.edge_label.numel() == 2 * 
edge_index.size(1)\n    assert train_data.edge_label.min() == 0\n    assert train_data.edge_label.max() == 2\n    assert train_data.edge_label[edge_index.size(1):].sum() == 0\n\n\ndef test_random_link_split_on_hetero_data():\n    data = HeteroData()\n\n    data['p'].x = torch.arange(100)\n    data['a'].x = torch.arange(100, 300)\n\n    data['p', 'p'].edge_index = get_random_edge_index(100, 100, 500)\n    data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index)\n    data['p', 'p'].edge_attr = torch.arange(data['p', 'p'].num_edges)\n    data['p', 'a'].edge_index = get_random_edge_index(100, 200, 1000)\n    data['p', 'a'].edge_attr = torch.arange(500, 1500)\n    data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0])\n    data['a', 'p'].edge_attr = torch.arange(1500, 2500)\n\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=True,\n                                edge_types=('p', 'p'))\n    train_data, val_data, test_data = transform(data)\n\n    assert len(train_data['p']) == 1\n    assert len(train_data['a']) == 1\n    assert len(train_data['p', 'p']) == 4\n    assert len(train_data['p', 'a']) == 2\n    assert len(train_data['a', 'p']) == 2\n\n    assert is_undirected(train_data['p', 'p'].edge_index,\n                         train_data['p', 'p'].edge_attr)\n    assert is_undirected(val_data['p', 'p'].edge_index,\n                         val_data['p', 'p'].edge_attr)\n    assert is_undirected(test_data['p', 'p'].edge_index,\n                         test_data['p', 'p'].edge_attr)\n\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2,\n                                edge_types=('p', 'a'),\n                                rev_edge_types=('a', 'p'))\n    train_data, val_data, test_data = transform(data)\n\n    assert len(train_data['p']) == 1\n    assert len(train_data['a']) == 1\n    assert len(train_data['p', 'p']) == 2\n    assert len(train_data['p', 'a']) == 4\n    assert len(train_data['a', 'p']) == 2\n\n    
assert train_data['p', 'a'].edge_index.size() == (2, 600)\n    assert train_data['p', 'a'].edge_attr.size() == (600, )\n    assert train_data['p', 'a'].edge_attr.min() >= 500\n    assert train_data['p', 'a'].edge_attr.max() <= 1500\n    assert train_data['a', 'p'].edge_index.size() == (2, 600)\n    assert train_data['a', 'p'].edge_attr.size() == (600, )\n    assert train_data['a', 'p'].edge_attr.min() >= 500\n    assert train_data['a', 'p'].edge_attr.max() <= 1500\n    assert train_data['p', 'a'].edge_label_index.size() == (2, 1200)\n    assert train_data['p', 'a'].edge_label.size() == (1200, )\n\n    assert val_data['p', 'a'].edge_index.size() == (2, 600)\n    assert val_data['p', 'a'].edge_attr.size() == (600, )\n    assert val_data['p', 'a'].edge_attr.min() >= 500\n    assert val_data['p', 'a'].edge_attr.max() <= 1500\n    assert val_data['a', 'p'].edge_index.size() == (2, 600)\n    assert val_data['a', 'p'].edge_attr.size() == (600, )\n    assert val_data['a', 'p'].edge_attr.min() >= 500\n    assert val_data['a', 'p'].edge_attr.max() <= 1500\n    assert val_data['p', 'a'].edge_label_index.size() == (2, 400)\n    assert val_data['p', 'a'].edge_label.size() == (400, )\n\n    assert test_data['p', 'a'].edge_index.size() == (2, 800)\n    assert test_data['p', 'a'].edge_attr.size() == (800, )\n    assert test_data['p', 'a'].edge_attr.min() >= 500\n    assert test_data['p', 'a'].edge_attr.max() <= 1500\n    assert test_data['a', 'p'].edge_index.size() == (2, 800)\n    assert test_data['a', 'p'].edge_attr.size() == (800, )\n    assert test_data['a', 'p'].edge_attr.min() >= 500\n    assert test_data['a', 'p'].edge_attr.max() <= 1500\n    assert test_data['p', 'a'].edge_label_index.size() == (2, 400)\n    assert test_data['p', 'a'].edge_label.size() == (400, )\n\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=True,\n                                edge_types=[('p', 'p'), ('p', 'a')],\n                                rev_edge_types=[None, ('a', 
'p')])\n    train_data, val_data, test_data = transform(data)\n\n    assert len(train_data['p']) == 1\n    assert len(train_data['a']) == 1\n    assert len(train_data['p', 'p']) == 4\n    assert len(train_data['p', 'a']) == 4\n    assert len(train_data['a', 'p']) == 2\n\n    assert is_undirected(train_data['p', 'p'].edge_index,\n                         train_data['p', 'p'].edge_attr)\n    assert train_data['p', 'a'].edge_index.size() == (2, 600)\n    assert train_data['a', 'p'].edge_index.size() == (2, 600)\n\n    # No reverse edge types specified:\n    transform = RandomLinkSplit(edge_types=[('p', 'p'), ('p', 'a')])\n    train_data, val_data, test_data = transform(data)\n    assert train_data['p', 'p'].num_edges < data['p', 'p'].num_edges\n    assert train_data['p', 'a'].num_edges < data['p', 'a'].num_edges\n    assert train_data['a', 'p'].num_edges == data['a', 'p'].num_edges\n\n\ndef test_random_link_split_on_undirected_hetero_data():\n    data = HeteroData()\n    data['p'].x = torch.arange(100)\n    data['p', 'p'].edge_index = get_random_edge_index(100, 100, 500)\n    data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index)\n\n    transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'))\n    train_data, val_data, test_data = transform(data)\n    assert train_data['p', 'p'].is_undirected()\n\n    transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'),\n                                rev_edge_types=('p', 'p'))\n    train_data, val_data, test_data = transform(data)\n    assert train_data['p', 'p'].is_undirected()\n\n    transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'),\n                                rev_edge_types=('p', 'p'))\n    train_data, val_data, test_data = transform(data)\n    assert train_data['p', 'p'].is_undirected()\n\n\ndef test_random_link_split_insufficient_negative_edges():\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2], [1, 3, 0, 2, 0, 1]])\n    data = 
Data(edge_index=edge_index, num_nodes=4)\n\n    transform = RandomLinkSplit(num_val=0.34, num_test=0.34,\n                                is_undirected=False, neg_sampling_ratio=2,\n                                split_labels=True)\n\n    with pytest.warns(UserWarning, match=\"not enough negative edges\"):\n        train_data, val_data, test_data = transform(data)\n\n    assert train_data.neg_edge_label_index.size() == (2, 2)\n    assert val_data.neg_edge_label_index.size() == (2, 2)\n    assert test_data.neg_edge_label_index.size() == (2, 2)\n\n\ndef test_random_link_split_non_contiguous():\n    edge_index = get_random_edge_index(40, 40, num_edges=150)\n    edge_index = edge_index[:, :100]\n    assert not edge_index.is_contiguous()\n\n    data = Data(edge_index=edge_index, num_nodes=40)\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2)\n    train_data, val_data, test_data = transform(data)\n    assert train_data.num_edges == 60\n    assert train_data.edge_index.is_contiguous()\n\n    data = HeteroData()\n    data['p'].num_nodes = 40\n    data['p', 'p'].edge_index = edge_index\n    transform = RandomLinkSplit(num_val=0.2, num_test=0.2,\n                                edge_types=('p', 'p'))\n    train_data, val_data, test_data = transform(data)\n    assert train_data['p', 'p'].num_edges == 60\n    assert train_data['p', 'p'].edge_index.is_contiguous()\n\n\n@onlyOnline\n@onlyFullTest\ndef test_random_link_split_on_dataset(get_dataset):\n    dataset = get_dataset(name='MUTAG')\n\n    dataset.transform = RandomLinkSplit(\n        num_val=0.1,\n        num_test=0.1,\n        disjoint_train_ratio=0.3,\n        add_negative_train_samples=False,\n    )\n\n    train_dataset, val_dataset, test_dataset = zip(*dataset)\n    assert len(train_dataset) == len(dataset)\n    assert len(val_dataset) == len(dataset)\n    assert len(test_dataset) == len(dataset)\n\n    assert isinstance(train_dataset[0], Data)\n    assert train_dataset[0].edge_label.min() == 1.0\n    
assert train_dataset[0].edge_label.max() == 1.0\n\n    assert isinstance(val_dataset[0], Data)\n    assert val_dataset[0].edge_label.min() == 0.0\n    assert val_dataset[0].edge_label.max() == 1.0\n\n    assert isinstance(test_dataset[0], Data)\n    assert test_dataset[0].edge_label.min() == 0.0\n    assert test_dataset[0].edge_label.max() == 1.0\n"
  },
  {
    "path": "test/transforms/test_random_node_split.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import RandomNodeSplit\n\n\n@pytest.mark.parametrize('num_splits', [1, 2])\ndef test_random_node_split(num_splits):\n    num_nodes, num_classes = 1000, 4\n    x = torch.randn(num_nodes, 16)\n    y = torch.randint(num_classes, (num_nodes, ), dtype=torch.long)\n    data = Data(x=x, y=y)\n\n    transform = RandomNodeSplit(split='train_rest', num_splits=num_splits,\n                                num_val=100, num_test=200)\n    assert str(transform) == 'RandomNodeSplit(split=train_rest)'\n    data = transform(data)\n    assert len(data) == 5\n\n    train_mask = data.train_mask\n    train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask\n    assert train_mask.size() == (num_nodes, num_splits)\n    val_mask = data.val_mask\n    val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask\n    assert val_mask.size() == (num_nodes, num_splits)\n    test_mask = data.test_mask\n    test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask\n    assert test_mask.size() == (num_nodes, num_splits)\n\n    for i in range(train_mask.size(-1)):\n        assert train_mask[:, i].sum() == num_nodes - 100 - 200\n        assert val_mask[:, i].sum() == 100\n        assert test_mask[:, i].sum() == 200\n        assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0\n        assert ((train_mask[:, i] | val_mask[:, i]\n                 | test_mask[:, i]).sum() == num_nodes)\n\n    transform = RandomNodeSplit(split='train_rest', num_splits=num_splits,\n                                num_val=0.1, num_test=0.2)\n    data = transform(data)\n\n    train_mask = data.train_mask\n    train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask\n    val_mask = data.val_mask\n    val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask\n    test_mask = data.test_mask\n    test_mask = 
test_mask.unsqueeze(-1) if num_splits == 1 else test_mask\n\n    for i in range(train_mask.size(-1)):\n        assert train_mask[:, i].sum() == num_nodes - 100 - 200\n        assert val_mask[:, i].sum() == 100\n        assert test_mask[:, i].sum() == 200\n        assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0\n        assert ((train_mask[:, i] | val_mask[:, i]\n                 | test_mask[:, i]).sum() == num_nodes)\n\n    transform = RandomNodeSplit(split='test_rest', num_splits=num_splits,\n                                num_train_per_class=10, num_val=100)\n    assert str(transform) == 'RandomNodeSplit(split=test_rest)'\n    data = transform(data)\n    assert len(data) == 5\n\n    train_mask = data.train_mask\n    train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask\n    val_mask = data.val_mask\n    val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask\n    test_mask = data.test_mask\n    test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask\n\n    for i in range(train_mask.size(-1)):\n        assert train_mask[:, i].sum() == 10 * num_classes\n        assert val_mask[:, i].sum() == 100\n        assert test_mask[:, i].sum() == num_nodes - 10 * num_classes - 100\n        assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0\n        assert ((train_mask[:, i] | val_mask[:, i]\n                 | test_mask[:, i]).sum() == num_nodes)\n\n    transform = RandomNodeSplit(split='test_rest', num_splits=num_splits,\n                                num_train_per_class=10, num_val=0.1)\n    data = transform(data)\n\n    train_mask = data.train_mask\n    train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask\n    val_mask = data.val_mask\n    val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask\n    test_mask = data.test_mask\n    test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask\n\n    for i in range(train_mask.size(-1)):\n        
assert train_mask[:, i].sum() == 10 * num_classes\n        assert val_mask[:, i].sum() == 100\n        assert test_mask[:, i].sum() == num_nodes - 10 * num_classes - 100\n        assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0\n        assert ((train_mask[:, i] | val_mask[:, i]\n                 | test_mask[:, i]).sum() == num_nodes)\n\n    transform = RandomNodeSplit(split='random', num_splits=num_splits,\n                                num_train_per_class=10, num_val=100,\n                                num_test=200)\n    assert str(transform) == 'RandomNodeSplit(split=random)'\n    data = transform(data)\n    assert len(data) == 5\n\n    train_mask = data.train_mask\n    train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask\n    val_mask = data.val_mask\n    val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask\n    test_mask = data.test_mask\n    test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask\n\n    for i in range(train_mask.size(-1)):\n        assert train_mask[:, i].sum() == 10 * num_classes\n        assert val_mask[:, i].sum() == 100\n        assert test_mask[:, i].sum() == 200\n        assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0\n        assert ((train_mask[:, i] | val_mask[:, i]\n                 | test_mask[:, i]).sum() == 10 * num_classes + 100 + 200)\n\n    transform = RandomNodeSplit(split='random', num_splits=num_splits,\n                                num_train_per_class=10, num_val=0.1,\n                                num_test=0.2)\n    assert str(transform) == 'RandomNodeSplit(split=random)'\n    data = transform(data)\n\n    train_mask = data.train_mask\n    train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask\n    val_mask = data.val_mask\n    val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask\n    test_mask = data.test_mask\n    test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask\n\n 
   for i in range(train_mask.size(-1)):\n        assert train_mask[:, i].sum() == 10 * num_classes\n        assert val_mask[:, i].sum() == 100\n        assert test_mask[:, i].sum() == 200\n        assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0\n        assert ((train_mask[:, i] | val_mask[:, i]\n                 | test_mask[:, i]).sum() == 10 * num_classes + 100 + 200)\n\n\ndef test_random_node_split_on_hetero_data():\n    data = HeteroData()\n\n    data['paper'].x = torch.randn(2000, 16)\n    data['paper'].y = torch.randint(4, (2000, ), dtype=torch.long)\n    data['author'].x = torch.randn(300, 16)\n\n    transform = RandomNodeSplit()\n    assert str(transform) == 'RandomNodeSplit(split=train_rest)'\n    data = transform(data)\n    assert len(data) == 5\n\n    assert len(data['author']) == 1\n    assert len(data['paper']) == 5\n\n    assert data['paper'].train_mask.sum() == 500\n    assert data['paper'].val_mask.sum() == 500\n    assert data['paper'].test_mask.sum() == 1000\n"
  },
  {
    "path": "test/transforms/test_random_rotate.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import RandomRotate\n\n\ndef test_random_rotate():\n    assert str(RandomRotate([-180, 180])) == ('RandomRotate('\n                                              '[-180, 180], axis=0)')\n\n    pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n\n    data = Data(pos=pos)\n    data = RandomRotate(0)(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == pos.tolist()\n\n    data = Data(pos=pos)\n    data = RandomRotate([180, 180])(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[1, 1], [1, -1], [-1, 1], [-1, -1]]\n\n    pos = torch.tensor([\n        [-1.0, -1.0, 1.0],\n        [-1.0, 1.0, 1.0],\n        [1.0, -1.0, -1.0],\n        [1.0, 1.0, -1.0],\n    ])\n\n    data = Data(pos=pos)\n    data = RandomRotate([180, 180], axis=0)(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[-1, 1, -1], [-1, -1, -1], [1, 1, 1],\n                                 [1, -1, 1]]\n\n    data = Data(pos=pos)\n    data = RandomRotate([180, 180], axis=1)(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[1, -1, -1], [1, 1, -1], [-1, -1, 1],\n                                 [-1, 1, 1]]\n\n    data = Data(pos=pos)\n    data = RandomRotate([180, 180], axis=2)(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[1, 1, 1], [1, -1, 1], [-1, 1, -1],\n                                 [-1, -1, -1]]\n"
  },
  {
    "path": "test/transforms/test_random_scale.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import RandomScale\n\n\ndef test_random_scale():\n    assert str(RandomScale([1, 2])) == 'RandomScale([1, 2])'\n\n    pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n\n    data = Data(pos=pos)\n    data = RandomScale([1, 1])(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == pos.tolist()\n\n    data = Data(pos=pos)\n    data = RandomScale([2, 2])(data)\n    assert len(data) == 1\n    assert data.pos.tolist() == [[-2, -2], [-2, 2], [2, -2], [2, 2]]\n"
  },
  {
    "path": "test/transforms/test_random_shear.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import RandomShear\n\n\ndef test_random_shear():\n    assert str(RandomShear(0.1)) == 'RandomShear(0.1)'\n\n    pos = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n\n    data = Data(pos=pos)\n    data = RandomShear(0)(data)\n    assert len(data) == 1\n    assert torch.allclose(data.pos, pos)\n\n    data = Data(pos=pos)\n    data = RandomShear(0.1)(data)\n    assert len(data) == 1\n    assert not torch.allclose(data.pos, pos)\n"
  },
  {
    "path": "test/transforms/test_remove_duplicated_edges.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import RemoveDuplicatedEdges\n\n\ndef test_remove_duplicated_edges():\n    edge_index = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1],\n                               [0, 0, 1, 1, 0, 0, 1, 1]])\n    edge_weight = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])\n    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=2)\n\n    transform = RemoveDuplicatedEdges()\n    assert str(transform) == 'RemoveDuplicatedEdges()'\n\n    out = transform(data)\n    assert len(out) == 3\n    assert out.num_nodes == 2\n    assert out.edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]]\n    assert out.edge_weight.tolist() == [3, 7, 11, 15]\n"
  },
  {
    "path": "test/transforms/test_remove_isolated_nodes.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import RemoveIsolatedNodes\n\n\ndef test_remove_isolated_nodes():\n    assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()'\n\n    data = Data()\n    data.x = torch.arange(3)\n    data.edge_index = torch.tensor([[0, 2], [2, 0]])\n    data.edge_attr = torch.arange(2)\n\n    data = RemoveIsolatedNodes()(data)\n\n    assert len(data) == 3\n    assert data.x.tolist() == [0, 2]\n    assert data.edge_index.tolist() == [[0, 1], [1, 0]]\n    assert data.edge_attr.tolist() == [0, 1]\n\n\ndef test_remove_isolated_nodes_in_hetero_data():\n    data = HeteroData()\n\n    data['p'].x = torch.arange(6)\n    data['a'].x = torch.arange(6)\n    data['i'].num_nodes = 4\n\n    # isolated paper nodes: {4}\n    # isolated author nodes: {3, 4, 5}\n    # isolated institution nodes: {0, 1, 2, 3}\n    data['p', '1', 'p'].edge_index = torch.tensor([[0, 1, 2], [0, 1, 3]])\n    data['p', '2', 'a'].edge_index = torch.tensor([[1, 3, 5], [0, 1, 2]])\n    data['p', '2', 'a'].edge_attr = torch.arange(3)\n    data['p', '3', 'a'].edge_index = torch.tensor([[5], [2]])\n\n    data = RemoveIsolatedNodes()(data)\n\n    assert len(data) == 4\n    assert data['p'].num_nodes == 5\n    assert data['a'].num_nodes == 3\n    assert data['i'].num_nodes == 0\n\n    assert data['p'].x.tolist() == [0, 1, 2, 3, 5]\n    assert data['a'].x.tolist() == [0, 1, 2]\n\n    assert data['1'].edge_index.tolist() == [[0, 1, 2], [0, 1, 3]]\n    assert data['2'].edge_index.tolist() == [[1, 3, 4], [0, 1, 2]]\n    assert data['2'].edge_attr.tolist() == [0, 1, 2]\n    assert data['3'].edge_index.tolist() == [[4], [2]]\n"
  },
  {
    "path": "test/transforms/test_remove_self_loops.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import RemoveSelfLoops\n\n\ndef test_remove_self_loops():\n    assert str(RemoveSelfLoops()) == 'RemoveSelfLoops()'\n\n    assert len(RemoveSelfLoops()(Data())) == 0\n\n    edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]])\n    edge_weight = torch.tensor([1, 2, 3, 4])\n    edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])\n\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = RemoveSelfLoops()(data)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [[1, 2], [0, 1]]\n    assert data.num_nodes == 3\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)\n    data = RemoveSelfLoops(attr='edge_weight')(data)\n    assert data.edge_index.tolist() == [[1, 2], [0, 1]]\n    assert data.num_nodes == 3\n    assert data.edge_weight.tolist() == [2, 4]\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)\n    data = RemoveSelfLoops(attr='edge_attr')(data)\n    assert data.edge_index.tolist() == [[1, 2], [0, 1]]\n    assert data.num_nodes == 3\n    assert data.edge_attr.tolist() == [[3, 4], [7, 8]]\n\n\ndef test_hetero_remove_self_loops():\n    edge_index = torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]])\n\n    data = HeteroData()\n    data['v'].num_nodes = 3\n    data['w'].num_nodes = 3\n    data['v', 'v'].edge_index = edge_index\n    data['v', 'w'].edge_index = edge_index\n    data = RemoveSelfLoops()(data)\n    assert data['v', 'v'].edge_index.tolist() == [[1, 2], [0, 1]]\n    assert data['v', 'w'].edge_index.tolist() == edge_index.tolist()\n"
  },
  {
    "path": "test/transforms/test_remove_training_classes.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import RemoveTrainingClasses\n\n\ndef test_remove_training_classes():\n    y = torch.tensor([1, 0, 0, 2, 1, 3])\n    train_mask = torch.tensor([False, False, True, True, True, True])\n\n    data = Data(y=y, train_mask=train_mask)\n\n    transform = RemoveTrainingClasses(classes=[0, 1])\n    assert str(transform) == 'RemoveTrainingClasses([0, 1])'\n\n    data = transform(data)\n    assert len(data) == 2\n    assert torch.equal(data.y, y)\n    assert data.train_mask.tolist() == [False, False, False, True, False, True]\n"
  },
  {
    "path": "test/transforms/test_rooted_subgraph.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.transforms import RootedEgoNets, RootedRWSubgraph\n\n\ndef test_rooted_ego_nets():\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_attr = torch.randn(4, 8)\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n\n    transform = RootedEgoNets(num_hops=1)\n    assert str(transform) == 'RootedEgoNets(num_hops=1)'\n\n    out = transform(data)\n    assert len(out) == 8\n\n    assert torch.equal(out.x, data.x)\n    assert torch.equal(out.edge_index, data.edge_index)\n    assert torch.equal(out.edge_attr, data.edge_attr)\n\n    assert out.sub_edge_index.tolist() == [[0, 1, 2, 3, 3, 4, 5, 6],\n                                           [1, 0, 3, 2, 4, 3, 6, 5]]\n    assert out.n_id.tolist() == [0, 1, 0, 1, 2, 1, 2]\n    assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 1, 2, 2]\n    assert out.e_id.tolist() == [0, 1, 0, 1, 2, 3, 2, 3]\n    assert out.e_sub_batch.tolist() == [0, 0, 1, 1, 1, 1, 2, 2]\n\n    out = out.map_data()\n    assert len(out) == 4\n\n    assert torch.allclose(out.x, x[[0, 1, 0, 1, 2, 1, 2]])\n    assert out.edge_index.tolist() == [[0, 1, 2, 3, 3, 4, 5, 6],\n                                       [1, 0, 3, 2, 4, 3, 6, 5]]\n    assert torch.allclose(out.edge_attr, edge_attr[[0, 1, 0, 1, 2, 3, 2, 3]])\n    assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 1, 2, 2]\n\n\n@withPackage('torch_cluster')\ndef test_rooted_rw_subgraph():\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    data = Data(edge_index=edge_index, num_nodes=3)\n\n    transform = RootedRWSubgraph(walk_length=1)\n    assert str(transform) == 'RootedRWSubgraph(walk_length=1)'\n\n    out = transform(data)\n    assert len(out) == 7\n\n    assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 2, 2]\n    assert out.sub_edge_index.size() == (2, 
6)\n\n    out = out.map_data()\n    assert len(out) == 3\n\n    assert out.edge_index.size() == (2, 6)\n    assert out.num_nodes == 6\n    assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 2, 2]\n\n\ndef test_rooted_subgraph_minibatch():\n    x = torch.randn(3, 8)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_attr = torch.randn(4, 8)\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n\n    transform = RootedEgoNets(num_hops=1)\n    data = transform(data)\n\n    loader = DataLoader([data, data], batch_size=2)\n    batch = next(iter(loader))\n    batch = batch.map_data()\n    assert batch.num_graphs == len(batch) == 2\n\n    assert batch.x.size() == (14, 8)\n    assert batch.edge_index.size() == (2, 16)\n    assert batch.edge_attr.size() == (16, 8)\n    assert batch.n_sub_batch.size() == (14, )\n    assert batch.batch.size() == (14, )\n    assert batch.ptr.size() == (3, )\n\n    assert batch.edge_index.min() == 0\n    assert batch.edge_index.max() == 13\n\n    assert batch.n_sub_batch.min() == 0\n    assert batch.n_sub_batch.max() == 5\n"
  },
  {
    "path": "test/transforms/test_sample_points.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import SamplePoints\n\n\ndef test_sample_points():\n    assert str(SamplePoints(1024)) == 'SamplePoints(1024)'\n\n    pos = torch.tensor([\n        [0.0, 0.0, 0.0],\n        [1.0, 0.0, 0.0],\n        [0.0, 1.0, 0.0],\n        [1.0, 1.0, 0.0],\n    ])\n    face = torch.tensor([[0, 1], [1, 2], [2, 3]])\n\n    data = Data(pos=pos)\n    data.face = face\n    data = SamplePoints(8)(data)\n    assert len(data) == 1\n    assert pos[:, 0].min() >= 0 and pos[:, 0].max() <= 1\n    assert pos[:, 1].min() >= 0 and pos[:, 1].max() <= 1\n    assert pos[:, 2].abs().sum() == 0\n\n    data = Data(pos=pos)\n    data.face = face\n    data = SamplePoints(8, include_normals=True)(data)\n    assert len(data) == 2\n    assert data.normal[:, :2].abs().sum() == 0\n    assert data.normal[:, 2].abs().sum() == 8\n"
  },
  {
    "path": "test/transforms/test_sign.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import SIGN\n\n\ndef test_sign():\n    x = torch.ones(5, 3)\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 3, 4],\n        [1, 0, 3, 2, 4, 3],\n    ])\n    data = Data(x=x, edge_index=edge_index)\n\n    transform = SIGN(K=2)\n    assert str(transform) == 'SIGN(K=2)'\n\n    expected_x1 = torch.tensor([\n        [1, 1, 1],\n        [1, 1, 1],\n        [0.7071, 0.7071, 0.7071],\n        [1.4142, 1.4142, 1.4142],\n        [0.7071, 0.7071, 0.7071],\n    ])\n    expected_x2 = torch.ones(5, 3)\n\n    out = transform(data)\n    assert len(out) == 4\n    assert torch.equal(out.edge_index, edge_index)\n    assert torch.allclose(out.x, x)\n    assert torch.allclose(out.x1, expected_x1, atol=1e-4)\n    assert torch.allclose(out.x2, expected_x2)\n"
  },
  {
    "path": "test/transforms/test_spherical.py",
    "content": "from math import pi as PI\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import Spherical\n\n\ndef test_spherical():\n    assert str(Spherical()) == 'Spherical(norm=True, max_value=None)'\n\n    pos = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    edge_attr = torch.tensor([1.0, 1.0])\n\n    data = Data(edge_index=edge_index, pos=pos)\n    data = Spherical(norm=False)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 0.0, PI / 2.0], [1.0, PI, PI / 2.0]]),\n        atol=1e-4,\n    )\n\n    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)\n    data = Spherical(norm=True)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 1.0, 0.0, 0.5], [1.0, 1.0, 0.5, 0.5]]),\n        atol=1e-4,\n    )\n\n    pos = torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n\n    data = Data(edge_index=edge_index, pos=pos)\n    data = Spherical(norm=False)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, PI]]),\n        atol=1e-4,\n    )\n\n    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)\n    data = Spherical(norm=True)(data)\n    assert len(data) == 3\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert torch.allclose(\n        data.edge_attr,\n        torch.tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 
1.0]]),\n        atol=1e-4,\n    )\n"
  },
  {
    "path": "test/transforms/test_svd_feature_reduction.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import SVDFeatureReduction\n\n\ndef test_svd_feature_reduction():\n    assert str(SVDFeatureReduction(10)) == 'SVDFeatureReduction(10)'\n\n    x = torch.randn(4, 16)\n    U, S, _ = torch.linalg.svd(x)\n    data = Data(x=x)\n    data = SVDFeatureReduction(10)(data)\n    assert torch.allclose(data.x, torch.mm(U[:, :10], torch.diag(S[:10])))\n\n    x = torch.randn(4, 8)\n    data.x = x\n    data = SVDFeatureReduction(10)(Data(x=x))\n    assert torch.allclose(data.x, x)\n"
  },
  {
    "path": "test/transforms/test_target_indegree.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import TargetIndegree\n\n\ndef test_target_indegree():\n    assert str(TargetIndegree()) == 'TargetIndegree(norm=True, max_value=None)'\n\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_attr = torch.tensor([1.0, 1.0, 1.0, 1.0])\n\n    data = Data(edge_index=edge_index, num_nodes=3)\n    data = TargetIndegree(norm=False)(data)\n    assert len(data) == 3\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.edge_attr.tolist() == [[2], [1], [1], [2]]\n    assert data.num_nodes == 3\n\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3)\n    data = TargetIndegree(norm=True)(data)\n    assert len(data) == 3\n    assert data.edge_index.tolist() == edge_index.tolist()\n    assert data.edge_attr.tolist() == [[1, 1], [1, 0.5], [1, 0.5], [1, 1]]\n    assert data.num_nodes == 3\n"
  },
  {
    "path": "test/transforms/test_to_dense.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import ToDense\n\n\ndef test_to_dense():\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])\n    num_nodes = edge_index.max().item() + 1\n    x = torch.randn((num_nodes, 4))\n    pos = torch.randn((num_nodes, 3))\n    y = torch.randint(0, 4, (num_nodes, ), dtype=torch.long)\n\n    transform = ToDense()\n    assert str(transform) == 'ToDense()'\n    data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, y=y)\n    data = transform(data)\n    assert len(data) == 5\n    assert data.x.tolist() == x.tolist()\n    assert data.pos.tolist() == pos.tolist()\n    assert data.y.tolist() == y.tolist()\n    assert data.adj.size() == (num_nodes, num_nodes)\n    assert data.adj.tolist() == [\n        [0, 1, 2, 3],\n        [4, 0, 0, 0],\n        [5, 0, 0, 0],\n        [6, 0, 0, 0],\n    ]\n    assert data.mask.tolist() == [1, 1, 1, 1]\n\n    transform = ToDense(num_nodes=5)\n    assert str(transform) == 'ToDense(num_nodes=5)'\n    data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr, y=y)\n    data = transform(data)\n    assert len(data) == 5\n    assert data.x.size() == (5, 4)\n    assert data.x[:4].tolist() == x.tolist()\n    assert data.x[4].tolist() == [0, 0, 0, 0]\n    assert data.pos.size() == (5, 3)\n    assert data.pos[:4].tolist() == pos.tolist()\n    assert data.pos[4].tolist() == [0, 0, 0]\n    assert data.y.size() == (5, )\n    assert data.y[:4].tolist() == y.tolist()\n    assert data.y[4].tolist() == 0\n    assert data.adj.size() == (5, 5)\n    assert data.adj.tolist() == [\n        [0, 1, 2, 3, 0],\n        [4, 0, 0, 0, 0],\n        [5, 0, 0, 0, 0],\n        [6, 0, 0, 0, 0],\n        [0, 0, 0, 0, 0],\n    ]\n    assert data.mask.tolist() == [1, 1, 1, 1, 0]\n"
  },
  {
    "path": "test/transforms/test_to_device.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.testing import withDevice\nfrom torch_geometric.transforms import ToDevice\n\n\n@withDevice\ndef test_to_device(device):\n    x = torch.randn(3, 4)\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n\n    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)\n\n    transform = ToDevice(device)\n    assert str(transform) == f'ToDevice({device})'\n\n    data = transform(data)\n    for _, value in data:\n        assert value.device == device\n"
  },
  {
    "path": "test/transforms/test_to_sparse_tensor.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import ToSparseTensor\n\n\n@pytest.mark.parametrize('layout', [None, torch.sparse_coo, torch.sparse_csr])\ndef test_to_sparse_tensor_basic(layout):\n    transform = ToSparseTensor(layout=layout)\n    assert str(transform) == (f'ToSparseTensor(attr=edge_weight, '\n                              f'layout={layout})')\n\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    perm = torch.tensor([1, 0, 3, 2])\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight,\n                edge_attr=edge_attr, num_nodes=3)\n    data = transform(data)\n\n    assert len(data) == 3\n    assert data.num_nodes == 3\n    assert torch.equal(data.edge_attr, edge_attr[perm])\n    assert 'adj_t' in data\n\n    if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:\n        row, col, value = data.adj_t.coo()\n        assert row.tolist() == [0, 1, 1, 2]\n        assert col.tolist() == [1, 0, 2, 1]\n        assert torch.equal(value, edge_weight[perm])\n    else:\n        adj_t = data.adj_t\n        assert adj_t.layout == layout or torch.sparse_csr\n        if layout != torch.sparse_coo:\n            adj_t = adj_t.to_sparse_coo()\n        assert adj_t.coalesce().indices().tolist() == [\n            [0, 1, 1, 2],\n            [1, 0, 2, 1],\n        ]\n        assert torch.equal(adj_t.coalesce().values(), edge_weight[perm])\n\n\ndef test_to_sparse_tensor_and_keep_edge_index():\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    perm = torch.tensor([1, 0, 3, 2])\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight,\n                edge_attr=edge_attr, num_nodes=3)\n    data = 
ToSparseTensor(remove_edge_index=False)(data)\n\n    assert len(data) == 5\n    assert torch.equal(data.edge_index, edge_index[:, perm])\n    assert torch.equal(data.edge_weight, edge_weight[perm])\n    assert torch.equal(data.edge_attr, edge_attr[perm])\n\n\n@pytest.mark.parametrize('layout', [None, torch.sparse_coo, torch.sparse_csr])\ndef test_hetero_to_sparse_tensor(layout):\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n\n    data = HeteroData()\n    data['v'].num_nodes = 3\n    data['w'].num_nodes = 3\n    data['v', 'v'].edge_index = edge_index\n    data['v', 'w'].edge_index = edge_index\n\n    data = ToSparseTensor(layout=layout)(data)\n\n    if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:\n        row, col, value = data['v', 'v'].adj_t.coo()\n        assert row.tolist() == [0, 1, 1, 2]\n        assert col.tolist() == [1, 0, 2, 1]\n        assert value is None\n\n        row, col, value = data['v', 'w'].adj_t.coo()\n        assert row.tolist() == [0, 1, 1, 2]\n        assert col.tolist() == [1, 0, 2, 1]\n        assert value is None\n    else:\n        adj_t = data['v', 'v'].adj_t\n        assert adj_t.layout == (layout or torch.sparse_csr)\n        if layout != torch.sparse_coo:\n            adj_t = adj_t.to_sparse_coo()\n        assert adj_t.coalesce().indices().tolist() == [\n            [0, 1, 1, 2],\n            [1, 0, 2, 1],\n        ]\n        assert adj_t.coalesce().values().tolist() == [1., 1., 1., 1.]\n\n        adj_t = data['v', 'w'].adj_t\n        assert adj_t.layout == (layout or torch.sparse_csr)\n        if layout != torch.sparse_coo:\n            adj_t = adj_t.to_sparse_coo()\n        assert adj_t.coalesce().indices().tolist() == [\n            [0, 1, 1, 2],\n            [1, 0, 2, 1],\n        ]\n        assert adj_t.coalesce().values().tolist() == [1., 1., 1., 1.]\n\n\ndef test_to_sparse_tensor_num_nodes_equals_num_edges():\n    x = torch.arange(4)\n    y = torch.arange(4)\n    edge_index = torch.tensor([[0, 1, 1, 
2], [1, 0, 2, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    perm = torch.tensor([1, 0, 3, 2])\n\n    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight,\n                edge_attr=edge_attr, y=y)\n    data = ToSparseTensor()(data)\n\n    assert len(data) == 4\n    assert torch.equal(data.x, x)\n    assert torch.equal(data.y, y)\n    assert torch.equal(data.edge_attr, edge_attr[perm])\n"
  },
  {
    "path": "test/transforms/test_to_superpixels.py",
    "content": "import os\nimport os.path as osp\n\nimport torch\n\nfrom torch_geometric.data import download_url, extract_gz\nfrom torch_geometric.loader import DataLoader\nfrom torch_geometric.testing import onlyOnline, withPackage\nfrom torch_geometric.transforms import ToSLIC\n\nresources = [\n    'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz',\n    'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz',\n]\n\n\n@onlyOnline\n@withPackage('torchvision', 'skimage')\ndef test_to_superpixels(tmp_path):\n    import torchvision.transforms as T\n    from torchvision.datasets.mnist import (\n        MNIST,\n        read_image_file,\n        read_label_file,\n    )\n\n    raw_folder = osp.join(tmp_path, 'MNIST', 'raw')\n    processed_folder = osp.join(tmp_path, 'MNIST', 'processed')\n\n    os.makedirs(raw_folder, exist_ok=True)\n    os.makedirs(processed_folder, exist_ok=True)\n    for resource in resources:\n        path = download_url(resource, raw_folder)\n        extract_gz(path, osp.join(tmp_path, raw_folder))\n\n    test_set = (\n        read_image_file(osp.join(raw_folder, 't10k-images-idx3-ubyte')),\n        read_label_file(osp.join(raw_folder, 't10k-labels-idx1-ubyte')),\n    )\n\n    torch.save(test_set, osp.join(processed_folder, 'training.pt'))\n    torch.save(test_set, osp.join(processed_folder, 'test.pt'))\n\n    dataset = MNIST(tmp_path, download=False)\n\n    dataset.transform = T.Compose([T.ToTensor(), ToSLIC()])\n\n    data, y = dataset[0]\n    assert len(data) == 2\n    assert data.pos.dim() == 2 and data.pos.size(1) == 2\n    assert data.x.dim() == 2 and data.x.size(1) == 1\n    assert data.pos.size(0) == data.x.size(0)\n    assert y == 7\n\n    loader = DataLoader(dataset, batch_size=2, shuffle=False)\n    for batch, y in loader:\n        assert batch.num_graphs == len(batch) == 2\n        assert batch.pos.dim() == 2 and batch.pos.size(1) == 2\n        assert batch.x.dim() == 2 and batch.x.size(1) == 1\n 
       assert batch.batch.dim() == 1\n        assert batch.ptr.dim() == 1\n        assert batch.pos.size(0) == batch.x.size(0) == batch.batch.size(0)\n        assert y.tolist() == [7, 2]\n        break\n\n    dataset.transform = T.Compose(\n        [T.ToTensor(), ToSLIC(add_seg=True, add_img=True)])\n\n    data, y = dataset[0]\n    assert len(data) == 4\n    assert data.pos.dim() == 2 and data.pos.size(1) == 2\n    assert data.x.dim() == 2 and data.x.size(1) == 1\n    assert data.pos.size(0) == data.x.size(0)\n    assert data.seg.size() == (1, 28, 28)\n    assert data.img.size() == (1, 1, 28, 28)\n    assert data.seg.max().item() + 1 == data.x.size(0)\n    assert y == 7\n\n    loader = DataLoader(dataset, batch_size=2, shuffle=False)\n    for batch, y in loader:\n        assert batch.num_graphs == len(batch) == 2\n        assert batch.pos.dim() == 2 and batch.pos.size(1) == 2\n        assert batch.x.dim() == 2 and batch.x.size(1) == 1\n        assert batch.batch.dim() == 1\n        assert batch.ptr.dim() == 1\n        assert batch.pos.size(0) == batch.x.size(0) == batch.batch.size(0)\n        assert batch.seg.size() == (2, 28, 28)\n        assert batch.img.size() == (2, 1, 28, 28)\n        assert y.tolist() == [7, 2]\n        break\n"
  },
  {
    "path": "test/transforms/test_to_undirected.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import ToUndirected\n\n\ndef test_to_undirected():\n    assert str(ToUndirected()) == 'ToUndirected()'\n\n    edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]])\n    edge_weight = torch.randn(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    perm = torch.tensor([1, 2, 1, 2, 0, 0])\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight,\n                edge_attr=edge_attr, num_nodes=4)\n    data = ToUndirected()(data)\n    assert len(data) == 4\n    assert data.edge_index.tolist() == [[0, 0, 1, 2, 2, 3], [1, 2, 0, 0, 3, 2]]\n    assert data.edge_weight.tolist() == edge_weight[perm].tolist()\n    assert data.edge_attr.tolist() == edge_attr[perm].tolist()\n    assert data.num_nodes == 4\n\n\ndef test_to_undirected_with_duplicates():\n    edge_index = torch.tensor([[0, 0, 1, 1], [0, 1, 0, 2]])\n    edge_weight = torch.ones(4)\n\n    data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3)\n    data = ToUndirected()(data)\n    assert len(data) == 3\n    assert data.edge_index.tolist() == [[0, 0, 1, 1, 2], [0, 1, 0, 2, 1]]\n    assert data.edge_weight.tolist() == [2, 2, 2, 1, 1]\n    assert data.num_nodes == 3\n\n\ndef test_hetero_to_undirected():\n    edge_index = torch.tensor([[2, 0], [3, 1]])\n    edge_weight = torch.randn(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    perm = torch.tensor([1, 1, 0, 0])\n\n    data = HeteroData()\n    data['v'].num_nodes = 4\n    data['w'].num_nodes = 4\n    data['v', 'v'].edge_index = edge_index\n    data['v', 'v'].edge_weight = edge_weight\n    data['v', 'v'].edge_attr = edge_attr\n    data['v', 'w'].edge_index = edge_index\n    data['v', 'w'].edge_weight = edge_weight\n    data['v', 'w'].edge_attr = edge_attr\n\n    from torch_geometric.transforms import ToUndirected\n\n    assert not data.is_undirected()\n    data = 
ToUndirected()(data)\n    assert data.is_undirected()\n\n    assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 3], [1, 0, 3, 2]]\n    assert data['v', 'v'].edge_weight.tolist() == edge_weight[perm].tolist()\n    assert data['v', 'v'].edge_attr.tolist() == edge_attr[perm].tolist()\n    assert data['v', 'w'].edge_index.tolist() == edge_index.tolist()\n    assert data['v', 'w'].edge_weight.tolist() == edge_weight.tolist()\n    assert data['v', 'w'].edge_attr.tolist() == edge_attr.tolist()\n    assert data['w', 'v'].edge_index.tolist() == [[3, 1], [2, 0]]\n    assert data['w', 'v'].edge_weight.tolist() == edge_weight.tolist()\n    assert data['w', 'v'].edge_attr.tolist() == edge_attr.tolist()\n"
  },
  {
    "path": "test/transforms/test_two_hop.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import TwoHop\n\n\ndef test_two_hop():\n    transform = TwoHop()\n    assert str(transform) == 'TwoHop()'\n\n    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])\n    edge_attr = torch.tensor([1, 2, 3, 1, 2, 3], dtype=torch.float)\n    data = Data(edge_index=edge_index, edge_attr=edge_attr, num_nodes=4)\n\n    data = transform(data)\n    assert len(data) == 3\n    assert data.edge_index.equal(\n        torch.tensor([\n            [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],\n            [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2],\n        ]))\n    assert data.edge_attr.equal(\n        torch.tensor([1, 2, 3, 1, 0, 0, 2, 0, 0, 3, 0, 0]))\n    assert data.num_nodes == 4\n\n    data = Data(edge_index=edge_index, num_nodes=4)\n    data = transform(data)\n    assert len(data) == 2\n    assert data.edge_index.equal(\n        torch.tensor([\n            [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],\n            [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2],\n        ]))\n    assert data.num_nodes == 4\n"
  },
  {
    "path": "test/transforms/test_virtual_node.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import VirtualNode\n\n\ndef test_virtual_node():\n    assert str(VirtualNode()) == 'VirtualNode()'\n\n    x = torch.randn(4, 16)\n    edge_index = torch.tensor([[2, 0, 2], [3, 1, 0]])\n    edge_weight = torch.rand(edge_index.size(1))\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight,\n                edge_attr=edge_attr, num_nodes=x.size(0))\n\n    data = VirtualNode()(data)\n    assert len(data) == 6\n\n    assert data.x.size() == (5, 16)\n    assert torch.allclose(data.x[:4], x)\n    assert data.x[4:].abs().sum() == 0\n\n    assert data.edge_index.tolist() == [[2, 0, 2, 0, 1, 2, 3, 4, 4, 4, 4],\n                                        [3, 1, 0, 4, 4, 4, 4, 0, 1, 2, 3]]\n\n    assert data.edge_weight.size() == (11, )\n    assert torch.allclose(data.edge_weight[:3], edge_weight)\n    assert data.edge_weight[3:].abs().sum() == 8\n\n    assert data.edge_attr.size() == (11, 8)\n    assert torch.allclose(data.edge_attr[:3], edge_attr)\n    assert data.edge_attr[3:].abs().sum() == 0\n\n    assert data.num_nodes == 5\n\n    assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]\n\n    data = Data(x=x, edge_index=torch.empty(2, 0, dtype=torch.long))\n    data = VirtualNode()(data)\n    assert len(data) == 3\n\n    assert data.x.size() == (5, 16)\n    assert torch.allclose(data.x[:4], x)\n    assert data.x[4:].abs().sum() == 0\n\n    assert data.edge_index.tolist() == [\n        [0, 1, 2, 3, 4, 4, 4, 4],\n        [4, 4, 4, 4, 0, 1, 2, 3],\n    ]\n    assert data.edge_type.tolist() == [1, 1, 1, 1, 2, 2, 2, 2]\n"
  },
  {
    "path": "test/utils/conftest.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.explain.config import ModelMode, ModelReturnType\nfrom torch_geometric.nn import SAGEConv, to_hetero\nfrom torch_geometric.testing import get_random_edge_index\n\n\nclass GraphSAGE(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = SAGEConv((-1, -1), 32)\n        self.conv2 = SAGEConv((-1, -1), 32)\n\n    def forward(self, x, edge_index):\n        x = self.conv1(x, edge_index).relu()\n        return self.conv2(x, edge_index)\n\n\nclass HeteroSAGE(torch.nn.Module):\n    def __init__(self, metadata, model_config=None):\n        super().__init__()\n        self.model_config = model_config\n        self.graph_sage = to_hetero(GraphSAGE(), metadata, debug=False)\n\n        # Determine output channels based on model_config\n        out_channels = 1\n        if (model_config\n                and model_config.mode == ModelMode.multiclass_classification):\n            out_channels = 7\n\n        self.lin = torch.nn.Linear(32, out_channels)\n\n    def forward(self, x_dict, edge_index_dict,\n                additonal_arg=None) -> torch.Tensor:\n        x = self.lin(self.graph_sage(x_dict, edge_index_dict)['paper'])\n\n        # Apply transformations based on model_config if available\n        if hasattr(self, 'model_config') and self.model_config:\n            if self.model_config.mode == ModelMode.binary_classification:\n                if self.model_config.return_type == ModelReturnType.probs:\n                    x = x.sigmoid()\n            elif self.model_config.mode == ModelMode.multiclass_classification:\n                if self.model_config.return_type == ModelReturnType.probs:\n                    x = x.softmax(dim=-1)\n                elif (self.model_config.return_type ==\n                      ModelReturnType.log_probs):\n                    x = x.log_softmax(dim=-1)\n\n        return 
x\n\n\n@pytest.fixture()\ndef hetero_data():\n    data = HeteroData()\n    data['paper'].x = torch.randn(8, 16)\n    data['author'].x = torch.randn(10, 8)\n\n    data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10)\n    data['paper', 'paper'].edge_attr = torch.randn(10, 16)\n    data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10)\n    data['paper', 'author'].edge_attr = torch.randn(10, 8)\n    data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10)\n    data['author', 'paper'].edge_attr = torch.randn(10, 8)\n\n    return data\n\n\n@pytest.fixture()\ndef hetero_model():\n    return HeteroSAGE\n"
  },
  {
    "path": "test/utils/test_assortativity.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import assortativity\n\n\ndef test_assortativity():\n    # Completely assortative graph:\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5],\n                               [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]])\n    out = assortativity(edge_index)\n    assert pytest.approx(out, abs=1e-5) == 1.0\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[6, 6])\n        out = assortativity(adj)\n        assert pytest.approx(out, abs=1e-5) == 1.0\n\n    # Completely disassortative graph:\n    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 5, 5, 5, 5],\n                               [5, 5, 5, 5, 5, 0, 1, 2, 3, 4]])\n    out = assortativity(edge_index)\n    assert pytest.approx(out, abs=1e-5) == -1.0\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[6, 6])\n        out = assortativity(adj)\n        assert pytest.approx(out, abs=1e-5) == -1.0\n"
  },
  {
    "path": "test/utils/test_augmentation.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric import seed_everything\nfrom torch_geometric.utils import (\n    add_random_edge,\n    is_undirected,\n    mask_feature,\n    shuffle_node,\n)\n\n\ndef test_shuffle_node():\n    x = torch.tensor([[0, 1, 2], [3, 4, 5]], dtype=torch.float)\n\n    out = shuffle_node(x, training=False)\n    assert out[0].tolist() == x.tolist()\n    assert out[1].tolist() == list(range(len(x)))\n\n    torch.manual_seed(5)\n    out = shuffle_node(x)\n    assert out[0].tolist() == [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]\n    assert out[1].tolist() == [1, 0]\n\n    torch.manual_seed(10)\n    x = torch.arange(21).view(7, 3).to(torch.float)\n    batch = torch.tensor([0, 0, 1, 1, 2, 2, 2])\n    out = shuffle_node(x, batch)\n    assert out[0].tolist() == [[3.0, 4.0, 5.0], [0.0, 1.0, 2.0],\n                               [9.0, 10.0, 11.0], [6.0, 7.0, 8.0],\n                               [12.0, 13.0, 14.0], [18.0, 19.0, 20.0],\n                               [15.0, 16.0, 17.0]]\n    assert out[1].tolist() == [1, 0, 3, 2, 4, 6, 5]\n\n\ndef test_mask_feature():\n    x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],\n                     dtype=torch.float)\n\n    out = mask_feature(x, training=False)\n    assert out[0].tolist() == x.tolist()\n    assert torch.all(out[1])\n\n    torch.manual_seed(4)\n    out = mask_feature(x)\n    assert out[0].tolist() == [[1.0, 2.0, 0.0, 0.0], [5.0, 6.0, 0.0, 0.0],\n                               [9.0, 10.0, 0.0, 0.0]]\n    assert out[1].tolist() == [[True, True, False, False]]\n\n    torch.manual_seed(5)\n    out = mask_feature(x, mode='row')\n    assert out[0].tolist() == [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0],\n                               [9.0, 10.0, 11.0, 12.0]]\n    assert out[1].tolist() == [[True], [False], [True]]\n\n    torch.manual_seed(7)\n    out = mask_feature(x, mode='all')\n    assert out[0].tolist() == [[1.0, 0.0, 3.0, 4.0], [0.0, 0.0, 0.0, 8.0],\n                 
              [0.0, 10.0, 11.0, 12.0]]\n\n    assert out[1].tolist() == [[True, False, True, True],\n                               [False, False, False, True],\n                               [False, True, True, True]]\n\n    torch.manual_seed(7)\n    out = mask_feature(x, mode='all', fill_value=-1)\n    assert out[0].tolist() == [[1.0, -1.0, 3.0, 4.0], [-1.0, -1.0, -1.0, 8.0],\n                               [-1.0, 10.0, 11.0, 12.0]]\n\n    assert out[1].tolist() == [[True, False, True, True],\n                               [False, False, False, True],\n                               [False, True, True, True]]\n\n\ndef test_add_random_edge():\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n    out = add_random_edge(edge_index, p=0.5, training=False)\n    assert out[0].tolist() == edge_index.tolist()\n    assert out[1].tolist() == [[], []]\n\n    seed_everything(5)\n    out = add_random_edge(edge_index, p=0.5)\n    assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 3, 1, 2],\n                               [1, 0, 2, 1, 3, 2, 0, 3, 0]]\n    assert out[1].tolist() == [[3, 1, 2], [0, 3, 0]]\n\n    seed_everything(6)\n    out = add_random_edge(edge_index, p=0.5, force_undirected=True)\n    assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 1, 3],\n                               [1, 0, 2, 1, 3, 2, 3, 1]]\n    assert out[1].tolist() == [[1, 3], [3, 1]]\n    assert is_undirected(out[0])\n    assert is_undirected(out[1])\n\n    # Test for bipartite graph:\n    seed_everything(7)\n    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [2, 3, 1, 4, 2, 1]])\n    with pytest.raises(RuntimeError, match=\"not supported for bipartite\"):\n        add_random_edge(edge_index, force_undirected=True, num_nodes=(6, 5))\n    out = add_random_edge(edge_index, p=0.5, num_nodes=(6, 5))\n    assert out[0].tolist() == [[0, 1, 2, 3, 4, 5, 2, 0, 2],\n                               [2, 3, 1, 4, 2, 1, 0, 4, 2]]\n    assert out[1].tolist() == [[2, 0, 2], [0, 4, 2]]\n"
  },
  {
    "path": "test/utils/test_coalesce.py",
    "content": "from typing import List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import coalesce\n\n\ndef test_coalesce():\n    edge_index = torch.tensor([[2, 1, 1, 0, 2], [1, 2, 0, 1, 1]])\n    edge_attr = torch.tensor([[1], [2], [3], [4], [5]])\n\n    out = coalesce(edge_index)\n    assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n\n    out = coalesce(edge_index, None)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1] is None\n\n    out = coalesce(edge_index, edge_attr)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1].tolist() == [[4], [3], [2], [6]]\n\n    out = coalesce(edge_index, [edge_attr, edge_attr.view(-1)])\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1][0].tolist() == [[4], [3], [2], [6]]\n    assert out[1][1].tolist() == [4, 3, 2, 6]\n\n    out = coalesce((edge_index[0], edge_index[1]))\n    assert isinstance(out, tuple)\n    assert out[0].tolist() == [0, 1, 1, 2]\n    assert out[1].tolist() == [1, 0, 2, 1]\n\n\ndef test_coalesce_without_duplicates():\n    edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]])\n    edge_attr = torch.tensor([[1], [2], [3], [4]])\n\n    out = coalesce(edge_index)\n    assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n\n    out = coalesce(edge_index, None)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1] is None\n\n    out = coalesce(edge_index, edge_attr)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1].tolist() == [[4], [3], [2], [1]]\n\n    out = coalesce(edge_index, [edge_attr, edge_attr.view(-1)])\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1][0].tolist() == [[4], [3], [2], [1]]\n    assert out[1][1].tolist() == [4, 3, 2, 1]\n\n\ndef test_coalesce_jit():\n    @torch.jit.script\n    def wrapper1(edge_index: Tensor) -> Tensor:\n        return coalesce(edge_index)\n\n    
@torch.jit.script\n    def wrapper2(\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor],\n    ) -> Tuple[Tensor, Optional[Tensor]]:\n        return coalesce(edge_index, edge_attr)\n\n    @torch.jit.script\n    def wrapper3(\n        edge_index: Tensor,\n        edge_attr: List[Tensor],\n    ) -> Tuple[Tensor, List[Tensor]]:\n        return coalesce(edge_index, edge_attr)\n\n    edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]])\n    edge_attr = torch.tensor([[1], [2], [3], [4]])\n\n    out = wrapper1(edge_index)\n    assert out.size() == edge_index.size()\n\n    out = wrapper2(edge_index, None)\n    assert out[0].size() == edge_index.size()\n    assert out[1] is None\n\n    out = wrapper2(edge_index, edge_attr)\n    assert out[0].size() == edge_index.size()\n    assert out[1].size() == edge_attr.size()\n\n    out = wrapper3(edge_index, [edge_attr, edge_attr.view(-1)])\n    assert out[0].size() == edge_index.size()\n    assert len(out[1]) == 2\n    assert out[1][0].size() == edge_attr.size()\n    assert out[1][1].size() == edge_attr.view(-1).size()\n"
  },
  {
    "path": "test/utils/test_convert.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.testing import get_random_edge_index, withPackage\nfrom torch_geometric.utils import (\n    from_cugraph,\n    from_dgl,\n    from_networkit,\n    from_networkx,\n    from_scipy_sparse_matrix,\n    from_trimesh,\n    sort_edge_index,\n    subgraph,\n    to_cugraph,\n    to_dgl,\n    to_networkit,\n    to_networkx,\n    to_scipy_sparse_matrix,\n    to_trimesh,\n)\n\n\n@withPackage('scipy')\ndef test_to_scipy_sparse_matrix():\n    import scipy.sparse as sp\n\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n\n    adj = to_scipy_sparse_matrix(edge_index)\n    assert isinstance(adj, sp.coo_matrix)\n    assert adj.shape == (2, 2)\n    assert adj.row.tolist() == edge_index[0].tolist()\n    assert adj.col.tolist() == edge_index[1].tolist()\n    assert adj.data.tolist() == [1, 1, 1]\n\n    edge_attr = torch.tensor([1.0, 2.0, 3.0])\n    adj = to_scipy_sparse_matrix(edge_index, edge_attr)\n    assert isinstance(adj, sp.coo_matrix)\n    assert adj.shape == (2, 2)\n    assert adj.row.tolist() == edge_index[0].tolist()\n    assert adj.col.tolist() == edge_index[1].tolist()\n    assert adj.data.tolist() == edge_attr.tolist()\n\n\n@withPackage('scipy')\ndef test_from_scipy_sparse_matrix():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    adj = to_scipy_sparse_matrix(edge_index)\n\n    out = from_scipy_sparse_matrix(adj)\n    assert out[0].tolist() == edge_index.tolist()\n    assert out[1].tolist() == [1, 1, 1]\n\n\n@withPackage('networkx')\ndef test_to_networkx():\n    import networkx as nx\n\n    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n    pos = torch.tensor([[0.0, 0.0], [1.0, 1.0]])\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_attr = torch.tensor([1.0, 2.0, 3.0])\n    data = Data(x=x, pos=pos, edge_index=edge_index, weight=edge_attr)\n\n    for remove_self_loops in [True, False]:\n        G = to_networkx(data, 
node_attrs=['x', 'pos'], edge_attrs=['weight'],\n                        remove_self_loops=remove_self_loops)\n\n        assert G.nodes[0]['x'] == [1.0, 2.0]\n        assert G.nodes[1]['x'] == [3.0, 4.0]\n        assert G.nodes[0]['pos'] == [0.0, 0.0]\n        assert G.nodes[1]['pos'] == [1.0, 1.0]\n\n        if remove_self_loops:\n            assert nx.to_numpy_array(G).tolist() == [[0.0, 1.0], [2.0, 0.0]]\n        else:\n            assert nx.to_numpy_array(G).tolist() == [[3.0, 1.0], [2.0, 0.0]]\n\n\n@withPackage('networkx')\ndef test_from_networkx_set_node_attributes():\n    import networkx as nx\n\n    G = nx.path_graph(3)\n    attrs = {\n        0: {\n            'x': torch.tensor([1, 0, 0])\n        },\n        1: {\n            'x': torch.tensor([0, 1, 0])\n        },\n        2: {\n            'x': torch.tensor([0, 0, 1])\n        },\n    }\n    nx.set_node_attributes(G, attrs)\n\n    assert from_networkx(G).x.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]\n\n\n@withPackage('networkx')\ndef test_to_networkx_undirected():\n    import networkx as nx\n\n    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])\n    pos = torch.tensor([[0.0, 0.0], [1.0, 1.0]])\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_attr = torch.tensor([1.0, 2.0, 3.0])\n    data = Data(x=x, pos=pos, edge_index=edge_index, weight=edge_attr)\n\n    for remove_self_loops in [True, False]:\n        G = to_networkx(\n            data,\n            node_attrs=['x', 'pos'],\n            edge_attrs=['weight'],\n            remove_self_loops=remove_self_loops,\n            to_undirected=True,\n        )\n\n        assert G.nodes[0]['x'] == [1, 2]\n        assert G.nodes[1]['x'] == [3, 4]\n        assert G.nodes[0]['pos'] == [0, 0]\n        assert G.nodes[1]['pos'] == [1, 1]\n\n        if remove_self_loops:\n            assert nx.to_numpy_array(G).tolist() == [[0, 2], [2, 0]]\n        else:\n            assert nx.to_numpy_array(G).tolist() == [[3, 2], [2, 0]]\n\n    G = to_networkx(data, 
edge_attrs=['weight'], to_undirected=False)\n    assert nx.to_numpy_array(G).tolist() == [[3, 1], [2, 0]]\n\n    G = to_networkx(data, edge_attrs=['weight'], to_undirected='upper')\n    assert nx.to_numpy_array(G).tolist() == [[3, 1], [1, 0]]\n\n    G = to_networkx(data, edge_attrs=['weight'], to_undirected='lower')\n    assert nx.to_numpy_array(G).tolist() == [[3, 2], [2, 0]]\n\n\n@withPackage('networkx')\ndef test_to_networkx_undirected_options():\n    import networkx as nx\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 0]])\n    data = Data(edge_index=edge_index, num_nodes=3)\n\n    G = to_networkx(data, to_undirected=True)\n    assert nx.to_numpy_array(G).tolist() == [[0, 1, 1], [1, 0, 1], [1, 1, 0]]\n\n    G = to_networkx(data, to_undirected='upper')\n    assert nx.to_numpy_array(G).tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]]\n\n    G = to_networkx(data, to_undirected='lower')\n    assert nx.to_numpy_array(G).tolist() == [[0, 1, 1], [1, 0, 0], [1, 0, 0]]\n\n\n@withPackage('networkx')\ndef test_to_networkx_hetero():\n    edge_index = get_random_edge_index(5, 10, 20, coalesce=True)\n\n    data = HeteroData()\n    data['global_id'] = 0\n    data['author'].x = torch.arange(5)\n    data['paper'].x = torch.arange(10)\n    data['author', 'paper'].edge_index = edge_index\n    data['author', 'paper'].edge_attr = torch.arange(edge_index.size(1))\n\n    G = to_networkx(data, node_attrs=['x'], edge_attrs=['edge_attr'],\n                    graph_attrs=['global_id'])\n\n    assert G.number_of_nodes() == 15\n    assert G.number_of_edges() == edge_index.size(1)\n\n    assert G.graph == {'global_id': 0}\n\n    for i, (v, data) in enumerate(G.nodes(data=True)):\n        assert i == v\n        assert len(data) == 2\n        if i < 5:\n            assert data['x'] == i\n            assert data['type'] == 'author'\n        else:\n            assert data['x'] == i - 5\n            assert data['type'] == 'paper'\n\n    for i, (v, w, data) in 
enumerate(G.edges(data=True)):\n        assert v == int(edge_index[0, i])\n        assert w == int(edge_index[1, i]) + 5\n        assert len(data) == 2\n        assert data['type'] == ('author', 'to', 'paper')\n        assert data['edge_attr'] == i\n\n\n@withPackage('networkx')\ndef test_from_networkx():\n    x = torch.randn(2, 8)\n    pos = torch.randn(2, 3)\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_attr = torch.randn(edge_index.size(1))\n    perm = torch.tensor([0, 2, 1])\n    data = Data(x=x, pos=pos, edge_index=edge_index, edge_attr=edge_attr)\n    G = to_networkx(data, node_attrs=['x', 'pos'], edge_attrs=['edge_attr'])\n    data = from_networkx(G)\n    assert len(data) == 4\n    assert data.x.tolist() == x.tolist()\n    assert data.pos.tolist() == pos.tolist()\n    assert data.edge_index.tolist() == edge_index[:, perm].tolist()\n    assert data.edge_attr.tolist() == edge_attr[perm].tolist()\n\n\n@withPackage('networkx')\ndef test_from_networkx_group_attrs():\n    x = torch.randn(2, 2)\n    x1 = torch.randn(2, 4)\n    x2 = torch.randn(2, 8)\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_attr1 = torch.randn(edge_index.size(1))\n    edge_attr2 = torch.randn(edge_index.size(1))\n    perm = torch.tensor([0, 2, 1])\n    data = Data(x=x, x1=x1, x2=x2, edge_index=edge_index,\n                edge_attr1=edge_attr1, edge_attr2=edge_attr2)\n    G = to_networkx(data, node_attrs=['x', 'x1', 'x2'],\n                    edge_attrs=['edge_attr1', 'edge_attr2'])\n    data = from_networkx(G, group_node_attrs=['x', 'x2'], group_edge_attrs=all)\n    assert len(data) == 4\n    assert data.x.tolist() == torch.cat([x, x2], dim=-1).tolist()\n    assert data.x1.tolist() == x1.tolist()\n    assert data.edge_index.tolist() == edge_index[:, perm].tolist()\n    assert data.edge_attr.tolist() == torch.stack([edge_attr1, edge_attr2],\n                                                  dim=-1)[perm].tolist()\n\n\n@withPackage('networkx')\ndef 
test_networkx_vice_versa_convert():\n    import networkx as nx\n\n    G = nx.complete_graph(5)\n    assert G.is_directed() is False\n    data = from_networkx(G)\n    assert data.is_directed() is False\n    G = to_networkx(data)\n    assert G.is_directed() is True\n    G = nx.to_undirected(G)\n    assert G.is_directed() is False\n\n\n@withPackage('networkx')\ndef test_from_networkx_non_consecutive():\n    import networkx as nx\n\n    graph = nx.Graph()\n    graph.add_node(4)\n    graph.add_node(2)\n    graph.add_edge(4, 2)\n    for node in graph.nodes():\n        graph.nodes[node]['x'] = node\n\n    data = from_networkx(graph)\n    assert len(data) == 2\n    assert data.x.tolist() == [4, 2]\n    assert data.edge_index.tolist() == [[0, 1], [1, 0]]\n\n\n@withPackage('networkx')\ndef test_from_networkx_inverse():\n    import networkx as nx\n\n    graph = nx.Graph()\n    graph.add_node(3)\n    graph.add_node(2)\n    graph.add_node(1)\n    graph.add_node(0)\n    graph.add_edge(3, 1)\n    graph.add_edge(2, 1)\n    graph.add_edge(1, 0)\n\n    data = from_networkx(graph)\n    assert len(data) == 2\n    assert data.edge_index.tolist() == [[0, 1, 2, 2, 2, 3], [2, 2, 0, 1, 3, 2]]\n    assert data.num_nodes == 4\n\n\n@withPackage('networkx')\ndef test_from_networkx_non_numeric_labels():\n    import networkx as nx\n\n    graph = nx.Graph()\n    graph.add_node('4')\n    graph.add_node('2')\n    graph.add_edge('4', '2')\n    for node in graph.nodes():\n        graph.nodes[node]['x'] = node\n    data = from_networkx(graph)\n    assert len(data) == 2\n    assert data.x == ['4', '2']\n    assert data.edge_index.tolist() == [[0, 1], [1, 0]]\n\n\n@withPackage('networkx')\ndef test_from_networkx_without_edges():\n    import networkx as nx\n\n    graph = nx.Graph()\n    graph.add_node(1)\n    graph.add_node(2)\n    data = from_networkx(graph)\n    assert len(data) == 2\n    assert data.edge_index.size() == (2, 0)\n    assert data.num_nodes == 2\n\n\n@withPackage('networkx')\ndef 
test_from_networkx_with_same_node_and_edge_attributes():\n    import networkx as nx\n\n    G = nx.Graph()\n    G.add_nodes_from([(0, {'age': 1}), (1, {'age': 6}), (2, {'age': 5})])\n    G.add_edges_from([(0, 1, {'age': 2}), (1, 2, {'age': 7})])\n\n    data = from_networkx(G)\n    assert len(data) == 4\n    assert data.age.tolist() == [1, 6, 5]\n    assert data.num_nodes == 3\n    assert data.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert data.edge_age.tolist() == [2, 2, 7, 7]\n\n    data = from_networkx(G, group_node_attrs=all, group_edge_attrs=all)\n    assert len(data) == 3\n    assert data.x.tolist() == [[1], [6], [5]]\n    assert data.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert data.edge_attr.tolist() == [[2], [2], [7], [7]]\n\n\n@withPackage('networkx')\ndef test_from_networkx_subgraph_convert():\n    import networkx as nx\n\n    G = nx.complete_graph(5)\n\n    edge_index = from_networkx(G).edge_index\n    sub_edge_index_1, _ = subgraph([0, 1, 3, 4], edge_index,\n                                   relabel_nodes=True)\n\n    sub_edge_index_2 = from_networkx(G.subgraph([0, 1, 3, 4])).edge_index\n\n    assert sub_edge_index_1.tolist() == sub_edge_index_2.tolist()\n\n\n@withPackage('networkx')\n@pytest.mark.parametrize('n', [100])\n@pytest.mark.parametrize('p', [0.8])\n@pytest.mark.parametrize('q', [0.2])\ndef test_from_networkx_sbm(n, p, q):\n    import networkx as nx\n    G = nx.stochastic_block_model(\n        sizes=[n // 2, n // 2],\n        p=[[p, q], [q, p]],\n        seed=0,\n        directed=False,\n    )\n\n    data = from_networkx(G)\n    assert data.num_nodes == 100\n    assert torch.equal(data.block[:50], data.block.new_zeros(50))\n    assert torch.equal(data.block[50:], data.block.new_ones(50))\n\n\n@withPackage('networkit')\ndef test_to_networkit_vice_versa():\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n\n    g = to_networkit(edge_index, directed=False)\n    assert not g.isDirected()\n    assert not 
g.isWeighted()\n\n    edge_index, edge_weight = from_networkit(g)\n    assert edge_index.tolist() == [[0, 1], [1, 0]]\n    assert edge_weight is None\n\n\n@withPackage('networkit')\n@pytest.mark.parametrize('directed', [True, False])\n@pytest.mark.parametrize('num_nodes', [None, 3])\n@pytest.mark.parametrize('edge_weight', [None, torch.rand(3)])\ndef test_to_networkit(directed, edge_weight, num_nodes):\n    import networkit\n\n    edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]], dtype=torch.long)\n    g = to_networkit(edge_index, edge_weight, num_nodes, directed)\n\n    assert isinstance(g, networkit.Graph)\n    assert g.isDirected() == directed\n    assert g.numberOfNodes() == 3\n\n    if edge_weight is None:\n        edge_weight = torch.tensor([1., 1., 1.])\n\n    assert g.weight(0, 1) == float(edge_weight[0])\n    assert g.weight(1, 2) == float(edge_weight[2])\n\n    if directed:\n        assert g.numberOfEdges() == 3\n        assert g.weight(1, 0) == float(edge_weight[1])\n    else:\n        assert g.numberOfEdges() == 2\n\n\n@pytest.mark.parametrize('directed', [True, False])\n@pytest.mark.parametrize('weighted', [True, False])\n@withPackage('networkit')\ndef test_from_networkit(directed, weighted):\n    import networkit\n\n    g = networkit.Graph(3, weighted=weighted, directed=directed)\n    g.addEdge(0, 1)\n    g.addEdge(1, 2)\n    if directed:\n        g.addEdge(1, 0)\n\n    if weighted:\n        for i, (u, v) in enumerate(g.iterEdges()):\n            g.setWeight(u, v, i + 1)\n\n    edge_index, edge_weight = from_networkit(g)\n\n    if directed:\n        assert edge_index.tolist() == [[0, 1, 1], [1, 2, 0]]\n        if weighted:\n            assert edge_weight.tolist() == [1, 2, 3]\n        else:\n            assert edge_weight is None\n    else:\n        assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n        if weighted:\n            assert edge_weight.tolist() == [1, 1, 2, 2]\n        else:\n            assert edge_weight is 
None\n\n\n@withPackage('trimesh')\ndef test_trimesh_vice_versa():\n    pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],\n                       dtype=torch.float)\n    face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()\n\n    data = Data(pos=pos, face=face)\n    mesh = to_trimesh(data)\n    data = from_trimesh(mesh)\n\n    assert pos.tolist() == data.pos.tolist()\n    assert face.tolist() == data.face.tolist()\n\n\n@withPackage('trimesh')\ndef test_to_trimesh():\n    import trimesh\n\n    pos = torch.tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]])\n    face = torch.tensor([[0, 1, 2], [2, 1, 3]]).t()\n    data = Data(pos=pos, face=face)\n\n    obj = to_trimesh(data)\n\n    assert isinstance(obj, trimesh.Trimesh)\n    assert obj.vertices.shape == (4, 3)\n    assert obj.faces.shape == (2, 3)\n    assert obj.vertices.tolist() == data.pos.tolist()\n    assert obj.faces.tolist() == data.face.t().contiguous().tolist()\n\n\n@withPackage('trimesh')\ndef test_from_trimesh():\n    import trimesh\n\n    vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0]]\n    faces = [[0, 1, 2]]\n    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)\n\n    data = from_trimesh(mesh)\n\n    assert data.pos.tolist() == vertices\n    assert data.face.t().contiguous().tolist() == faces\n\n\n@withPackage('cudf', 'cugraph')\n@pytest.mark.parametrize('edge_weight', [None, torch.rand(4)])\n@pytest.mark.parametrize('relabel_nodes', [True, False])\n@pytest.mark.parametrize('directed', [True, False])\ndef test_to_cugraph(edge_weight, directed, relabel_nodes):\n    import cugraph\n\n    if directed:\n        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    else:\n        edge_index = torch.tensor([[0, 1], [1, 2]])\n\n    if edge_weight is not None:\n        edge_weight = edge_weight[:edge_index.size(1)]\n\n    graph = to_cugraph(edge_index, edge_weight, relabel_nodes, directed)\n    assert isinstance(graph, cugraph.Graph)\n    assert graph.number_of_nodes() == 
3\n\n    edge_list = graph.view_edge_list()\n    assert edge_list is not None\n\n    edge_list = edge_list.sort_values(\n        by=[graph.source_columns, graph.destination_columns])\n\n    cu_edge_index = edge_list[[\n        graph.source_columns, graph.destination_columns\n    ]].to_pandas().values\n    cu_edge_index = torch.from_numpy(cu_edge_index).t()\n    cu_edge_weight = None\n    if edge_weight is not None:\n        cu_edge_weight = edge_list[graph.weight_column].to_pandas().values\n        cu_edge_weight = torch.from_numpy(cu_edge_weight)\n\n    cu_edge_index, cu_edge_weight = sort_edge_index(cu_edge_index,\n                                                    cu_edge_weight)\n\n    assert torch.equal(edge_index, cu_edge_index.cpu())\n    if edge_weight is not None:\n        assert torch.allclose(edge_weight, cu_edge_weight.cpu())\n\n\n@withPackage('cudf', 'cugraph')\n@pytest.mark.parametrize('edge_weight', [None, torch.randn(4)])\n@pytest.mark.parametrize('directed', [True, False])\n@pytest.mark.parametrize('relabel_nodes', [True, False])\ndef test_from_cugraph(edge_weight, directed, relabel_nodes):\n    import cudf\n    import cugraph\n    from torch.utils.dlpack import to_dlpack\n\n    if directed:\n        edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\n    else:\n        edge_index = torch.tensor([[0, 1], [1, 2]])\n\n    if edge_weight is not None:\n        edge_weight = edge_weight[:edge_index.size(1)]\n\n    G = cugraph.Graph(directed=directed)\n    df = cudf.DataFrame({\n        'source':\n        cudf.from_dlpack(to_dlpack(edge_index[0])),\n        'destination':\n        cudf.from_dlpack(to_dlpack(edge_index[1])),\n    })\n    if edge_weight is not None:\n        df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))\n\n    G.from_cudf_edgelist(\n        df,\n        source='source',\n        destination='destination',\n        edge_attr='weight' if edge_weight is not None else None,\n        renumber=relabel_nodes,\n    )\n\n    
cu_edge_index, cu_edge_weight = from_cugraph(G)\n    cu_edge_index, cu_edge_weight = sort_edge_index(cu_edge_index,\n                                                    cu_edge_weight)\n\n    assert torch.equal(edge_index, cu_edge_index.cpu())\n    if edge_weight is not None:\n        assert torch.allclose(edge_weight, cu_edge_weight.cpu())\n    else:\n        assert cu_edge_weight is None\n\n\n@withPackage('dgl')\ndef test_to_dgl_graph():\n    x = torch.randn(5, 3)\n    edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])\n    edge_attr = torch.randn(edge_index.size(1), 2)\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n\n    g = to_dgl(data)\n\n    assert torch.equal(data.x, g.ndata['x'])\n    row, col = g.edges()\n    assert torch.equal(row, edge_index[0])\n    assert torch.equal(col, edge_index[1])\n    assert torch.equal(data.edge_attr, g.edata['edge_attr'])\n\n\n@withPackage('dgl')\ndef test_to_dgl_hetero_graph():\n    data = HeteroData()\n    data['v1'].x = torch.randn(4, 3)\n    data['v2'].x = torch.randn(4, 3)\n    data['v1', 'v2'].edge_index = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])\n    data['v1', 'v2'].edge_attr = torch.randn(4, 2)\n\n    g = to_dgl(data)\n\n    assert data['v1', 'v2'].num_edges == g.num_edges(('v1', 'to', 'v2'))\n    assert data['v1'].num_nodes == g.num_nodes('v1')\n    assert data['v2'].num_nodes == g.num_nodes('v2')\n    assert torch.equal(data['v1'].x, g.nodes['v1'].data['x'])\n    assert torch.equal(data['v2'].x, g.nodes['v2'].data['x'])\n    row, col = g.edges()\n    assert torch.equal(row, data['v1', 'v2'].edge_index[0])\n    assert torch.equal(col, data['v1', 'v2'].edge_index[1])\n    assert torch.equal(g.edata['edge_attr'], data['v1', 'v2'].edge_attr)\n\n\n@withPackage('dgl', 'torch_sparse')\ndef test_to_dgl_sparse():\n    from torch_geometric.transforms import ToSparseTensor\n    x = torch.randn(5, 3)\n    edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])\n    edge_attr = 
torch.randn(edge_index.size(1), 2)\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n    data = ToSparseTensor()(data)\n\n    g = to_dgl(data)\n\n    assert torch.equal(data.x, g.ndata[\"x\"])\n    pyg_row, pyg_col, _ = data.adj_t.t().coo()\n    dgl_row, dgl_col = g.edges()\n    assert torch.equal(pyg_row, dgl_row)\n    assert torch.equal(pyg_col, dgl_col)\n    assert torch.equal(data.edge_attr, g.edata['edge_attr'])\n\n\n@withPackage('dgl')\ndef test_from_dgl_graph():\n    import dgl\n    g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0]))\n    g.ndata['x'] = torch.randn(g.num_nodes(), 3)\n    g.edata['edge_attr'] = torch.randn(g.num_edges())\n\n    data = from_dgl(g)\n\n    assert torch.equal(data.x, g.ndata['x'])\n    row, col = g.edges()\n    assert torch.equal(data.edge_index[0], row)\n    assert torch.equal(data.edge_index[1], col)\n    assert torch.equal(data.edge_attr, g.edata['edge_attr'])\n\n\n@withPackage('dgl')\ndef test_from_dgl_hetero_graph():\n    import dgl\n    g = dgl.heterograph({\n        ('v1', 'to', 'v2'): (\n            [0, 1, 1, 2, 3, 3, 4],\n            [0, 0, 1, 1, 1, 2, 2],\n        )\n    })\n    g.nodes['v1'].data['x'] = torch.randn(5, 3)\n    g.nodes['v2'].data['x'] = torch.randn(3, 3)\n\n    data = from_dgl(g)\n\n    assert data['v1', 'v2'].num_edges == g.num_edges(('v1', 'to', 'v2'))\n    assert data['v1'].num_nodes == g.num_nodes('v1')\n    assert data['v2'].num_nodes == g.num_nodes('v2')\n    assert torch.equal(data['v1'].x, g.nodes['v1'].data['x'])\n    assert torch.equal(data['v2'].x, g.nodes['v2'].data['x'])\n"
  },
  {
    "path": "test/utils/test_cross_entropy.py",
    "content": "import pytest\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.testing import withCUDA\nfrom torch_geometric.utils.cross_entropy import sparse_cross_entropy\n\n\n@withCUDA\n@pytest.mark.parametrize('with_edge_label_weight', [False, True])\ndef test_sparse_cross_entropy_multiclass(\n    with_edge_label_weight: bool,\n    device: torch.device,\n) -> None:\n    x = torch.randn(5, 5, device=device, requires_grad=True)\n    y = torch.eye(5, device=device)\n\n    edge_label_index = y.nonzero().t()\n    edge_label_weight = None\n    if with_edge_label_weight:\n        edge_label_weight = torch.rand(edge_label_index.size(1), device=device)\n        y[y == 1.0] = edge_label_weight\n\n    expected = F.cross_entropy(x, y)\n    expected.backward()\n    expected_grad = x.grad\n\n    x.grad = None\n    out = sparse_cross_entropy(x, edge_label_index, edge_label_weight)\n    out.backward()\n\n    assert torch.allclose(expected, out)\n    assert torch.allclose(expected_grad, x.grad)\n\n\n@withCUDA\n@pytest.mark.parametrize('with_edge_label_weight', [False, True])\ndef test_sparse_cross_entropy_multilabel(\n    with_edge_label_weight: bool,\n    device: torch.device,\n) -> None:\n    x = torch.randn(4, 4, device=device, requires_grad=True)\n    y = torch.randint_like(x, 0, 2)\n\n    edge_label_index = y.nonzero().t()\n    edge_label_weight = None\n    if with_edge_label_weight:\n        edge_label_weight = torch.rand(edge_label_index.size(1), device=device)\n        y[y == 1.0] = edge_label_weight\n\n    expected = F.cross_entropy(x, y)\n    expected.backward()\n    expected_grad = x.grad\n\n    x.grad = None\n    out = sparse_cross_entropy(x, edge_label_index, edge_label_weight)\n    out.backward()\n\n    assert torch.allclose(expected, out)\n    assert torch.allclose(expected_grad, x.grad)\n\n\n@withCUDA\n@pytest.mark.parametrize('edge_label_weights', [\n    [2.0, -10.0, 1.0, -5.0, 4.0, 0.0, -1.0],\n    [-2.0, -1.0, -1.0, -3.0, -4.0, -10.0, 
-1.0],\n    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n])\ndef test_sparse_cross_entropy_negative_weight(\n    edge_label_weights: list[float],\n    device: torch.device,\n) -> None:\n    x = torch.randn(4, 8, device=device, requires_grad=True)\n    edge_label_index = torch.tensor([\n        [0, 0, 1, 2, 2, 2, 3],\n        [2, 4, 6, 5, 3, 1, 1],\n    ], device=device)\n    edge_label_weight = torch.tensor(edge_label_weights, device=device)\n    pos_mask = edge_label_weight >= 0\n\n    y = torch.zeros_like(x)\n    y[\n        edge_label_index[0, pos_mask],\n        edge_label_index[1, pos_mask],\n    ] = edge_label_weight[pos_mask]\n\n    _x = x.clone()\n    _x[\n        edge_label_index[0, ~pos_mask],\n        edge_label_index[1, ~pos_mask],\n    ] += edge_label_weight[~pos_mask].abs().log()\n    expected = F.cross_entropy(_x, y)\n    expected.backward()\n    expected_grad = x.grad\n\n    x.grad = None\n    out = sparse_cross_entropy(x, edge_label_index, edge_label_weight)\n    out.backward()\n\n    assert torch.allclose(expected, out)\n    assert torch.allclose(expected_grad, x.grad)\n"
  },
  {
    "path": "test/utils/test_degree.py",
    "content": "import torch\n\nfrom torch_geometric.utils import degree\n\n\ndef test_degree():\n    row = torch.tensor([0, 1, 0, 2, 0])\n    deg = degree(row, dtype=torch.long)\n    assert deg.dtype == torch.long\n    assert deg.tolist() == [3, 1, 1]\n"
  },
  {
    "path": "test/utils/test_dropout.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.utils import (\n    dropout_adj,\n    dropout_edge,\n    dropout_node,\n    dropout_path,\n)\n\n\ndef test_dropout_adj():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [1, 0, 2, 1, 3, 2],\n    ])\n    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])\n\n    with pytest.warns(UserWarning, match=\"'dropout_adj' is deprecated\"):\n        out = dropout_adj(edge_index, edge_attr, training=False)\n    assert edge_index.tolist() == out[0].tolist()\n    assert edge_attr.tolist() == out[1].tolist()\n\n    torch.manual_seed(5)\n    with pytest.warns(UserWarning, match=\"'dropout_adj' is deprecated\"):\n        out = dropout_adj(edge_index, edge_attr)\n    assert out[0].tolist() == [[0, 1, 2, 2], [1, 2, 1, 3]]\n    assert out[1].tolist() == [1, 3, 4, 5]\n\n    torch.manual_seed(6)\n    with pytest.warns(UserWarning, match=\"'dropout_adj' is deprecated\"):\n        out = dropout_adj(edge_index, edge_attr, force_undirected=True)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 2, 0, 1]]\n    assert out[1].tolist() == [1, 3, 1, 3]\n\n\ndef test_dropout_node():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [1, 0, 2, 1, 3, 2],\n    ])\n\n    out = dropout_node(edge_index, training=False)\n    assert edge_index.tolist() == out[0].tolist()\n    assert out[1].tolist() == [True, True, True, True, True, True]\n    assert out[2].tolist() == [True, True, True, True]\n\n    torch.manual_seed(5)\n    out = dropout_node(edge_index)\n    assert out[0].tolist() == [[2, 3], [3, 2]]\n    assert out[1].tolist() == [False, False, False, False, True, True]\n    assert out[2].tolist() == [True, False, True, True]\n\n\ndef test_dropout_edge():\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n\n    out = dropout_edge(edge_index, training=False)\n    assert edge_index.tolist() == out[0].tolist()\n    
assert out[1].tolist() == [True, True, True, True, True, True]\n\n    torch.manual_seed(5)\n    out = dropout_edge(edge_index)\n    assert out[0].tolist() == [[0, 1, 2, 2], [1, 2, 1, 3]]\n    assert out[1].tolist() == [True, False, True, True, True, False]\n\n    torch.manual_seed(6)\n    out = dropout_edge(edge_index, force_undirected=True)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 2, 0, 1]]\n    assert out[1].tolist() == [0, 2, 0, 2]\n\n\n@withPackage('torch_cluster')\ndef test_dropout_path():\n    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])\n\n    out = dropout_path(edge_index, training=False)\n    assert edge_index.tolist() == out[0].tolist()\n    assert out[1].tolist() == [True, True, True, True, True, True]\n\n    torch.manual_seed(4)\n    out = dropout_path(edge_index, p=0.2)\n    assert out[0].tolist() == [[0, 1], [1, 0]]\n    assert out[1].tolist() == [True, True, False, False, False, False]\n    assert edge_index[:, out[1]].tolist() == out[0].tolist()\n\n    # test with unsorted edges\n    torch.manual_seed(5)\n    edge_index = torch.tensor([[3, 5, 2, 2, 2, 1], [1, 0, 0, 1, 3, 2]])\n    out = dropout_path(edge_index, p=0.2)\n    assert out[0].tolist() == [[3, 2, 2, 1], [1, 1, 3, 2]]\n    assert out[1].tolist() == [True, False, False, True, True, True]\n    assert edge_index[:, out[1]].tolist() == out[0].tolist()\n\n    # test with isolated nodes\n    torch.manual_seed(7)\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 2, 4]])\n    out = dropout_path(edge_index, p=0.2)\n    assert out[0].tolist() == [[2, 3], [2, 4]]\n    assert out[1].tolist() == [False, False, True, True]\n    assert edge_index[:, out[1]].tolist() == out[0].tolist()\n"
  },
  {
    "path": "test/utils/test_embedding.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.nn import GCNConv, Linear\nfrom torch_geometric.utils import get_embeddings\nfrom torch_geometric.utils.embedding import get_embeddings_hetero\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(5, 6)\n        self.conv2 = GCNConv(6, 7)\n\n    def forward(self, x0, edge_index):\n        x1 = self.conv1(x0, edge_index)\n        x2 = self.conv2(x1, edge_index)\n        return [x1, x2]\n\n\ndef test_get_embeddings():\n    x = torch.randn(6, 5)\n    edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])\n\n    with pytest.warns(UserWarning, match=\"any 'MessagePassing' layers\"):\n        intermediate_outs = get_embeddings(Linear(5, 5), x)\n    assert len(intermediate_outs) == 0\n\n    model = GNN()\n    expected_embeddings = model(x, edge_index)\n\n    embeddings = get_embeddings(model, x, edge_index)\n    assert len(embeddings) == 2\n    for expected, out in zip(expected_embeddings, embeddings):\n        assert torch.allclose(expected, out)\n\n\ndef test_get_embeddings_hetero(hetero_data, hetero_model):\n    # Create model using the metadata from hetero_data\n    metadata = hetero_data.metadata()\n    model = hetero_model(metadata)\n\n    # Get heterogeneous embeddings\n    embeddings_dict = get_embeddings_hetero(model, None, hetero_data.x_dict,\n                                            hetero_data.edge_index_dict)\n\n    # Verify the structure of the returned embeddings\n    assert isinstance(embeddings_dict, dict)\n    assert 'paper' in embeddings_dict\n    assert 'author' in embeddings_dict\n\n    # Verify that we have embeddings for both node types\n    assert len(embeddings_dict['paper']) > 0\n    assert len(embeddings_dict['author']) > 0\n\n    # Check that the embeddings have the right shape\n    num_paper_nodes = hetero_data['paper'].num_nodes\n    num_author_nodes = hetero_data['author'].num_nodes\n\n    # Verify 
dimensions of the first-layer embeddings\n    assert embeddings_dict['paper'][0].shape == (num_paper_nodes, 32)\n    assert embeddings_dict['author'][0].shape == (num_author_nodes, 32)\n"
  },
  {
    "path": "test/utils/test_functions.py",
    "content": "import torch\n\nfrom torch_geometric.utils import cumsum\n\n\ndef test_cumsum():\n    x = torch.tensor([2, 4, 1])\n    assert cumsum(x).tolist() == [0, 2, 6, 7]\n\n    x = torch.tensor([[2, 4], [3, 6]])\n    assert cumsum(x, dim=1).tolist() == [[0, 2, 6], [0, 3, 9]]\n"
  },
  {
    "path": "test/utils/test_geodesic.py",
    "content": "from math import sqrt\n\nimport torch\n\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.utils import geodesic_distance\n\n\n@withPackage('gdist')\ndef test_geodesic_distance():\n    pos = torch.tensor([\n        [0.0, 0.0, 0.0],\n        [2.0, 0.0, 0.0],\n        [0.0, 2.0, 0.0],\n        [2.0, 2.0, 0.0],\n    ])\n    face = torch.tensor([[0, 1, 3], [0, 2, 3]]).t()\n\n    out = geodesic_distance(pos, face)\n    expected = torch.tensor([\n        [0.0, 1.0, 1.0, sqrt(2)],\n        [1.0, 0.0, sqrt(2), 1.0],\n        [1.0, sqrt(2), 0.0, 1.0],\n        [sqrt(2), 1.0, 1.0, 0.0],\n    ])\n    assert torch.allclose(out, expected)\n    assert torch.allclose(out, geodesic_distance(pos, face, num_workers=-1))\n\n    out = geodesic_distance(pos, face, norm=False)\n    expected = torch.tensor([\n        [0, 2, 2, 2 * sqrt(2)],\n        [2, 0, 2 * sqrt(2), 2],\n        [2, 2 * sqrt(2), 0, 2],\n        [2 * sqrt(2), 2, 2, 0],\n    ])\n    assert torch.allclose(out, expected)\n\n    src = torch.tensor([0, 0, 0, 0])\n    dst = torch.tensor([0, 1, 2, 3])\n    out = geodesic_distance(pos, face, src=src, dst=dst)\n    expected = torch.tensor([0.0, 1.0, 1.0, sqrt(2)])\n    assert torch.allclose(out, expected)\n\n    out = geodesic_distance(pos, face, dst=dst)\n    expected = torch.tensor([0.0, 0.0, 0.0, 0.0])\n    assert torch.allclose(out, expected)\n"
  },
  {
    "path": "test/utils/test_grid.py",
    "content": "import torch\n\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import grid\n\n\ndef test_grid():\n    (row, col), pos = grid(height=3, width=2)\n\n    expected_row = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2]\n    expected_col = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5]\n    expected_row += [3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5]\n    expected_col += [0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5]\n\n    expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]]\n\n    assert row.tolist() == expected_row\n    assert col.tolist() == expected_col\n    assert pos.tolist() == expected_pos\n\n    if is_full_test():\n        jit = torch.jit.script(grid)\n        (row, col), pos = jit(height=3, width=2)\n        assert row.tolist() == expected_row\n        assert col.tolist() == expected_col\n        assert pos.tolist() == expected_pos\n"
  },
  {
    "path": "test/utils/test_hetero.py",
    "content": "import torch\n\nfrom torch_geometric.testing import get_random_edge_index\nfrom torch_geometric.utils.hetero import construct_bipartite_edge_index\n\n\ndef test_construct_bipartite_edge_index():\n    edge_index = get_random_edge_index(4, 6, num_edges=20)\n\n    edge_index_dict = {\n        ('author', 'paper'): edge_index,\n        ('paper', 'author'): edge_index.flip([0]),\n    }\n    edge_attr_dict = {\n        ('author', 'paper'): torch.randn(edge_index.size(1), 16),\n        ('paper', 'author'): torch.randn(edge_index.size(1), 16)\n    }\n\n    edge_index, edge_attr = construct_bipartite_edge_index(\n        edge_index_dict,\n        src_offset_dict={\n            ('author', 'paper'): 0,\n            ('paper', 'author'): 4\n        },\n        dst_offset_dict={\n            'author': 0,\n            'paper': 4\n        },\n        edge_attr_dict=edge_attr_dict,\n    )\n\n    assert edge_index.size() == (2, 40)\n    assert edge_index.min() >= 0\n    assert edge_index[0].max() > 4 and edge_index[1].max() > 6\n    assert edge_index.max() <= 10\n    assert edge_attr.size() == (40, 16)\n    assert torch.equal(edge_attr[:20], edge_attr_dict['author', 'paper'])\n    assert torch.equal(edge_attr[20:], edge_attr_dict['paper', 'author'])\n"
  },
  {
    "path": "test/utils/test_homophily.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import homophily\n\n\ndef test_homophily():\n    edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 0, 4]])\n    y = torch.tensor([0, 0, 0, 0, 1])\n    batch = torch.tensor([0, 0, 0, 1, 1])\n    row, col = edge_index\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj = SparseTensor(row=row, col=col, sparse_sizes=(5, 5))\n\n    method = 'edge'\n    assert pytest.approx(homophily(edge_index, y, method=method)) == 0.75\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert pytest.approx(homophily(adj, y, method=method)) == 0.75\n    assert homophily(edge_index, y, batch, method).tolist() == [1., 0.]\n\n    method = 'node'\n    assert pytest.approx(homophily(edge_index, y, method=method)) == 0.6\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert pytest.approx(homophily(adj, y, method=method)) == 0.6\n    assert homophily(edge_index, y, batch, method).tolist() == [1., 0.]\n\n    method = 'edge_insensitive'\n    assert pytest.approx(homophily(edge_index, y, method=method)) == 0.1999999\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert pytest.approx(homophily(adj, y, method=method)) == 0.1999999\n    assert homophily(edge_index, y, batch, method).tolist() == [0., 0.]\n"
  },
  {
    "path": "test/utils/test_index_sort.py",
    "content": "import torch\n\nfrom torch_geometric.testing import withDevice\nfrom torch_geometric.utils import index_sort\n\n\n@withDevice\ndef test_index_sort_stable(device):\n    for _ in range(100):\n        inputs = torch.randint(0, 4, size=(10, ), device=device)\n\n        out = index_sort(inputs, stable=True)\n        expected = torch.sort(inputs, stable=True)\n\n        assert torch.equal(out[0], expected[0])\n        assert torch.equal(out[1], expected[1])\n"
  },
  {
    "path": "test/utils/test_isolated.py",
    "content": "import torch\n\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import (\n    contains_isolated_nodes,\n    remove_isolated_nodes,\n)\n\n\ndef test_contains_isolated_nodes():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    assert not contains_isolated_nodes(edge_index)\n    assert contains_isolated_nodes(edge_index, num_nodes=3)\n\n    if is_full_test():\n        jit = torch.jit.script(contains_isolated_nodes)\n        assert not jit(edge_index)\n        assert jit(edge_index, num_nodes=3)\n\n    edge_index = torch.tensor([[0, 1, 2, 0], [1, 0, 2, 0]])\n    assert contains_isolated_nodes(edge_index)\n\n\ndef test_remove_isolated_nodes():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n\n    out, _, mask = remove_isolated_nodes(edge_index)\n    assert out.tolist() == [[0, 1, 0], [1, 0, 0]]\n    assert mask.tolist() == [1, 1]\n\n    if is_full_test():\n        jit = torch.jit.script(remove_isolated_nodes)\n        out, _, mask = jit(edge_index)\n        assert out.tolist() == [[0, 1, 0], [1, 0, 0]]\n        assert mask.tolist() == [1, 1]\n\n    out, _, mask = remove_isolated_nodes(edge_index, num_nodes=3)\n    assert out.tolist() == [[0, 1, 0], [1, 0, 0]]\n    assert mask.tolist() == [1, 1, 0]\n\n    edge_index = torch.tensor([[0, 2, 1, 0, 2], [2, 0, 1, 0, 2]])\n    edge_attr = torch.tensor([1, 2, 3, 4, 5])\n    out1, out2, mask = remove_isolated_nodes(edge_index, edge_attr)\n    assert out1.tolist() == [[0, 1, 0, 1], [1, 0, 0, 1]]\n    assert out2.tolist() == [1, 2, 4, 5]\n    assert mask.tolist() == [1, 0, 1]\n"
  },
  {
    "path": "test/utils/test_laplacian.py",
    "content": "import torch\n\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import get_laplacian\n\n\ndef test_get_laplacian():\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)\n    edge_weight = torch.tensor([1, 2, 2, 4], dtype=torch.float)\n\n    lap = get_laplacian(edge_index, edge_weight)\n    assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]\n    assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4]\n\n    if is_full_test():\n        jit = torch.jit.script(get_laplacian)\n        lap = jit(edge_index, edge_weight)\n        assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2],\n                                   [1, 0, 2, 1, 0, 1, 2]]\n        assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4]\n\n    lap_sym = get_laplacian(edge_index, edge_weight, normalization='sym')\n    assert lap_sym[0].tolist() == lap[0].tolist()\n    assert lap_sym[1].tolist() == [-0.5, -1, -0.5, -1, 1, 1, 1]\n\n    lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw')\n    assert lap_rw[0].tolist() == lap[0].tolist()\n    assert lap_rw[1].tolist() == [-1, -0.5, -0.5, -1, 1, 1, 1]\n"
  },
  {
    "path": "test/utils/test_lexsort.py",
    "content": "import numpy as np\nimport torch\n\nfrom torch_geometric.utils import lexsort\n\n\ndef test_lexsort():\n    keys = [torch.randn(100) for _ in range(3)]\n\n    expected = np.lexsort([key.numpy() for key in keys])\n    assert torch.equal(lexsort(keys), torch.from_numpy(expected))\n"
  },
  {
    "path": "test/utils/test_loop.py",
    "content": "import torch\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.utils import (\n    add_remaining_self_loops,\n    add_self_loops,\n    contains_self_loops,\n    get_self_loop_attr,\n    remove_self_loops,\n    segregate_self_loops,\n    to_torch_coo_tensor,\n)\n\n\ndef test_contains_self_loops():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    assert contains_self_loops(edge_index)\n\n    edge_index = torch.tensor([[0, 1, 1], [1, 0, 2]])\n    assert not contains_self_loops(edge_index)\n\n\ndef test_remove_self_loops():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_attr = torch.tensor([[1, 2], [3, 4], [5, 6]])\n\n    expected = torch.tensor([[0, 1], [1, 0]])\n\n    out = remove_self_loops(edge_index)\n    assert out[0].equal(expected)\n    assert out[1] is None\n\n    out = remove_self_loops(edge_index, edge_attr)\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([[1, 2], [3, 4]]))\n\n    adj = to_torch_coo_tensor(edge_index)\n    adj, _ = remove_self_loops(adj)\n    assert torch.diag(adj.to_dense()).tolist() == [0, 0]\n\n    edge_index = EdgeIndex(\n        edge_index,\n        sparse_size=(2, 2),\n        sort_order='row',\n        is_undirected=True,\n    )\n    out = remove_self_loops(edge_index)\n    assert out[0].equal(expected)\n    assert out[0].sparse_size() == (2, 2)\n    assert out[0].sort_order == 'row'\n    assert out[0].is_undirected\n    assert out[1] is None\n\n    out = remove_self_loops(edge_index, edge_attr)\n    assert out[0].equal(expected)\n    assert out[0].sparse_size() == (2, 2)\n    assert out[0].sort_order == 'row'\n    assert out[0].is_undirected\n    assert out[1].equal(torch.tensor([[1, 2], [3, 4]]))\n\n\ndef test_segregate_self_loops():\n    edge_index = torch.tensor([[0, 0, 1], [0, 1, 0]])\n\n    out = segregate_self_loops(edge_index)\n    assert out[0].equal(torch.tensor([[0, 1], [1, 0]]))\n    assert out[1] is None\n    assert 
out[2].equal(torch.tensor([[0], [0]]))\n    assert out[3] is None\n\n    edge_attr = torch.tensor([1, 2, 3])\n    out = segregate_self_loops(edge_index, edge_attr)\n    assert out[0].equal(torch.tensor([[0, 1], [1, 0]]))\n    assert out[1].equal(torch.tensor([2, 3]))\n    assert out[2].equal(torch.tensor([[0], [0]]))\n    assert out[3].equal(torch.tensor([1]))\n\n    edge_index = EdgeIndex(\n        edge_index,\n        sparse_size=(2, 2),\n        sort_order='row',\n        is_undirected=True,\n    )\n    out = segregate_self_loops(edge_index)\n    assert out[0].equal(torch.tensor([[0, 1], [1, 0]]))\n    assert out[0].sparse_size() == (2, 2)\n    assert out[0].sort_order == 'row'\n    assert out[0].is_undirected\n    assert out[1] is None\n    assert out[2].equal(torch.tensor([[0], [0]]))\n    assert out[2].sparse_size() == (2, 2)\n    assert out[2].sort_order == 'row'\n    assert out[2].is_undirected\n    assert out[3] is None\n\n    out = segregate_self_loops(edge_index, edge_attr)\n    assert out[0].equal(torch.tensor([[0, 1], [1, 0]]))\n    assert out[0].sparse_size() == (2, 2)\n    assert out[0].sort_order == 'row'\n    assert out[0].is_undirected\n    assert out[1].equal(torch.tensor([2, 3]))\n    assert out[2].equal(torch.tensor([[0], [0]]))\n    assert out[2].sparse_size() == (2, 2)\n    assert out[2].sort_order == 'row'\n    assert out[2].is_undirected\n    assert out[3].equal(torch.tensor([1]))\n\n\ndef test_add_self_loops():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_weight = torch.tensor([0.5, 0.5, 0.5])\n    edge_attr = torch.eye(3)\n    adj = to_torch_coo_tensor(edge_index, edge_weight)\n\n    expected = torch.tensor([[0, 1, 0, 0, 1], [1, 0, 0, 0, 1]])\n    assert add_self_loops(edge_index)[0].equal(expected)\n\n    out = add_self_loops(edge_index, edge_weight)\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 1., 1.]))\n\n    out = add_self_loops(adj)[0]\n    assert 
out._indices().equal(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]))\n    assert out._values().equal(torch.tensor([1.5, 0.5, 0.5, 1.0]))\n\n    out = add_self_loops(edge_index, edge_weight, fill_value=5)\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 5.0, 5.0]))\n\n    out = add_self_loops(adj, fill_value=5)[0]\n    assert out._indices().equal(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]))\n    assert out._values().equal(torch.tensor([5.5, 0.5, 0.5, 5.0]))\n\n    out = add_self_loops(edge_index, edge_weight, fill_value=torch.tensor(2.))\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 2., 2.]))\n\n    out = add_self_loops(adj, fill_value=torch.tensor(2.))[0]\n    assert out._indices().equal(torch.tensor([[0, 0, 1, 1], [0, 1, 0, 1]]))\n    assert out._values().equal(torch.tensor([2.5, 0.5, 0.5, 2.0]))\n\n    out = add_self_loops(edge_index, edge_weight, fill_value='add')\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 1, 0.5]))\n\n    # Tests with `edge_attr`:\n    out = add_self_loops(edge_index, edge_attr)\n    assert out[0].equal(expected)\n    assert out[1].equal(\n        torch.tensor([\n            [1., 0., 0.],\n            [0., 1., 0.],\n            [0., 0., 1.],\n            [1., 1., 1.],\n            [1., 1., 1.],\n        ]))\n\n    out = add_self_loops(edge_index, edge_attr,\n                         fill_value=torch.tensor([0., 1., 0.]))\n    assert out[0].equal(expected)\n    assert out[1].equal(\n        torch.tensor([\n            [1., 0., 0.],\n            [0., 1., 0.],\n            [0., 0., 1.],\n            [0., 1., 0.],\n            [0., 1., 0.],\n        ]))\n\n    out = add_self_loops(edge_index, edge_attr, fill_value='add')\n    assert out[0].equal(expected)\n    assert out[1].equal(\n        torch.tensor([\n            [1., 0., 0.],\n            [0., 1., 0.],\n            [0., 0., 1.],\n            [0., 1., 1.],\n            
[1., 0., 0.],\n        ]))\n\n    edge_index = EdgeIndex(\n        edge_index,\n        sparse_size=(2, 2),\n        sort_order='row',\n        is_undirected=True,\n    )\n    out, _ = add_self_loops(edge_index)\n    assert out.equal(expected)\n    assert out.sparse_size() == (2, 2)\n    assert out.sort_order is None\n    assert out.is_undirected\n\n    # Test empty `edge_index` and `edge_weight`:\n    edge_index = torch.empty(2, 0, dtype=torch.long)\n    edge_weight = torch.empty(0)\n    out = add_self_loops(edge_index, edge_weight, num_nodes=1)\n    assert out[0].equal(torch.tensor([[0], [0]]))\n    assert out[1].equal(torch.tensor([1.]))\n\n\ndef test_add_self_loops_bipartite():\n    edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])\n    adj = to_torch_coo_tensor(edge_index, size=(4, 2))\n\n    edge_index, _ = add_self_loops(edge_index, num_nodes=(4, 2))\n    assert edge_index.equal(\n        torch.tensor([\n            [0, 1, 2, 3, 0, 1],\n            [0, 0, 1, 1, 0, 1],\n        ]))\n\n    adj, _ = add_self_loops(adj)\n    assert adj._indices().equal(\n        torch.tensor([\n            [0, 1, 1, 2, 3],\n            [0, 0, 1, 1, 1],\n        ]))\n\n\ndef test_add_remaining_self_loops():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_weight = torch.tensor([0.5, 0.5, 0.5])\n    edge_attr = torch.eye(3)\n\n    expected = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]])\n\n    out = add_remaining_self_loops(edge_index, edge_weight)\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 1]))\n\n    out = add_remaining_self_loops(edge_index, edge_weight, fill_value=5)\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 5.0]))\n\n    out = add_remaining_self_loops(edge_index, edge_weight,\n                                   fill_value=torch.tensor(2.))\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 2.0]))\n\n    out = 
add_remaining_self_loops(edge_index, edge_weight, fill_value='add')\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 0.5]))\n\n    # Test with `edge_attr`:\n    out = add_remaining_self_loops(edge_index, edge_attr,\n                                   fill_value=torch.tensor([0., 1., 0.]))\n    assert out[0].equal(expected)\n    assert out[1].equal(\n        torch.tensor([\n            [1., 0., 0.],\n            [0., 1., 0.],\n            [0., 0., 1.],\n            [0., 1., 0.],\n        ]))\n\n    edge_index = EdgeIndex(\n        edge_index,\n        sparse_size=(2, 2),\n        sort_order='row',\n        is_undirected=True,\n    )\n    out, _ = add_remaining_self_loops(edge_index)\n    assert out.equal(expected)\n    assert out.sparse_size() == (2, 2)\n    assert out.sort_order is None\n    assert out.is_undirected\n\n\ndef test_add_remaining_self_loops_without_initial_loops():\n    edge_index = torch.tensor([[0, 1], [1, 0]])\n    edge_weight = torch.tensor([0.5, 0.5])\n\n    expected = torch.tensor([[0, 1, 0, 1], [1, 0, 0, 1]])\n\n    out = add_remaining_self_loops(edge_index, edge_weight)\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 1, 1]))\n\n    out = add_remaining_self_loops(edge_index, edge_weight, fill_value=5)\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 5.0, 5.0]))\n\n    out = add_remaining_self_loops(edge_index, edge_weight,\n                                   fill_value=torch.tensor(2.0))\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 2.0, 2.0]))\n\n    # Test string `fill_value`:\n    out = add_remaining_self_loops(edge_index, edge_weight, fill_value='add')\n    assert out[0].equal(expected)\n    assert out[1].equal(torch.tensor([0.5, 0.5, 0.5, 0.5]))\n\n\ndef test_get_self_loop_attr():\n    edge_index = torch.tensor([[0, 1, 0], [1, 0, 0]])\n    edge_weight = torch.tensor([0.2, 0.3, 0.5])\n\n    
full_loop_weight = get_self_loop_attr(edge_index, edge_weight)\n    assert full_loop_weight.equal(torch.tensor([0.5, 0.0]))\n\n    full_loop_weight = get_self_loop_attr(edge_index, edge_weight, num_nodes=4)\n    assert full_loop_weight.equal(torch.tensor([0.5, 0.0, 0.0, 0.0]))\n\n    full_loop_weight = get_self_loop_attr(edge_index)\n    assert full_loop_weight.equal(torch.tensor([1.0, 0.0]))\n\n    edge_attr = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 1.0]])\n    full_loop_attr = get_self_loop_attr(edge_index, edge_attr)\n    assert full_loop_attr.equal(torch.tensor([[0.5, 1.0], [0.0, 0.0]]))\n"
  },
  {
    "path": "test/utils/test_map.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import withDevice, withPackage\nfrom torch_geometric.utils.map import map_index\n\n\n@withDevice\n@withPackage('pandas')\n@pytest.mark.parametrize('max_index', [3, 100_000_000])\ndef test_map_index(device, max_index):\n    src = torch.tensor([2, 0, 1, 0, max_index], device=device)\n    index = torch.tensor([max_index, 2, 0, 1], device=device)\n\n    out, mask = map_index(src, index, inclusive=True)\n    assert out.device == device\n    assert mask is None\n    assert out.tolist() == [1, 2, 3, 2, 0]\n\n\n@withDevice\n@withPackage('pandas')\n@pytest.mark.parametrize('max_index', [3, 100_000_000])\ndef test_map_index_na(device, max_index):\n    src = torch.tensor([2, 0, 1, 0, max_index], device=device)\n    index = torch.tensor([max_index, 2, 0], device=device)\n\n    out, mask = map_index(src, index, inclusive=False)\n    assert out.device == device\n    assert mask.device == device\n    assert out.tolist() == [1, 2, 2, 0]\n    assert mask.tolist() == [True, True, False, True, True]\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    args = parser.parse_args()\n\n    src = torch.randint(0, 100_000_000, (100_000, ), device=args.device)\n    index = src.unique()\n\n    def trivial_map(src, index, max_index, inclusive):\n        if max_index is None:\n            max_index = max(src.max(), index.max())\n\n        if inclusive:\n            assoc = src.new_empty(max_index + 1)\n        else:\n            assoc = src.new_full((max_index + 1, ), -1)\n        assoc[index] = torch.arange(index.numel(), device=index.device)\n        out = assoc[src]\n\n        if inclusive:\n            return out, None\n        else:\n            mask = out != -1\n            return out[mask], mask\n\n    print('Inclusive:')\n    benchmark(\n        
funcs=[trivial_map, map_index],\n        func_names=['trivial', 'map_index'],\n        args=(src, index, None, True),\n        num_steps=100,\n        num_warmups=50,\n    )\n\n    print('Exclusive:')\n    benchmark(\n        funcs=[trivial_map, map_index],\n        func_names=['trivial', 'map_index'],\n        args=(src, index[:50_000], None, False),\n        num_steps=100,\n        num_warmups=50,\n    )\n"
  },
  {
    "path": "test/utils/test_mask.py",
    "content": "import torch\n\nfrom torch_geometric.utils import index_to_mask, mask_select, mask_to_index\n\n\ndef test_mask_select():\n    src = torch.randn(6, 8)\n    mask = torch.tensor([False, True, False, True, False, True])\n\n    out = mask_select(src, 0, mask)\n    assert out.size() == (3, 8)\n    assert torch.equal(src[torch.tensor([1, 3, 5])], out)\n\n    jit = torch.jit.script(mask_select)\n    assert torch.equal(jit(src, 0, mask), out)\n\n\ndef test_index_to_mask():\n    index = torch.tensor([1, 3, 5])\n\n    mask = index_to_mask(index)\n    assert mask.tolist() == [False, True, False, True, False, True]\n\n    mask = index_to_mask(index, size=7)\n    assert mask.tolist() == [False, True, False, True, False, True, False]\n\n\ndef test_mask_to_index():\n    mask = torch.tensor([False, True, False, True, False, True])\n\n    index = mask_to_index(mask)\n    assert index.tolist() == [1, 3, 5]\n"
  },
  {
    "path": "test/utils/test_mesh_laplacian.py",
    "content": "import torch\n\nfrom torch_geometric.utils import get_mesh_laplacian\n\n\ndef test_get_mesh_laplacian_of_cube():\n    pos = torch.tensor([\n        [1.0, 1.0, 1.0],\n        [1.0, -1.0, 1.0],\n        [-1.0, -1.0, 1.0],\n        [-1.0, 1.0, 1.0],\n        [1.0, 1.0, -1.0],\n        [1.0, -1.0, -1.0],\n        [-1.0, -1.0, -1.0],\n        [-1.0, 1.0, -1.0],\n    ])\n\n    face = torch.tensor([\n        [0, 1, 2],\n        [0, 3, 2],\n        [4, 5, 1],\n        [4, 0, 1],\n        [7, 6, 5],\n        [7, 4, 5],\n        [3, 2, 6],\n        [3, 7, 6],\n        [4, 0, 3],\n        [4, 7, 3],\n        [1, 5, 6],\n        [1, 2, 6],\n    ])\n\n    edge_index, edge_weight = get_mesh_laplacian(pos, face.t(),\n                                                 normalization='rw')\n\n    assert edge_index.tolist() == [\n        [\n            0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4,\n            4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 0, 1, 2, 3, 4, 5, 6, 7\n        ],\n        [\n            1, 2, 3, 4, 0, 2, 4, 5, 6, 0, 1, 3, 6, 0, 2, 4, 6, 7, 0, 1, 3, 5,\n            7, 1, 4, 6, 7, 1, 2, 3, 5, 7, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7\n        ],\n    ]\n\n    assert torch.allclose(\n        edge_weight,\n        torch.tensor([\n            0.375, 0.0, 0.375, 0.375, 0.3, 0.3, 0.0, 0.3, 0.0, 0.0, 0.375,\n            0.375, 0.375, 0.3, 0.3, 0.0, 0.0, 0.3, 0.3, 0.0, 0.0, 0.3, 0.3,\n            0.375, 0.375, 0.375, 0.0, 0.0, 0.3, 0.0, 0.3, 0.3, 0.375, 0.375,\n            0.0, 0.375, -1.125, -0.9, -1.125, -0.9, -0.9, -1.125, -0.9, -1.125\n        ]))\n\n\ndef test_get_mesh_laplacian_of_irregular_triangular_prism():\n    pos = torch.tensor([\n        [0.0, 0.0, 0.0],\n        [4.0, 0.0, 0.0],\n        [0.0, 0.0, -3.0],\n        [1.0, 5.0, -1.0],\n        [3.0, 5.0, -1.0],\n        [2.0, 5.0, -2.0],\n    ])\n\n    face = torch.tensor([\n        [0, 1, 2],\n        [3, 4, 5],\n        [0, 1, 4],\n        [0, 3, 4],\n        [1, 2, 5],\n       
 [1, 4, 5],\n        [2, 0, 3],\n        [2, 5, 3],\n    ])\n\n    edge_index, edge_weight = get_mesh_laplacian(pos, face.t(),\n                                                 normalization='rw')\n\n    assert edge_index.tolist() == [\n        [\n            0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5,\n            5, 5, 0, 1, 2, 3, 4, 5\n        ],\n        [\n            1, 2, 3, 4, 0, 2, 4, 5, 0, 1, 3, 5, 0, 2, 4, 5, 0, 1, 3, 5, 1, 2,\n            3, 4, 0, 1, 2, 3, 4, 5\n        ],\n    ]\n\n    assert torch.allclose(\n        edge_weight,\n        torch.tensor([\n            0.09730332, 0.15039921, 0.05081503, 0.00000000, 0.08726977,\n            0.03521059, 0.05363689, 0.00723919, 0.14497279, 0.03784235,\n            0.01629947, 0.03438699, 0.08362866, 0.02782887, 0.24252312,\n            0.40727590, 0.00000000, 0.08728313, 0.21507657, 0.38582093,\n            0.01117009, 0.04936920, 0.34247482, 0.36583540, -0.29851755,\n            -0.18335645, -0.23350160, -0.76125660, -0.68818060, -0.76884955\n        ]))\n"
  },
  {
    "path": "test/utils/test_negative_sampling.py",
    "content": "import torch\n\nfrom torch_geometric.utils import (\n    batched_negative_sampling,\n    contains_self_loops,\n    is_undirected,\n    negative_sampling,\n    structured_negative_sampling,\n    structured_negative_sampling_feasible,\n    to_undirected,\n)\nfrom torch_geometric.utils._negative_sampling import (\n    edge_index_to_vector,\n    vector_to_edge_index,\n)\n\n\ndef is_negative(edge_index, neg_edge_index, size, bipartite):\n    adj = torch.zeros(size, dtype=torch.bool)\n    neg_adj = torch.zeros(size, dtype=torch.bool)\n\n    adj[edge_index[0], edge_index[1]] = True\n    neg_adj[neg_edge_index[0], neg_edge_index[1]] = True\n\n    if not bipartite:\n        arange = torch.arange(size[0])\n        assert neg_adj[arange, arange].sum() == 0\n\n    return (adj & neg_adj).sum() == 0\n\n\ndef test_edge_index_to_vector_and_vice_versa():\n    # Create a fully-connected graph:\n    N = 10\n    row = torch.arange(N).view(-1, 1).repeat(1, N).view(-1)\n    col = torch.arange(N).view(1, -1).repeat(N, 1).view(-1)\n    edge_index = torch.stack([row, col], dim=0)\n\n    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True)\n    assert population == N * N\n    assert idx.tolist() == list(range(population))\n    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=True)\n    assert is_undirected(edge_index2)\n    assert edge_index.tolist() == edge_index2.tolist()\n\n    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False)\n    assert population == N * N - N\n    assert idx.tolist() == list(range(population))\n    mask = edge_index[0] != edge_index[1]  # Remove self-loops.\n    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False)\n    assert is_undirected(edge_index2)\n    assert edge_index[:, mask].tolist() == edge_index2.tolist()\n\n    idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False,\n                                           force_undirected=True)\n    assert 
population == (N * (N + 1)) / 2 - N\n    assert idx.tolist() == list(range(population))\n    mask = edge_index[0] != edge_index[1]  # Remove self-loops.\n    edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False,\n                                       force_undirected=True)\n    assert is_undirected(edge_index2)\n    assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist()\n\n\ndef test_negative_sampling():\n    edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])\n\n    neg_edge_index = negative_sampling(edge_index)\n    assert neg_edge_index.size(1) == edge_index.size(1)\n    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)\n\n    neg_edge_index = negative_sampling(edge_index, method='dense')\n    assert neg_edge_index.size(1) == edge_index.size(1)\n    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)\n\n    neg_edge_index = negative_sampling(edge_index, num_neg_samples=2)\n    assert neg_edge_index.size(1) == 2\n    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)\n\n    # Test with float multiplier less than 1\n    neg_edge_index = negative_sampling(edge_index, num_neg_samples=0.5)\n    assert neg_edge_index.size(1) == 2  # 50% of 4 edges = 2 edges\n    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)\n\n    # Test with float multiplier greater than 1\n    neg_edge_index = negative_sampling(edge_index, num_neg_samples=1.5)\n    assert neg_edge_index.size(1) == 6  # 150% of 4 edges = 6 edges\n    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)\n\n    edge_index = to_undirected(edge_index)\n    neg_edge_index = negative_sampling(edge_index, force_undirected=True)\n    assert neg_edge_index.size(1) == edge_index.size(1) - 1\n    assert is_undirected(neg_edge_index)\n    assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False)\n\n\ndef test_bipartite_negative_sampling():\n    edge_index = 
torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])\n\n    neg_edge_index = negative_sampling(edge_index, num_nodes=(3, 4))\n    assert neg_edge_index.size(1) == edge_index.size(1)\n    assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True)\n\n    neg_edge_index = negative_sampling(edge_index, num_nodes=(3, 4),\n                                       num_neg_samples=2)\n    assert neg_edge_index.size(1) == 2\n    assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True)\n\n\ndef test_batched_negative_sampling():\n    edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])\n    edge_index = torch.cat([edge_index, edge_index + 4], dim=1)\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])\n\n    neg_edge_index = batched_negative_sampling(edge_index, batch)\n    assert neg_edge_index.size(1) <= edge_index.size(1)\n\n    # Test with float multiplier less than 1\n    neg_edge_index = batched_negative_sampling(edge_index, batch,\n                                               num_neg_samples=0.5)\n    assert neg_edge_index.size(1) <= 4  # 50% of 8 edges = 4 edges\n\n    # Test with float multiplier greater than 1\n    neg_edge_index = batched_negative_sampling(edge_index, batch,\n                                               num_neg_samples=1.5)\n    assert neg_edge_index.size(1) <= 12  # 150% of 8 edges = 12 edges\n\n    adj = torch.zeros(8, 8, dtype=torch.bool)\n    adj[edge_index[0], edge_index[1]] = True\n    neg_adj = torch.zeros(8, 8, dtype=torch.bool)\n    neg_adj[neg_edge_index[0], neg_edge_index[1]] = True\n\n    assert (adj & neg_adj).sum() == 0\n    assert (adj | neg_adj).sum() == edge_index.size(1) + neg_edge_index.size(1)\n    assert neg_adj[:4, 4:].sum() == 0\n    assert neg_adj[4:, :4].sum() == 0\n\n\ndef test_bipartite_batched_negative_sampling():\n    edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])\n    edge_index2 = edge_index1 + torch.tensor([[2], [4]])\n    edge_index3 = edge_index2 + torch.tensor([[2], [4]])\n   
 edge_index = torch.cat([edge_index1, edge_index2, edge_index3], dim=1)\n    src_batch = torch.tensor([0, 0, 1, 1, 2, 2])\n    dst_batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])\n\n    neg_edge_index = batched_negative_sampling(edge_index,\n                                               (src_batch, dst_batch))\n    assert neg_edge_index.size(1) <= edge_index.size(1)\n\n    adj = torch.zeros(6, 12, dtype=torch.bool)\n    adj[edge_index[0], edge_index[1]] = True\n    neg_adj = torch.zeros(6, 12, dtype=torch.bool)\n    neg_adj[neg_edge_index[0], neg_edge_index[1]] = True\n\n    assert (adj & neg_adj).sum() == 0\n    assert (adj | neg_adj).sum() == edge_index.size(1) + neg_edge_index.size(1)\n\n\ndef test_structured_negative_sampling():\n    edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])\n\n    i, j, k = structured_negative_sampling(edge_index)\n    assert i.size(0) == edge_index.size(1)\n    assert j.size(0) == edge_index.size(1)\n    assert k.size(0) == edge_index.size(1)\n\n    adj = torch.zeros(4, 4, dtype=torch.bool)\n    adj[i, j] = 1\n\n    neg_adj = torch.zeros(4, 4, dtype=torch.bool)\n    neg_adj[i, k] = 1\n    assert (adj & neg_adj).sum() == 0\n\n    # Test with no self-loops:\n    edge_index = torch.LongTensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]])\n    i, j, k = structured_negative_sampling(edge_index, num_nodes=4,\n                                           contains_neg_self_loops=False)\n    neg_edge_index = torch.vstack([i, k])\n    assert not contains_self_loops(neg_edge_index)\n\n\ndef test_structured_negative_sampling_feasible():\n    edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2],\n                                   [1, 2, 0, 2, 0, 1, 1]])\n    assert not structured_negative_sampling_feasible(edge_index, 3, False)\n    assert structured_negative_sampling_feasible(edge_index, 3, True)\n    assert structured_negative_sampling_feasible(edge_index, 4, False)\n"
  },
  {
    "path": "test/utils/test_nested.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.utils import from_nested_tensor, to_nested_tensor\n\n\ndef test_to_nested_tensor():\n    x = torch.randn(5, 4, 3)\n\n    out = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1]))\n    out = out.to_padded_tensor(padding=0)\n    assert out.size() == (2, 3, 4, 3)\n    assert torch.allclose(out[0, :2], x[0:2])\n    assert torch.allclose(out[1, :3], x[2:5])\n\n    out = to_nested_tensor(x, ptr=torch.tensor([0, 2, 5]))\n    out = out.to_padded_tensor(padding=0)\n    assert out.size() == (2, 3, 4, 3)\n    assert torch.allclose(out[0, :2], x[0:2])\n    assert torch.allclose(out[1, :3], x[2:5])\n\n    out = to_nested_tensor(x)\n    out = out.to_padded_tensor(padding=0)\n    assert out.size() == (1, 5, 4, 3)\n    assert torch.allclose(out[0], x)\n\n\ndef test_from_nested_tensor():\n    x = torch.randn(5, 4, 3)\n\n    nested = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1]))\n    out, batch = from_nested_tensor(nested, return_batch=True)\n\n    assert torch.equal(x, out)\n    assert batch.tolist() == [0, 0, 1, 1, 1]\n\n    nested = torch.nested.nested_tensor([torch.randn(4, 3), torch.randn(5, 4)])\n    with pytest.raises(ValueError, match=\"have the same size in dimension 1\"):\n        from_nested_tensor(nested)\n\n    # Test zero-copy:\n    nested = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1]))\n    out = from_nested_tensor(nested)\n    out += 1  # Increment in-place (which should increment `nested` as well).\n    assert torch.equal(nested.to_padded_tensor(padding=0)[0, :2], out[0:2])\n    assert torch.equal(nested.to_padded_tensor(padding=0)[1, :3], out[2:5])\n\n\ndef test_to_and_from_nested_tensor_autograd():\n    x = torch.randn(5, 4, 3, requires_grad=True)\n    grad = torch.randn_like(x)\n\n    out = to_nested_tensor(x, batch=torch.tensor([0, 0, 1, 1, 1]))\n    out = from_nested_tensor(out)\n    out.backward(grad)\n    assert torch.equal(x.grad, grad)\n"
  },
  {
    "path": "test/utils/test_noise_scheduler.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.utils.noise_scheduler import (\n    get_diffusion_beta_schedule,\n    get_smld_sigma_schedule,\n)\n\n\ndef test_get_smld_sigma_schedule():\n    expected = torch.tensor([\n        1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637,\n        0.04641589, 0.02782559, 0.01668101, 0.01\n    ])\n    out = get_smld_sigma_schedule(\n        sigma_min=0.01,\n        sigma_max=1.0,\n        num_scales=10,\n    )\n    assert torch.allclose(out, expected)\n\n\n@pytest.mark.parametrize(\n    'schedule_type',\n    ['linear', 'quadratic', 'constant', 'sigmoid'],\n)\ndef test_get_diffusion_beta_schedule(schedule_type):\n    out = get_diffusion_beta_schedule(\n        schedule_type,\n        beta_start=0.1,\n        beta_end=0.2,\n        num_diffusion_timesteps=10,\n    )\n    assert out.size() == (10, )\n"
  },
  {
    "path": "test/utils/test_normalize_edge_index.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.utils import normalize_edge_index\n\n\n@pytest.mark.parametrize('add_self_loops', [False, True])\n@pytest.mark.parametrize('symmetric', [False, True])\ndef test_normalize_edge_index(add_self_loops: bool, symmetric: bool):\n    edge_index = torch.tensor([[0, 2, 2, 3], [2, 0, 3, 0]])\n\n    out = normalize_edge_index(\n        edge_index,\n        add_self_loops=add_self_loops,\n        symmetric=symmetric,\n    )\n    assert isinstance(out, tuple) and len(out) == 2\n    if not add_self_loops:\n        assert out[0].equal(edge_index)\n    else:\n        assert out[0].tolist() == [\n            [0, 2, 2, 3, 0, 1, 2, 3],\n            [2, 0, 3, 0, 0, 1, 2, 3],\n        ]\n\n    assert out[1].min() >= 0.0\n    assert out[1].min() <= 1.0\n"
  },
  {
    "path": "test/utils/test_normalized_cut.py",
    "content": "import torch\n\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import normalized_cut\n\n\ndef test_normalized_cut():\n    row = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4])\n    col = torch.tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3])\n    edge_attr = torch.tensor(\n        [3.0, 3.0, 6.0, 3.0, 6.0, 1.0, 3.0, 2.0, 1.0, 2.0])\n    expected = torch.tensor([4.0, 4.0, 5.0, 2.5, 5.0, 1.0, 2.5, 2.0, 1.0, 2.0])\n\n    out = normalized_cut(torch.stack([row, col], dim=0), edge_attr)\n    assert torch.allclose(out, expected)\n\n    if is_full_test():\n        jit = torch.jit.script(normalized_cut)\n        out = jit(torch.stack([row, col], dim=0), edge_attr)\n        assert torch.allclose(out, expected)\n"
  },
  {
    "path": "test/utils/test_num_nodes.py",
    "content": "import torch\n\nfrom torch_geometric.utils import to_torch_coo_tensor\nfrom torch_geometric.utils.num_nodes import (\n    maybe_num_nodes,\n    maybe_num_nodes_dict,\n)\n\n\ndef test_maybe_num_nodes():\n    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]])\n\n    assert maybe_num_nodes(edge_index, 4) == 4\n    assert maybe_num_nodes(edge_index) == 3\n\n    adj = to_torch_coo_tensor(edge_index)\n    assert maybe_num_nodes(adj, 4) == 4\n    assert maybe_num_nodes(adj) == 3\n\n\ndef test_maybe_num_nodes_dict():\n    edge_index_dict = {\n        '1': torch.tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]),\n        '2': torch.tensor([[0, 0, 1, 3], [1, 2, 0, 4]])\n    }\n    num_nodes_dict = {'2': 6}\n\n    assert maybe_num_nodes_dict(edge_index_dict) == {'1': 3, '2': 5}\n    assert maybe_num_nodes_dict(edge_index_dict, num_nodes_dict) == {\n        '1': 3,\n        '2': 6,\n    }\n"
  },
  {
    "path": "test/utils/test_one_hot.py",
    "content": "import torch\n\nfrom torch_geometric.utils import one_hot\n\n\ndef test_one_hot():\n    index = torch.tensor([0, 1, 2])\n\n    out = one_hot(index)\n    assert out.size() == (3, 3)\n    assert out.dtype == torch.float\n    assert out.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]\n\n    out = one_hot(index, num_classes=4, dtype=torch.long)\n    assert out.size() == (3, 4)\n    assert out.dtype == torch.long\n    assert out.tolist() == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]\n"
  },
  {
    "path": "test/utils/test_ppr.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.datasets import KarateClub\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.utils import get_ppr\n\n\n@withPackage('numba')\n@pytest.mark.parametrize('target', [None, torch.tensor([0, 4, 5, 6])])\ndef test_get_ppr(target):\n    data = KarateClub()[0]\n\n    edge_index, edge_weight = get_ppr(\n        data.edge_index,\n        alpha=0.1,\n        eps=1e-5,\n        target=target,\n    )\n\n    assert edge_index.size(0) == 2\n    assert edge_index.size(1) == edge_weight.numel()\n\n    min_row = 0 if target is None else target.min()\n    max_row = data.num_nodes - 1 if target is None else target.max()\n    assert edge_index[0].min() == min_row and edge_index[0].max() == max_row\n    assert edge_index[1].min() >= 0 and edge_index[1].max() < data.num_nodes\n    assert edge_weight.min() >= 0.0 and edge_weight.max() <= 1.0\n"
  },
  {
    "path": "test/utils/test_random.py",
    "content": "import numpy as np\nimport torch\n\nfrom torch_geometric.utils import (\n    barabasi_albert_graph,\n    erdos_renyi_graph,\n    stochastic_blockmodel_graph,\n)\n\n\ndef test_erdos_renyi_graph():\n    torch.manual_seed(1234)\n    edge_index = erdos_renyi_graph(5, 0.2, directed=False)\n    assert edge_index.tolist() == [\n        [0, 1, 1, 1, 2, 4],\n        [1, 0, 2, 4, 1, 1],\n    ]\n\n    edge_index = erdos_renyi_graph(5, 0.5, directed=True)\n    assert edge_index.tolist() == [\n        [1, 1, 2, 2, 3, 4, 4, 4],\n        [0, 3, 0, 4, 0, 0, 1, 3],\n    ]\n\n\ndef test_stochastic_blockmodel_graph():\n    torch.manual_seed(12345)\n\n    block_sizes = [2, 2, 4]\n    edge_probs = [\n        [0.25, 0.05, 0.02],\n        [0.05, 0.35, 0.07],\n        [0.02, 0.07, 0.40],\n    ]\n\n    edge_index = stochastic_blockmodel_graph(block_sizes, edge_probs,\n                                             directed=False)\n    assert edge_index.tolist() == [\n        [2, 3, 4, 4, 5, 5, 6, 7, 7, 7],\n        [3, 2, 5, 7, 4, 7, 7, 4, 5, 6],\n    ]\n\n    edge_index = stochastic_blockmodel_graph(block_sizes, edge_probs,\n                                             directed=True)\n    assert edge_index.tolist() == [\n        [0, 1, 3, 5, 6, 6, 7, 7],\n        [3, 3, 2, 4, 4, 7, 5, 6],\n    ]\n\n\ndef test_barabasi_albert_graph():\n    torch.manual_seed(12345)\n    np.random.seed(12345)\n\n    edge_index = barabasi_albert_graph(num_nodes=8, num_edges=3)\n    assert edge_index.size() == (2, 26)\n"
  },
  {
    "path": "test/utils/test_repeat.py",
    "content": "from torch_geometric.utils.repeat import repeat\n\n\ndef test_repeat():\n    assert repeat(None, length=4) is None\n    assert repeat(4, length=4) == [4, 4, 4, 4]\n    assert repeat([2, 3, 4], length=4) == [2, 3, 4, 4]\n    assert repeat([1, 2, 3, 4], length=4) == [1, 2, 3, 4]\n    assert repeat([1, 2, 3, 4, 5], length=4) == [1, 2, 3, 4]\n"
  },
  {
    "path": "test/utils/test_scatter.py",
    "content": "from itertools import product\n\nimport pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import withCUDA, withDevice, withPackage\nfrom torch_geometric.utils import group_argsort, group_cat, scatter\nfrom torch_geometric.utils._scatter import scatter_argmax\n\n\ndef test_scatter_validate():\n    src = torch.randn(100, 32)\n    index = torch.randint(0, 10, (100, ), dtype=torch.long)\n\n    with pytest.raises(ValueError, match=\"must be one-dimensional\"):\n        scatter(src, index.view(-1, 1))\n\n    with pytest.raises(ValueError, match=\"must lay between 0 and 1\"):\n        scatter(src, index, dim=2)\n\n    with pytest.raises(ValueError, match=\"invalid `reduce` argument 'std'\"):\n        scatter(src, index, reduce='std')\n\n\n@withDevice\n@withPackage('torch_scatter')\n@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max'])\ndef test_scatter(reduce, device):\n    import torch_scatter\n\n    src = torch.randn(100, 16, device=device)\n    index = torch.randint(0, 8, (100, ), device=device)\n\n    if device.type == 'mps' and reduce in ['min', 'max']:\n        with pytest.raises(NotImplementedError, match=\"for the MPS device\"):\n            scatter(src, index, dim=0, reduce=reduce)\n        return\n\n    out1 = scatter(src, index, dim=0, reduce=reduce)\n    out2 = torch_scatter.scatter(src, index, dim=0, reduce=reduce)\n    assert out1.device == device\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n    jit = torch.jit.script(scatter)\n    out3 = jit(src, index, dim=0, reduce=reduce)\n    assert torch.allclose(out1, out3, atol=1e-6)\n\n    src = torch.randn(8, 100, 16, device=device)\n    out1 = scatter(src, index, dim=1, reduce=reduce)\n    out2 = torch_scatter.scatter(src, index, dim=1, reduce=reduce)\n    assert out1.device == device\n    assert torch.allclose(out1, out2, atol=1e-6)\n\n\n@withDevice\n@pytest.mark.parametrize('reduce', 
['sum', 'add', 'mean', 'min', 'max'])\ndef test_scatter_backward(reduce, device):\n    src = torch.randn(8, 100, 16, device=device, requires_grad=True)\n    index = torch.randint(0, 8, (100, ), device=device)\n\n    if device.type == 'mps' and reduce in ['min', 'max']:\n        with pytest.raises(NotImplementedError, match=\"for the MPS device\"):\n            scatter(src, index, dim=1, reduce=reduce)\n        return\n\n    out = scatter(src, index, dim=1, reduce=reduce)\n\n    assert src.grad is None\n    out.mean().backward()\n    assert src.grad is not None\n\n\n@withDevice\ndef test_scatter_any(device):\n    src = torch.randn(6, 4, device=device)\n    index = torch.tensor([0, 0, 1, 1, 2, 2], device=device)\n\n    out = scatter(src, index, dim=0, reduce='any')\n\n    for i in range(3):\n        for j in range(4):\n            assert float(out[i, j]) in src[2 * i:2 * i + 2, j].tolist()\n\n\n@withDevice\n@pytest.mark.parametrize('num_groups', [4])\n@pytest.mark.parametrize('descending', [False, True])\ndef test_group_argsort(num_groups, descending, device):\n    src = torch.randn(20, device=device)\n    index = torch.randint(0, num_groups, (20, ), device=device)\n\n    out = group_argsort(src, index, 0, num_groups, descending=descending)\n\n    expected = torch.empty_like(index)\n    for i in range(num_groups):\n        mask = index == i\n        tmp = src[mask].argsort(descending=descending)\n        perm = torch.empty_like(tmp)\n        perm[tmp] = torch.arange(tmp.numel(), device=device)\n        expected[mask] = perm\n\n    assert torch.equal(out, expected)\n\n    empty_tensor = torch.tensor([], device=device)\n    out = group_argsort(empty_tensor, empty_tensor)\n    assert out.numel() == 0\n\n\n@withCUDA\ndef test_scatter_argmax(device):\n    src = torch.arange(5, device=device)\n    index = torch.tensor([2, 2, 0, 0, 3], device=device)\n\n    old_state = torch_geometric.typing.WITH_TORCH_SCATTER\n    torch_geometric.typing.WITH_TORCH_SCATTER = False\n    
argmax = scatter_argmax(src, index, dim_size=6)\n    torch_geometric.typing.WITH_TORCH_SCATTER = old_state\n    assert argmax.tolist() == [3, 5, 1, 4, 5, 5]\n\n\n@withDevice\ndef test_group_cat(device):\n    x1 = torch.randn(4, 4, device=device)\n    x2 = torch.randn(2, 4, device=device)\n    index1 = torch.tensor([0, 0, 1, 2], device=device)\n    index2 = torch.tensor([0, 2], device=device)\n\n    expected = torch.cat([x1[:2], x2[:1], x1[2:4], x2[1:]], dim=0)\n\n    out, index = group_cat(\n        [x1, x2],\n        [index1, index2],\n        dim=0,\n        return_index=True,\n    )\n    assert torch.equal(out, expected)\n    assert index.tolist() == [0, 0, 0, 1, 2, 2]\n\n\nif __name__ == '__main__':\n    # Insights on GPU:\n    # ================\n    # * \"sum\": Prefer `scatter_add_` implementation\n    # * \"mean\": Prefer manual implementation via `scatter_add_` + `count`\n    # * \"min\"/\"max\":\n    #   * Prefer `scatter_reduce_` implementation without gradients\n    #   * Prefer `torch_sparse` implementation with gradients\n    # * \"mul\": Prefer `torch_sparse` implementation\n    #\n    # Insights on CPU:\n    # ================\n    # * \"sum\": Prefer `scatter_add_` implementation\n    # * \"mean\": Prefer manual implementation via `scatter_add_` + `count`\n    # * \"min\"/\"max\": Prefer `scatter_reduce_` implementation\n    # * \"mul\" (probably not worth branching for this):\n    #   * Prefer `scatter_reduce_` implementation without gradients\n    #   * Prefer `torch_sparse` implementation with gradients\n    import argparse\n\n    from torch_geometric.typing import WITH_TORCH_SCATTER, torch_scatter\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    parser.add_argument('--aggr', type=str, default='all')\n    args = parser.parse_args()\n\n    num_nodes_list = [4_000, 8_000, 16_000, 32_000, 64_000]\n\n    if args.aggr == 'all':\n   
     aggrs = ['sum', 'mean', 'min', 'max', 'mul']\n    else:\n        aggrs = args.aggr.split(',')\n\n    def pytorch_scatter(x, index, dim_size, reduce):\n        if reduce == 'min' or reduce == 'max':\n            reduce = f'a{reduce}'  # `amin` or `amax`\n        elif reduce == 'mul':\n            reduce = 'prod'\n        out = x.new_zeros(dim_size, x.size(-1))\n        include_self = reduce in ['sum', 'mean']\n        index = index.view(-1, 1).expand(-1, x.size(-1))\n        out.scatter_reduce_(0, index, x, reduce, include_self=include_self)\n        return out\n\n    def pytorch_index_add(x, index, dim_size, reduce):\n        if reduce != 'sum':\n            raise NotImplementedError\n        out = x.new_zeros(dim_size, x.size(-1))\n        out.index_add_(0, index, x)\n        return out\n\n    def own_scatter(x, index, dim_size, reduce):\n        return torch_scatter.scatter(x, index, dim=0, dim_size=dim_size,\n                                     reduce=reduce)\n\n    def optimized_scatter(x, index, dim_size, reduce):\n        return scatter(x, index, dim=0, dim_size=dim_size, reduce=reduce)\n\n    for aggr, num_nodes in product(aggrs, num_nodes_list):\n        num_edges = num_nodes * 50\n        print(f'aggr: {aggr}, #nodes: {num_nodes}, #edges: {num_edges}')\n\n        x = torch.randn(num_edges, 64, device=args.device)\n        index = torch.randint(num_nodes, (num_edges, ), device=args.device)\n\n        funcs = [pytorch_scatter]\n        func_names = ['PyTorch scatter_reduce']\n\n        if aggr == 'sum':\n            funcs.append(pytorch_index_add)\n            func_names.append('PyTorch index_add')\n\n        if WITH_TORCH_SCATTER:\n            funcs.append(own_scatter)\n            func_names.append('torch_scatter')\n\n        funcs.append(optimized_scatter)\n        func_names.append('Optimized PyG Scatter')\n\n        benchmark(\n            funcs=funcs,\n            func_names=func_names,\n            args=(x, index, num_nodes, aggr),\n            
num_steps=100 if args.device == 'cpu' else 1000,\n            num_warmups=50 if args.device == 'cpu' else 500,\n            backward=args.backward,\n        )\n"
  },
  {
    "path": "test/utils/test_segment.py",
    "content": "from itertools import product\n\nimport pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.index import index2ptr\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import withCUDA, withoutExtensions\nfrom torch_geometric.utils import scatter, segment, segment_logsumexp\n\n\n@withCUDA\n@withoutExtensions\n@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max'])\ndef test_segment(device, without_extensions, reduce):\n    src = torch.randn(20, 16, device=device)\n    ptr = torch.tensor([0, 0, 5, 10, 15, 20], device=device)\n\n    if (not torch_geometric.typing.WITH_TORCH_SCATTER\n            and not torch_geometric.typing.WITH_PT20):\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            segment(src, ptr, reduce=reduce)\n    else:\n        out = segment(src, ptr, reduce=reduce)\n\n        expected = getattr(torch, reduce)(src.view(4, 5, -1), dim=1)\n        expected = expected[0] if isinstance(expected, tuple) else expected\n\n        assert torch.allclose(out[:1], torch.zeros(1, 16, device=device))\n        assert torch.allclose(out[1:], expected)\n\n\n@withCUDA\n@withoutExtensions\ndef test_segment_logsumexp(device, without_extensions) -> None:\n    src = torch.randn(5, 4, device=device)\n\n    expected = src.logsumexp(dim=0)\n    ptr = torch.tensor([0, 0, 5, 5], device=device)\n    out = segment_logsumexp(src, ptr, dim=0)\n    assert out.size() == (3, 4)\n    assert out[0].abs().sum() == 0.0\n    assert torch.allclose(expected, out[1])\n    assert out[2].abs().sum() == 0.0\n\n    expected = src.logsumexp(dim=1)\n    ptr = torch.tensor([0, 0, 4, 4], device=device)\n    out = segment_logsumexp(src, ptr, dim=1)\n    assert out.size() == (5, 3)\n    assert out[:, 0].abs().sum() == 0.0\n    assert torch.allclose(expected, out[:, 1])\n    assert out[:, 2].abs().sum() == 0.0\n\n\nif __name__ == '__main__':\n    # Insights on GPU:\n    # ================\n   
 # * \"mean\": Prefer `torch._segment_reduce` implementation\n    # * others: Prefer `torch_scatter` implementation\n    #\n    # Insights on CPU:\n    # ================\n    # * \"all\": Prefer `torch_scatter` implementation (but `scatter(...)`\n    #          implementation is far superior due to multi-threading usage).\n    import argparse\n\n    from torch_geometric.typing import WITH_TORCH_SCATTER, torch_scatter\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    parser.add_argument('--aggr', type=str, default='all')\n    args = parser.parse_args()\n\n    num_nodes_list = [4_000, 8_000, 16_000, 32_000, 64_000]\n\n    if args.aggr == 'all':\n        aggrs = ['sum', 'mean', 'min', 'max']\n    else:\n        aggrs = args.aggr.split(',')\n\n    def pytorch_segment(x, ptr, reduce):\n        if reduce == 'min' or reduce == 'max':\n            reduce = f'a{reduce}'  # `amin` or `amax`\n        return torch._segment_reduce(x, reduce, offsets=ptr)\n\n    def own_segment(x, ptr, reduce):\n        return torch_scatter.segment_csr(x, ptr, reduce=reduce)\n\n    def optimized_scatter(x, index, reduce, dim_size):\n        return scatter(x, index, dim=0, dim_size=dim_size, reduce=reduce)\n\n    def optimized_segment(x, ptr, reduce):\n        return segment(x, ptr, reduce=reduce)\n\n    for aggr, num_nodes in product(aggrs, num_nodes_list):\n        num_edges = num_nodes * 50\n        print(f'aggr: {aggr}, #nodes: {num_nodes}, #edges: {num_edges}')\n\n        x = torch.randn(num_edges, 64, device=args.device)\n        index = torch.randint(num_nodes, (num_edges, ), device=args.device)\n        index, _ = index.sort()\n        ptr = index2ptr(index, size=num_nodes)\n\n        funcs = [pytorch_segment]\n        func_names = ['PyTorch segment_reduce']\n        arg_list = [(x, ptr, aggr)]\n\n        if WITH_TORCH_SCATTER:\n            funcs.append(own_segment)\n    
        func_names.append('torch_scatter')\n            arg_list.append((x, ptr, aggr))\n\n        funcs.append(optimized_scatter)\n        func_names.append('Optimized PyG Scatter')\n        arg_list.append((x, index, aggr, num_nodes))\n\n        funcs.append(optimized_segment)\n        func_names.append('Optimized PyG Segment')\n        arg_list.append((x, ptr, aggr))\n\n        benchmark(\n            funcs=funcs,\n            func_names=func_names,\n            args=arg_list,\n            num_steps=100 if args.device == 'cpu' else 1000,\n            num_warmups=50 if args.device == 'cpu' else 500,\n            backward=args.backward,\n        )\n"
  },
  {
    "path": "test/utils/test_select.py",
    "content": "import torch\n\nfrom torch_geometric.utils import narrow, select\n\n\ndef test_select():\n    src = torch.randn(5, 3)\n    index = torch.tensor([0, 2, 4])\n    mask = torch.tensor([True, False, True, False, True])\n\n    out = select(src, index, dim=0)\n    assert torch.equal(out, src[index])\n    assert torch.equal(out, select(src, mask, dim=0))\n    assert torch.equal(out, torch.tensor(select(src.tolist(), index, dim=0)))\n    assert torch.equal(out, torch.tensor(select(src.tolist(), mask, dim=0)))\n\n\ndef test_narrow():\n    src = torch.randn(5, 3)\n\n    out = narrow(src, dim=0, start=2, length=2)\n    assert torch.equal(out, src[2:4])\n    assert torch.equal(out, torch.tensor(narrow(src.tolist(), 0, 2, 2)))\n"
  },
  {
    "path": "test/utils/test_smiles.py",
    "content": "import pytest\n\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.utils import from_rdmol, from_smiles, to_rdmol, to_smiles\n\nsmiles = [\n    r'F/C=C/F',\n    r'F/C=C\\F',\n    r'F/C=C\\F',\n    (r'COc1cccc([C@@H]2Oc3ccc(OC)cc3/C(=N/OC[C@@H](C)[C@H](OCc3ccccc3)'\n     r'C(C)C)[C@@H]2O)c1'),\n    r'C/C(=C\\C(=O)c1ccc(C)o1)Nc1ccc2c(c1)OCO2',\n    r'F[B-](F)(F)c1cnc2ccccc2c1',\n    r'COC(=O)[C@@]1(Cc2ccccc2)[C@H]2C(=O)N(C)C(=O)[C@H]2[C@H]2CN=C(SC)N21',\n    (r'O=C(O)c1ccc(NS(=O)(=O)c2ccc3c(c2)C(=O)c2cc(S(=O)(=O)Nc4ccc(C(=O)O)'\n     r'cc4)ccc2-3)cc1'),\n]\n\n\n@withPackage('rdkit')\n@pytest.mark.parametrize('smiles', smiles)\ndef test_from_to_smiles(smiles):\n    data = from_smiles(smiles)\n    assert to_smiles(data) == smiles\n\n\n@withPackage('rdkit')\n@pytest.mark.parametrize('smiles', smiles)\ndef test_from_to_rdmol(smiles):\n    from rdkit import Chem\n    mol1 = Chem.MolFromSmiles(smiles)\n    data = from_rdmol(mol1)\n    mol2 = to_rdmol(data)\n    assert Chem.MolToSmiles(mol1) == Chem.MolToSmiles(mol2)\n"
  },
  {
    "path": "test/utils/test_softmax.py",
    "content": "import pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.utils import softmax\n\nCALCULATION_VIA_PTR_AVAILABLE = (torch_geometric.typing.WITH_SOFTMAX\n                                 or torch_geometric.typing.WITH_TORCH_SCATTER)\n\n\ndef test_softmax():\n    src = torch.tensor([1., 1., 1., 1.])\n    index = torch.tensor([0, 0, 1, 2])\n    ptr = torch.tensor([0, 2, 3, 4])\n\n    out = softmax(src, index)\n    assert out.tolist() == [0.5, 0.5, 1, 1]\n    assert softmax(src, ptr=ptr).tolist() == out.tolist()\n\n    src = src.view(-1, 1)\n    out = softmax(src, index)\n    assert out.tolist() == [[0.5], [0.5], [1], [1]]\n    assert softmax(src, ptr=ptr).tolist() == out.tolist()\n\n    jit = torch.jit.script(softmax)\n    assert torch.allclose(jit(src, index), out)\n\n\ndef test_softmax_backward():\n    src_sparse = torch.rand(4, 8)\n    index = torch.tensor([0, 0, 1, 1])\n    src_dense = src_sparse.clone().view(2, 2, src_sparse.size(-1))\n\n    src_sparse.requires_grad_(True)\n    src_dense.requires_grad_(True)\n\n    out_sparse = softmax(src_sparse, index)\n    out_sparse.mean().backward()\n    out_dense = src_dense.softmax(dim=1)\n    out_dense.mean().backward()\n\n    assert torch.allclose(out_sparse, out_dense.view_as(out_sparse))\n    assert torch.allclose(src_sparse.grad, src_dense.grad.view_as(src_sparse))\n\n\ndef test_softmax_dim():\n    index = torch.tensor([0, 0, 0, 0])\n    ptr = torch.tensor([0, 4])\n\n    src = torch.randn(4)\n    assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0))\n    assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))\n\n    src = torch.randn(4, 16)\n    assert torch.allclose(softmax(src, index, dim=0), src.softmax(dim=0))\n    assert torch.allclose(softmax(src, ptr=ptr, dim=0), src.softmax(dim=0))\n\n    src = torch.randn(4, 4)\n    assert torch.allclose(softmax(src, index, dim=-1), src.softmax(dim=-1))\n    
if CALCULATION_VIA_PTR_AVAILABLE:\n        assert torch.allclose(softmax(src, ptr=ptr, dim=-1), src.softmax(-1))\n    else:\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            softmax(src, ptr=ptr, dim=-1)\n\n    src = torch.randn(4, 4, 16)\n    assert torch.allclose(softmax(src, index, dim=1), src.softmax(dim=1))\n    if CALCULATION_VIA_PTR_AVAILABLE:\n        assert torch.allclose(softmax(src, ptr=ptr, dim=1), src.softmax(dim=1))\n    else:\n        with pytest.raises(ImportError, match=\"requires the 'torch-scatter'\"):\n            softmax(src, ptr=ptr, dim=1)\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    num_nodes, num_edges = 10_000, 200_000\n    x = torch.randn(num_edges, 64, device=args.device)\n    index = torch.randint(num_nodes, (num_edges, ), device=args.device)\n\n    compiled_softmax = torch.compile(softmax)\n\n    def dense_softmax(x, index):\n        x = x.view(num_nodes, -1, x.size(-1))\n        return x.softmax(dim=1)\n\n    benchmark(\n        funcs=[dense_softmax, softmax, compiled_softmax],\n        func_names=['Dense Softmax', 'Vanilla', 'Compiled'],\n        args=(x, index),\n        num_steps=50 if args.device == 'cpu' else 500,\n        num_warmups=10 if args.device == 'cpu' else 100,\n        backward=args.backward,\n    )\n"
  },
  {
    "path": "test/utils/test_sort_edge_index.py",
    "content": "from typing import List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.utils import sort_edge_index\n\n\ndef test_sort_edge_index():\n    edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]])\n    edge_attr = torch.tensor([[1], [2], [3], [4]])\n\n    out = sort_edge_index(edge_index)\n    assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n\n    torch_geometric.typing.MAX_INT64 = 1\n    out = sort_edge_index(edge_index)\n    torch_geometric.typing.MAX_INT64 = torch.iinfo(torch.int64).max\n    assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n\n    out = sort_edge_index((edge_index[0], edge_index[1]))\n    assert isinstance(out, tuple)\n    assert out[0].tolist() == [0, 1, 1, 2]\n    assert out[1].tolist() == [1, 0, 2, 1]\n\n    out = sort_edge_index(edge_index, None)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1] is None\n\n    out = sort_edge_index(edge_index, edge_attr)\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1].tolist() == [[4], [3], [2], [1]]\n\n    out = sort_edge_index(edge_index, [edge_attr, edge_attr.view(-1)])\n    assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n    assert out[1][0].tolist() == [[4], [3], [2], [1]]\n    assert out[1][1].tolist() == [4, 3, 2, 1]\n\n\ndef test_sort_edge_index_jit():\n    @torch.jit.script\n    def wrapper1(edge_index: Tensor) -> Tensor:\n        return sort_edge_index(edge_index)\n\n    @torch.jit.script\n    def wrapper2(\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor],\n    ) -> Tuple[Tensor, Optional[Tensor]]:\n        return sort_edge_index(edge_index, edge_attr)\n\n    @torch.jit.script\n    def wrapper3(\n        edge_index: Tensor,\n        edge_attr: List[Tensor],\n    ) -> Tuple[Tensor, List[Tensor]]:\n        return sort_edge_index(edge_index, edge_attr)\n\n    edge_index = torch.tensor([[2, 1, 1, 0], [1, 2, 0, 1]])\n    edge_attr = 
torch.tensor([[1], [2], [3], [4]])\n\n    out = wrapper1(edge_index)\n    assert out.size() == edge_index.size()\n\n    out = wrapper2(edge_index, None)\n    assert out[0].size() == edge_index.size()\n    assert out[1] is None\n\n    out = wrapper2(edge_index, edge_attr)\n    assert out[0].size() == edge_index.size()\n    assert out[1].size() == edge_attr.size()\n\n    out = wrapper3(edge_index, [edge_attr, edge_attr.view(-1)])\n    assert out[0].size() == edge_index.size()\n    assert len(out[1]) == 2\n    assert out[1][0].size() == edge_attr.size()\n    assert out[1][1].size() == edge_attr.view(-1).size()\n"
  },
  {
    "path": "test/utils/test_sparse.py",
    "content": "import os.path as osp\n\nimport pytest\nimport torch\n\nimport torch_geometric.typing\nfrom torch_geometric.io import fs\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import is_full_test, withCUDA, withPackage\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import (\n    dense_to_sparse,\n    is_sparse,\n    is_torch_sparse_tensor,\n    to_edge_index,\n    to_torch_coo_tensor,\n    to_torch_csc_tensor,\n    to_torch_csr_tensor,\n    to_torch_sparse_tensor,\n)\nfrom torch_geometric.utils.sparse import cat\n\n\ndef test_dense_to_sparse():\n    adj = torch.tensor([\n        [3.0, 1.0],\n        [2.0, 0.0],\n    ])\n    edge_index, edge_attr = dense_to_sparse(adj)\n    assert edge_index.tolist() == [[0, 0, 1], [0, 1, 0]]\n    assert edge_attr.tolist() == [3, 1, 2]\n\n    if is_full_test():\n        jit = torch.jit.script(dense_to_sparse)\n        edge_index, edge_attr = jit(adj)\n        assert edge_index.tolist() == [[0, 0, 1], [0, 1, 0]]\n        assert edge_attr.tolist() == [3, 1, 2]\n\n    adj = torch.tensor([[\n        [3.0, 1.0],\n        [2.0, 0.0],\n    ], [\n        [0.0, 1.0],\n        [0.0, 2.0],\n    ]])\n    edge_index, edge_attr = dense_to_sparse(adj)\n    assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]]\n    assert edge_attr.tolist() == [3, 1, 2, 1, 2]\n\n    if is_full_test():\n        jit = torch.jit.script(dense_to_sparse)\n        edge_index, edge_attr = jit(adj)\n        assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]]\n        assert edge_attr.tolist() == [3, 1, 2, 1, 2]\n\n    adj = torch.tensor([\n        [\n            [3.0, 1.0, 0.0],\n            [2.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0],\n        ],\n        [\n            [0.0, 1.0, 0.0],\n            [0.0, 2.0, 3.0],\n            [0.0, 5.0, 0.0],\n        ],\n    ])\n    mask = torch.tensor([[True, True, False], [True, True, True]])\n\n    edge_index, edge_attr = 
dense_to_sparse(adj, mask)\n\n    assert edge_index.tolist() == [[0, 0, 1, 2, 3, 3, 4],\n                                   [0, 1, 0, 3, 3, 4, 3]]\n    assert edge_attr.tolist() == [3, 1, 2, 1, 2, 3, 5]\n\n    if is_full_test():\n        jit = torch.jit.script(dense_to_sparse)\n        edge_index, edge_attr = jit(adj, mask)\n        assert edge_index.tolist() == [[0, 0, 1, 2, 3, 3, 4],\n                                       [0, 1, 0, 3, 3, 4, 3]]\n        assert edge_attr.tolist() == [3, 1, 2, 1, 2, 3, 5]\n\n\ndef test_dense_to_sparse_bipartite():\n    edge_index, edge_attr = dense_to_sparse(torch.rand(2, 10, 5))\n    assert edge_index[0].max() == 19\n    assert edge_index[1].max() == 9\n\n\ndef test_is_torch_sparse_tensor():\n    x = torch.randn(5, 5)\n\n    assert not is_torch_sparse_tensor(x)\n    assert is_torch_sparse_tensor(x.to_sparse())\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert not is_torch_sparse_tensor(SparseTensor.from_dense(x))\n\n\ndef test_is_sparse():\n    x = torch.randn(5, 5)\n\n    assert not is_sparse(x)\n    assert is_sparse(x.to_sparse())\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        assert is_sparse(SparseTensor.from_dense(x))\n\n\ndef test_to_torch_coo_tensor():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [1, 0, 2, 1, 3, 2],\n    ])\n    edge_attr = torch.randn(edge_index.size(1), 8)\n\n    adj = to_torch_coo_tensor(edge_index, is_coalesced=False)\n    assert adj.is_coalesced()\n    assert adj.size() == (4, 4)\n    assert adj.layout == torch.sparse_coo\n    assert torch.allclose(adj.indices(), edge_index)\n\n    adj = to_torch_coo_tensor(edge_index, is_coalesced=True)\n    assert adj.is_coalesced()\n    assert adj.size() == (4, 4)\n    assert adj.layout == torch.sparse_coo\n    assert torch.allclose(adj.indices(), edge_index)\n\n    adj = to_torch_coo_tensor(edge_index, size=6)\n    assert adj.size() == (6, 6)\n    assert adj.layout == torch.sparse_coo\n    assert 
torch.allclose(adj.indices(), edge_index)\n\n    adj = to_torch_coo_tensor(edge_index, edge_attr)\n    assert adj.size() == (4, 4, 8)\n    assert adj.layout == torch.sparse_coo\n    assert torch.allclose(adj.indices(), edge_index)\n    assert torch.allclose(adj.values(), edge_attr)\n\n    if is_full_test():\n        jit = torch.jit.script(to_torch_coo_tensor)\n        adj = jit(edge_index, edge_attr)\n        assert adj.size() == (4, 4, 8)\n        assert adj.layout == torch.sparse_coo\n        assert torch.allclose(adj.indices(), edge_index)\n        assert torch.allclose(adj.values(), edge_attr)\n\n\ndef test_to_torch_csr_tensor():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [1, 0, 2, 1, 3, 2],\n    ])\n\n    adj = to_torch_csr_tensor(edge_index)\n    assert adj.size() == (4, 4)\n    assert adj.layout == torch.sparse_csr\n    assert torch.allclose(adj.to_sparse_coo().coalesce().indices(), edge_index)\n\n    edge_weight = torch.randn(edge_index.size(1))\n    adj = to_torch_csr_tensor(edge_index, edge_weight)\n    assert adj.size() == (4, 4)\n    assert adj.layout == torch.sparse_csr\n    coo = adj.to_sparse_coo().coalesce()\n    assert torch.allclose(coo.indices(), edge_index)\n    assert torch.allclose(coo.values(), edge_weight)\n\n    if torch_geometric.typing.WITH_PT20:\n        edge_attr = torch.randn(edge_index.size(1), 8)\n        adj = to_torch_csr_tensor(edge_index, edge_attr)\n        assert adj.size() == (4, 4, 8)\n        assert adj.layout == torch.sparse_csr\n        coo = adj.to_sparse_coo().coalesce()\n        assert torch.allclose(coo.indices(), edge_index)\n        assert torch.allclose(coo.values(), edge_attr)\n\n\ndef test_to_torch_csc_tensor():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [1, 0, 2, 1, 3, 2],\n    ])\n\n    adj = to_torch_csc_tensor(edge_index)\n    assert adj.size() == (4, 4)\n    assert adj.layout == torch.sparse_csc\n    adj_coo = adj.to_sparse_coo().coalesce()\n    if 
torch_geometric.typing.WITH_PT20:\n        assert torch.allclose(adj_coo.indices(), edge_index)\n    else:\n        assert torch.allclose(adj_coo.indices().flip([0]), edge_index)\n\n    edge_weight = torch.randn(edge_index.size(1))\n    adj = to_torch_csc_tensor(edge_index, edge_weight)\n    assert adj.size() == (4, 4)\n    assert adj.layout == torch.sparse_csc\n    adj_coo = adj.to_sparse_coo().coalesce()\n    if torch_geometric.typing.WITH_PT20:\n        assert torch.allclose(adj_coo.indices(), edge_index)\n        assert torch.allclose(adj_coo.values(), edge_weight)\n    else:\n        perm = adj_coo.indices()[0].argsort()\n        assert torch.allclose(adj_coo.indices()[:, perm], edge_index)\n        assert torch.allclose(adj_coo.values()[perm], edge_weight)\n\n    if torch_geometric.typing.WITH_PT20:\n        edge_attr = torch.randn(edge_index.size(1), 8)\n        adj = to_torch_csc_tensor(edge_index, edge_attr)\n        assert adj.size() == (4, 4, 8)\n        assert adj.layout == torch.sparse_csc\n        assert torch.allclose(adj.to_sparse_coo().coalesce().indices(),\n                              edge_index)\n        assert torch.allclose(adj.to_sparse_coo().coalesce().values(),\n                              edge_attr)\n\n\n@withPackage('torch>=2.1.0')\ndef test_to_torch_coo_tensor_save_load(tmp_path):\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3],\n        [1, 0, 2, 1, 3, 2],\n    ])\n    adj = to_torch_coo_tensor(edge_index, is_coalesced=False)\n    assert adj.is_coalesced()\n\n    path = osp.join(tmp_path, 'adj.t')\n    torch.save(adj, path)\n    adj = fs.torch_load(path)\n    assert adj.is_coalesced()\n\n\ndef test_to_edge_index():\n    adj = torch.tensor([\n        [0., 1., 0., 0.],\n        [1., 0., 1., 0.],\n        [0., 1., 0., 1.],\n        [0., 0., 1., 0.],\n    ]).to_sparse()\n\n    edge_index, edge_attr = to_edge_index(adj)\n    assert edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]\n    assert edge_attr.tolist() 
== [1., 1., 1., 1., 1., 1.]\n\n    if is_full_test():\n        jit = torch.jit.script(to_edge_index)\n        edge_index, edge_attr = jit(adj)\n        assert edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]\n        assert edge_attr.tolist() == [1., 1., 1., 1., 1., 1.]\n\n\n@withCUDA\n@pytest.mark.parametrize(\n    'layout',\n    [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc],\n)\n@pytest.mark.parametrize('dim', [0, 1, (0, 1)])\ndef test_cat(layout, dim, device):\n    edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], device=device)\n    if torch_geometric.typing.WITH_PT20:\n        edge_weight = torch.rand(4, 2, device=device)\n    else:\n        edge_weight = torch.rand(4, device=device)\n\n    adj = to_torch_sparse_tensor(edge_index, edge_weight, layout=layout)\n\n    out = cat([adj, adj], dim=dim)\n    edge_index, edge_weight = to_edge_index(out.to_sparse_csr())\n\n    if dim == 0:\n        if torch_geometric.typing.WITH_PT20:\n            assert out.size() == (6, 3, 2)\n        else:\n            assert out.size() == (6, 3)\n        assert edge_index[0].tolist() == [0, 1, 1, 2, 3, 4, 4, 5]\n        assert edge_index[1].tolist() == [1, 0, 2, 1, 1, 0, 2, 1]\n    elif dim == 1:\n        if torch_geometric.typing.WITH_PT20:\n            assert out.size() == (3, 6, 2)\n        else:\n            assert out.size() == (3, 6)\n        assert edge_index[0].tolist() == [0, 0, 1, 1, 1, 1, 2, 2]\n        assert edge_index[1].tolist() == [1, 4, 0, 2, 3, 5, 1, 4]\n    else:\n        if torch_geometric.typing.WITH_PT20:\n            assert out.size() == (6, 6, 2)\n        else:\n            assert out.size() == (6, 6)\n        assert edge_index[0].tolist() == [0, 1, 1, 2, 3, 4, 4, 5]\n        assert edge_index[1].tolist() == [1, 0, 2, 1, 4, 3, 5, 4]\n\n\nif __name__ == '__main__':\n    import argparse\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    args = parser.parse_args()\n\n    
num_nodes, num_edges = 10_000, 200_000\n    edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)\n\n    benchmark(\n        funcs=[\n            SparseTensor.from_edge_index, to_torch_coo_tensor,\n            to_torch_csr_tensor, to_torch_csc_tensor\n        ],\n        func_names=['SparseTensor', 'To COO', 'To CSR', 'To CSC'],\n        args=(edge_index, None, (num_nodes, num_nodes)),\n        num_steps=50 if args.device == 'cpu' else 500,\n        num_warmups=10 if args.device == 'cpu' else 100,\n    )\n"
  },
  {
    "path": "test/utils/test_spmm.py",
    "content": "import itertools\nimport warnings\n\nimport pytest\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.profile import benchmark\nfrom torch_geometric.testing import withCUDA, withPackage\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import spmm, to_torch_coo_tensor\n\n\n@withCUDA\n@pytest.mark.parametrize('reduce', ['sum', 'mean'])\ndef test_spmm_basic(device, reduce):\n    src = torch.randn(5, 4, device=device)\n    other = torch.randn(4, 8, device=device)\n\n    out1 = (src @ other) / (src.size(1) if reduce == 'mean' else 1)\n    out2 = spmm(src.to_sparse_csr(), other, reduce=reduce)\n    assert out1.size() == (5, 8)\n    assert torch.allclose(out1, out2, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)\n        assert torch.allclose(out2, out3, atol=1e-6)\n\n    # Test `mean` reduction with isolated nodes:\n    src[0] = 0.\n    out1 = (src @ other) / (4. 
if reduce == 'mean' else 1.)\n    out2 = spmm(src.to_sparse_csr(), other, reduce=reduce)\n    assert out1.size() == (5, 8)\n    assert torch.allclose(out1, out2, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)\n        assert torch.allclose(out2, out3, atol=1e-6)\n\n\n@withCUDA\n@withPackage('torch>=2.0.0')\n@pytest.mark.parametrize('reduce', ['min', 'max'])\ndef test_spmm_reduce(device, reduce):\n    src = torch.randn(5, 4, device=device)\n    other = torch.randn(4, 8, device=device)\n\n    if src.is_cuda:\n        with pytest.raises(NotImplementedError, match=\"not yet supported\"):\n            spmm(src.to_sparse_csr(), other, reduce)\n    else:\n        out1 = spmm(src.to_sparse_csr(), other, reduce)\n        assert out1.size() == (5, 8)\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            out2 = spmm(SparseTensor.from_dense(src), other, reduce=reduce)\n            assert torch.allclose(out1, out2)\n\n\n@withCUDA\n@withPackage('torch>=2.0.0')\n@pytest.mark.parametrize(\n    'layout', [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc])\n@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max'])\ndef test_spmm_layout(device, layout, reduce):\n    src = torch.randn(5, 4, device=device)\n    if layout == torch.sparse_coo:\n        src = src.to_sparse_coo()\n    elif layout == torch.sparse_csr:\n        src = src.to_sparse_csr()\n    else:\n        assert layout == torch.sparse_csc\n        src = src.to_sparse_csc()\n    other = torch.randn(4, 8, device=device)\n\n    if src.is_cuda and reduce in {'min', 'max'}:\n        with pytest.raises(NotImplementedError, match=\"not yet supported\"):\n            spmm(src, other, reduce=reduce)\n    elif layout != torch.sparse_csr:\n        with pytest.warns(UserWarning, match=\"Converting sparse tensor\"):\n            spmm(src, other, reduce=reduce)\n    else:\n        spmm(src, other, 
reduce=reduce)\n\n\n@pytest.mark.parametrize('reduce', ['sum', 'mean'])\ndef test_spmm_jit(reduce):\n    @torch.jit.script\n    def jit_torch_sparse(src: SparseTensor, other: Tensor,\n                         reduce: str) -> Tensor:\n        return spmm(src, other, reduce=reduce)\n\n    @torch.jit.script\n    def jit_torch(src: Tensor, other: Tensor, reduce: str) -> Tensor:\n        return spmm(src, other, reduce=reduce)\n\n    src = torch.randn(5, 4)\n    other = torch.randn(4, 8)\n\n    out1 = src @ other\n    out2 = jit_torch(src.to_sparse_csr(), other, reduce)\n    assert out1.size() == (5, 8)\n    if reduce == 'sum':\n        assert torch.allclose(out1, out2, atol=1e-6)\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        out3 = jit_torch_sparse(SparseTensor.from_dense(src), other, reduce)\n        assert torch.allclose(out2, out3, atol=1e-6)\n\n\n@withCUDA\n@withPackage('torch>=2.0.0')\n@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max'])\ndef test_spmm_edge_index(device, reduce):\n    src = EdgeIndex(\n        [[0, 1, 1, 2], [1, 0, 2, 1]],\n        sparse_size=(4, 3),\n        sort_order='row',\n        device=device,\n    )\n    other = torch.rand(3, 4, device=device)\n    out = spmm(src, other, reduce=reduce)\n    assert out.size() == (4, 4)\n\n    if not other.is_cuda or reduce not in ['min', 'max']:\n        out2 = spmm(src.to_sparse_csr(), other, reduce=reduce)\n        assert torch.allclose(out, out2)\n\n\nif __name__ == '__main__':\n    import argparse\n\n    warnings.filterwarnings('ignore', \".*Sparse CSR tensor support.*\")\n    warnings.filterwarnings('ignore', \".*Converting sparse tensor to CSR.*\")\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--device', type=str, default='cuda')\n    parser.add_argument('--backward', action='store_true')\n    args = parser.parse_args()\n\n    num_nodes, num_edges = 10_000, 200_000\n    x = torch.randn(num_nodes, 64, device=args.device)\n    edge_index = 
torch.randint(num_nodes, (2, num_edges), device=args.device)\n\n    reductions = ['sum', 'mean']\n    if not x.is_cuda:\n        reductions.extend(['min', 'max'])\n    layouts = [torch.sparse_coo, torch.sparse_csr, torch.sparse_csc]\n\n    for reduce, layout in itertools.product(reductions, layouts):\n        print(f'Aggregator: {reduce}, Layout: {layout}')\n\n        adj = to_torch_coo_tensor(edge_index, size=num_nodes)\n        adj = adj.to_sparse(layout=layout)\n\n        benchmark(\n            funcs=[spmm],\n            func_names=['spmm'],\n            args=(adj, x, reduce),\n            num_steps=50 if args.device == 'cpu' else 500,\n            num_warmups=10 if args.device == 'cpu' else 100,\n            backward=args.backward,\n        )\n"
  },
  {
    "path": "test/utils/test_subgraph.py",
    "content": "import torch\n\nfrom torch_geometric.nn import GCNConv, Linear\nfrom torch_geometric.testing import withDevice, withPackage\nfrom torch_geometric.utils import (\n    bipartite_subgraph,\n    get_num_hops,\n    index_to_mask,\n    k_hop_subgraph,\n    subgraph,\n)\n\n\ndef test_get_num_hops():\n    class GNN(torch.nn.Module):\n        def __init__(self):\n            super().__init__()\n            self.conv1 = GCNConv(3, 16, normalize=False)\n            self.conv2 = GCNConv(16, 16, normalize=False)\n            self.lin = Linear(16, 2)\n\n        def forward(self, x, edge_index):\n            x = torch.nn.functional.relu(self.conv1(x, edge_index))\n            x = self.conv2(x, edge_index)\n            return self.lin(x)\n\n    assert get_num_hops(GNN()) == 2\n\n\ndef test_subgraph():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],\n        [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5],\n    ])\n    edge_attr = torch.tensor(\n        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0])\n\n    idx = torch.tensor([3, 4, 5])\n    mask = index_to_mask(idx, 7)\n    indices = idx.tolist()\n\n    for subset in [idx, mask, indices]:\n        out = subgraph(subset, edge_index, edge_attr, return_edge_mask=True)\n        assert out[0].tolist() == [[3, 4, 4, 5], [4, 3, 5, 4]]\n        assert out[1].tolist() == [7.0, 8.0, 9.0, 10.0]\n        assert out[2].tolist() == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]\n\n        out = subgraph(subset, edge_index, edge_attr, relabel_nodes=True)\n        assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n        assert out[1].tolist() == [7, 8, 9, 10]\n\n\n@withDevice\n@withPackage('pandas')\ndef test_subgraph_large_index(device):\n    subset = torch.tensor([50_000_000], device=device)\n    edge_index = torch.tensor([[50_000_000], [50_000_000]], device=device)\n    edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True)\n    assert edge_index.tolist() == [[0], [0]]\n\n\ndef 
test_bipartite_subgraph():\n    edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6],\n                               [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]])\n    edge_attr = torch.tensor(\n        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0])\n    idx = (torch.tensor([2, 3, 5]), torch.tensor([2, 3]))\n    mask = (index_to_mask(idx[0], 7), index_to_mask(idx[1], 4))\n    indices = (idx[0].tolist(), idx[1].tolist())\n    mixed = (mask[0], idx[1])\n\n    for subset in [idx, mask, indices, mixed]:\n        out = bipartite_subgraph(subset, edge_index, edge_attr,\n                                 return_edge_mask=True)\n        assert out[0].tolist() == [[2, 3, 5, 5], [3, 2, 2, 3]]\n        assert out[1].tolist() == [3.0, 4.0, 9.0, 10.0]\n        assert out[2].tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0]\n\n        out = bipartite_subgraph(subset, edge_index, edge_attr,\n                                 relabel_nodes=True)\n        assert out[0].tolist() == [[0, 1, 2, 2], [1, 0, 0, 1]]\n        assert out[1].tolist() == [3.0, 4.0, 9.0, 10.0]\n\n\n@withDevice\n@withPackage('pandas')\ndef test_bipartite_subgraph_large_index(device):\n    subset = torch.tensor([50_000_000], device=device)\n    edge_index = torch.tensor([[50_000_000], [50_000_000]], device=device)\n\n    edge_index, _ = bipartite_subgraph(\n        (subset, subset),\n        edge_index,\n        relabel_nodes=True,\n    )\n    assert edge_index.tolist() == [[0], [0]]\n\n\ndef test_k_hop_subgraph():\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 5],\n        [2, 2, 4, 4, 6, 6],\n    ])\n    subset, edge_index, mapping, edge_mask = k_hop_subgraph(\n        node_idx=6,\n        num_hops=2,\n        edge_index=edge_index,\n        relabel_nodes=True,\n    )\n    assert subset.tolist() == [2, 3, 4, 5, 6]\n    assert edge_index.tolist() == [[0, 1, 2, 3], [2, 2, 4, 4]]\n    assert mapping.tolist() == [4]\n    assert edge_mask.tolist() == [False, False, True, True, True, True]\n\n    
edge_index = torch.tensor([\n        [1, 2, 4, 5],\n        [0, 1, 5, 6],\n    ])\n    subset, edge_index, mapping, edge_mask = k_hop_subgraph(\n        node_idx=[0, 6],\n        num_hops=2,\n        edge_index=edge_index,\n        relabel_nodes=True,\n    )\n    assert subset.tolist() == [0, 1, 2, 4, 5, 6]\n    assert edge_index.tolist() == [[1, 2, 3, 4], [0, 1, 4, 5]]\n    assert mapping.tolist() == [0, 5]\n    assert edge_mask.tolist() == [True, True, True, True]\n\n    edge_index = torch.tensor([\n        [0, 1, 2, 3, 4, 4, 5],\n        [2, 2, 4, 4, 2, 6, 6],\n    ])\n    subset, edge_index, mapping, edge_mask = k_hop_subgraph(\n        node_idx=6,\n        num_hops=2,\n        edge_index=edge_index,\n        relabel_nodes=False,\n        directed=True,\n    )\n    assert subset.tolist() == [2, 3, 4, 5, 6]\n    assert edge_index.tolist() == [[2, 3, 4, 5], [4, 4, 6, 6]]\n    assert mapping.tolist() == [4]\n    assert edge_mask.tolist() == [False, False, True, True, False, True, True]\n"
  },
  {
    "path": "test/utils/test_to_dense_adj.py",
    "content": "import torch\n\nfrom torch_geometric.testing import is_full_test\nfrom torch_geometric.utils import to_dense_adj\n\n\ndef test_to_dense_adj():\n    edge_index = torch.tensor([\n        [0, 0, 1, 2, 3, 4],\n        [0, 1, 0, 3, 4, 2],\n    ])\n    batch = torch.tensor([0, 0, 1, 1, 1])\n\n    adj = to_dense_adj(edge_index, batch)\n    assert adj.size() == (2, 3, 3)\n    assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]\n    assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]\n\n    if is_full_test():\n        jit = torch.jit.script(to_dense_adj)\n        adj = jit(edge_index, batch)\n        assert adj.size() == (2, 3, 3)\n        assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]\n        assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]\n\n    adj = to_dense_adj(edge_index, batch, max_num_nodes=2)\n    assert adj.size() == (2, 2, 2)\n    assert adj[0].tolist() == [[1, 1], [1, 0]]\n    assert adj[1].tolist() == [[0, 1], [0, 0]]\n\n    adj = to_dense_adj(edge_index, batch, max_num_nodes=5)\n    assert adj.size() == (2, 5, 5)\n    assert adj[0][:3, :3].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]]\n    assert adj[1][:3, :3].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]]\n\n    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])\n    adj = to_dense_adj(edge_index, batch, edge_attr)\n    assert adj.size() == (2, 3, 3)\n    assert adj[0].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]]\n    assert adj[1].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]]\n\n    adj = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=5)\n    assert adj.size() == (2, 5, 5)\n    assert adj[0][:3, :3].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]]\n    assert adj[1][:3, :3].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]]\n\n    edge_attr = edge_attr.view(-1, 1)\n    adj = to_dense_adj(edge_index, batch, edge_attr)\n    assert adj.size() == (2, 3, 3, 1)\n\n    edge_attr = edge_attr.view(-1, 1)\n    adj = to_dense_adj(edge_index, batch, 
edge_attr, max_num_nodes=5)\n    assert adj.size() == (2, 5, 5, 1)\n\n    adj = to_dense_adj(edge_index)\n    assert adj.size() == (1, 5, 5)\n    assert adj[0].nonzero(as_tuple=False).t().tolist() == edge_index.tolist()\n\n    adj = to_dense_adj(edge_index, max_num_nodes=10)\n    assert adj.size() == (1, 10, 10)\n    assert adj[0].nonzero(as_tuple=False).t().tolist() == edge_index.tolist()\n\n    adj = to_dense_adj(edge_index, batch, batch_size=4)\n    assert adj.size() == (4, 3, 3)\n\n\ndef test_to_dense_adj_with_empty_edge_index():\n    edge_index = torch.tensor([[], []], dtype=torch.long)\n    batch = torch.tensor([0, 0, 1, 1, 1])\n\n    adj = to_dense_adj(edge_index)\n    assert adj.size() == (1, 0, 0)\n\n    adj = to_dense_adj(edge_index, max_num_nodes=10)\n    assert adj.size() == (1, 10, 10) and adj.sum() == 0\n\n    adj = to_dense_adj(edge_index, batch)\n    assert adj.size() == (2, 3, 3) and adj.sum() == 0\n\n    adj = to_dense_adj(edge_index, batch, max_num_nodes=10)\n    assert adj.size() == (2, 10, 10) and adj.sum() == 0\n\n\ndef test_to_dense_adj_with_duplicate_entries():\n    edge_index = torch.tensor([\n        [0, 0, 0, 1, 2, 3, 3, 4],\n        [0, 0, 1, 0, 3, 4, 4, 2],\n    ])\n    batch = torch.tensor([0, 0, 1, 1, 1])\n\n    adj = to_dense_adj(edge_index, batch)\n    assert adj.size() == (2, 3, 3)\n    assert adj[0].tolist() == [[2, 1, 0], [1, 0, 0], [0, 0, 0]]\n    assert adj[1].tolist() == [[0, 1, 0], [0, 0, 2], [1, 0, 0]]\n\n    edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])\n    adj = to_dense_adj(edge_index, batch, edge_attr)\n    assert adj.size() == (2, 3, 3)\n    assert adj[0].tolist() == [\n        [3.0, 3.0, 0.0],\n        [4.0, 0.0, 0.0],\n        [0.0, 0.0, 0.0],\n    ]\n    assert adj[1].tolist() == [\n        [0.0, 5.0, 0.0],\n        [0.0, 0.0, 13.0],\n        [8.0, 0.0, 0.0],\n    ]\n"
  },
  {
    "path": "test/utils/test_to_dense_batch.py",
    "content": "from typing import Tuple\n\nimport pytest\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import set_experimental_mode\nfrom torch_geometric.testing import onlyFullTest\nfrom torch_geometric.utils import to_dense_batch\n\n\n@pytest.mark.parametrize('fill', [70.0, torch.tensor(49.0)])\ndef test_to_dense_batch(fill):\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    batch = torch.tensor([0, 0, 1, 2, 2, 2])\n\n    item = fill.item() if isinstance(fill, Tensor) else fill\n    expected = torch.tensor([\n        [[1.0, 2.0], [3.0, 4.0], [item, item]],\n        [[5.0, 6.0], [item, item], [item, item]],\n        [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],\n    ])\n\n    out, mask = to_dense_batch(x, batch, fill_value=fill)\n    assert out.size() == (3, 3, 2)\n    assert torch.equal(out, expected)\n    assert mask.tolist() == [[1, 1, 0], [1, 0, 0], [1, 1, 1]]\n\n    out, mask = to_dense_batch(x, batch, max_num_nodes=2, fill_value=fill)\n    assert out.size() == (3, 2, 2)\n    assert torch.equal(out, expected[:, :2])\n    assert mask.tolist() == [[1, 1], [1, 0], [1, 1]]\n\n    out, mask = to_dense_batch(x, batch, max_num_nodes=5, fill_value=fill)\n    assert out.size() == (3, 5, 2)\n    assert torch.equal(out[:, :3], expected)\n    assert mask.tolist() == [[1, 1, 0, 0, 0], [1, 0, 0, 0, 0], [1, 1, 1, 0, 0]]\n\n    out, mask = to_dense_batch(x, fill_value=fill)\n    assert out.size() == (1, 6, 2)\n    assert torch.equal(out[0], x)\n    assert mask.tolist() == [[1, 1, 1, 1, 1, 1]]\n\n    out, mask = to_dense_batch(x, max_num_nodes=2, fill_value=fill)\n    assert out.size() == (1, 2, 2)\n    assert torch.equal(out[0], x[:2])\n    assert mask.tolist() == [[1, 1]]\n\n    out, mask = to_dense_batch(x, max_num_nodes=10, fill_value=fill)\n    assert out.size() == (1, 10, 2)\n    assert torch.equal(out[0, :6], x)\n    assert 
mask.tolist() == [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]]\n\n    out, mask = to_dense_batch(x, batch, batch_size=4, fill_value=fill)\n    assert out.size() == (4, 3, 2)\n\n\ndef test_to_dense_batch_disable_dynamic_shapes():\n    x = torch.tensor([\n        [1.0, 2.0],\n        [3.0, 4.0],\n        [5.0, 6.0],\n        [7.0, 8.0],\n        [9.0, 10.0],\n        [11.0, 12.0],\n    ])\n    batch = torch.tensor([0, 0, 1, 2, 2, 2])\n\n    with set_experimental_mode(True, 'disable_dynamic_shapes'):\n        with pytest.raises(ValueError, match=\"'batch_size' needs to be set\"):\n            out, mask = to_dense_batch(x, batch, max_num_nodes=6)\n        with pytest.raises(ValueError, match=\"'max_num_nodes' needs to be\"):\n            out, mask = to_dense_batch(x, batch, batch_size=4)\n        with pytest.raises(ValueError, match=\"'batch_size' needs to be set\"):\n            out, mask = to_dense_batch(x)\n\n        out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6)\n        assert out.size() == (1, 6, 2)\n        assert mask.size() == (1, 6)\n\n        out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10)\n        assert out.size() == (3, 10, 2)\n        assert mask.size() == (3, 10)\n\n\n@onlyFullTest\ndef test_to_dense_batch_jit():\n    @torch.jit.script\n    def to_dense_batch_jit(\n        x: Tensor,\n        batch: Tensor,\n        fill_value: Tensor,\n    ) -> Tuple[Tensor, Tensor]:\n        return to_dense_batch(x, batch, fill_value=fill_value)\n\n    x = torch.randn(6, 2)\n    batch = torch.tensor([0, 0, 1, 2, 2, 2])\n\n    out, mask = to_dense_batch_jit(x, batch, fill_value=torch.tensor(0.0))\n    assert out.size() == (3, 3, 2)\n    assert mask.size() == (3, 3)\n"
  },
  {
    "path": "test/utils/test_total_influence.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.utils import total_influence\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = GCNConv(5, 6)\n        self.conv2 = GCNConv(6, 7)\n\n    def forward(self, x0, edge_index):\n        x1 = self.conv1(x0, edge_index)\n        x2 = self.conv2(x1, edge_index)\n        return x2\n\n\ndef test_total_influence_smoke():\n    x = torch.randn(6, 5)\n    edge_index = torch.tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]])\n    max_hops = 2\n    num_samples = 4\n    data = Data(\n        x=x,\n        edge_index=edge_index,\n    )\n    model = GNN()\n    I, R = total_influence(\n        model,\n        data,\n        max_hops=max_hops,\n        num_samples=num_samples,\n    )\n\n    assert I.shape == (max_hops + 1, )\n    assert 0.0 <= R <= max_hops\n\n    I, R = total_influence(\n        model,\n        data,\n        max_hops=max_hops,\n        num_samples=num_samples,\n        average=False,\n    )\n    assert I.shape == torch.Size([num_samples, max_hops + 1])\n"
  },
  {
    "path": "test/utils/test_train_test_split_edges.py",
    "content": "import pytest\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.utils import train_test_split_edges\n\n\ndef test_train_test_split_edges():\n    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])\n    edge_attr = torch.arange(edge_index.size(1))\n    data = Data(edge_index=edge_index, edge_attr=edge_attr)\n    data.num_nodes = edge_index.max().item() + 1\n\n    with pytest.warns(UserWarning, match='deprecated'):\n        data = train_test_split_edges(data, val_ratio=0.2, test_ratio=0.3)\n\n    assert len(data) == 10\n    assert data.val_pos_edge_index.size() == (2, 2)\n    assert data.val_neg_edge_index.size() == (2, 2)\n    assert data.test_pos_edge_index.size() == (2, 3)\n    assert data.test_neg_edge_index.size() == (2, 3)\n    assert data.train_pos_edge_index.size() == (2, 10)\n    assert data.train_neg_adj_mask.size() == (11, 11)\n    assert data.train_neg_adj_mask.sum().item() == (11**2 - 11) / 2 - 4 - 6 - 5\n    assert data.train_pos_edge_attr.size() == (10, )\n    assert data.val_pos_edge_attr.size() == (2, )\n    assert data.test_pos_edge_attr.size() == (3, )\n"
  },
  {
    "path": "test/utils/test_tree_decomposition.py",
    "content": "import pytest\n\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.utils import tree_decomposition\n\n\n@withPackage('rdkit')\n@pytest.mark.parametrize('smiles', [\n    r'F/C=C/F',\n    r'C/C(=C\\C(=O)c1ccc(C)o1)Nc1ccc2c(c1)OCO2',\n])\ndef test_tree_decomposition(smiles):\n    from rdkit import Chem\n    mol = Chem.MolFromSmiles(smiles)\n    tree_decomposition(mol)  # TODO Test output\n"
  },
  {
    "path": "test/utils/test_trim_to_layer.py",
    "content": "from typing import List, Optional\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import NeighborLoader\nfrom torch_geometric.nn import GraphConv\nfrom torch_geometric.testing import withPackage\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import trim_to_layer\nfrom torch_geometric.utils._trim_to_layer import trim_sparse_tensor\n\n\n@withPackage('torch_sparse')\ndef test_trim_sparse_tensor():\n    edge_index = torch.tensor([[0, 0, 1, 2], [1, 2, 3, 4]])\n    adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[5, 5])\n\n    adj = trim_sparse_tensor(adj, size=(3, 3), num_seed_nodes=1)\n\n    row, col, _ = adj.coo()\n    assert row.tolist() == [0, 0]\n    assert col.tolist() == [1, 2]\n\n\ndef test_trim_to_layer_basic():\n    x0 = torch.arange(4)\n    edge_index0 = torch.tensor([[1, 2, 3], [0, 1, 2]])\n    edge_weight0 = torch.arange(3)\n\n    num_sampled_nodes_per_hop = [1, 1, 1]\n    num_sampled_edges_per_hop = [1, 1, 1]\n\n    x1, edge_index1, edge_weight1 = trim_to_layer(\n        layer=0,\n        num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,\n        num_sampled_edges_per_hop=num_sampled_edges_per_hop,\n        x=x0,\n        edge_index=edge_index0,\n        edge_attr=edge_weight0,\n    )\n    assert torch.equal(x1, torch.arange(4))\n    assert edge_index1.tolist() == [[1, 2, 3], [0, 1, 2]]\n    assert torch.equal(edge_weight1, torch.arange(3))\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj0 = SparseTensor.from_edge_index(edge_index0, edge_weight0, (4, 4))\n        x1, adj_t1, _ = trim_to_layer(\n            layer=0,\n            num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,\n            num_sampled_edges_per_hop=num_sampled_edges_per_hop,\n            x=x0,\n            edge_index=adj0.t(),\n            edge_attr=edge_weight0,\n        )\n        adj1 = adj_t1.t()\n        assert 
adj1.sizes() == [4, 4]\n\n        row, col, value = adj1.coo()\n        assert torch.equal(x1, torch.arange(4))\n        assert row.tolist() == [1, 2, 3]\n        assert col.tolist() == [0, 1, 2]\n        assert torch.equal(value, torch.arange(3))\n\n    x2, edge_index2, edge_weight2 = trim_to_layer(\n        layer=1,\n        num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,\n        num_sampled_edges_per_hop=num_sampled_edges_per_hop,\n        x=x1,\n        edge_index=edge_index1,\n        edge_attr=edge_weight1,\n    )\n    assert torch.equal(x2, torch.arange(3))\n    assert edge_index2.tolist() == [[1, 2], [0, 1]]\n    assert torch.equal(edge_weight2, torch.arange(2))\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj1 = SparseTensor.from_edge_index(edge_index1, edge_weight1, (4, 4))\n        x2, adj_t2, _ = trim_to_layer(\n            layer=1,\n            num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,\n            num_sampled_edges_per_hop=num_sampled_edges_per_hop,\n            x=x1,\n            edge_index=adj1.t(),\n        )\n        adj2 = adj_t2.t()\n        assert adj2.sizes() == [3, 3]\n\n        row, col, value = adj2.coo()\n        assert torch.equal(x2, torch.arange(3))\n        assert row.tolist() == [1, 2]\n        assert col.tolist() == [0, 1]\n        assert torch.equal(value, torch.arange(2))\n\n    x3, edge_index3, edge_weight3 = trim_to_layer(\n        layer=2,\n        num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,\n        num_sampled_edges_per_hop=num_sampled_edges_per_hop,\n        x=x2,\n        edge_index=edge_index2,\n        edge_attr=edge_weight2,\n    )\n    assert torch.equal(x3, torch.arange(2))\n    assert edge_index3.tolist() == [[1], [0]]\n    assert torch.equal(edge_weight3, torch.arange(1))\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE:\n        adj2 = SparseTensor.from_edge_index(edge_index2, edge_weight2, (3, 3))\n        x3, adj_t3, _ = trim_to_layer(\n            layer=2,\n            
num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,\n            num_sampled_edges_per_hop=num_sampled_edges_per_hop,\n            x=x2,\n            edge_index=adj2.t(),\n        )\n        adj3 = adj_t3.t()\n        assert adj3.sizes() == [2, 2]\n\n        row, col, value = adj3.coo()\n        assert torch.equal(x3, torch.arange(2))\n        assert row.tolist() == [1]\n        assert col.tolist() == [0]\n        assert torch.equal(value, torch.arange(1))\n\n\ndef test_trim_to_layer_hetero():\n    x = {'v': torch.arange(4)}\n    edge_index = {('v', 'to', 'v'): torch.tensor([[1, 2, 3], [0, 1, 2]])}\n    edge_weight = {('v', 'to', 'v'): torch.arange(3)}\n\n    num_sampled_nodes_per_hop = {'v': [1, 1, 1, 1]}\n    num_sampled_edges_per_hop = {('v', 'to', 'v'): [1, 1, 1]}\n\n    x, edge_index, edge_weight = trim_to_layer(\n        layer=1,\n        num_sampled_nodes_per_hop=num_sampled_nodes_per_hop,\n        num_sampled_edges_per_hop=num_sampled_edges_per_hop,\n        x=x,\n        edge_index=edge_index,\n        edge_attr=edge_weight,\n    )\n    assert torch.equal(x['v'], torch.arange(3))\n    assert edge_index['v', 'to', 'v'].tolist() == [[1, 2], [0, 1]]\n    assert torch.equal(edge_weight['v', 'to', 'v'], torch.arange(2))\n\n\nclass GNN(torch.nn.Module):\n    def __init__(self, num_layers: int):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList(\n            GraphConv(16, 16) for _ in range(num_layers))\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_weight: Tensor,\n        num_sampled_nodes: Optional[List[int]] = None,\n        num_sampled_edges: Optional[List[int]] = None,\n    ) -> Tensor:\n        for i, conv in enumerate(self.convs):\n            if num_sampled_nodes is not None:\n                x, edge_index, edge_weight = trim_to_layer(\n                    i, num_sampled_nodes, num_sampled_edges, x, edge_index,\n                    edge_weight)\n            x = conv(x, edge_index, 
edge_weight)\n        return x\n\n\n@withPackage('pyg_lib')\ndef test_trim_to_layer_with_neighbor_loader():\n    x = torch.randn(14, 16)\n    edge_index = torch.tensor([\n        [2, 3, 4, 5, 7, 7, 10, 11, 12, 13],\n        [0, 1, 2, 3, 2, 3, 7, 7, 7, 7],\n    ])\n    edge_weight = torch.rand(edge_index.size(1))\n    data = Data(x=x, edge_index=edge_index, edge_weight=edge_weight)\n\n    loader = NeighborLoader(\n        data,\n        num_neighbors=[1, 2, 4],\n        batch_size=2,\n        shuffle=False,\n    )\n    batch = next(iter(loader))\n\n    model = GNN(num_layers=3)\n    out1 = model(batch.x, batch.edge_index, batch.edge_weight)[:2]\n    assert out1.size() == (2, 16)\n\n    out2 = model(batch.x, batch.edge_index, batch.edge_weight,\n                 batch.num_sampled_nodes, batch.num_sampled_edges)[:2]\n    assert out2.size() == (2, 16)\n\n    assert torch.allclose(out1, out2, atol=1e-6)\n"
  },
  {
    "path": "test/utils/test_unbatch.py",
    "content": "import torch\n\nfrom torch_geometric.utils import unbatch, unbatch_edge_index\n\n\ndef test_unbatch():\n    src = torch.arange(10)\n    batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 3, 4, 4])\n\n    out = unbatch(src, batch)\n    assert len(out) == 5\n    for i in range(len(out)):\n        assert torch.equal(out[i], src[batch == i])\n\n\ndef test_unbatch_edge_index():\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 4, 5, 5, 6],\n        [1, 0, 2, 1, 3, 2, 5, 4, 6, 5],\n    ])\n    batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])\n\n    edge_indices = unbatch_edge_index(edge_index, batch)\n    assert edge_indices[0].tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]\n    assert edge_indices[1].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n"
  },
  {
    "path": "test/utils/test_undirected.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import is_undirected, to_undirected\n\n\ndef test_is_undirected():\n    row = torch.tensor([0, 1, 0])\n    col = torch.tensor([1, 0, 0])\n    sym_weight = torch.tensor([0, 0, 1])\n    asym_weight = torch.tensor([0, 1, 1])\n\n    assert is_undirected(torch.stack([row, col], dim=0))\n    assert is_undirected(torch.stack([row, col], dim=0), sym_weight)\n    assert not is_undirected(torch.stack([row, col], dim=0), asym_weight)\n\n    row = torch.tensor([0, 1, 1])\n    col = torch.tensor([1, 0, 2])\n\n    assert not is_undirected(torch.stack([row, col], dim=0))\n\n    @torch.jit.script\n    def jit(edge_index: Tensor) -> bool:\n        return is_undirected(edge_index)\n\n    assert not jit(torch.stack([row, col], dim=0))\n\n\ndef test_to_undirected():\n    row = torch.tensor([0, 1, 1])\n    col = torch.tensor([1, 0, 2])\n\n    edge_index = to_undirected(torch.stack([row, col], dim=0))\n    assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]]\n\n    @torch.jit.script\n    def jit(edge_index: Tensor) -> Tensor:\n        return to_undirected(edge_index)\n\n    assert torch.equal(jit(torch.stack([row, col], dim=0)), edge_index)\n"
  },
  {
    "path": "test/visualization/test_graph_visualization.py",
    "content": "import os.path as osp\n\nimport pytest\nimport torch\n\nfrom torch_geometric.testing import onlyGraphviz, withPackage\nfrom torch_geometric.visualization import visualize_graph\n\n\n@onlyGraphviz\n@pytest.mark.parametrize('backend', [None, 'graphviz'])\ndef test_visualize_graph_via_graphviz(tmp_path, backend):\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4],\n        [1, 0, 2, 1, 3, 2, 4, 3],\n    ])\n    edge_weight = (torch.rand(edge_index.size(1)) > 0.5).float()\n\n    path = osp.join(tmp_path, 'graph.pdf')\n    visualize_graph(edge_index, edge_weight, path, backend)\n    assert osp.exists(path)\n\n\n@onlyGraphviz\n@pytest.mark.parametrize('backend', [None, 'graphviz'])\ndef test_visualize_graph_via_graphviz_with_node_labels(tmp_path, backend):\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4],\n        [1, 0, 2, 1, 3, 2, 4, 3],\n    ])\n    edge_weight = (torch.rand(edge_index.size(1)) > 0.5).float()\n    node_labels = ['A', 'B', 'C', 'D', 'E']\n\n    path = osp.join(tmp_path, 'graph.pdf')\n    visualize_graph(edge_index, edge_weight, path, backend, node_labels)\n    assert osp.exists(path)\n\n\n@withPackage('networkx', 'matplotlib')\n@pytest.mark.parametrize('backend', [None, 'networkx'])\ndef test_visualize_graph_via_networkx(tmp_path, backend):\n    edge_index = torch.tensor([\n        [0, 1, 1, 2, 2, 3, 3, 4],\n        [1, 0, 2, 1, 3, 2, 4, 3],\n    ])\n    edge_weight = (torch.rand(edge_index.size(1)) > 0.5).float()\n\n    path = osp.join(tmp_path, 'graph.pdf')\n    visualize_graph(edge_index, edge_weight, path, backend)\n    assert osp.exists(path)\n"
  },
  {
    "path": "test/visualization/test_influence.py",
    "content": "import torch\n\nfrom torch_geometric.datasets import KarateClub\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.visualization import influence\n\n\nclass Net(torch.nn.Module):\n    def __init__(self, in_channels, out_channels):\n        super().__init__()\n        self.conv1 = GCNConv(in_channels, out_channels)\n        self.conv2 = GCNConv(out_channels, out_channels)\n\n    def forward(self, x, edge_index):\n        x = torch.nn.functional.relu(self.conv1(x, edge_index))\n        x = self.conv2(x, edge_index)\n        return x\n\n\ndef test_influence():\n    data = KarateClub()[0]\n    x = torch.randn(data.num_nodes, 8)\n\n    out = influence(Net(x.size(1), 16), x, data.edge_index)\n    assert out.size() == (data.num_nodes, data.num_nodes)\n    assert torch.allclose(out.sum(dim=-1), torch.ones(data.num_nodes),\n                          atol=1e-04)\n"
  },
  {
    "path": "torch_geometric/__init__.py",
    "content": "from collections import defaultdict\n\nimport torch\nimport torch_geometric.typing\n\nfrom ._compile import compile, is_compiling\nfrom ._onnx import is_in_onnx_export, safe_onnx_export\nfrom .index import Index\nfrom .edge_index import EdgeIndex\nfrom .hash_tensor import HashTensor\nfrom .seed import seed_everything\nfrom .home import get_home_dir, set_home_dir\nfrom .device import is_mps_available, is_xpu_available, device\nfrom .isinstance import is_torch_instance\nfrom .debug import is_debug_enabled, debug, set_debug\n\nimport torch_geometric.utils\nimport torch_geometric.data\nimport torch_geometric.sampler\nimport torch_geometric.loader\nimport torch_geometric.transforms\nimport torch_geometric.datasets\nimport torch_geometric.nn\nimport torch_geometric.explain\nimport torch_geometric.profile\n\nfrom .experimental import (is_experimental_mode_enabled, experimental_mode,\n                           set_experimental_mode)\nfrom .lazy_loader import LazyLoader\n\ncontrib = LazyLoader('contrib', globals(), 'torch_geometric.contrib')\ngraphgym = LazyLoader('graphgym', globals(), 'torch_geometric.graphgym')\n\n__version__ = '2.8.0'\n\n__all__ = [\n    'Index',\n    'EdgeIndex',\n    'HashTensor',\n    'seed_everything',\n    'get_home_dir',\n    'set_home_dir',\n    'compile',\n    'is_compiling',\n    'is_in_onnx_export',\n    'safe_onnx_export',\n    'is_mps_available',\n    'is_xpu_available',\n    'device',\n    'is_torch_instance',\n    'is_debug_enabled',\n    'debug',\n    'set_debug',\n    'is_experimental_mode_enabled',\n    'experimental_mode',\n    'set_experimental_mode',\n    'torch_geometric',\n    '__version__',\n]\n\nif not torch_geometric.typing.WITH_PT113:\n    import warnings as std_warnings\n\n    std_warnings.warn(\n        \"PyG 2.7 removed support for PyTorch < 1.13. \"\n        \"Consider upgrading to PyTorch >= 1.13 or downgrading \"\n        \"to PyG <= 2.6. 
\", stacklevel=2)\n\n# Serialization ###############################################################\n\nif torch_geometric.typing.WITH_PT24:\n    torch.serialization.add_safe_globals([\n        dict,\n        list,\n        defaultdict,\n        Index,\n        torch_geometric.index.CatMetadata,\n        EdgeIndex,\n        torch_geometric.edge_index.SortOrder,\n        torch_geometric.edge_index.CatMetadata,\n        HashTensor,\n    ])\n"
  },
  {
    "path": "torch_geometric/_compile.py",
    "content": "import warnings\nfrom typing import Any, Callable, Optional, Union\n\nimport torch\n\nimport torch_geometric.typing\n\n\ndef is_compiling() -> bool:\n    r\"\"\"Returns :obj:`True` in case :pytorch:`PyTorch` is compiling via\n    :meth:`torch.compile`.\n    \"\"\"\n    if torch_geometric.typing.WITH_PT23:\n        return torch.compiler.is_compiling()\n    if torch_geometric.typing.WITH_PT21:\n        return torch._dynamo.is_compiling()\n    return False  # pragma: no cover\n\n\ndef compile(\n    model: Optional[torch.nn.Module] = None,\n    *args: Any,\n    **kwargs: Any,\n) -> Union[torch.nn.Module, Callable[[torch.nn.Module], torch.nn.Module]]:\n    r\"\"\"Optimizes the given :pyg:`PyG` model/function via\n    :meth:`torch.compile`.\n    This function has the same signature as :meth:`torch.compile` (see\n    `here <https://pytorch.org/docs/stable/generated/torch.compile.html>`__).\n\n    Args:\n        model: The model to compile.\n        *args: Additional arguments of :meth:`torch.compile`.\n        **kwargs: Additional keyword arguments of :meth:`torch.compile`.\n\n    .. note::\n        :meth:`torch_geometric.compile` is deprecated in favor of\n        :meth:`torch.compile`.\n    \"\"\"\n    warnings.warn(\n        \"'torch_geometric.compile' is deprecated in favor of \"\n        \"'torch.compile'\", stacklevel=2)\n    return torch.compile(model, *args, **kwargs)  # type: ignore\n"
  },
  {
    "path": "torch_geometric/_onnx.py",
    "content": "import warnings\nfrom os import PathLike\nfrom typing import Any, Union\n\nimport torch\n\nfrom torch_geometric import is_compiling\n\n\ndef is_in_onnx_export() -> bool:\n    r\"\"\"Returns :obj:`True` in case :pytorch:`PyTorch` is exporting to ONNX via\n    :meth:`torch.onnx.export`.\n    \"\"\"\n    if is_compiling():\n        return False\n    if torch.jit.is_scripting():\n        return False\n    return torch.onnx.is_in_onnx_export()\n\n\ndef safe_onnx_export(\n    model: torch.nn.Module,\n    args: Union[torch.Tensor, tuple[Any, ...]],\n    f: Union[str, PathLike[Any], None],\n    skip_on_error: bool = False,\n    **kwargs: Any,\n) -> bool:\n    r\"\"\"A safe wrapper around :meth:`torch.onnx.export` that handles known\n    ONNX serialization issues in PyTorch Geometric.\n\n    This function provides workarounds for the ``onnx_ir.serde.SerdeError``\n    with boolean ``allowzero`` attributes that occurs in certain environments.\n\n    Args:\n        model (torch.nn.Module): The model to export.\n        args (torch.Tensor or tuple): The input arguments for the model.\n        f (str or PathLike): The file path to save the model.\n        skip_on_error (bool): If True, return False instead of raising when\n            workarounds fail. Useful for CI environments.\n        **kwargs: Additional arguments passed to :meth:`torch.onnx.export`.\n\n    Returns:\n        bool: True if export succeeded, False if skipped due to known issues\n              (only when skip_on_error=True).\n\n    Example:\n        >>> from torch_geometric.nn import SAGEConv\n        >>> from torch_geometric import safe_onnx_export\n        >>>\n        >>> class MyModel(torch.nn.Module):\n        ...     def __init__(self):\n        ...         super().__init__()\n        ...         self.conv = SAGEConv(8, 16)\n        ...     def forward(self, x, edge_index):\n        ...         
return self.conv(x, edge_index)\n        >>>\n        >>> model = MyModel()\n        >>> x = torch.randn(3, 8)\n        >>> edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])\n        >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx')\n        >>>\n        >>> # For CI environments:\n        >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx',\n        ...                             skip_on_error=True)\n        >>> if not success:\n        ...     print(\"ONNX export skipped due to known upstream issue\")\n    \"\"\"\n    # Convert single tensor to tuple for torch.onnx.export compatibility\n    if isinstance(args, torch.Tensor):\n        args = (args, )\n\n    try:\n        # First attempt: standard ONNX export\n        torch.onnx.export(model, args, f, **kwargs)\n        return True\n\n    except Exception as e:\n        error_str = str(e)\n        error_type = type(e).__name__\n\n        # Check for the specific onnx_ir.serde.SerdeError patterns\n        is_allowzero_error = (('onnx_ir.serde.SerdeError' in error_str\n                               and 'allowzero' in error_str) or\n                              'ValueError: Value out of range: 1' in error_str\n                              or 'serialize_model_into' in error_str\n                              or 'serialize_attribute_into' in error_str)\n\n        if is_allowzero_error:\n            warnings.warn(\n                f\"Encountered known ONNX serialization issue ({error_type}). \"\n                \"This is likely the allowzero boolean attribute bug. 
\"\n                \"Attempting workaround...\", UserWarning, stacklevel=2)\n\n            # Apply workaround strategies\n            return _apply_onnx_allowzero_workaround(model, args, f,\n                                                    skip_on_error, **kwargs)\n\n        else:\n            # Re-raise other errors\n            raise\n\n\ndef _apply_onnx_allowzero_workaround(\n    model: torch.nn.Module,\n    args: tuple[Any, ...],\n    f: Union[str, PathLike[Any], None],\n    skip_on_error: bool = False,\n    **kwargs: Any,\n) -> bool:\n    r\"\"\"Apply workaround strategies for onnx_ir.serde.SerdeError with allowzero\n    attributes.\n\n    Returns:\n        bool: True if export succeeded, False if skipped (when\n              skip_on_error=True).\n    \"\"\"\n    # Strategy 1: Try without dynamo if it was enabled\n    if kwargs.get('dynamo', False):\n        try:\n            kwargs_no_dynamo = kwargs.copy()\n            kwargs_no_dynamo['dynamo'] = False\n\n            warnings.warn(\n                \"Retrying ONNX export with dynamo=False as workaround\",\n                UserWarning, stacklevel=3)\n\n            torch.onnx.export(model, args, f, **kwargs_no_dynamo)\n            return True\n\n        except Exception:\n            pass\n\n    # Strategy 2: Try with different opset versions\n    original_opset = kwargs.get('opset_version', 18)\n    for opset_version in [17, 16, 15, 14, 13, 11]:\n        if opset_version != original_opset:\n            try:\n                kwargs_opset = kwargs.copy()\n                kwargs_opset['opset_version'] = opset_version\n\n                warnings.warn(\n                    f\"Retrying ONNX export with opset_version={opset_version}\",\n                    UserWarning, stacklevel=3)\n\n                torch.onnx.export(model, args, f, **kwargs_opset)\n                return True\n\n            except Exception:\n                continue\n\n    # Strategy 3: Try legacy export (non-dynamo with older opset)\n    
try:\n        kwargs_legacy = kwargs.copy()\n        kwargs_legacy['dynamo'] = False\n        kwargs_legacy['opset_version'] = 11\n\n        warnings.warn(\n            \"Retrying ONNX export with legacy settings \"\n            \"(dynamo=False, opset_version=11)\", UserWarning, stacklevel=3)\n\n        torch.onnx.export(model, args, f, **kwargs_legacy)\n        return True\n\n    except Exception:\n        pass\n\n    # Strategy 4: Try with minimal settings\n    try:\n        minimal_kwargs: dict[str, Any] = {\n            'opset_version': 11,\n            'dynamo': False,\n        }\n        # Add optional parameters if they exist\n        if kwargs.get('input_names') is not None:\n            minimal_kwargs['input_names'] = kwargs.get('input_names')\n        if kwargs.get('output_names') is not None:\n            minimal_kwargs['output_names'] = kwargs.get('output_names')\n\n        warnings.warn(\n            \"Retrying ONNX export with minimal settings as last resort\",\n            UserWarning, stacklevel=3)\n\n        torch.onnx.export(model, args, f, **minimal_kwargs)\n        return True\n\n    except Exception:\n        pass\n\n    # If all strategies fail, handle based on skip_on_error flag\n    import os\n    pytest_detected = 'PYTEST_CURRENT_TEST' in os.environ or 'pytest' in str(f)\n\n    if skip_on_error:\n        # For CI environments: skip gracefully instead of failing\n        warnings.warn(\n            \"ONNX export skipped due to known upstream issue \"\n            \"(onnx_ir.serde.SerdeError). \"\n            \"This is caused by a bug in the onnx_ir package where boolean \"\n            \"allowzero attributes cannot be serialized. All workarounds \"\n            \"failed. 
Consider updating packages: pip install --upgrade onnx \"\n            \"onnxscript \"\n            \"onnx_ir\", UserWarning, stacklevel=3)\n        return False\n\n    # For regular usage: provide detailed error message\n    error_msg = (\n        \"Failed to export model to ONNX due to known serialization issue. \"\n        \"This is caused by a bug in the onnx_ir package where boolean \"\n        \"allowzero attributes cannot be serialized. \"\n        \"Workarounds attempted: dynamo=False, multiple opset versions, \"\n        \"and legacy export. \")\n\n    if pytest_detected:\n        error_msg += (\n            \"\\n\\nThis error commonly occurs in pytest environments. \"\n            \"Try one of these solutions:\\n\"\n            \"1. Run the export outside of pytest (in a regular Python \"\n            \"script)\\n\"\n            \"2. Update packages: pip install --upgrade onnx onnxscript \"\n            \"onnx_ir\\n\"\n            \"3. Use torch.jit.script() instead of ONNX export for testing\\n\"\n            \"4. Use safe_onnx_export(..., skip_on_error=True) to skip \"\n            \"gracefully in CI\")\n    else:\n        error_msg += (\"\\n\\nTry updating packages: pip install --upgrade onnx \"\n                      \"onnxscript onnx_ir\")\n\n    raise RuntimeError(error_msg)\n"
  },
  {
    "path": "torch_geometric/backend.py",
    "content": "from typing import Optional\n\nimport torch\n\n# If set to `True`, PyG is configured to use the `segment_matmul` and\n# `grouped_matmul` kernels from `pyg-lib` to parallelize matrix multiplication\n# across segments/groups of potentially varying size.\n# If set to `None`, will automatically decide whether to utilize\n# `segment_matmul` and `grouped_matmul` based on input sizes.\n# Requires `pyg-lib` to be installed.\nuse_segment_matmul: Optional[bool] = None\n\n# Helper functions ############################################################\n\n\ndef use_segment_matmul_heuristic(\n    num_segments: int,\n    max_segment_size: int,\n    in_channels: int,\n    out_channels: int,\n) -> bool:\n    r\"\"\"A heuristic based on input sizes to determine whether the usage of\n    :meth:`segment_matmul` can speed up computation.\n    \"\"\"\n    # NOTE This heuristic was learned on an A100 via sklearn using a simple\n    # StandardScaler() -> LinearSVC() model.\n    # For now, it is only used in combination with `RGCNConv`.\n    x = torch.tensor([\n        num_segments,\n        max_segment_size,\n        in_channels,\n        out_channels,\n    ])\n    mean = torch.tensor([\n        125.11603189,\n        12133.21523472,\n        163.81222321,\n        32.43755536,\n    ])\n    std = torch.tensor([\n        163.34480422,\n        27572.94543809,\n        177.6426489,\n        56.82103934,\n    ])\n    weight = torch.tensor([\n        2.43877659e+00,\n        1.67583047e+00,\n        -5.20527282e-04,\n        3.43925501e-01,\n    ])\n    bias = 1.20236999\n\n    x = (x - mean) / std\n    return bool(x @ weight >= bias)\n"
  },
  {
    "path": "torch_geometric/config_mixin.py",
    "content": "import inspect\nfrom dataclasses import fields, is_dataclass\nfrom importlib import import_module\nfrom typing import Any, Dict\n\nfrom torch.nn import ModuleDict, ModuleList\n\nfrom torch_geometric.config_store import (\n    class_from_dataclass,\n    dataclass_from_class,\n)\nfrom torch_geometric.isinstance import is_torch_instance\n\n\nclass ConfigMixin:\n    r\"\"\"Enables a class to serialize/deserialize itself to a dataclass.\"\"\"\n    def config(self) -> Any:\n        r\"\"\"Creates a serializable configuration of the class.\"\"\"\n        data_cls = dataclass_from_class(self.__class__)\n        if data_cls is None:\n            raise ValueError(f\"Could not find the configuration class that \"\n                             f\"belongs to '{self.__class__.__name__}'. Please \"\n                             f\"register it in the configuration store.\")\n\n        kwargs: Dict[str, Any] = {}\n        for field in fields(data_cls):\n            if not hasattr(self, field.name):\n                continue\n            kwargs[field.name] = _recursive_config(getattr(self, field.name))\n        return data_cls(**kwargs)\n\n    @classmethod\n    def from_config(cls, cfg: Any, *args: Any, **kwargs: Any) -> Any:\n        r\"\"\"Instantiates the class from a serializable configuration.\"\"\"\n        if getattr(cfg, '_target_', None):\n            cls = _locate_cls(cfg._target_)\n        elif isinstance(cfg, dict) and '_target_' in cfg:\n            cls = _locate_cls(cfg['_target_'])\n\n        data_cls = cfg.__class__\n        if not is_dataclass(data_cls):\n            data_cls = dataclass_from_class(cls)\n            if data_cls is None:\n                raise ValueError(f\"Could not find the configuration class \"\n                                 f\"that belongs to '{cls.__name__}'. 
Please \"\n                                 f\"register it in the configuration store.\")\n\n        field_names = {field.name for field in fields(data_cls)}\n        if isinstance(cfg, dict):\n            _kwargs = {k: v for k, v in cfg.items() if k in field_names}\n            cfg = data_cls(**_kwargs)\n        assert is_dataclass(cfg)\n\n        if len(args) > 0:  # Convert `*args` to `**kwargs`:\n            param_names = list(inspect.signature(cls).parameters.keys())\n            if 'args' in param_names:\n                param_names.remove('args')\n            if 'kwargs' in param_names:\n                param_names.remove('kwargs')\n\n            for name, arg in zip(param_names, args):\n                kwargs[name] = arg\n\n        for key in field_names:\n            if key not in kwargs and key != '_target_':\n                kwargs[key] = _recursive_from_config(getattr(cfg, key))\n\n        return cls(**kwargs)\n\n\ndef _recursive_config(value: Any) -> Any:\n    if isinstance(value, ConfigMixin):\n        return value.config()\n    if is_torch_instance(value, ConfigMixin):\n        return value.config()\n    if isinstance(value, (tuple, list, ModuleList)):\n        return [_recursive_config(v) for v in value]\n    if isinstance(value, (dict, ModuleDict)):\n        return {k: _recursive_config(v) for k, v in value.items()}\n    return value\n\n\ndef _recursive_from_config(value: Any) -> Any:\n    cls: Any = None\n    if is_dataclass(value):\n        if getattr(value, '_target_', None):\n            try:\n                cls = _locate_cls(value._target_)  # type: ignore\n            except ImportError:\n                pass  # Keep the dataclass as it is.\n        else:\n            cls = class_from_dataclass(value.__class__)\n    elif isinstance(value, dict) and '_target_' in value:\n        cls = _locate_cls(value['_target_'])\n\n    if cls is not None and issubclass(cls, ConfigMixin):\n        return cls.from_config(value)\n    if isinstance(value, 
(tuple, list)):\n        return [_recursive_from_config(v) for v in value]\n    if isinstance(value, dict):\n        return {k: _recursive_from_config(v) for k, v in value.items()}\n    return value\n\n\ndef _locate_cls(qualname: str) -> Any:\n    parts = qualname.split('.')\n\n    if len(parts) <= 1:\n        raise ValueError(f\"Qualified name is missing a dot (got '{qualname}')\")\n\n    if any([len(part) == 0 for part in parts]):\n        raise ValueError(f\"Relative imports not supported (got '{qualname}')\")\n\n    module_name, cls_name = '.'.join(parts[:-1]), parts[-1]\n    return getattr(import_module(module_name), cls_name)\n"
  },
  {
    "path": "torch_geometric/config_store.py",
    "content": "import copy\nimport inspect\nimport typing\nfrom collections import defaultdict\nfrom dataclasses import dataclass, field, make_dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\n\nEXCLUDE = {'self', 'args', 'kwargs'}\n\nMAPPING = {\n    torch.nn.Module: Any,\n    torch.Tensor: Any,\n}\n\ntry:\n    from omegaconf import MISSING\nexcept Exception:\n    MISSING = '???'\n\ntry:\n    import hydra  # noqa\n    WITH_HYDRA = True\nexcept Exception:\n    WITH_HYDRA = False\n\nif not typing.TYPE_CHECKING and WITH_HYDRA:\n    from hydra.core.config_store import ConfigStore\n\n    def get_node(cls: Union[str, Any]) -> Optional[Any]:\n        if (not isinstance(cls, str)\n                and cls.__module__ in {'builtins', 'typing'}):\n            return None\n\n        def _get_candidates(repo: Dict[str, Any]) -> List[Any]:\n            outs: List[Any] = []\n            for key, value in repo.items():\n                if isinstance(value, dict):\n                    outs.extend(_get_candidates(value))\n                elif getattr(value.node._metadata, 'object_type', None) == cls:\n                    outs.append(value.node)\n                elif getattr(value.node._metadata, 'orig_type', None) == cls:\n                    outs.append(value.node)\n                elif isinstance(cls, str) and key == f'{cls}.yaml':\n                    outs.append(value.node)\n\n            return outs\n\n        candidates = _get_candidates(get_config_store().repo)\n\n        if len(candidates) > 1:\n            raise ValueError(f\"Found multiple entries in the configuration \"\n                             f\"store for the same node '{candidates[0].name}'\")\n\n        return candidates[0] if len(candidates) == 1 else None\n\n    def dataclass_from_class(cls: Union[str, Any]) -> Optional[Any]:\n        r\"\"\"Returns the :obj:`dataclass` of a class registered in the global\n        configuration store.\n        \"\"\"\n        node 
= get_node(cls)\n        return node._metadata.object_type if node is not None else None\n\n    def class_from_dataclass(cls: Union[str, Any]) -> Optional[Any]:\n        r\"\"\"Returns the original class of a :obj:`dataclass` registered in the\n        global configuration store.\n        \"\"\"\n        node = get_node(cls)\n        return node._metadata.orig_type if node is not None else None\n\nelse:\n\n    class Singleton(type):\n        _instances: Dict[type, Any] = {}\n\n        def __call__(cls, *args: Any, **kwargs: Any) -> Any:\n            if cls not in cls._instances:\n                instance = super().__call__(*args, **kwargs)\n                cls._instances[cls] = instance\n                return instance\n            return cls._instances[cls]\n\n    @dataclass\n    class Metadata:\n        orig_type: Optional[Any] = None\n\n    @dataclass\n    class ConfigNode:\n        name: str\n        node: Any\n        group: Optional[str] = None\n        _metadata: Metadata = field(default_factory=Metadata)\n\n    class ConfigStore(metaclass=Singleton):\n        def __init__(self) -> None:\n            self.repo: Dict[str, Any] = defaultdict(dict)\n\n        @classmethod\n        def instance(cls, *args: Any, **kwargs: Any) -> 'ConfigStore':\n            return cls(*args, **kwargs)\n\n        def store(\n            self,\n            name: str,\n            node: Any,\n            group: Optional[str] = None,\n            orig_type: Optional[Any] = None,\n        ) -> None:\n            cur = self.repo\n            if group is not None:\n                cur = cur[group]\n            if name in cur:\n                raise KeyError(f\"Configuration '{name}' already registered. 
\"\n                               f\"Please store it under a different group.\")\n            metadata = Metadata(orig_type=orig_type)\n            cur[name] = ConfigNode(name, node, group, metadata)\n\n    def get_node(cls: Union[str, Any]) -> Optional[ConfigNode]:\n        if (not isinstance(cls, str)\n                and cls.__module__ in {'builtins', 'typing'}):\n            return None\n\n        def _get_candidates(repo: Dict[str, Any]) -> List[ConfigNode]:\n            outs: List[ConfigNode] = []\n            for key, value in repo.items():\n                if isinstance(value, dict):\n                    outs.extend(_get_candidates(value))\n                elif value.node == cls:\n                    outs.append(value)\n                elif value._metadata.orig_type == cls:\n                    outs.append(value)\n                elif isinstance(cls, str) and key == cls:\n                    outs.append(value)\n\n            return outs\n\n        candidates = _get_candidates(get_config_store().repo)\n\n        if len(candidates) > 1:\n            raise ValueError(f\"Found multiple entries in the configuration \"\n                             f\"store for the same node '{candidates[0].name}'\")\n\n        return candidates[0] if len(candidates) == 1 else None\n\n    def dataclass_from_class(cls: Union[str, Any]) -> Optional[Any]:\n        r\"\"\"Returns the :obj:`dataclass` of a class registered in the global\n        configuration store.\n        \"\"\"\n        node = get_node(cls)\n        return node.node if node is not None else None\n\n    def class_from_dataclass(cls: Union[str, Any]) -> Optional[Any]:\n        r\"\"\"Returns the original class of a :obj:`dataclass` registered in the\n        global configuration store.\n        \"\"\"\n        node = get_node(cls)\n        return node._metadata.orig_type if node is not None else None\n\n\ndef map_annotation(\n    annotation: Any,\n    mapping: Optional[Dict[Any, Any]] = None,\n) -> Any:\n    origin 
= getattr(annotation, '__origin__', None)\n    args: Tuple[Any, ...] = getattr(annotation, '__args__', tuple())\n    if origin in {Union, list, dict, tuple}:\n        assert origin is not None\n        args = tuple(map_annotation(a, mapping) for a in args)\n        if type(annotation).__name__ == 'GenericAlias':\n            # If annotated with `list[...]` or `dict[...]`:\n            annotation = origin[args]\n        else:\n            # If annotated with `typing.List[...]` or `typing.Dict[...]`:\n            annotation = copy.copy(annotation)\n            annotation.__args__ = args\n\n        return annotation\n\n    if mapping is not None and annotation in mapping:\n        return mapping[annotation]\n\n    out = dataclass_from_class(annotation)\n    if out is not None:\n        return out\n\n    return annotation\n\n\ndef to_dataclass(\n    cls: Any,\n    base_cls: Optional[Any] = None,\n    with_target: Optional[bool] = None,\n    map_args: Optional[Dict[str, Tuple]] = None,\n    exclude_args: Optional[List[str]] = None,\n    strict: bool = False,\n) -> Any:\n    r\"\"\"Converts the input arguments of a given class :obj:`cls` to a\n    :obj:`dataclass` schema.\n\n    For example,\n\n    .. code-block:: python\n\n        from torch_geometric.transforms import NormalizeFeatures\n\n        dataclass = to_dataclass(NormalizeFeatures)\n\n    will generate\n\n    .. code-block:: python\n\n        @dataclass\n        class NormalizeFeatures:\n            _target_: str = \"torch_geometric.transforms.NormalizeFeatures\"\n            attrs: List[str] = field(default_factory = lambda: [\"x\"])\n\n    Args:\n        cls (Any): The class to generate a schema for.\n        base_cls (Any, optional): The base class of the schema.\n            (default: :obj:`None`)\n        with_target (bool, optional): If set to :obj:`False`, will not add the\n            :obj:`_target_` attribute to the schema. 
If set to :obj:`None`,\n            will only add the :obj:`_target_` in case :obj:`base_cls` is given.\n            (default: :obj:`None`)\n        map_args (Dict[str, Tuple], optional): Arguments for which annotation\n            and default values should be overridden. (default: :obj:`None`)\n        exclude_args (List[str or int], optional): Arguments to exclude.\n            (default: :obj:`None`)\n        strict (bool, optional): If set to :obj:`True`, ensures that all\n            arguments in both :obj:`map_args` and :obj:`exclude_args` are\n            present in the input parameters. (default: :obj:`False`)\n    \"\"\"\n    fields = []\n\n    params = inspect.signature(cls.__init__).parameters\n\n    if strict:  # Check that keys in map_args or exclude_args are present.\n        keys = set() if map_args is None else set(map_args.keys())\n        if exclude_args is not None:\n            keys |= {arg for arg in exclude_args if isinstance(arg, str)}\n        diff = keys - set(params.keys())\n        if len(diff) > 0:\n            raise ValueError(f\"Expected input argument(s) {diff} in \"\n                             f\"'{cls.__name__}'\")\n\n    for i, (name, arg) in enumerate(params.items()):\n        if name in EXCLUDE:\n            continue\n        if exclude_args is not None:\n            if name in exclude_args or i in exclude_args:\n                continue\n        if base_cls is not None:\n            if name in base_cls.__dataclass_fields__:\n                continue\n\n        if map_args is not None and name in map_args:\n            fields.append((name, ) + map_args[name])\n            continue\n\n        annotation, default = arg.annotation, arg.default\n        annotation = map_annotation(annotation, mapping=MAPPING)\n\n        if annotation != inspect.Parameter.empty:\n            # `Union` types are not supported (except for `Optional`).\n            # As such, we replace them with either `Any` or `Optional[Any]`.\n            origin = 
getattr(annotation, '__origin__', None)\n            args = getattr(annotation, '__args__', [])\n            if origin == Union and type(None) in args and len(args) > 2:\n                annotation = Optional[Any]\n            elif origin == Union and type(None) not in args:\n                annotation = Any\n            elif origin == list:\n                if getattr(args[0], '__origin__', None) == Union:\n                    annotation = List[Any]\n            elif origin == dict:\n                if getattr(args[1], '__origin__', None) == Union:\n                    annotation = Dict[args[0], Any]  # type: ignore\n        else:\n            annotation = Any\n\n        if str(default) == \"<required parameter>\":\n            # Fix `torch.optim.SGD.lr = _RequiredParameter()`:\n            # https://github.com/pytorch/hydra-torch/blob/main/\n            # hydra-configs-torch/hydra_configs/torch/optim/sgd.py\n            default = field(default=MISSING)\n        elif default != inspect.Parameter.empty:\n            if isinstance(default, (list, dict)):\n                # Avoid late binding of default values inside a loop:\n                # https://stackoverflow.com/questions/3431676/\n                # creating-functions-in-a-loop\n                def wrapper(default: Any) -> Callable[[], Any]:\n                    return lambda: default\n\n                default = field(default_factory=wrapper(default))\n        else:\n            default = field(default=MISSING)\n\n        fields.append((name, annotation, default))\n\n    with_target = base_cls is not None if with_target is None else with_target\n    if with_target:\n        full_cls_name = f'{cls.__module__}.{cls.__qualname__}'\n        fields.append(('_target_', str, field(default=full_cls_name)))\n\n    return make_dataclass(cls.__qualname__, fields=fields,\n                          bases=() if base_cls is None else (base_cls, ))\n\n\ndef get_config_store() -> ConfigStore:\n    r\"\"\"Returns the global 
configuration store.\"\"\"\n    return ConfigStore.instance()\n\n\ndef clear_config_store() -> ConfigStore:\n    r\"\"\"Clears the global configuration store.\"\"\"\n    config_store = get_config_store()\n    for key in list(config_store.repo.keys()):\n        if key != 'hydra' and not key.endswith('.yaml'):\n            del config_store.repo[key]\n    return config_store\n\n\ndef register(\n    cls: Optional[Any] = None,\n    data_cls: Optional[Any] = None,\n    group: Optional[str] = None,\n    **kwargs: Any,\n) -> Union[Any, Callable]:\n    r\"\"\"Registers a class in the global configuration store.\n\n    Args:\n        cls (Any, optional): The class to register. If set to :obj:`None`, will\n            return a decorator. (default: :obj:`None`)\n        data_cls (Any, optional): The data class to register. If set to\n            :obj:`None`, will dynamically create the data class according to\n            :class:`~torch_geometric.config_store.to_dataclass`.\n            (default: :obj:`None`)\n        group (str, optional): The group in the global configuration store.\n            (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`~torch_geometric.config_store.to_dataclass`.\n    \"\"\"\n    if cls is not None:\n        name = cls.__name__\n\n        if get_node(cls):\n            raise ValueError(f\"The class '{name}' is already registered in \"\n                             \"the global configuration store\")\n\n        if data_cls is None:\n            data_cls = to_dataclass(cls, **kwargs)\n        elif get_node(data_cls):\n            raise ValueError(\n                f\"The data class '{data_cls.__name__}' is already registered \"\n                f\"in the global configuration store\")\n\n        if not typing.TYPE_CHECKING and WITH_HYDRA:\n            get_config_store().store(name, data_cls, group)\n            get_node(name)._metadata.orig_type = cls\n        else:\n            
get_config_store().store(name, data_cls, group, cls)\n\n        return data_cls\n\n    def bounded_register(cls: Any) -> Any:  # Otherwise, return a decorator:\n        register(cls=cls, data_cls=data_cls, group=group, **kwargs)\n        return cls\n\n    return bounded_register\n\n\n###############################################################################\n\n\n@dataclass\nclass Transform:\n    pass\n\n\n@dataclass\nclass Dataset:\n    pass\n\n\n@dataclass\nclass Model:\n    pass\n\n\n@dataclass\nclass Optimizer:\n    pass\n\n\n@dataclass\nclass LRScheduler:\n    pass\n\n\n@dataclass\nclass Config:\n    dataset: Dataset = MISSING\n    model: Model = MISSING\n    optim: Optimizer = MISSING\n    lr_scheduler: Optional[LRScheduler] = None\n\n\ndef fill_config_store() -> None:\n    import torch_geometric\n\n    config_store = get_config_store()\n\n    # Register `torch_geometric.transforms` ###################################\n    transforms = torch_geometric.transforms\n    for cls_name in set(transforms.__all__) - {\n            'BaseTransform',\n            'Compose',\n            'ComposeFilters',\n            'LinearTransformation',\n            'AddMetaPaths',  # TODO\n    }:\n        cls = to_dataclass(getattr(transforms, cls_name), base_cls=Transform)\n        # We use an explicit additional nesting level inside each config to\n        # allow for applying multiple transformations.\n        # See: hydra.cc/docs/patterns/select_multiple_configs_from_config_group\n        config_store.store(cls_name, group='transform', node={cls_name: cls})\n\n    # Register `torch_geometric.datasets` #####################################\n    datasets = torch_geometric.datasets\n    map_dataset_args: Dict[str, Any] = {\n        'transform': (Dict[str, Transform], field(default_factory=dict)),\n        'pre_transform': (Dict[str, Transform], field(default_factory=dict)),\n    }\n\n    for cls_name in set(datasets.__all__) - set():\n        cls = 
to_dataclass(getattr(datasets, cls_name), base_cls=Dataset,\n                           map_args=map_dataset_args,\n                           exclude_args=['pre_filter'])\n        config_store.store(cls_name, group='dataset', node=cls)\n\n    # Register `torch_geometric.models` #######################################\n    models = torch_geometric.nn.models.basic_gnn\n    for cls_name in set(models.__all__) - set():\n        cls = to_dataclass(getattr(models, cls_name), base_cls=Model)\n        config_store.store(cls_name, group='model', node=cls)\n\n    # Register `torch.optim.Optimizer` ########################################\n    for cls_name in {\n            key\n            for key, cls in torch.optim.__dict__.items()\n            if inspect.isclass(cls) and issubclass(cls, torch.optim.Optimizer)\n    } - {\n            'Optimizer',\n    }:\n        cls = to_dataclass(getattr(torch.optim, cls_name), base_cls=Optimizer,\n                           exclude_args=['params'])\n        config_store.store(cls_name, group='optimizer', node=cls)\n\n    # Register `torch.optim.lr_scheduler` #####################################\n    for cls_name in {\n            key\n            for key, cls in torch.optim.lr_scheduler.__dict__.items()\n            if inspect.isclass(cls)\n    } - {\n            'Optimizer',\n            '_LRScheduler',\n            'Counter',\n            'SequentialLR',\n            'ChainedScheduler',\n    }:\n        cls = to_dataclass(getattr(torch.optim.lr_scheduler, cls_name),\n                           base_cls=LRScheduler, exclude_args=['optimizer'])\n        config_store.store(cls_name, group='lr_scheduler', node=cls)\n\n    # Register global schema ##################################################\n    config_store.store('config', node=Config)\n"
  },
  {
    "path": "torch_geometric/contrib/__init__.py",
    "content": "import warnings\n\nimport torch_geometric.contrib.transforms  # noqa\nimport torch_geometric.contrib.datasets  # noqa\nimport torch_geometric.contrib.nn  # noqa\nimport torch_geometric.contrib.explain  # noqa\n\nwarnings.warn(\n    \"'torch_geometric.contrib' contains experimental code and is subject to \"\n    \"change. Please use with caution.\", stacklevel=2)\n\n__all__ = []\n"
  },
  {
    "path": "torch_geometric/contrib/datasets/__init__.py",
    "content": "__all__ = classes = []\n"
  },
  {
    "path": "torch_geometric/contrib/explain/__init__.py",
    "content": "from torch_geometric.deprecation import deprecated\n\nfrom .pgm_explainer import PGMExplainer\nfrom torch_geometric.explain.algorithm.graphmask_explainer import (\n    GraphMaskExplainer as NewGraphMaskExplainer)\n\nGraphMaskExplainer = deprecated(\n    \"use 'torch_geometric.explain.algorithm.GraphMaskExplainer' instead\", )(\n        NewGraphMaskExplainer)\n\n__all__ = classes = [\n    'PGMExplainer',\n]\n"
  },
  {
    "path": "torch_geometric/contrib/explain/pgm_explainer.py",
    "content": "import logging\nfrom typing import List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import ExplainerAlgorithm\nfrom torch_geometric.explain.config import ModelMode, ModelTaskLevel\nfrom torch_geometric.explain.explanation import Explanation\nfrom torch_geometric.utils import k_hop_subgraph\nfrom torch_geometric.utils._subgraph import get_num_hops\n\n\nclass PGMExplainer(ExplainerAlgorithm):\n    r\"\"\"The PGMExplainer model from the `\"PGMExplainer: Probabilistic\n    Graphical Model Explanations for Graph Neural Networks\"\n    <https://arxiv.org/abs/2010.05788>`_ paper.\n\n    The generated :class:`~torch_geometric.explain.Explanation` provides a\n    :obj:`node_mask` and a :obj:`pgm_stats` tensor, which stores the\n    :math:`p`-values of each node as calculated by the Chi-squared test.\n\n    Args:\n        feature_index (List, optional): The indices of the perturbed features.\n            If set to :obj:`None`, all features are perturbed.\n            (default: :obj:`None`)\n        perturbation_mode (str, optional): The method to generate the\n            variations in features. One of :obj:`\"randint\"`, :obj:`\"mean\"`,\n            :obj:`\"zero\"`, :obj:`\"max\"` or :obj:`\"uniform\"`.\n            (default: :obj:`\"randint\"`)\n        perturbations_is_positive_only (bool, optional): If set to :obj:`True`,\n            restrict perturbed values to be positive. (default: :obj:`False`)\n        is_perturbation_scaled (bool, optional): If set to :obj:`True`, will\n            normalize the range of the perturbed features.\n            (default: :obj:`False`)\n        num_samples (int, optional): The number of samples of perturbations\n            used to test the significance of nodes to the prediction.\n            (default: :obj:`100`)\n        max_subgraph_size (int, optional): The maximum number of neighbors to\n            consider for the explanation. 
(default: :obj:`None`)\n        significance_threshold (float, optional): The statistical threshold\n            (:math:`p`-value) for which a node is considered to have an effect\n            on the prediction. (default: :obj:`0.05`)\n        pred_threshold (float, optional): The buffer value (in range\n            :obj:`[0, 1]`) to consider the output from a perturbed data to be\n            different from the original. (default: :obj:`0.1`)\n    \"\"\"\n    def __init__(\n        self,\n        feature_index: Optional[List] = None,\n        perturbation_mode: str = \"randint\",\n        perturbations_is_positive_only: bool = False,\n        is_perturbation_scaled: bool = False,\n        num_samples: int = 100,\n        max_subgraph_size: Optional[int] = None,\n        significance_threshold: float = 0.05,\n        pred_threshold: float = 0.1,\n    ):\n        super().__init__()\n        self.feature_index = feature_index\n        self.perturbation_mode = perturbation_mode\n        self.perturbations_is_positive_only = perturbations_is_positive_only\n        self.is_perturbation_scaled = is_perturbation_scaled\n        self.num_samples = num_samples\n        self.max_subgraph_size = max_subgraph_size\n        self.significance_threshold = significance_threshold\n        self.pred_threshold = pred_threshold\n\n    def _perturb_features_on_nodes(\n        self,\n        x: Tensor,\n        index: Tensor,\n    ) -> Tensor:\n        r\"\"\"Perturbs feature matrix :obj:`x`.\n\n        Args:\n            x (torch.Tensor): The feature matrix.\n            index (torch.Tensor): The indices of nodes to perturb.\n        \"\"\"\n        x_perturb = x.detach().clone()\n        perturb_array = x_perturb[index]\n        epsilon = 0.05 * torch.max(x, dim=0).values\n\n        if self.perturbation_mode == \"randint\":\n            perturb_array = torch.randint(high=2, size=perturb_array.size(),\n                                          device=x.device)\n        elif 
self.perturbation_mode == \"mean\":\n            perturb_array[:, self.feature_index] = torch.mean(\n                x[:, self.feature_index])\n        elif self.perturbation_mode == \"zero\":\n            perturb_array[:, self.feature_index] = 0\n        elif self.perturbation_mode == \"max\":\n            perturb_array[:, self.feature_index] = torch.max(\n                x[:, self.feature_index])\n        elif self.perturbation_mode == \"uniform\":\n            random_perturbations = torch.rand(\n                perturb_array.shape,\n                device=x.device) * 2 * epsilon - epsilon\n            perturb_array[:, self.feature_index] = (\n                perturb_array + random_perturbations)[:, self.feature_index]\n            perturb_array = torch.minimum(\n                perturb_array.clamp(min=0), torch.max(x, dim=0).values)\n\n        if self.is_perturbation_scaled:\n            perturb_array = torch.multiply(\n                perturb_array, torch.rand(size=perturb_array.size())) * 2\n\n        x_perturb[index] = perturb_array.type(x_perturb.dtype)\n        return x_perturb\n\n    def _batch_perturb_features_on_node(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        indices_to_perturb: np.array,\n        percentage: float = 50.,  # % time node gets perturbed\n        **kwargs,\n    ) -> Tensor:\n        r\"\"\"Perturbs the node features of a batch of graphs for graph\n        classification tasks.\n\n        Args:\n            model (torch.nn.Module): The GNN model.\n            x (torch.Tensor): The node feature matrix.\n            edge_index (torch.Tensor): The edge indices.\n            indices_to_perturb (np.array): The indices of nodes to perturb.\n            percentage (float, optional): The percentage of times a node gets\n                perturbed. 
(default: :obj:`50.`)\n            **kwargs (optional): Additional arguments passed to\n                :meth:`model.forward`.\n        \"\"\"\n        pred_torch = model(x, edge_index, **kwargs)\n        soft_pred = torch.softmax(pred_torch, dim=1)\n        pred_label = torch.argmax(soft_pred, dim=1)\n        num_nodes = x.shape[0]\n\n        samples = []\n        for _ in range(self.num_samples):\n            x_perturb = x.detach().clone()\n\n            seeds = np.random.randint(0, 100, size=len(indices_to_perturb))\n            perturbed_node_indexes = indices_to_perturb[(seeds < percentage)]\n            x_perturb = self._perturb_features_on_nodes(\n                x=x_perturb,\n                index=perturbed_node_indexes,\n            )\n            sample = np.zeros(num_nodes + 1)\n            sample[perturbed_node_indexes] = 1\n\n            pred_perturb_torch = model(x_perturb, edge_index, **kwargs)\n            soft_pred_perturb = torch.softmax(pred_perturb_torch,\n                                              dim=1).squeeze()\n\n            pred_change = torch.max(soft_pred) - soft_pred_perturb[pred_label]\n\n            sample[num_nodes] = pred_change.detach()\n            samples.append(sample)\n\n        samples = torch.tensor(np.array(samples))\n        if self.perturbations_is_positive_only:\n            samples = torch.abs(samples)\n\n        top = int(self.num_samples / 8)\n        top_idx = torch.argsort(samples[:, num_nodes])[-top:]\n        for i in range(self.num_samples):\n            if i in top_idx:\n                samples[i, num_nodes] = 1\n            else:\n                samples[i, num_nodes] = 0\n\n        return samples\n\n    def _explain_graph(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        target=None,\n        **kwargs,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Generates explanations for graph classification tasks.\n\n        Args:\n            model 
(torch.nn.Module): The model to explain.\n            x (torch.Tensor): The node features.\n            edge_index (torch.Tensor): The edge indices of the input graph.\n            target (torch.Tensor, optional): The predicted label from the\n                model. (default: :obj:`None`)\n            **kwargs (optional): Additional arguments passed to\n                :meth:`model.forward`.\n\n        Returns:\n            node_mask (torch.Tensor): A hard node mask corresponding to whether\n                a node is significant in the graph's prediction.\n            pgm_stats (torch.Tensor): The :math:`p`-values of all the nodes in\n                the graph, ordered by node index.\n        \"\"\"\n        import pandas as pd\n        from pgmpy.estimators.CITests import chi_square\n\n        num_nodes = x.shape[0]\n        if not self.max_subgraph_size:\n            self.max_subgraph_size = int(num_nodes / 20)\n\n        samples = self._batch_perturb_features_on_node(\n            indices_to_perturb=np.array(range(num_nodes)),\n            x=x,\n            model=model,\n            edge_index=edge_index,\n            **kwargs,\n        )\n\n        # note: the PC estimator is in the original code, i.e. 
est= PC(data)\n        # but as it does nothing it is not included here\n        data = pd.DataFrame(np.array(samples.detach().cpu()))\n\n        p_values = []\n        for node in range(num_nodes):\n            chi2, p, _ = chi_square(\n                node, int(target.detach().cpu()), [], data, boolean=False,\n                significance_level=self.significance_threshold)\n            p_values.append(p)\n\n        # the original code uses number_candidates_nodes = int(top_nodes * 4)\n        # if we consider 'top nodes' to equate to max number of nodes\n        # it seems more correct to limit number_candidates_nodes to this\n        candidate_nodes = np.argpartition(\n            p_values, self.max_subgraph_size)[0:self.max_subgraph_size]\n\n        # Round 2\n        samples = self._batch_perturb_features_on_node(\n            indices_to_perturb=candidate_nodes, x=x, edge_index=edge_index,\n            model=model, **kwargs)\n\n        # note: the PC estimator is in the original code, ie. 
est= PC(data)\n        # but as it does nothing it is not included here\n        data = pd.DataFrame(np.array(samples.detach().cpu()))\n\n        p_values = []\n        dependent_nodes = []\n\n        target = num_nodes\n        for node in range(num_nodes):\n            _, p, _ = chi_square(\n                node, target, [], data, boolean=False,\n                significance_level=self.significance_threshold)\n            p_values.append(p)\n            if p < self.significance_threshold:\n                dependent_nodes.append(node)\n\n        top_p = np.min((self.max_subgraph_size, num_nodes - 1))\n        ind_top_p = np.argpartition(p_values, top_p)[0:top_p]\n        pgm_nodes = list(ind_top_p)\n\n        node_mask = torch.zeros(x.size(), dtype=torch.int)\n        node_mask[pgm_nodes] = 1\n        pgm_stats = torch.tensor(p_values)\n\n        return node_mask, pgm_stats\n\n    def _explain_node(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        target: Tensor,\n        index: int,\n        **kwargs,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Generates explanations for node classification tasks.\n\n        Args:\n            model (torch.nn.Module): The model to explain.\n            x (torch.Tensor): The node features.\n            edge_index (torch.Tensor): The edge indices of the input graph.\n            target (torch.Tensor): The predicted label from the model.\n            index (int): The index of the node for which the explanations is\n                generated.\n            **kwargs (optional): Additional arguments passed to\n                :meth:`model.forward`.\n\n        Returns:\n            node_mask (torch.Tensor): A hard node mask corresponding to whether\n                a node is significant in the selected node's prediction.\n            pgm_stats (torch.Tensor): The :math:`p`-values of all the nodes in\n                the graph, ordered by node index.\n        \"\"\"\n        
import pandas as pd\n        from pgmpy.estimators.CITests import chi_square\n\n        neighbors, _, _, _ = k_hop_subgraph(\n            node_idx=index,\n            num_hops=get_num_hops(model),\n            edge_index=edge_index,\n            relabel_nodes=False,\n            num_nodes=x.size(0),\n        )\n\n        if index not in neighbors:\n            neighbors = torch.cat(\n                [neighbors, neighbors.new_tensor([index])], dim=0)\n\n        pred_model = model(x, edge_index, **kwargs)\n\n        softmax_pred = torch.softmax(pred_model, dim=1)\n\n        samples = []\n        pred_samples = []\n\n        for _ in range(self.num_samples):\n            # A subset of neighbors is selected randomly for perturbing:\n            seeds = np.random.choice([1, 0], size=(len(neighbors), ))\n            x_perturb = self._perturb_features_on_nodes(\n                x=x,\n                index=neighbors[seeds == 1],\n            )\n\n            # prediction after perturbation\n            pred_perturb = model(x_perturb, edge_index, **kwargs)\n            softmax_pred_perturb = torch.softmax(pred_perturb, dim=1)\n            sample_bool = np.ones(shape=(len(neighbors), ))\n            sample_bool[((softmax_pred_perturb[neighbors, target] +\n                          self.pred_threshold)\n                         >= softmax_pred[neighbors, target]).cpu()] = 0\n\n            samples.append(seeds)\n            pred_samples.append(sample_bool)\n\n        samples = np.asarray(samples)\n        pred_samples = np.asarray(pred_samples)\n        combine_samples = (samples * 10 + pred_samples) + 1\n\n        neighbors = np.array(neighbors.detach().cpu())\n        data_pgm = pd.DataFrame(combine_samples)\n        data_pgm = data_pgm.rename(columns={\n            0: \"A\",\n            1: \"B\"\n        })  # Trick to use chi_square test on first two data columns\n        index_original_to_subgraph = dict(\n            zip(neighbors, list(data_pgm.columns)))\n        index_subgraph_to_original = dict(\n          
  zip(list(data_pgm.columns), neighbors))\n        p_values = []\n\n        dependent_neighbors = []\n        dependent_neighbors_p_values = []\n        for node in neighbors:\n            if node == index:\n                # null hypothesis is perturbing a particular\n                # node has no effect on result\n                p = 0\n            else:\n                _, p, _ = chi_square(\n                    index_original_to_subgraph[node],\n                    index_original_to_subgraph[index], [], data_pgm,\n                    boolean=False,\n                    significance_level=self.significance_threshold)\n            p_values.append(p)\n            if p < self.significance_threshold:\n                dependent_neighbors.append(node)\n                dependent_neighbors_p_values.append(p)\n\n        pgm_stats = torch.ones(x.size(0), dtype=torch.float)\n        node_mask = torch.zeros(x.size(), dtype=torch.int)\n\n        pgm_stats[neighbors] = torch.tensor(p_values, dtype=torch.float)\n\n        if self.max_subgraph_size is None:\n            pgm_nodes = dependent_neighbors\n        else:\n            top_p = np.min((self.max_subgraph_size, len(neighbors) - 1))\n            ind_top_p = np.argpartition(p_values, top_p)[0:top_p]\n            pgm_nodes = [\n                index_subgraph_to_original[node] for node in ind_top_p\n            ]\n        node_mask[pgm_nodes] = 1\n        return node_mask, pgm_stats\n\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,  # node index\n        **kwargs,\n    ) -> Explanation:\n\n        if self.feature_index is None:\n            self.feature_index = list(range(x.shape[-1]))\n\n        if isinstance(index, Tensor):\n            if index.numel() > 1:\n                raise NotImplementedError(\n                    f\"'{self.__class__.__name__}' only supports a 
single \"\n                    f\"`index` for now\")\n            index = index.item()\n\n        if self.model_config.task_level == ModelTaskLevel.node:\n            node_mask, pgm_stats = self._explain_node(\n                model=model,\n                x=x,\n                edge_index=edge_index,\n                target=target[index],\n                index=index,\n                **kwargs,\n            )\n            return Explanation(\n                x=x,\n                edge_index=edge_index,\n                node_mask=node_mask,\n                pgm_stats=pgm_stats,\n            )\n\n        elif self.model_config.task_level == ModelTaskLevel.graph:\n            node_mask, pgm_stats = self._explain_graph(\n                model=model,\n                x=x,\n                target=target,\n                edge_index=edge_index,\n                **kwargs,\n            )\n            return Explanation(\n                node_mask=node_mask,\n                pgm_stats=pgm_stats,\n            )\n\n    def supports(self) -> bool:\n        task_level = self.model_config.task_level\n        if task_level not in [ModelTaskLevel.node, ModelTaskLevel.graph]:\n            logging.error(f\"Task level '{task_level.value}' not supported\")\n            return False\n        if self.explainer_config.edge_mask_type is not None:\n            logging.error(\"Generation of edge masks is not supported\")\n            return False\n        if self.model_config.mode == ModelMode.regression:\n            logging.error(\"'PGMExplainer' only supports classification tasks\")\n            return False\n        return True\n"
  },
  {
    "path": "torch_geometric/contrib/nn/__init__.py",
    "content": "from .conv import *  # noqa\nfrom .models import *  # noqa\n\n__all__ = []\n"
  },
  {
    "path": "torch_geometric/contrib/nn/conv/__init__.py",
    "content": "__all__ = classes = []\n"
  },
  {
    "path": "torch_geometric/contrib/nn/models/__init__.py",
    "content": "from .rbcd_attack import PRBCDAttack, GRBCDAttack\n\n__all__ = classes = [\n    'PRBCDAttack',\n    'GRBCDAttack',\n]\n"
  },
  {
    "path": "torch_geometric/contrib/nn/models/rbcd_attack.py",
    "content": "from collections import defaultdict\nfrom functools import partial\nfrom typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nfrom torch_geometric.utils import coalesce, to_undirected\n\n# (predictions, labels, ids/mask) -> Tensor with one element\nLOSS_TYPE = Callable[[Tensor, Tensor, Optional[Tensor]], Tensor]\n\n\nclass PRBCDAttack(torch.nn.Module):\n    r\"\"\"The Projected Randomized Block Coordinate Descent (PRBCD) adversarial\n    attack from the `Robustness of Graph Neural Networks at Scale\n    <https://www.cs.cit.tum.de/daml/robustness-of-gnns-at-scale>`_ paper.\n\n    This attack uses an efficient gradient based approach that (during the\n    attack) relaxes the discrete entries in the adjacency matrix\n    :math:`\\{0, 1\\}` to :math:`[0, 1]` and solely perturbs the adjacency matrix\n    (no feature perturbations). Thus, this attack supports all models that can\n    handle weighted graphs that are differentiable w.r.t. these edge weights,\n    *e.g.*, :class:`~torch_geometric.nn.conv.GCNConv` or\n    :class:`~torch_geometric.nn.conv.GraphConv`. For non-differentiable models\n    you might need modifications, e.g., see example for\n    :class:`~torch_geometric.nn.conv.GATConv`.\n\n    The memory overhead is driven by the additional edges (at most\n    :attr:`block_size`). For scalability reasons, the block is drawn with\n    replacement and then the index is made unique. Thus, the actual block size\n    is typically slightly smaller than specified.\n\n    This attack can be used for both global and local attacks as well as\n    test-time attacks (evasion) and training-time attacks (poisoning). 
Please\n    see the provided examples.\n\n    This attack is designed with a focus on node- or graph-classification.\n    To adapt it to other tasks, you most likely only need to provide an\n    appropriate loss and model. However, we currently do not support batching\n    out of the box (sampling needs to be adapted).\n\n    .. note::\n        For examples of using the PRBCD Attack, see\n        `examples/contrib/rbcd_attack.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        contrib/rbcd_attack.py>`_\n        for a test time attack (evasion) or\n        `examples/contrib/rbcd_attack_poisoning.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        contrib/rbcd_attack_poisoning.py>`_\n        for a training time (poisoning) attack.\n\n    Args:\n        model (torch.nn.Module): The GNN module to assess.\n        block_size (int): Number of randomly selected elements in the\n            adjacency matrix to consider.\n        epochs (int, optional): Number of epochs (aborts early if\n            :obj:`mode='greedy'` and budget is satisfied) (default: :obj:`125`)\n        epochs_resampling (int, optional): Number of epochs to resample the\n            random block. (default: :obj:`100`)\n        loss (str or callable, optional): A loss to quantify the \"strength\" of\n            an attack. Note that this function must match the output format of\n            :attr:`model`. By default, it is assumed that the task is\n            classification and that the model returns raw predictions (*i.e.*,\n            no output activation) or uses :obj:`logsoftmax`. Moreover, the\n            number of predictions should match the number of labels passed to\n            :attr:`attack`. 
Either pass a callable or one of: :obj:`'masked'`,\n            :obj:`'margin'`, :obj:`'prob_margin'`, :obj:`'tanh_margin'`.\n            (default: :obj:`'prob_margin'`)\n        metric (callable, optional): Second (potentially\n            non-differentiable) loss for monitoring or early stopping (if\n            :obj:`mode='greedy'`). (default: same as :attr:`loss`)\n        lr (float, optional): Learning rate for updating edge weights.\n            Additionally, it is heuristically corrected for :attr:`block_size`,\n            budget (see :attr:`attack`) and graph size. (default: :obj:`1_000`)\n        is_undirected (bool, optional): If :obj:`True` the graph is\n            assumed to be undirected. (default: :obj:`True`)\n        log (bool, optional): If set to :obj:`False`, will not log any learning\n            progress. (default: :obj:`True`)\n    \"\"\"\n    coeffs = {\n        'max_final_samples': 20,\n        'max_trials_sampling': 20,\n        'with_early_stopping': True,\n        'eps': 1e-7\n    }\n\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        block_size: int,\n        epochs: int = 125,\n        epochs_resampling: int = 100,\n        loss: Optional[Union[str, LOSS_TYPE]] = 'prob_margin',\n        metric: Optional[Union[str, LOSS_TYPE]] = None,\n        lr: float = 1_000,\n        is_undirected: bool = True,\n        log: bool = True,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.model = model\n        self.block_size = block_size\n        self.epochs = epochs\n\n        if isinstance(loss, str):\n            if loss == 'masked':\n                self.loss = self._masked_cross_entropy\n            elif loss == 'margin':\n                self.loss = partial(self._margin_loss, reduce='mean')\n            elif loss == 'prob_margin':\n                self.loss = self._probability_margin_loss\n            elif loss == 'tanh_margin':\n                self.loss = self._tanh_margin_loss\n            
else:\n                raise ValueError(f'Unknown loss `{loss}`')\n        else:\n            self.loss = loss\n\n        self.is_undirected = is_undirected\n        self.log = log\n        self.metric = metric or self.loss\n\n        self.epochs_resampling = epochs_resampling\n        self.lr = lr\n\n        self.coeffs.update(kwargs)\n\n    def attack(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        labels: Tensor,\n        budget: int,\n        idx_attack: Optional[Tensor] = None,\n        **kwargs,\n    ) -> Tuple[Tensor, Tensor]:\n        \"\"\"Attack the predictions for the provided model and graph.\n\n        A subset of predictions may be specified with :attr:`idx_attack`. The\n        attack is allowed to flip (i.e. add or delete) :attr:`budget` edges and\n        will return the strongest perturbation it can find. It returns both the\n        resulting perturbed :attr:`edge_index` as well as the perturbations.\n\n        Args:\n            x (torch.Tensor): The node feature matrix.\n            edge_index (torch.Tensor): The edge indices.\n            labels (torch.Tensor): The labels.\n            budget (int): The number of allowed perturbations (i.e.\n                number of edges that are flipped at most).\n            idx_attack (torch.Tensor, optional): Filter for predictions/labels.\n                Shape and type must match that it can index :attr:`labels`\n                and the model's predictions.\n            **kwargs (optional): Additional arguments passed to the GNN module.\n\n        :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`)\n        \"\"\"\n        self.model.eval()\n\n        self.device = x.device\n        assert kwargs.get('edge_weight') is None\n        edge_weight = torch.ones(edge_index.size(1), device=self.device)\n        self.edge_index = edge_index.cpu().clone()\n        self.edge_weight = edge_weight.cpu().clone()\n        self.num_nodes = x.size(0)\n\n        # For collecting attack 
statistics\n        self.attack_statistics = defaultdict(list)\n\n        # Prepare attack and obtain the step sequence to iterate over\n        step_sequence = self._prepare(budget)\n\n        # Loop over the epochs (Algorithm 1, line 5)\n        for step in tqdm(step_sequence, disable=not self.log, desc='Attack'):\n            loss, gradient = self._forward_and_gradient(\n                x, labels, idx_attack, **kwargs)\n\n            scalars = self._update(step, gradient, x, labels, budget,\n                                   idx_attack, **kwargs)\n\n            scalars['loss'] = loss.item()\n            self._append_statistics(scalars)\n\n        perturbed_edge_index, flipped_edges = self._close(\n            x, labels, budget, idx_attack, **kwargs)\n\n        assert flipped_edges.size(1) <= budget, (\n            f'# perturbed edges {flipped_edges.size(1)} '\n            f'exceeds budget {budget}')\n\n        return perturbed_edge_index, flipped_edges\n\n    def _prepare(self, budget: int) -> Iterable[int]:\n        \"\"\"Prepare attack.\"\"\"\n        if self.block_size <= budget:\n            raise ValueError(\n                f'The search space size ({self.block_size}) must be '\n                f'greater than the number of perturbations ({budget})')\n\n        # For early stopping (not explicitly covered by pseudo code)\n        self.best_metric = float('-Inf')\n\n        # Sample initial search space (Algorithm 1, line 3-4)\n        self._sample_random_block(budget)\n\n        steps = range(self.epochs)\n        return steps\n\n    @torch.no_grad()\n    def _update(self, epoch: int, gradient: Tensor, x: Tensor, labels: Tensor,\n                budget: int, idx_attack: Optional[Tensor] = None,\n                **kwargs) -> Dict[str, float]:\n        \"\"\"Update edge weights given gradient.\"\"\"\n        # Gradient update step (Algorithm 1, line 7)\n        self.block_edge_weight = self._update_edge_weights(\n            budget, self.block_edge_weight, 
epoch, gradient)\n\n        # For monitoring\n        pmass_update = torch.clamp(self.block_edge_weight, 0, 1)\n        # Projection to stay within relaxed `L_0` budget\n        # (Algorithm 1, line 8)\n        self.block_edge_weight = self._project(budget, self.block_edge_weight,\n                                               self.coeffs['eps'])\n\n        # For monitoring\n        scalars = dict(\n            prob_mass_after_update=pmass_update.sum().item(),\n            prob_mass_after_update_max=pmass_update.max().item(),\n            prob_mass_after_projection=self.block_edge_weight.sum().item(),\n            prob_mass_after_projection_nonzero_weights=(\n                self.block_edge_weight > self.coeffs['eps']).sum().item(),\n            prob_mass_after_projection_max=self.block_edge_weight.max().item())\n        if not self.coeffs['with_early_stopping']:\n            return scalars\n\n        # Calculate metric after the current epoch (overhead\n        # for monitoring and early stopping)\n        topk_block_edge_weight = torch.zeros_like(self.block_edge_weight)\n        topk_block_edge_weight[torch.topk(self.block_edge_weight,\n                                          budget).indices] = 1\n        edge_index, edge_weight = self._get_modified_adj(\n            self.edge_index, self.edge_weight, self.block_edge_index,\n            topk_block_edge_weight)\n        prediction = self._forward(x, edge_index, edge_weight, **kwargs)\n        metric = self.metric(prediction, labels, idx_attack)\n\n        # Save best epoch for early stopping\n        # (not explicitly covered by pseudo code)\n        if metric > self.best_metric:\n            self.best_metric = metric\n            self.best_block = self.current_block.cpu().clone()\n            self.best_edge_index = self.block_edge_index.cpu().clone()\n            self.best_pert_edge_weight = self.block_edge_weight.cpu().clone()\n\n        # Resampling of search space (Algorithm 1, line 9-14)\n        if epoch 
< self.epochs_resampling - 1:\n            self._resample_random_block(budget)\n        elif epoch == self.epochs_resampling - 1:\n            # Retrieve best epoch if early stopping is active\n            # (not explicitly covered by pseudo code)\n            self.current_block = self.best_block.to(self.device)\n            self.block_edge_index = self.best_edge_index.to(self.device)\n            block_edge_weight = self.best_pert_edge_weight.clone()\n            self.block_edge_weight = block_edge_weight.to(self.device)\n\n        scalars['metric'] = metric.item()\n        return scalars\n\n    @torch.no_grad()\n    def _close(self, x: Tensor, labels: Tensor, budget: int,\n               idx_attack: Optional[Tensor] = None,\n               **kwargs) -> Tuple[Tensor, Tensor]:\n        \"\"\"Clean up and prepare return argument.\"\"\"\n        # Retrieve best epoch if early stopping is active\n        # (not explicitly covered by pseudo code)\n        if self.coeffs['with_early_stopping']:\n            self.current_block = self.best_block.to(self.device)\n            self.block_edge_index = self.best_edge_index.to(self.device)\n            self.block_edge_weight = self.best_pert_edge_weight.to(self.device)\n\n        # Sample final discrete graph (Algorithm 1, line 16)\n        edge_index, flipped_edges = self._sample_final_edges(\n            x, labels, budget, idx_attack=idx_attack, **kwargs)\n\n        return edge_index, flipped_edges\n\n    def _forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,\n                 **kwargs) -> Tensor:\n        \"\"\"Forward model.\"\"\"\n        return self.model(x, edge_index, edge_weight, **kwargs)\n\n    def _forward_and_gradient(self, x: Tensor, labels: Tensor,\n                              idx_attack: Optional[Tensor] = None,\n                              **kwargs) -> Tuple[Tensor, Tensor]:\n        \"\"\"Forward and update edge weights.\"\"\"\n        self.block_edge_weight.requires_grad = True\n\n        
# Retrieve sparse perturbed adjacency matrix `A \\oplus p_{t-1}`\n        # (Algorithm 1, line 6 / Algorithm 2, line 7)\n        edge_index, edge_weight = self._get_modified_adj(\n            self.edge_index, self.edge_weight, self.block_edge_index,\n            self.block_edge_weight)\n\n        # Get prediction (Algorithm 1, line 6 / Algorithm 2, line 7)\n        prediction = self._forward(x, edge_index, edge_weight, **kwargs)\n        # Calculate the loss aggregated over all nodes\n        # (Algorithm 1, line 7 / Algorithm 2, line 8)\n        loss = self.loss(prediction, labels, idx_attack)\n        # Retrieve gradient towards the current block\n        # (Algorithm 1, line 7 / Algorithm 2, line 8)\n        gradient = torch.autograd.grad(loss, self.block_edge_weight)[0]\n\n        return loss, gradient\n\n    def _get_modified_adj(self, edge_index: Tensor, edge_weight: Tensor,\n                          block_edge_index: Tensor,\n                          block_edge_weight: Tensor) -> Tuple[Tensor, Tensor]:\n        \"\"\"Merges adjacency matrix with current block (incl. 
weights).\"\"\"\n        if self.is_undirected:\n            block_edge_index, block_edge_weight = to_undirected(\n                block_edge_index, block_edge_weight, num_nodes=self.num_nodes,\n                reduce='mean')\n\n        modified_edge_index = torch.cat(\n            (edge_index.to(self.device), block_edge_index), dim=-1)\n        modified_edge_weight = torch.cat(\n            (edge_weight.to(self.device), block_edge_weight))\n\n        modified_edge_index, modified_edge_weight = coalesce(\n            modified_edge_index, modified_edge_weight,\n            num_nodes=self.num_nodes, reduce='sum')\n\n        # Allow (soft) removal of edges\n        is_edge_in_clean_adj = modified_edge_weight > 1\n        modified_edge_weight[is_edge_in_clean_adj] = (\n            2 - modified_edge_weight[is_edge_in_clean_adj])\n\n        return modified_edge_index, modified_edge_weight\n\n    def _filter_self_loops_in_block(self, with_weight: bool):\n        is_not_sl = self.block_edge_index[0] != self.block_edge_index[1]\n        self.current_block = self.current_block[is_not_sl]\n        self.block_edge_index = self.block_edge_index[:, is_not_sl]\n        if with_weight:\n            self.block_edge_weight = self.block_edge_weight[is_not_sl]\n\n    def _sample_random_block(self, budget: int = 0):\n        for _ in range(self.coeffs['max_trials_sampling']):\n            num_possible_edges = self._num_possible_edges(\n                self.num_nodes, self.is_undirected)\n            self.current_block = torch.randint(num_possible_edges,\n                                               (self.block_size, ),\n                                               device=self.device)\n            self.current_block = torch.unique(self.current_block, sorted=True)\n            if self.is_undirected:\n                self.block_edge_index = self._linear_to_triu_idx(\n                    self.num_nodes, self.current_block)\n            else:\n                self.block_edge_index = 
self._linear_to_full_idx(\n                    self.num_nodes, self.current_block)\n                self._filter_self_loops_in_block(with_weight=False)\n\n            self.block_edge_weight = torch.full(self.current_block.shape,\n                                                self.coeffs['eps'],\n                                                device=self.device)\n            if self.current_block.size(0) >= budget:\n                return\n        raise RuntimeError('Sampling random block was not successful. '\n                           'Please decrease `budget`.')\n\n    def _resample_random_block(self, budget: int):\n        # Keep at most half of the block (i.e. resample low weights)\n        sorted_idx = torch.argsort(self.block_edge_weight)\n        keep_above = (self.block_edge_weight\n                      <= self.coeffs['eps']).sum().long()\n        if keep_above < sorted_idx.size(0) // 2:\n            keep_above = sorted_idx.size(0) // 2\n        sorted_idx = sorted_idx[keep_above:]\n\n        self.current_block = self.current_block[sorted_idx]\n\n        # Sample until enough edges were drawn\n        for _ in range(self.coeffs['max_trials_sampling']):\n            n_edges_resample = self.block_size - self.current_block.size(0)\n            num_possible_edges = self._num_possible_edges(\n                self.num_nodes, self.is_undirected)\n            lin_index = torch.randint(num_possible_edges, (n_edges_resample, ),\n                                      device=self.device)\n\n            current_block = torch.cat((self.current_block, lin_index))\n            self.current_block, unique_idx = torch.unique(\n                current_block, sorted=True, return_inverse=True)\n\n            if self.is_undirected:\n                self.block_edge_index = self._linear_to_triu_idx(\n                    self.num_nodes, self.current_block)\n            else:\n                self.block_edge_index = self._linear_to_full_idx(\n                    self.num_nodes, 
self.current_block)\n\n            # Merge existing weights with new edge weights\n            block_edge_weight_prev = self.block_edge_weight[sorted_idx]\n            self.block_edge_weight = torch.full(self.current_block.shape,\n                                                self.coeffs['eps'],\n                                                device=self.device)\n            self.block_edge_weight[\n                unique_idx[:sorted_idx.size(0)]] = block_edge_weight_prev\n\n            if not self.is_undirected:\n                self._filter_self_loops_in_block(with_weight=True)\n\n            if self.current_block.size(0) > budget:\n                return\n        raise RuntimeError('Sampling random block was not successful. '\n                           'Please decrease `budget`.')\n\n    def _sample_final_edges(self, x: Tensor, labels: Tensor, budget: int,\n                            idx_attack: Optional[Tensor] = None,\n                            **kwargs) -> Tuple[Tensor, Tensor]:\n        best_metric = float('-Inf')\n        block_edge_weight = self.block_edge_weight\n        block_edge_weight[block_edge_weight <= self.coeffs['eps']] = 0\n\n        for i in range(self.coeffs['max_final_samples']):\n            if i == 0:\n                # In first iteration employ top k heuristic instead of sampling\n                sampled_edges = torch.zeros_like(block_edge_weight)\n                sampled_edges[torch.topk(block_edge_weight,\n                                         budget).indices] = 1\n            else:\n                sampled_edges = torch.bernoulli(block_edge_weight).float()\n\n            if sampled_edges.sum() > budget:\n                # Allowed budget is exceeded\n                continue\n\n            edge_index, edge_weight = self._get_modified_adj(\n                self.edge_index, self.edge_weight, self.block_edge_index,\n                sampled_edges)\n            prediction = self._forward(x, edge_index, edge_weight, **kwargs)\n        
    metric = self.metric(prediction, labels, idx_attack)\n\n            # Save best sample\n            if metric > best_metric:\n                best_metric = metric\n                self.block_edge_weight = sampled_edges.clone().cpu()\n\n        # Recover best sample\n        self.block_edge_weight = self.block_edge_weight.to(self.device)\n        flipped_edges = self.block_edge_index[:, self.block_edge_weight > 0]\n\n        edge_index, edge_weight = self._get_modified_adj(\n            self.edge_index, self.edge_weight, self.block_edge_index,\n            self.block_edge_weight)\n        edge_mask = edge_weight == 1\n        edge_index = edge_index[:, edge_mask]\n\n        return edge_index, flipped_edges\n\n    def _update_edge_weights(self, budget: int, block_edge_weight: Tensor,\n                             epoch: int, gradient: Tensor) -> Tensor:\n        # The learning rate is refined heuristically, s.t. (1) it is\n        # independent of the number of perturbations (assuming an undirected\n        # adjacency matrix) and (2) to decay learning rate during fine-tuning\n        # (i.e. 
fixed search space).\n        lr = (budget / self.num_nodes * self.lr /\n              np.sqrt(max(0, epoch - self.epochs_resampling) + 1))\n        return block_edge_weight + lr * gradient\n\n    @staticmethod\n    def _project(budget: int, values: Tensor, eps: float = 1e-7) -> Tensor:\n        r\"\"\"Project :obj:`values`:\n        :math:`budget \\ge \\sum \\Pi_{[0, 1]}(\\text{values})`.\n        \"\"\"\n        if torch.clamp(values, 0, 1).sum() > budget:\n            left = (values - 1).min()\n            right = values.max()\n            miu = PRBCDAttack._bisection(values, left, right, budget)\n            values = values - miu\n        return torch.clamp(values, min=eps, max=1 - eps)\n\n    @staticmethod\n    def _bisection(edge_weights: Tensor, a: float, b: float, n_pert: int,\n                   eps=1e-5, max_iter=1e3) -> Tensor:\n        \"\"\"Bisection search for projection.\"\"\"\n        def shift(offset: float):\n            return (torch.clamp(edge_weights - offset, 0, 1).sum() - n_pert)\n\n        miu = a\n        for _ in range(int(max_iter)):\n            miu = (a + b) / 2\n            # Check if middle point is root\n            if (shift(miu) == 0.0):\n                break\n            # Decide the side to repeat the steps\n            if (shift(miu) * shift(a) < 0):\n                b = miu\n            else:\n                a = miu\n            if ((b - a) <= eps):\n                break\n        return miu\n\n    @staticmethod\n    def _num_possible_edges(n: int, is_undirected: bool) -> int:\n        \"\"\"Determine number of possible edges for graph.\"\"\"\n        if is_undirected:\n            return n * (n - 1) // 2\n        else:\n            return int(n**2)  # We filter self-loops later\n\n    @staticmethod\n    def _linear_to_triu_idx(n: int, lin_idx: Tensor) -> Tensor:\n        \"\"\"Linear index to upper triangular matrix without diagonal. 
This is\n        similar to\n        https://stackoverflow.com/questions/242711/algorithm-for-index-numbers-of-triangular-matrix-coefficients/28116498#28116498\n        with number nodes decremented and col index incremented by one.\n        \"\"\"\n        nn = n * (n - 1)\n        row_idx = n - 2 - torch.floor(\n            torch.sqrt(-8 * lin_idx.double() + 4 * nn - 7) / 2.0 - 0.5).long()\n        col_idx = 1 + lin_idx + row_idx - nn // 2 + torch.div(\n            (n - row_idx) * (n - row_idx - 1), 2, rounding_mode='floor')\n        return torch.stack((row_idx, col_idx))\n\n    @staticmethod\n    def _linear_to_full_idx(n: int, lin_idx: Tensor) -> Tensor:\n        \"\"\"Linear index to dense matrix including diagonal.\"\"\"\n        row_idx = torch.div(lin_idx, n, rounding_mode='floor')\n        col_idx = lin_idx % n\n        return torch.stack((row_idx, col_idx))\n\n    @staticmethod\n    def _margin_loss(score: Tensor, labels: Tensor,\n                     idx_mask: Optional[Tensor] = None,\n                     reduce: Optional[str] = None) -> Tensor:\n        r\"\"\"Margin loss between true score and highest non-target score.\n\n        .. math::\n            m = - s_{y} + max_{y' \\ne y} s_{y'}\n\n        where :math:`m` is the margin :math:`s` the score and :math:`y` the\n        labels.\n\n        Args:\n            score (Tensor): Some score (*e.g.*, logits) of shape\n                :obj:`[n_elem, dim]`.\n            labels (LongTensor): The labels of shape :obj:`[n_elem]`.\n            idx_mask (Tensor, optional): To select subset of `score` and\n                `labels` of shape :obj:`[n_select]`. 
Defaults to None.\n            reduce (str, optional): if :obj:`mean` the result is aggregated.\n                Otherwise, return element wise margin.\n\n        :rtype: (Tensor)\n        \"\"\"\n        if idx_mask is not None:\n            score = score[idx_mask]\n            labels = labels[idx_mask]\n\n        linear_idx = torch.arange(score.size(0), device=score.device)\n        true_score = score[linear_idx, labels]\n\n        score = score.clone()\n        score[linear_idx, labels] = float('-Inf')\n        best_non_target_score = score.amax(dim=-1)\n\n        margin_ = best_non_target_score - true_score\n        if reduce is None:\n            return margin_\n        return margin_.mean()\n\n    @staticmethod\n    def _tanh_margin_loss(prediction: Tensor, labels: Tensor,\n                          idx_mask: Optional[Tensor] = None) -> Tensor:\n        \"\"\"Calculate tanh margin loss, a node-classification loss that focuses\n        on nodes next to decision boundary.\n\n        Args:\n            prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`.\n            labels (LongTensor): The labels of shape :obj:`[n_elem]`.\n            idx_mask (Tensor, optional): To select subset of `score` and\n                `labels` of shape :obj:`[n_select]`. Defaults to None.\n\n        :rtype: (Tensor)\n        \"\"\"\n        log_prob = F.log_softmax(prediction, dim=-1)\n        margin_ = GRBCDAttack._margin_loss(log_prob, labels, idx_mask)\n        loss = torch.tanh(margin_).mean()\n        return loss\n\n    @staticmethod\n    def _probability_margin_loss(prediction: Tensor, labels: Tensor,\n                                 idx_mask: Optional[Tensor] = None) -> Tensor:\n        \"\"\"Calculate probability margin loss, a node-classification loss that\n        focuses  on nodes next to decision boundary. 
See `Are Defenses for\n        Graph Neural Networks Robust?\n        <https://www.cs.cit.tum.de/daml/are-gnn-defenses-robust>`_ for details.\n\n        Args:\n            prediction (Tensor): Prediction of shape :obj:`[n_elem, dim]`.\n            labels (LongTensor): The labels of shape :obj:`[n_elem]`.\n            idx_mask (Tensor, optional): To select subset of `score` and\n                `labels` of shape :obj:`[n_select]`. Defaults to None.\n\n        :rtype: (Tensor)\n        \"\"\"\n        prob = F.softmax(prediction, dim=-1)\n        margin_ = GRBCDAttack._margin_loss(prob, labels, idx_mask)\n        return margin_.mean()\n\n    @staticmethod\n    def _masked_cross_entropy(log_prob: Tensor, labels: Tensor,\n                              idx_mask: Optional[Tensor] = None) -> Tensor:\n        \"\"\"Calculate masked cross entropy loss, a node-classification loss that\n        focuses on nodes next to decision boundary.\n\n        Args:\n            log_prob (Tensor): Log probabilities of shape :obj:`[n_elem, dim]`.\n            labels (LongTensor): The labels of shape :obj:`[n_elem]`.\n            idx_mask (Tensor, optional): To select subset of `score` and\n                `labels` of shape :obj:`[n_select]`. 
Defaults to None.\n\n        :rtype: (Tensor)\n        \"\"\"\n        if idx_mask is not None:\n            log_prob = log_prob[idx_mask]\n            labels = labels[idx_mask]\n\n        is_correct = log_prob.argmax(-1) == labels\n        if is_correct.any():\n            log_prob = log_prob[is_correct]\n            labels = labels[is_correct]\n\n        return F.nll_loss(log_prob, labels)\n\n    def _append_statistics(self, mapping: Dict[str, Any]):\n        for key, value in mapping.items():\n            self.attack_statistics[key].append(value)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n\n\nclass GRBCDAttack(PRBCDAttack):\n    r\"\"\"The Greedy Randomized Block Coordinate Descent (GRBCD) adversarial\n    attack from the `Robustness of Graph Neural Networks at Scale\n    <https://www.cs.cit.tum.de/daml/robustness-of-gnns-at-scale>`_ paper.\n\n    GRBCD shares most of the properties and requirements with\n    :class:`PRBCDAttack`. It also uses an efficient gradient based approach.\n    However, it greedily flips edges based on the gradient towards the\n    adjacency matrix.\n\n    .. note::\n        For examples of using the GRBCD Attack, see\n        `examples/contrib/rbcd_attack.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        contrib/rbcd_attack.py>`_\n        for a test time attack (evasion).\n\n    Args:\n        model (torch.nn.Module): The GNN module to assess.\n        block_size (int): Number of randomly selected elements in the\n            adjacency matrix to consider.\n        epochs (int, optional): Number of epochs (aborts early if\n            :obj:`mode='greedy'` and budget is satisfied) (default: :obj:`125`)\n        loss (str or callable, optional): A loss to quantify the \"strength\" of\n            an attack. Note that this function must match the output format of\n            :attr:`model`. 
By default, it is assumed that the task is\n            classification and that the model returns raw predictions (*i.e.*,\n            no output activation) or uses :obj:`logsoftmax`. Moreover, the\n            number of predictions should match the number of labels passed to\n            :attr:`attack`. Either pass a callable or one of: :obj:`'masked'`,\n            :obj:`'margin'`, :obj:`'prob_margin'`, :obj:`'tanh_margin'`.\n            (default: :obj:`'masked'`)\n        is_undirected (bool, optional): If :obj:`True` the graph is\n            assumed to be undirected. (default: :obj:`True`)\n        log (bool, optional): If set to :obj:`False`, will not log any learning\n            progress. (default: :obj:`True`)\n    \"\"\"\n    coeffs = {'max_trials_sampling': 20, 'eps': 1e-7}\n\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        block_size: int,\n        epochs: int = 125,\n        loss: Optional[Union[str, LOSS_TYPE]] = 'masked',\n        is_undirected: bool = True,\n        log: bool = True,\n        **kwargs,\n    ):\n        super().__init__(model, block_size, epochs, loss=loss,\n                         is_undirected=is_undirected, log=log, **kwargs)\n\n    @torch.no_grad()\n    def _prepare(self, budget: int) -> List[int]:\n        \"\"\"Prepare attack.\"\"\"\n        self.flipped_edges = self.edge_index.new_empty(2, 0).to(self.device)\n\n        # Determine the number of edges to be flipped in each attack step/epoch\n        step_size = budget // self.epochs\n        if step_size > 0:\n            steps = self.epochs * [step_size]\n            for i in range(budget % self.epochs):\n                steps[i] += 1\n        else:\n            steps = [1] * budget\n\n        # Sample initial search space (Algorithm 2, line 3-4)\n        self._sample_random_block(step_size)\n\n        return steps\n\n    @torch.no_grad()\n    def _update(self, step_size: int, gradient: Tensor, *args,\n                **kwargs) -> Dict[str, 
Any]:\n        \"\"\"Update edge weights given gradient.\"\"\"\n        _, topk_edge_index = torch.topk(gradient, step_size)\n\n        flip_edge_index = self.block_edge_index[:, topk_edge_index]\n        flip_edge_weight = torch.ones_like(flip_edge_index[0],\n                                           dtype=torch.float32)\n\n        self.flipped_edges = torch.cat((self.flipped_edges, flip_edge_index),\n                                       axis=-1)\n\n        if self.is_undirected:\n            flip_edge_index, flip_edge_weight = to_undirected(\n                flip_edge_index, flip_edge_weight, num_nodes=self.num_nodes,\n                reduce='mean')\n        edge_index = torch.cat(\n            (self.edge_index.to(self.device), flip_edge_index.to(self.device)),\n            dim=-1)\n        edge_weight = torch.cat((self.edge_weight.to(self.device),\n                                 flip_edge_weight.to(self.device)))\n        edge_index, edge_weight = coalesce(edge_index, edge_weight,\n                                           num_nodes=self.num_nodes,\n                                           reduce='sum')\n\n        is_one_mask = torch.isclose(edge_weight, torch.tensor(1.))\n        self.edge_index = edge_index[:, is_one_mask]\n        self.edge_weight = edge_weight[is_one_mask]\n        # self.edge_weight = torch.ones_like(self.edge_weight)\n        assert self.edge_index.size(1) == self.edge_weight.size(0)\n\n        # Sample initial search space (Algorithm 2, line 3-4)\n        self._sample_random_block(step_size)\n\n        # Return debug information\n        scalars = {\n            'number_positive_entries_in_gradient': (gradient > 0).sum().item()\n        }\n        return scalars\n\n    def _close(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:\n        \"\"\"Clean up and prepare return argument.\"\"\"\n        return self.edge_index, self.flipped_edges\n"
  },
  {
    "path": "torch_geometric/contrib/transforms/__init__.py",
    "content": "__all__ = classes = []\n"
  },
  {
    "path": "torch_geometric/data/__init__.py",
    "content": "# flake8: noqa\n\nimport torch\nimport torch_geometric.typing\n\nfrom .feature_store import FeatureStore, TensorAttr\nfrom .graph_store import GraphStore, EdgeAttr, EdgeLayout\nfrom .data import Data\nfrom .hetero_data import HeteroData\nfrom .batch import Batch\nfrom .temporal import TemporalData\nfrom .database import Database, SQLiteDatabase, RocksDatabase\nfrom .dataset import Dataset\nfrom .in_memory_dataset import InMemoryDataset\nfrom .on_disk_dataset import OnDiskDataset\nfrom .makedirs import makedirs\nfrom .download import download_url, download_google_url\nfrom .extract import extract_tar, extract_zip, extract_bz2, extract_gz\n\nfrom torch_geometric.lazy_loader import LazyLoader\n\ndata_classes = [\n    'Data',\n    'HeteroData',\n    'Batch',\n    'TemporalData',\n    'Dataset',\n    'InMemoryDataset',\n    'OnDiskDataset',\n]\n\nremote_backend_classes = [\n    'FeatureStore',\n    'GraphStore',\n    'TensorAttr',\n    'EdgeAttr',\n]\n\ndatabase_classes = [\n    'Database',\n    'SQLiteDatabase',\n    'RocksDatabase',\n]\n\nhelper_functions = [\n    'makedirs',\n    'download_url',\n    'download_google_url',\n    'extract_tar',\n    'extract_zip',\n    'extract_bz2',\n    'extract_gz',\n]\n\n__all__ = data_classes + remote_backend_classes + helper_functions\n\nlightning = LazyLoader('lightning', globals(),\n                       'torch_geometric.data.lightning')\n\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.loader import NeighborSampler\nfrom torch_geometric.loader import ClusterData\nfrom torch_geometric.loader import ClusterLoader\nfrom torch_geometric.loader import GraphSAINTSampler\nfrom torch_geometric.loader import GraphSAINTNodeSampler\nfrom torch_geometric.loader import GraphSAINTEdgeSampler\nfrom torch_geometric.loader import GraphSAINTRandomWalkSampler\nfrom torch_geometric.loader import ShaDowKHopSampler\nfrom torch_geometric.loader import RandomNodeLoader\nfrom torch_geometric.loader import 
DataLoader\nfrom torch_geometric.loader import DataListLoader\nfrom torch_geometric.loader import DenseDataLoader\n\n# Serialization ###############################################################\n\nif torch_geometric.typing.WITH_PT24:\n    torch.serialization.add_safe_globals([\n        Data,\n        HeteroData,\n        TemporalData,\n        ClusterData,\n        TensorAttr,\n        EdgeAttr,\n        EdgeLayout,\n    ])\n\n# Deprecations ################################################################\n\nNeighborSampler = deprecated(  # type: ignore\n    details=\"use 'loader.NeighborSampler' instead\",\n    func_name='data.NeighborSampler',\n)(NeighborSampler)\nClusterData = deprecated(  # type: ignore\n    details=\"use 'loader.ClusterData' instead\",\n    func_name='data.ClusterData',\n)(ClusterData)\nClusterLoader = deprecated(  # type: ignore\n    details=\"use 'loader.ClusterLoader' instead\",\n    func_name='data.ClusterLoader',\n)(ClusterLoader)\nGraphSAINTSampler = deprecated(  # type: ignore\n    details=\"use 'loader.GraphSAINTSampler' instead\",\n    func_name='data.GraphSAINTSampler',\n)(GraphSAINTSampler)\nGraphSAINTNodeSampler = deprecated(  # type: ignore\n    details=\"use 'loader.GraphSAINTNodeSampler' instead\",\n    func_name='data.GraphSAINTNodeSampler',\n)(GraphSAINTNodeSampler)\nGraphSAINTEdgeSampler = deprecated(  # type: ignore\n    details=\"use 'loader.GraphSAINTEdgeSampler' instead\",\n    func_name='data.GraphSAINTEdgeSampler',\n)(GraphSAINTEdgeSampler)\nGraphSAINTRandomWalkSampler = deprecated(  # type: ignore\n    details=\"use 'loader.GraphSAINTRandomWalkSampler' instead\",\n    func_name='data.GraphSAINTRandomWalkSampler',\n)(GraphSAINTRandomWalkSampler)\nShaDowKHopSampler = deprecated(  # type: ignore\n    details=\"use 'loader.ShaDowKHopSampler' instead\",\n    func_name='data.ShaDowKHopSampler',\n)(ShaDowKHopSampler)\nRandomNodeSampler = deprecated(\n    details=\"use 'loader.RandomNodeLoader' instead\",\n    
func_name='data.RandomNodeSampler',\n)(RandomNodeLoader)\nDataLoader = deprecated(  # type: ignore\n    details=\"use 'loader.DataLoader' instead\",\n    func_name='data.DataLoader',\n)(DataLoader)\nDataListLoader = deprecated(  # type: ignore\n    details=\"use 'loader.DataListLoader' instead\",\n    func_name='data.DataListLoader',\n)(DataListLoader)\nDenseDataLoader = deprecated(  # type: ignore\n    details=\"use 'loader.DenseDataLoader' instead\",\n    func_name='data.DenseDataLoader',\n)(DenseDataLoader)\n"
  },
  {
    "path": "torch_geometric/data/batch.py",
    "content": "import inspect\nfrom collections.abc import Sequence\nfrom typing import Any, List, Optional, Type, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom typing_extensions import Self\n\nfrom torch_geometric.data.collate import collate\nfrom torch_geometric.data.data import BaseData, Data\nfrom torch_geometric.data.dataset import IndexType\nfrom torch_geometric.data.separate import separate\n\n\nclass DynamicInheritance(type):\n    # A meta class that sets the base class of a `Batch` object, e.g.:\n    # * `Batch(Data)` in case `Data` objects are batched together\n    # * `Batch(HeteroData)` in case `HeteroData` objects are batched together\n    def __call__(cls, *args: Any, **kwargs: Any) -> Any:\n        base_cls = kwargs.pop('_base_cls', Data)\n\n        if issubclass(base_cls, Batch):\n            new_cls = base_cls\n        else:\n            name = f'{base_cls.__name__}{cls.__name__}'\n\n            # NOTE `MetaResolver` is necessary to resolve metaclass conflict\n            # problems between `DynamicInheritance` and the metaclass of\n            # `base_cls`. 
In particular, it creates a new common metaclass\n            # from the defined metaclasses.\n            class MetaResolver(type(cls), type(base_cls)):  # type: ignore\n                pass\n\n            if name not in globals():\n                globals()[name] = MetaResolver(name, (cls, base_cls), {})\n            new_cls = globals()[name]\n\n        params = list(inspect.signature(base_cls.__init__).parameters.items())\n        for i, (k, v) in enumerate(params[1:]):\n            if k == 'args' or k == 'kwargs':\n                continue\n            if i < len(args) or k in kwargs:\n                continue\n            if v.default is not inspect.Parameter.empty:\n                continue\n            kwargs[k] = None\n\n        return super(DynamicInheritance, new_cls).__call__(*args, **kwargs)\n\n\nclass DynamicInheritanceGetter:\n    def __call__(self, cls: Type, base_cls: Type) -> Self:\n        return cls(_base_cls=base_cls)\n\n\nclass Batch(metaclass=DynamicInheritance):\n    r\"\"\"A data object describing a batch of graphs as one big (disconnected)\n    graph.\n    Inherits from :class:`torch_geometric.data.Data` or\n    :class:`torch_geometric.data.HeteroData`.\n    In addition, single graphs can be identified via the assignment vector\n    :obj:`batch`, which maps each node to its respective graph identifier.\n\n    :pyg:`PyG` allows modification to the underlying batching procedure by\n    overwriting the :meth:`~Data.__inc__` and :meth:`~Data.__cat_dim__`\n    functionalities.\n    The :meth:`~Data.__inc__` method defines the incremental count between two\n    consecutive graph attributes.\n    By default, :pyg:`PyG` increments attributes by the number of nodes\n    whenever their attribute names contain the substring :obj:`index`\n    (for historical reasons), which comes in handy for attributes such as\n    :obj:`edge_index` or :obj:`node_index`.\n    However, note that this may lead to unexpected behavior for attributes\n    whose names 
contain the substring :obj:`index` but should not be\n    incremented.\n    To make sure, it is best practice to always double-check the output of\n    batching.\n    Furthermore, :meth:`~Data.__cat_dim__` defines in which dimension graph\n    tensors of the same attribute should be concatenated together.\n    \"\"\"\n    @classmethod\n    def from_data_list(\n        cls,\n        data_list: List[BaseData],\n        follow_batch: Optional[List[str]] = None,\n        exclude_keys: Optional[List[str]] = None,\n    ) -> Self:\n        r\"\"\"Constructs a :class:`~torch_geometric.data.Batch` object from a\n        list of :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` objects.\n        The assignment vector :obj:`batch` is created on the fly.\n        In addition, creates assignment vectors for each key in\n        :obj:`follow_batch`.\n        Will exclude any keys given in :obj:`exclude_keys`.\n        \"\"\"\n        batch, slice_dict, inc_dict = collate(\n            cls,\n            data_list=data_list,\n            increment=True,\n            add_batch=not isinstance(data_list[0], Batch),\n            follow_batch=follow_batch,\n            exclude_keys=exclude_keys,\n        )\n\n        batch._num_graphs = len(data_list)  # type: ignore\n        batch._slice_dict = slice_dict  # type: ignore\n        batch._inc_dict = inc_dict  # type: ignore\n\n        return batch\n\n    def get_example(self, idx: int) -> BaseData:\n        r\"\"\"Gets the :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` object at index :obj:`idx`.\n        The :class:`~torch_geometric.data.Batch` object must have been created\n        via :meth:`from_data_list` in order to be able to reconstruct the\n        initial object.\n        \"\"\"\n        if not hasattr(self, '_slice_dict'):\n            raise RuntimeError(\n                \"Cannot reconstruct 'Data' object from 'Batch' because \"\n              
  \"'Batch' was not created via 'Batch.from_data_list()'\")\n\n        data = separate(\n            cls=self.__class__.__bases__[-1],\n            batch=self,\n            idx=idx,\n            slice_dict=self._slice_dict,\n            inc_dict=self._inc_dict,\n            decrement=True,\n        )\n\n        return data\n\n    def index_select(self, idx: IndexType) -> List[BaseData]:\n        r\"\"\"Creates a subset of :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` objects from specified\n        indices :obj:`idx`.\n        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a\n        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type\n        long or bool.\n        The :class:`~torch_geometric.data.Batch` object must have been created\n        via :meth:`from_data_list` in order to be able to reconstruct the\n        initial objects.\n        \"\"\"\n        index: Sequence[int]\n        if isinstance(idx, slice):\n            index = list(range(self.num_graphs)[idx])\n\n        elif isinstance(idx, Tensor) and idx.dtype == torch.long:\n            index = idx.flatten().tolist()\n\n        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:\n            index = idx.flatten().nonzero(as_tuple=False).flatten().tolist()\n\n        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:\n            index = idx.flatten().tolist()\n\n        elif isinstance(idx, np.ndarray) and idx.dtype == bool:\n            index = idx.flatten().nonzero()[0].flatten().tolist()\n\n        elif isinstance(idx, Sequence) and not isinstance(idx, str):\n            index = idx\n\n        else:\n            raise IndexError(\n                f\"Only slices (':'), list, tuples, torch.tensor and \"\n                f\"np.ndarray of dtype long or bool are valid indices (got \"\n                f\"'{type(idx).__name__}')\")\n\n        return [self.get_example(i) for i in index]\n\n    def 
__getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any:\n        if (isinstance(idx, (int, np.integer))\n                or (isinstance(idx, Tensor) and idx.dim() == 0)\n                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):\n            return self.get_example(idx)  # type: ignore\n        elif isinstance(idx, str) or (isinstance(idx, tuple)\n                                      and isinstance(idx[0], str)):\n            # Accessing attributes or node/edge types:\n            return super().__getitem__(idx)  # type: ignore\n        else:\n            return self.index_select(idx)\n\n    def to_data_list(self) -> List[BaseData]:\n        r\"\"\"Reconstructs the list of :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` objects from the\n        :class:`~torch_geometric.data.Batch` object.\n        The :class:`~torch_geometric.data.Batch` object must have been created\n        via :meth:`from_data_list` in order to be able to reconstruct the\n        initial objects.\n        \"\"\"\n        return [self.get_example(i) for i in range(self.num_graphs)]\n\n    @property\n    def num_graphs(self) -> int:\n        \"\"\"Returns the number of graphs in the batch.\"\"\"\n        if hasattr(self, '_num_graphs'):\n            return self._num_graphs\n        elif hasattr(self, 'ptr'):\n            return self.ptr.numel() - 1\n        elif hasattr(self, 'batch'):\n            return int(self.batch.max()) + 1\n        else:\n            raise ValueError(\"Can not infer the number of graphs\")\n\n    @property\n    def batch_size(self) -> int:\n        r\"\"\"Alias for :obj:`num_graphs`.\"\"\"\n        return self.num_graphs\n\n    def __len__(self) -> int:\n        return self.num_graphs\n\n    def __reduce__(self) -> Any:\n        state = self.__dict__.copy()\n        return DynamicInheritanceGetter(), self.__class__.__bases__, state\n"
  },
  {
    "path": "torch_geometric/data/collate.py",
    "content": "from collections import defaultdict\nfrom collections.abc import Mapping, Sequence\nfrom typing import (\n    Any,\n    Dict,\n    Iterable,\n    List,\n    Optional,\n    Tuple,\n    Type,\n    TypeVar,\n    Union,\n)\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.data.storage import BaseStorage, NodeStorage\nfrom torch_geometric.edge_index import SortOrder\nfrom torch_geometric.typing import (\n    SparseTensor,\n    TensorFrame,\n    torch_frame,\n    torch_sparse,\n)\nfrom torch_geometric.utils import cumsum, is_sparse, is_torch_sparse_tensor\nfrom torch_geometric.utils.sparse import cat\n\nT = TypeVar('T')\nSliceDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]]\nIncDictType = Dict[str, Union[Tensor, Dict[str, Tensor]]]\n\n\ndef collate(\n    cls: Type[T],\n    data_list: List[BaseData],\n    increment: bool = True,\n    add_batch: bool = True,\n    follow_batch: Optional[Iterable[str]] = None,\n    exclude_keys: Optional[Iterable[str]] = None,\n) -> Tuple[T, SliceDictType, IncDictType]:\n    # Collates a list of `data` objects into a single object of type `cls`.\n    # `collate` can handle both homogeneous and heterogeneous data objects by\n    # individually collating all their stores.\n    # In addition, `collate` can handle nested data structures such as\n    # dictionaries and lists.\n\n    if not isinstance(data_list, (list, tuple)):\n        # Materialize `data_list` to keep the `_parent` weakref alive.\n        data_list = list(data_list)\n\n    if cls != data_list[0].__class__:  # Dynamic inheritance.\n        out = cls(_base_cls=data_list[0].__class__)  # type: ignore\n    else:\n        out = cls()\n\n    # Create empty stores:\n    out.stores_as(data_list[0])  # type: ignore\n\n    follow_batch = set(follow_batch or [])\n    exclude_keys = set(exclude_keys or [])\n\n    # Group all 
storage objects of every data object in the `data_list` by key,\n    # i.e. `key_to_stores = { key: [store_1, store_2, ...], ... }`:\n    key_to_stores = defaultdict(list)\n    for data in data_list:\n        for store in data.stores:\n            key_to_stores[store._key].append(store)\n\n    # With this, we iterate over each list of storage objects and recursively\n    # collate all its attributes into a unified representation:\n\n    # We maintain two additional dictionaries:\n    # * `slice_dict` stores a compressed index representation of each attribute\n    #    and is needed to re-construct individual elements from mini-batches.\n    # * `inc_dict` stores how individual elements need to be incremented, e.g.,\n    #   `edge_index` is incremented by the cumulated sum of previous elements.\n    #   We also need to make use of `inc_dict` when re-constructing individual\n    #   elements as attributes that got incremented need to be decremented\n    #   while separating to obtain original values.\n    device: Optional[torch.device] = None\n    slice_dict: SliceDictType = {}\n    inc_dict: IncDictType = {}\n    for out_store in out.stores:  # type: ignore\n        key = out_store._key\n        stores = key_to_stores[key]\n        for attr in stores[0].keys():\n\n            if attr in exclude_keys:  # Do not include top-level attribute.\n                continue\n\n            values = [store[attr] for store in stores]\n\n            # The `num_nodes` attribute needs special treatment, as we need to\n            # sum their values up instead of merging them to a list:\n            if attr == 'num_nodes':\n                out_store._num_nodes = values\n                out_store.num_nodes = sum(values)\n                continue\n\n            # Skip batching of `ptr` vectors for now:\n            if attr == 'ptr':\n                continue\n\n            # Collate attributes into a unified representation:\n            value, slices, incs = _collate(attr, values, 
data_list, stores,\n                                           increment)\n\n            # If parts of the data are already on GPU, make sure that auxiliary\n            # data like `batch` or `ptr` are also created on GPU:\n            if isinstance(value, Tensor) and value.is_cuda:\n                device = value.device\n\n            out_store[attr] = value\n\n            if key is not None:  # Heterogeneous:\n                store_slice_dict = slice_dict.get(key, {})\n                assert isinstance(store_slice_dict, dict)\n                store_slice_dict[attr] = slices\n                slice_dict[key] = store_slice_dict\n\n                store_inc_dict = inc_dict.get(key, {})\n                assert isinstance(store_inc_dict, dict)\n                store_inc_dict[attr] = incs\n                inc_dict[key] = store_inc_dict\n            else:  # Homogeneous:\n                slice_dict[attr] = slices\n                inc_dict[attr] = incs\n\n            # Add an additional batch vector for the given attribute:\n            if attr in follow_batch:\n                batch, ptr = _batch_and_ptr(slices, device)\n                out_store[f'{attr}_batch'] = batch\n                out_store[f'{attr}_ptr'] = ptr\n\n        # For node-level storages, we add a top-level batch vector to them:\n        if (add_batch and isinstance(stores[0], NodeStorage)\n                and stores[0].can_infer_num_nodes):\n            repeats = [store.num_nodes or 0 for store in stores]\n            out_store.batch = repeat_interleave(repeats, device=device)\n            out_store.ptr = cumsum(torch.tensor(repeats, device=device))\n\n    return out, slice_dict, inc_dict\n\n\ndef _collate(\n    key: str,\n    values: List[Any],\n    data_list: List[BaseData],\n    stores: List[BaseStorage],\n    increment: bool,\n) -> Tuple[Any, Any, Any]:\n\n    elem = values[0]\n\n    if isinstance(elem, Tensor) and not is_sparse(elem):\n        # Concatenate a list of `torch.Tensor` along the 
`cat_dim`.\n        # NOTE: We need to take care of incrementing elements appropriately.\n        key = str(key)\n        cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])\n        if cat_dim is None or elem.dim() == 0:\n            values = [value.unsqueeze(0) for value in values]\n        sizes = torch.tensor([value.size(cat_dim or 0) for value in values])\n        slices = cumsum(sizes)\n        if increment:\n            incs = get_incs(key, values, data_list, stores)\n            if incs.dim() > 1 or int(incs[-1]) != 0:\n                values = [\n                    value + inc.to(value.device)\n                    for value, inc in zip(values, incs)\n                ]\n        else:\n            incs = None\n\n        if getattr(elem, 'is_nested', False):\n            tensors = []\n            for nested_tensor in values:\n                tensors.extend(nested_tensor.unbind())\n            value = torch.nested.nested_tensor(tensors)\n\n            return value, slices, incs\n\n        out = None\n        if (torch.utils.data.get_worker_info() is not None\n                and not isinstance(elem, (Index, EdgeIndex))):\n            # Write directly into shared memory to avoid an extra copy:\n            numel = sum(value.numel() for value in values)\n            if torch_geometric.typing.WITH_PT20:\n                storage = elem.untyped_storage()._new_shared(\n                    numel * elem.element_size(), device=elem.device)\n            else:\n                storage = elem.storage()._new_shared(numel, device=elem.device)\n            shape = list(elem.size())\n            if cat_dim is None or elem.dim() == 0:\n                shape = [len(values)] + shape\n            else:\n                shape[cat_dim] = int(slices[-1])\n            out = elem.new(storage).resize_(*shape)\n\n        value = torch.cat(values, dim=cat_dim or 0, out=out)\n\n        if increment and isinstance(value, Index) and values[0].is_sorted:\n            # Check whether 
the whole `Index` is sorted:\n            if (value.diff() >= 0).all():\n                value._is_sorted = True\n\n        if increment and isinstance(value, EdgeIndex) and values[0].is_sorted:\n            # Check whether the whole `EdgeIndex` is sorted by row:\n            if values[0].is_sorted_by_row and (value[0].diff() >= 0).all():\n                value._sort_order = SortOrder.ROW\n            # Check whether the whole `EdgeIndex` is sorted by column:\n            elif values[0].is_sorted_by_col and (value[1].diff() >= 0).all():\n                value._sort_order = SortOrder.COL\n\n        return value, slices, incs\n\n    elif isinstance(elem, TensorFrame):\n        key = str(key)\n        sizes = torch.tensor([value.num_rows for value in values])\n        slices = cumsum(sizes)\n        value = torch_frame.cat(values, dim=0)\n        return value, slices, None\n\n    elif is_sparse(elem) and increment:\n        # Concatenate a list of `SparseTensor` along the `cat_dim`.\n        # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking.\n        key = str(key)\n        cat_dim = data_list[0].__cat_dim__(key, elem, stores[0])\n        cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim\n        repeats = [[value.size(dim) for dim in cat_dims] for value in values]\n        slices = cumsum(torch.tensor(repeats))\n        if is_torch_sparse_tensor(elem):\n            value = cat(values, dim=cat_dim)\n        else:\n            value = torch_sparse.cat(values, dim=cat_dim)\n        return value, slices, None\n\n    elif isinstance(elem, (int, float)):\n        # Convert a list of numerical values to a `torch.Tensor`.\n        value = torch.tensor(values)\n        if increment:\n            incs = get_incs(key, values, data_list, stores)\n            if int(incs[-1]) != 0:\n                value.add_(incs)\n        else:\n            incs = None\n        slices = torch.arange(len(values) + 1)\n        return value, slices, incs\n\n    elif 
isinstance(elem, Mapping):\n        # Recursively collate elements of dictionaries.\n        value_dict, slice_dict, inc_dict = {}, {}, {}\n        for key in elem.keys():\n            value_dict[key], slice_dict[key], inc_dict[key] = _collate(\n                key, [v[key] for v in values], data_list, stores, increment)\n        return value_dict, slice_dict, inc_dict\n\n    elif (isinstance(elem, Sequence) and not isinstance(elem, str)\n          and len(elem) > 0 and isinstance(elem[0], (Tensor, SparseTensor))):\n        # Recursively collate elements of lists.\n        value_list, slice_list, inc_list = [], [], []\n        for i in range(len(elem)):\n            value, slices, incs = _collate(key, [v[i] for v in values],\n                                           data_list, stores, increment)\n            value_list.append(value)\n            slice_list.append(slices)\n            inc_list.append(incs)\n        return value_list, slice_list, inc_list\n\n    else:\n        # Otherwise, just return the list of values as is.\n        slices = torch.arange(len(values) + 1)\n        return values, slices, None\n\n\ndef _batch_and_ptr(\n    slices: Any,\n    device: Optional[torch.device] = None,\n) -> Tuple[Any, Any]:\n    if (isinstance(slices, Tensor) and slices.dim() == 1):\n        # Default case, turn slices tensor into batch.\n        repeats = slices[1:] - slices[:-1]\n        batch = repeat_interleave(repeats.tolist(), device=device)\n        ptr = cumsum(repeats.to(device))\n        return batch, ptr\n\n    elif isinstance(slices, Mapping):\n        # Recursively batch elements of dictionaries.\n        batch, ptr = {}, {}\n        for k, v in slices.items():\n            batch[k], ptr[k] = _batch_and_ptr(v, device)\n        return batch, ptr\n\n    elif (isinstance(slices, Sequence) and not isinstance(slices, str)\n          and isinstance(slices[0], Tensor)):\n        # Recursively batch elements of lists.\n        batch, ptr = [], []\n        for s 
in slices:\n            sub_batch, sub_ptr = _batch_and_ptr(s, device)\n            batch.append(sub_batch)\n            ptr.append(sub_ptr)\n        return batch, ptr\n\n    else:\n        # Failure of batching, usually due to slices.dim() != 1\n        return None, None\n\n\n###############################################################################\n\n\ndef repeat_interleave(\n    repeats: List[int],\n    device: Optional[torch.device] = None,\n) -> Tensor:\n    outs = [torch.full((n, ), i, device=device) for i, n in enumerate(repeats)]\n    return torch.cat(outs, dim=0)\n\n\ndef get_incs(key, values: List[Any], data_list: List[BaseData],\n             stores: List[BaseStorage]) -> Tensor:\n    repeats = [\n        data.__inc__(key, value, store)\n        for value, data, store in zip(values, data_list, stores)\n    ]\n    if isinstance(repeats[0], Tensor):\n        repeats = torch.stack(repeats, dim=0)\n    else:\n        repeats = torch.tensor(repeats)\n    return cumsum(repeats[:-1])\n"
  },
  {
    "path": "torch_geometric/data/data.py",
    "content": "import copy\nimport warnings\nfrom collections import defaultdict\nfrom collections.abc import Mapping, Sequence\nfrom dataclasses import dataclass\nfrom itertools import chain\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    NamedTuple,\n    Optional,\n    Tuple,\n    Union,\n    overload,\n)\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom typing_extensions import Self\n\nfrom torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr\nfrom torch_geometric.data.feature_store import _FieldStatus\nfrom torch_geometric.data.graph_store import EdgeLayout\nfrom torch_geometric.data.storage import (\n    BaseStorage,\n    EdgeStorage,\n    GlobalStorage,\n    NodeStorage,\n)\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.index import Index\nfrom torch_geometric.typing import (\n    EdgeTensorType,\n    EdgeType,\n    FeatureTensorType,\n    NodeType,\n    OptTensor,\n    SparseTensor,\n    TensorFrame,\n)\nfrom torch_geometric.utils import is_sparse, select, subgraph\n\n\nclass BaseData:\n    def __getattr__(self, key: str) -> Any:\n        raise NotImplementedError\n\n    def __setattr__(self, key: str, value: Any):\n        raise NotImplementedError\n\n    def __delattr__(self, key: str):\n        raise NotImplementedError\n\n    def __getitem__(self, key: str) -> Any:\n        raise NotImplementedError\n\n    def __setitem__(self, key: str, value: Any):\n        raise NotImplementedError\n\n    def __delitem__(self, key: str):\n        raise NotImplementedError\n\n    def __copy__(self):\n        raise NotImplementedError\n\n    def __deepcopy__(self, memo):\n        raise NotImplementedError\n\n    def __repr__(self) -> str:\n        raise NotImplementedError\n\n    def stores_as(self, data: Self):\n        raise NotImplementedError\n\n    @property\n    def stores(self) -> List[BaseStorage]:\n        raise NotImplementedError\n\n    @property\n    
def node_stores(self) -> List[NodeStorage]:\n        raise NotImplementedError\n\n    @property\n    def edge_stores(self) -> List[EdgeStorage]:\n        raise NotImplementedError\n\n    def to_dict(self) -> Dict[str, Any]:\n        r\"\"\"Returns a dictionary of stored key/value pairs.\"\"\"\n        raise NotImplementedError\n\n    def to_namedtuple(self) -> NamedTuple:\n        r\"\"\"Returns a :obj:`NamedTuple` of stored key/value pairs.\"\"\"\n        raise NotImplementedError\n\n    def update(self, data: Self) -> Self:\n        r\"\"\"Updates the data object with the elements from another data object.\n        Added elements will override existing ones (in case of duplicates).\n        \"\"\"\n        raise NotImplementedError\n\n    def concat(self, data: Self) -> Self:\n        r\"\"\"Concatenates :obj:`self` with another :obj:`data` object.\n        All values need to have matching shapes at non-concat dimensions.\n        \"\"\"\n        out = copy.copy(self)\n        for store, other_store in zip(out.stores, data.stores):\n            store.concat(other_store)\n        return out\n\n    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:\n        r\"\"\"Returns the dimension for which the value :obj:`value` of the\n        attribute :obj:`key` will get concatenated when creating mini-batches\n        using :class:`torch_geometric.loader.DataLoader`.\n\n        .. note::\n\n            This method is for internal use only, and should only be overridden\n            in case the mini-batch creation process is corrupted for a specific\n            attribute.\n        \"\"\"\n        raise NotImplementedError\n\n    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:\n        r\"\"\"Returns the incremental count to cumulatively increase the value\n        :obj:`value` of the attribute :obj:`key` when creating mini-batches\n        using :class:`torch_geometric.loader.DataLoader`.\n\n        .. 
note::\n\n            This method is for internal use only, and should only be overridden\n            in case the mini-batch creation process is corrupted for a specific\n            attribute.\n        \"\"\"\n        raise NotImplementedError\n\n    def debug(self):\n        raise NotImplementedError\n\n    ###########################################################################\n\n    def keys(self) -> List[str]:\n        r\"\"\"Returns a list of all graph attribute names.\"\"\"\n        out = []\n        for store in self.stores:\n            out += list(store.keys())\n        return list(set(out))\n\n    def __len__(self) -> int:\n        r\"\"\"Returns the number of graph attributes.\"\"\"\n        return len(self.keys())\n\n    def __contains__(self, key: str) -> bool:\n        r\"\"\"Returns :obj:`True` if the attribute :obj:`key` is present in the\n        data.\n        \"\"\"\n        return key in self.keys()\n\n    def __getstate__(self) -> Dict[str, Any]:\n        return self.__dict__\n\n    def __setstate__(self, mapping: Dict[str, Any]):\n        for key, value in mapping.items():\n            self.__dict__[key] = value\n\n    @property\n    def num_nodes(self) -> Optional[int]:\n        r\"\"\"Returns the number of nodes in the graph.\n\n        .. 
note::\n            The number of nodes in the data object is automatically inferred\n            in case node-level attributes are present, *e.g.*, :obj:`data.x`.\n            In some cases, however, a graph may only be given without any\n            node-level attributes.\n            :pyg:`PyG` then *guesses* the number of nodes according to\n            :obj:`edge_index.max().item() + 1`.\n            However, if the graph contains isolated nodes, this number may be\n            incorrect, which can result in unexpected behavior.\n            Thus, we recommend setting the number of nodes in your data object\n            explicitly via :obj:`data.num_nodes = ...`.\n            You will be given a warning that requests you to do so.\n        \"\"\"\n        try:\n            return sum([v.num_nodes for v in self.node_stores])\n        except TypeError:\n            return None\n\n    @overload\n    def size(self) -> Tuple[Optional[int], Optional[int]]:\n        pass\n\n    @overload\n    def size(self, dim: int) -> Optional[int]:\n        pass\n\n    def size(\n        self, dim: Optional[int] = None\n    ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]:\n        r\"\"\"Returns the size of the adjacency matrix induced by the graph.\"\"\"\n        size = (self.num_nodes, self.num_nodes)\n        return size if dim is None else size[dim]\n\n    @property\n    def num_edges(self) -> int:\n        r\"\"\"Returns the number of edges in the graph.\n        For undirected graphs, this will return the number of bi-directional\n        edges, which is double the number of unique edges.\n        \"\"\"\n        return sum([v.num_edges for v in self.edge_stores])\n\n    def node_attrs(self) -> List[str]:\n        r\"\"\"Returns all node-level tensor attribute names.\"\"\"\n        return list(set(chain(*[s.node_attrs() for s in self.node_stores])))\n\n    def edge_attrs(self) -> List[str]:\n        r\"\"\"Returns all edge-level tensor attribute 
names.\"\"\"\n        return list(set(chain(*[s.edge_attrs() for s in self.edge_stores])))\n\n    @property\n    def node_offsets(self) -> Dict[NodeType, int]:\n        out: Dict[NodeType, int] = {}\n        offset: int = 0\n        for store in self.node_stores:\n            out[store._key] = offset\n            offset = offset + store.num_nodes\n        return out\n\n    def generate_ids(self):\n        r\"\"\"Generates and sets :obj:`n_id` and :obj:`e_id` attributes to assign\n        each node and edge to a continuously ascending and unique ID.\n        \"\"\"\n        for store in self.node_stores:\n            store.n_id = torch.arange(store.num_nodes)\n        for store in self.edge_stores:\n            store.e_id = torch.arange(store.num_edges)\n\n    def is_sorted(self, sort_by_row: bool = True) -> bool:\n        r\"\"\"Returns :obj:`True` if edge indices :obj:`edge_index` are sorted.\n\n        Args:\n            sort_by_row (bool, optional): If set to :obj:`False`, will require\n                column-wise order/by destination node order of\n                :obj:`edge_index`. 
(default: :obj:`True`)\n        \"\"\"\n        return all(\n            [store.is_sorted(sort_by_row) for store in self.edge_stores])\n\n    def sort(self, sort_by_row: bool = True) -> Self:\n        r\"\"\"Sorts edge indices :obj:`edge_index` and their corresponding edge\n        features.\n\n        Args:\n            sort_by_row (bool, optional): If set to :obj:`False`, will sort\n                :obj:`edge_index` in column-wise order/by destination node.\n                (default: :obj:`True`)\n        \"\"\"\n        out = copy.copy(self)\n        for store in out.edge_stores:\n            store.sort(sort_by_row)\n        return out\n\n    def is_coalesced(self) -> bool:\n        r\"\"\"Returns :obj:`True` if edge indices :obj:`edge_index` are sorted\n        and do not contain duplicate entries.\n        \"\"\"\n        return all([store.is_coalesced() for store in self.edge_stores])\n\n    def coalesce(self) -> Self:\n        r\"\"\"Sorts and removes duplicated entries from edge indices\n        :obj:`edge_index`.\n        \"\"\"\n        out = copy.copy(self)\n        for store in out.edge_stores:\n            store.coalesce()\n        return out\n\n    def is_sorted_by_time(self) -> bool:\n        r\"\"\"Returns :obj:`True` if :obj:`time` is sorted.\"\"\"\n        return all([store.is_sorted_by_time() for store in self.stores])\n\n    def sort_by_time(self) -> Self:\n        r\"\"\"Sorts data associated with :obj:`time` according to :obj:`time`.\"\"\"\n        out = copy.copy(self)\n        for store in out.stores:\n            store.sort_by_time()\n        return out\n\n    def snapshot(\n        self,\n        start_time: Union[float, int],\n        end_time: Union[float, int],\n        attr: str = 'time',\n    ) -> Self:\n        r\"\"\"Returns a snapshot of :obj:`data` to only hold events that occurred\n        in period :obj:`[start_time, end_time]`.\n        \"\"\"\n        out = copy.copy(self)\n        for store in out.stores:\n            
store.snapshot(start_time, end_time, attr)\n        return out\n\n    def up_to(self, end_time: Union[float, int]) -> Self:\n        r\"\"\"Returns a snapshot of :obj:`data` to only hold events that occurred\n        up to :obj:`end_time` (inclusive of :obj:`edge_time`).\n        \"\"\"\n        out = copy.copy(self)\n        for store in out.stores:\n            store.up_to(end_time)\n        return out\n\n    def has_isolated_nodes(self) -> bool:\n        r\"\"\"Returns :obj:`True` if the graph contains isolated nodes.\"\"\"\n        return any([store.has_isolated_nodes() for store in self.edge_stores])\n\n    def has_self_loops(self) -> bool:\n        \"\"\"Returns :obj:`True` if the graph contains self-loops.\"\"\"\n        return any([store.has_self_loops() for store in self.edge_stores])\n\n    def is_undirected(self) -> bool:\n        r\"\"\"Returns :obj:`True` if graph edges are undirected.\"\"\"\n        return all([store.is_undirected() for store in self.edge_stores])\n\n    def is_directed(self) -> bool:\n        r\"\"\"Returns :obj:`True` if graph edges are directed.\"\"\"\n        return not self.is_undirected()\n\n    def apply_(self, func: Callable, *args: str):\n        r\"\"\"Applies the in-place function :obj:`func`, either to all attributes\n        or only the ones given in :obj:`*args`.\n        \"\"\"\n        for store in self.stores:\n            store.apply_(func, *args)\n        return self\n\n    def apply(self, func: Callable, *args: str):\n        r\"\"\"Applies the function :obj:`func`, either to all attributes or only\n        the ones given in :obj:`*args`.\n        \"\"\"\n        for store in self.stores:\n            store.apply(func, *args)\n        return self\n\n    def clone(self, *args: str):\n        r\"\"\"Performs cloning of tensors, either for all attributes or only the\n        ones given in :obj:`*args`.\n        \"\"\"\n        return copy.copy(self).apply(lambda x: x.clone(), *args)\n\n    def contiguous(self, *args: 
str):\n        r\"\"\"Ensures a contiguous memory layout, either for all attributes or\n        only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.contiguous(), *args)\n\n    def to(self, device: Union[int, str, torch.device], *args: str,\n           non_blocking: bool = False):\n        r\"\"\"Performs tensor device conversion, either for all attributes or\n        only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(\n            lambda x: x.to(device=device, non_blocking=non_blocking), *args)\n\n    def cpu(self, *args: str):\n        r\"\"\"Copies attributes to CPU memory, either for all attributes or only\n        the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.cpu(), *args)\n\n    def cuda(self, device: Optional[Union[int, str]] = None, *args: str,\n             non_blocking: bool = False):\n        r\"\"\"Copies attributes to CUDA memory, either for all attributes or only\n        the ones given in :obj:`*args`.\n        \"\"\"\n        # Some PyTorch tensor like objects require a default value for `cuda`:\n        device = 'cuda' if device is None else device\n        return self.apply(lambda x: x.cuda(device, non_blocking=non_blocking),\n                          *args)\n\n    def pin_memory(self, *args: str):\n        r\"\"\"Copies attributes to pinned memory, either for all attributes or\n        only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.pin_memory(), *args)\n\n    def share_memory_(self, *args: str):\n        r\"\"\"Moves attributes to shared memory, either for all attributes or\n        only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply_(lambda x: x.share_memory_(), *args)\n\n    def detach_(self, *args: str):\n        r\"\"\"Detaches attributes from the computation graph, either for all\n        attributes or only the ones given in :obj:`*args`.\n        \"\"\"\n       
 return self.apply_(lambda x: x.detach_(), *args)\n\n    def detach(self, *args: str):\n        r\"\"\"Detaches attributes from the computation graph by creating a new\n        tensor, either for all attributes or only the ones given in\n        :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.detach(), *args)\n\n    def requires_grad_(self, *args: str, requires_grad: bool = True):\n        r\"\"\"Tracks gradient computation, either for all attributes or only the\n        ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply_(\n            lambda x: x.requires_grad_(requires_grad=requires_grad), *args)\n\n    def record_stream(self, stream: torch.cuda.Stream, *args: str):\n        r\"\"\"Ensures that the tensor memory is not reused for another tensor\n        until all current work queued on :obj:`stream` has been completed,\n        either for all attributes or only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply_(lambda x: x.record_stream(stream), *args)\n\n    @property\n    def is_cuda(self) -> bool:\n        r\"\"\"Returns :obj:`True` if any :class:`torch.Tensor` attribute is\n        stored on the GPU, :obj:`False` otherwise.\n        \"\"\"\n        for store in self.stores:\n            for value in store.values():\n                if isinstance(value, Tensor) and value.is_cuda:\n                    return True\n        return False\n\n    # Deprecated functions ####################################################\n\n    @deprecated(details=\"use 'has_isolated_nodes' instead\")\n    def contains_isolated_nodes(self) -> bool:\n        return self.has_isolated_nodes()\n\n    @deprecated(details=\"use 'has_self_loops' instead\")\n    def contains_self_loops(self) -> bool:\n        return self.has_self_loops()\n\n\n###############################################################################\n\n\n@dataclass\nclass DataTensorAttr(TensorAttr):\n    r\"\"\"Tensor attribute for `Data` without group 
name.\"\"\"\n    def __init__(\n        self,\n        attr_name=_FieldStatus.UNSET,\n        index=None,\n    ):\n        super().__init__(None, attr_name, index)\n\n\n@dataclass\nclass DataEdgeAttr(EdgeAttr):\n    r\"\"\"Edge attribute class for `Data` without edge type.\"\"\"\n    def __init__(\n        self,\n        layout: Optional[EdgeLayout] = None,\n        is_sorted: bool = False,\n        size: Optional[Tuple[int, int]] = None,\n    ):\n        super().__init__(None, layout, is_sorted, size)\n\n\n###############################################################################\n\n\nclass Data(BaseData, FeatureStore, GraphStore):\n    r\"\"\"A data object describing a homogeneous graph.\n    The data object can hold node-level, link-level and graph-level attributes.\n    In general, :class:`~torch_geometric.data.Data` tries to mimic the\n    behavior of a regular :python:`Python` dictionary.\n    In addition, it provides useful functionality for analyzing graph\n    structures, and provides basic PyTorch tensor functionalities.\n    See `here <https://pytorch-geometric.readthedocs.io/en/latest/get_started/\n    introduction.html#data-handling-of-graphs>`__ for the accompanying\n    tutorial.\n\n    .. code-block:: python\n\n        from torch_geometric.data import Data\n\n        data = Data(x=x, edge_index=edge_index, ...)\n\n        # Add additional arguments to `data`:\n        data.train_idx = torch.tensor([...], dtype=torch.long)\n        data.test_mask = torch.tensor([...], dtype=torch.bool)\n\n        # Analyzing the graph structure:\n        data.num_nodes\n        >>> 23\n\n        data.is_directed()\n        >>> False\n\n        # PyTorch tensor functionality:\n        data = data.pin_memory()\n        data = data.to('cuda:0', non_blocking=True)\n\n    Args:\n        x (torch.Tensor, optional): Node feature matrix with shape\n            :obj:`[num_nodes, num_node_features]`. 
(default: :obj:`None`)\n        edge_index (LongTensor, optional): Graph connectivity in COO format\n            with shape :obj:`[2, num_edges]`. (default: :obj:`None`)\n        edge_attr (torch.Tensor, optional): Edge feature matrix with shape\n            :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`)\n        y (torch.Tensor, optional): Graph-level or node-level ground-truth\n            labels with arbitrary shape. (default: :obj:`None`)\n        pos (torch.Tensor, optional): Node position matrix with shape\n            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)\n        time (torch.Tensor, optional): The timestamps for each event with shape\n            :obj:`[num_edges]` or :obj:`[num_nodes]`. (default: :obj:`None`)\n        **kwargs (optional): Additional attributes.\n    \"\"\"\n    def __init__(\n        self,\n        x: Optional[Tensor] = None,\n        edge_index: OptTensor = None,\n        edge_attr: OptTensor = None,\n        y: Optional[Union[Tensor, int, float]] = None,\n        pos: OptTensor = None,\n        time: OptTensor = None,\n        **kwargs,\n    ):\n        # `Data` doesn't support group_name, so we need to adjust `TensorAttr`\n        # accordingly here to avoid requiring `group_name` to be set:\n        super().__init__(tensor_attr_cls=DataTensorAttr)\n\n        # `Data` doesn't support edge_type, so we need to adjust `EdgeAttr`\n        # accordingly here to avoid requiring `edge_type` to be set:\n        GraphStore.__init__(self, edge_attr_cls=DataEdgeAttr)\n\n        self.__dict__['_store'] = GlobalStorage(_parent=self)\n\n        if x is not None:\n            self.x = x\n        if edge_index is not None:\n            self.edge_index = edge_index\n        if edge_attr is not None:\n            self.edge_attr = edge_attr\n        if y is not None:\n            self.y = y\n        if pos is not None:\n            self.pos = pos\n        if time is not None:\n            self.time = time\n\n        for 
key, value in kwargs.items():\n            setattr(self, key, value)\n\n    def __getattr__(self, key: str) -> Any:\n        if '_store' not in self.__dict__:\n            raise RuntimeError(\n                \"The 'data' object was created by an older version of PyG. \"\n                \"If this error occurred while loading an already existing \"\n                \"dataset, remove the 'processed/' directory in the dataset's \"\n                \"root folder and try again.\")\n        return getattr(self._store, key)\n\n    def __setattr__(self, key: str, value: Any):\n        propobj = getattr(self.__class__, key, None)\n        if propobj is not None and getattr(propobj, 'fset', None) is not None:\n            propobj.fset(self, value)\n        else:\n            setattr(self._store, key, value)\n\n    def __delattr__(self, key: str):\n        delattr(self._store, key)\n\n    # TODO consider supporting the feature store interface for\n    # __getitem__, __setitem__, and __delitem__ so, for example, we\n    # can accept key: Union[str, TensorAttr] in __getitem__.\n    def __getitem__(self, key: str) -> Any:\n        return self._store[key]\n\n    def __setitem__(self, key: str, value: Any):\n        self._store[key] = value\n\n    def __delitem__(self, key: str):\n        if key in self._store:\n            del self._store[key]\n\n    def __copy__(self):\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = value\n        out.__dict__['_store'] = copy.copy(self._store)\n        out._store._parent = out\n        return out\n\n    def __deepcopy__(self, memo):\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = copy.deepcopy(value, memo)\n        out._store._parent = out\n        return out\n\n    def __repr__(self) -> str:\n        cls = self.__class__.__name__\n        has_dict = any([isinstance(v, 
Mapping) for v in self._store.values()])\n\n        if not has_dict:\n            info = [size_repr(k, v) for k, v in self._store.items()]\n            info = ', '.join(info)\n            return f'{cls}({info})'\n        else:\n            info = [size_repr(k, v, indent=2) for k, v in self._store.items()]\n            info = ',\\n'.join(info)\n            return f'{cls}(\\n{info}\\n)'\n\n    @property\n    def num_nodes(self) -> Optional[int]:\n        return super().num_nodes\n\n    @num_nodes.setter\n    def num_nodes(self, num_nodes: Optional[int]):\n        self._store.num_nodes = num_nodes\n\n    def stores_as(self, data: Self):\n        return self\n\n    @property\n    def stores(self) -> List[BaseStorage]:\n        return [self._store]\n\n    @property\n    def node_stores(self) -> List[NodeStorage]:\n        return [self._store]\n\n    @property\n    def edge_stores(self) -> List[EdgeStorage]:\n        return [self._store]\n\n    def to_dict(self) -> Dict[str, Any]:\n        return self._store.to_dict()\n\n    def to_namedtuple(self) -> NamedTuple:\n        return self._store.to_namedtuple()\n\n    def update(self, data: Union[Self, Dict[str, Any]]) -> Self:\n        for key, value in data.items():\n            self[key] = value\n        return self\n\n    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:\n        if is_sparse(value) and ('adj' in key or 'edge_index' in key):\n            return (0, 1)\n        elif 'index' in key or key == 'face':\n            return -1\n        else:\n            return 0\n\n    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:\n        if 'batch' in key and isinstance(value, Tensor):\n            if isinstance(value, Index):\n                return value.get_dim_size()\n            return int(value.max()) + 1\n        elif 'index' in key or key == 'face':\n            num_nodes = self.num_nodes\n            if num_nodes is None:\n                raise RuntimeError(f\"Unable to infer 
'num_nodes' from the \"\n                                   f\"attribute '{key}'. Please explicitly set \"\n                                   f\"'num_nodes' as an attribute of 'data' to \"\n                                   f\"prevent this error\")\n            return num_nodes\n        else:\n            return 0\n\n    def validate(self, raise_on_error: bool = True) -> bool:\n        r\"\"\"Validates the correctness of the data.\"\"\"\n        cls_name = self.__class__.__name__\n        status = True\n\n        num_nodes = self.num_nodes\n        if num_nodes is None:\n            status = False\n            warn_or_raise(f\"'num_nodes' is undefined in '{cls_name}'\",\n                          raise_on_error)\n\n        if 'edge_index' in self:\n            if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2:\n                status = False\n                warn_or_raise(\n                    f\"'edge_index' needs to be of shape [2, num_edges] in \"\n                    f\"'{cls_name}' (found {self.edge_index.size()})\",\n                    raise_on_error)\n\n        if 'edge_index' in self and self.edge_index.numel() > 0:\n            if self.edge_index.min() < 0:\n                status = False\n                warn_or_raise(\n                    f\"'edge_index' contains negative indices in \"\n                    f\"'{cls_name}' (found {int(self.edge_index.min())})\",\n                    raise_on_error)\n\n            if num_nodes is not None and self.edge_index.max() >= num_nodes:\n                status = False\n                warn_or_raise(\n                    f\"'edge_index' contains larger indices than the number \"\n                    f\"of nodes ({num_nodes}) in '{cls_name}' \"\n                    f\"(found {int(self.edge_index.max())})\", raise_on_error)\n\n        return status\n\n    def debug(self):\n        pass  # TODO\n\n    def is_node_attr(self, key: str) -> bool:\n        r\"\"\"Returns :obj:`True` if the object at key 
:obj:`key` denotes a\n        node-level tensor attribute.\n        \"\"\"\n        return self._store.is_node_attr(key)\n\n    def is_edge_attr(self, key: str) -> bool:\n        r\"\"\"Returns :obj:`True` if the object at key :obj:`key` denotes an\n        edge-level tensor attribute.\n        \"\"\"\n        return self._store.is_edge_attr(key)\n\n    def subgraph(self, subset: Tensor) -> Self:\n        r\"\"\"Returns the induced subgraph given by the node indices\n        :obj:`subset`.\n\n        Args:\n            subset (LongTensor or BoolTensor): The nodes to keep.\n        \"\"\"\n        if 'edge_index' in self:\n            edge_index, _, edge_mask = subgraph(\n                subset,\n                self.edge_index,\n                relabel_nodes=True,\n                num_nodes=self.num_nodes,\n                return_edge_mask=True,\n            )\n        else:\n            edge_index = None\n            edge_mask = torch.ones(\n                self.num_edges,\n                dtype=torch.bool,\n                device=subset.device,\n            )\n\n        data = copy.copy(self)\n\n        for key, value in self:\n            if key == 'edge_index':\n                data.edge_index = edge_index\n            elif key == 'num_nodes':\n                if subset.dtype == torch.bool:\n                    data.num_nodes = int(subset.sum())\n                else:\n                    data.num_nodes = subset.size(0)\n            elif self.is_node_attr(key):\n                cat_dim = self.__cat_dim__(key, value)\n                data[key] = select(value, subset, dim=cat_dim)\n            elif self.is_edge_attr(key):\n                cat_dim = self.__cat_dim__(key, value)\n                data[key] = select(value, edge_mask, dim=cat_dim)\n\n        return data\n\n    def edge_subgraph(self, subset: Tensor) -> Self:\n        r\"\"\"Returns the induced subgraph given by the edge indices\n        :obj:`subset`.\n        Will currently preserve all the nodes in 
the graph, even if they are\n        isolated after subgraph computation.\n\n        Args:\n            subset (LongTensor or BoolTensor): The edges to keep.\n        \"\"\"\n        data = copy.copy(self)\n\n        for key, value in self:\n            if self.is_edge_attr(key):\n                cat_dim = self.__cat_dim__(key, value)\n                data[key] = select(value, subset, dim=cat_dim)\n\n        return data\n\n    def to_heterogeneous(\n        self,\n        node_type: Optional[Tensor] = None,\n        edge_type: Optional[Tensor] = None,\n        node_type_names: Optional[List[NodeType]] = None,\n        edge_type_names: Optional[List[EdgeType]] = None,\n    ):\n        r\"\"\"Converts a :class:`~torch_geometric.data.Data` object to a\n        heterogeneous :class:`~torch_geometric.data.HeteroData` object.\n        For this, node and edge attributes are split according to the\n        node-level and edge-level vectors :obj:`node_type` and\n        :obj:`edge_type`, respectively.\n        :obj:`node_type_names` and :obj:`edge_type_names` can be used to give\n        meaningful node and edge type names, respectively.\n        That is, the node_type :obj:`0` is given by :obj:`node_type_names[0]`.\n        If the :class:`~torch_geometric.data.Data` object was constructed via\n        :meth:`~torch_geometric.data.HeteroData.to_homogeneous`, the object can\n        be reconstructed without any need to pass in additional arguments.\n\n        Args:\n            node_type (torch.Tensor, optional): A node-level vector denoting\n                the type of each node. (default: :obj:`None`)\n            edge_type (torch.Tensor, optional): An edge-level vector denoting\n                the type of each edge. (default: :obj:`None`)\n            node_type_names (List[str], optional): The names of node types.\n                (default: :obj:`None`)\n            edge_type_names (List[Tuple[str, str, str]], optional): The names\n                of edge types. 
(default: :obj:`None`)\n        \"\"\"\n        from torch_geometric.data import HeteroData\n\n        if node_type is None:\n            node_type = self._store.get('node_type', None)\n        if node_type is None:\n            node_type = torch.zeros(self.num_nodes, dtype=torch.long)\n\n        if node_type_names is None:\n            store = self._store\n            node_type_names = store.__dict__.get('_node_type_names', None)\n        if node_type_names is None:\n            node_type_names = [str(i) for i in node_type.unique().tolist()]\n\n        if edge_type is None:\n            edge_type = self._store.get('edge_type', None)\n        if edge_type is None:\n            edge_type = torch.zeros(self.num_edges, dtype=torch.long)\n\n        if edge_type_names is None:\n            store = self._store\n            edge_type_names = store.__dict__.get('_edge_type_names', None)\n        if edge_type_names is None:\n            edge_type_names = []\n            edge_index = self.edge_index\n            for i in edge_type.unique().tolist():\n                src, dst = edge_index[:, edge_type == i]\n                src_types = node_type[src].unique().tolist()\n                dst_types = node_type[dst].unique().tolist()\n                if len(src_types) != 1 or len(dst_types) != 1:\n                    raise ValueError(\n                        \"Could not construct a 'HeteroData' object from the \"\n                        \"'Data' object because single edge types span over \"\n                        \"multiple node types\")\n                edge_type_names.append((node_type_names[src_types[0]], str(i),\n                                        node_type_names[dst_types[0]]))\n\n        # We iterate over node types to find the local node indices belonging\n        # to each node type. 
Furthermore, we create a global `index_map` vector\n        # that maps global node indices to local ones in the final\n        # heterogeneous graph:\n        node_ids, index_map = {}, torch.empty_like(node_type)\n        for i in range(len(node_type_names)):\n            node_ids[i] = (node_type == i).nonzero(as_tuple=False).view(-1)\n            index_map[node_ids[i]] = torch.arange(len(node_ids[i]),\n                                                  device=index_map.device)\n\n        # We iterate over edge types to find the local edge indices:\n        edge_ids = {}\n        for i in range(len(edge_type_names)):\n            edge_ids[i] = (edge_type == i).nonzero(as_tuple=False).view(-1)\n\n        data = HeteroData()\n\n        for i, key in enumerate(node_type_names):\n            for attr, value in self.items():\n                if attr in {'node_type', 'edge_type', 'ptr'}:\n                    continue\n                elif isinstance(value, Tensor) and self.is_node_attr(attr):\n                    cat_dim = self.__cat_dim__(attr, value)\n                    data[key][attr] = value.index_select(cat_dim, node_ids[i])\n                elif (isinstance(value, TensorFrame)\n                      and self.is_node_attr(attr)):\n                    data[key][attr] = value[node_ids[i]]\n\n            if len(data[key]) == 0:\n                data[key].num_nodes = node_ids[i].size(0)\n\n        for i, key in enumerate(edge_type_names):\n            src, _, dst = key\n            for attr, value in self.items():\n                if attr in {'node_type', 'edge_type', 'ptr'}:\n                    continue\n                elif attr == 'edge_index':\n                    edge_index = value[:, edge_ids[i]]\n                    edge_index[0] = index_map[edge_index[0]]\n                    edge_index[1] = index_map[edge_index[1]]\n                    data[key].edge_index = edge_index\n                elif isinstance(value, Tensor) and self.is_edge_attr(attr):\n              
      cat_dim = self.__cat_dim__(attr, value)\n                    data[key][attr] = value.index_select(cat_dim, edge_ids[i])\n                elif (isinstance(value, TensorFrame)\n                      and self.is_edge_attr(attr)):\n                    data[key][attr] = value[edge_ids[i]]\n\n        # Add global attributes.\n        exclude_keys = set(data.keys()) | {\n            'node_type', 'edge_type', 'edge_index', 'num_nodes', 'ptr'\n        }\n        for attr, value in self.items():\n            if attr in exclude_keys:\n                continue\n            data[attr] = value\n\n        return data\n\n    def connected_components(self) -> List[Self]:\n        r\"\"\"Extracts connected components of the graph using a union-find\n        algorithm. The components are returned as a list of\n        :class:`~torch_geometric.data.Data` objects, where each object\n        represents a connected component of the graph.\n\n        .. code-block:: python\n\n            data = Data()\n            data.x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])\n            data.y = torch.tensor([[1.1], [2.1], [3.1], [4.1]])\n            data.edge_index = torch.tensor(\n                [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long\n            )\n\n            components = data.connected_components()\n            print(len(components))\n            >>> 2\n\n            print(components[0])\n            >>> Data(x=[2, 1], y=[2, 1], edge_index=[2, 2])\n\n        Returns:\n            List[Data]: A list of disconnected components.\n        \"\"\"\n        # Union-Find algorithm to find connected components\n        self._parents: Dict[int, int] = {}\n        self._ranks: Dict[int, int] = {}\n        for edge in self.edge_index.t().tolist():\n            self._union(edge[0], edge[1])\n\n        # Rerun _find_parent to ensure all nodes are covered correctly\n        for node in range(self.num_nodes):\n            self._find_parent(node)\n\n        # Group parents\n        grouped_parents 
= defaultdict(list)\n        for node, parent in self._parents.items():\n            grouped_parents[parent].append(node)\n        del self._ranks\n        del self._parents\n\n        # Create components based on the found parents (roots)\n        components: List[Self] = []\n        for nodes in grouped_parents.values():\n            # Convert the list of node IDs to a tensor\n            subset = torch.tensor(nodes, dtype=torch.long)\n\n            # Use the existing subgraph function\n            component_data = self.subgraph(subset)\n            components.append(component_data)\n\n        return components\n\n    ###########################################################################\n\n    @classmethod\n    def from_dict(cls, mapping: Dict[str, Any]) -> Self:\n        r\"\"\"Creates a :class:`~torch_geometric.data.Data` object from a\n        dictionary.\n        \"\"\"\n        return cls(**mapping)\n\n    @property\n    def num_node_features(self) -> int:\n        r\"\"\"Returns the number of features per node in the graph.\"\"\"\n        return self._store.num_node_features\n\n    @property\n    def num_features(self) -> int:\n        r\"\"\"Returns the number of features per node in the graph.\n        Alias for :py:attr:`~num_node_features`.\n        \"\"\"\n        return self.num_node_features\n\n    @property\n    def num_edge_features(self) -> int:\n        r\"\"\"Returns the number of features per edge in the graph.\"\"\"\n        return self._store.num_edge_features\n\n    @property\n    def num_node_types(self) -> int:\n        r\"\"\"Returns the number of node types in the graph.\"\"\"\n        return int(self.node_type.max()) + 1 if 'node_type' in self else 1\n\n    @property\n    def num_edge_types(self) -> int:\n        r\"\"\"Returns the number of edge types in the graph.\"\"\"\n        return int(self.edge_type.max()) + 1 if 'edge_type' in self else 1\n\n    def __iter__(self) -> Iterable:\n        r\"\"\"Iterates over all attributes 
in the data, yielding their attribute\n        names and values.\n        \"\"\"\n        yield from self._store.items()\n\n    def __call__(self, *args: str) -> Iterable:\n        r\"\"\"Iterates over all attributes :obj:`*args` in the data, yielding\n        their attribute names and values.\n        If :obj:`*args` is not given, will iterate over all attributes.\n        \"\"\"\n        yield from self._store.items(*args)\n\n    @property\n    def x(self) -> Optional[Tensor]:\n        return self['x'] if 'x' in self._store else None\n\n    @x.setter\n    def x(self, x: Optional[Tensor]):\n        self._store.x = x\n\n    @property\n    def edge_index(self) -> Optional[Tensor]:\n        return self['edge_index'] if 'edge_index' in self._store else None\n\n    @edge_index.setter\n    def edge_index(self, edge_index: Optional[Tensor]):\n        self._store.edge_index = edge_index\n\n    @property\n    def edge_weight(self) -> Optional[Tensor]:\n        return self['edge_weight'] if 'edge_weight' in self._store else None\n\n    @edge_weight.setter\n    def edge_weight(self, edge_weight: Optional[Tensor]):\n        self._store.edge_weight = edge_weight\n\n    @property\n    def edge_attr(self) -> Optional[Tensor]:\n        return self['edge_attr'] if 'edge_attr' in self._store else None\n\n    @edge_attr.setter\n    def edge_attr(self, edge_attr: Optional[Tensor]):\n        self._store.edge_attr = edge_attr\n\n    @property\n    def y(self) -> Optional[Union[Tensor, int, float]]:\n        return self['y'] if 'y' in self._store else None\n\n    @y.setter\n    def y(self, y: Optional[Tensor]):\n        self._store.y = y\n\n    @property\n    def pos(self) -> Optional[Tensor]:\n        return self['pos'] if 'pos' in self._store else None\n\n    @pos.setter\n    def pos(self, pos: Optional[Tensor]):\n        self._store.pos = pos\n\n    @property\n    def batch(self) -> Optional[Tensor]:\n        return self['batch'] if 'batch' in self._store else None\n\n    
@batch.setter\n    def batch(self, batch: Optional[Tensor]):\n        self._store.batch = batch\n\n    @property\n    def time(self) -> Optional[Tensor]:\n        return self['time'] if 'time' in self._store else None\n\n    @time.setter\n    def time(self, time: Optional[Tensor]):\n        self._store.time = time\n\n    @property\n    def face(self) -> Optional[Tensor]:\n        return self['face'] if 'face' in self._store else None\n\n    @face.setter\n    def face(self, face: Optional[Tensor]):\n        self._store.face = face\n\n    # Deprecated functions ####################################################\n\n    @property\n    @deprecated(details=\"use 'data.face.size(-1)' instead\")\n    def num_faces(self) -> Optional[int]:\n        r\"\"\"Returns the number of faces in the mesh.\"\"\"\n        if 'face' in self._store and isinstance(self.face, Tensor):\n            return self.face.size(self.__cat_dim__('face', self.face))\n        return None\n\n    # FeatureStore interface ##################################################\n\n    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:\n        out = self.get(attr.attr_name)\n        if out is not None and attr.index is not None:\n            out[attr.index] = tensor\n        else:\n            assert attr.index is None\n            setattr(self, attr.attr_name, tensor)\n        return True\n\n    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:\n        tensor = getattr(self, attr.attr_name, None)\n        if tensor is not None:\n            # TODO this behavior is a bit odd, since TensorAttr requires that\n            # we set `index`. 
So, we assume here that indexing by `None` is\n            # equivalent to not indexing at all, which is not in line with\n            # Python semantics.\n            return tensor[attr.index] if attr.index is not None else tensor\n        return None\n\n    def _remove_tensor(self, attr: TensorAttr) -> bool:\n        if hasattr(self, attr.attr_name):\n            delattr(self, attr.attr_name)\n            return True\n        return False\n\n    def _get_tensor_size(self, attr: TensorAttr) -> Tuple:\n        return self._get_tensor(attr).size()\n\n    def get_all_tensor_attrs(self) -> List[TensorAttr]:\n        r\"\"\"Obtains all feature attributes stored in `Data`.\"\"\"\n        return [\n            TensorAttr(attr_name=name) for name in self._store.keys()\n            if self._store.is_node_attr(name)\n        ]\n\n    # GraphStore interface ####################################################\n\n    def _put_edge_index(self, edge_index: EdgeTensorType,\n                        edge_attr: EdgeAttr) -> bool:\n        if not hasattr(self, '_edge_attrs'):\n            self._edge_attrs = {}\n        self._edge_attrs[edge_attr.layout] = edge_attr\n\n        row, col = edge_index\n\n        if edge_attr.layout == EdgeLayout.COO:\n            self.edge_index = torch.stack([row, col], dim=0)\n        elif edge_attr.layout == EdgeLayout.CSR:\n            self.adj = SparseTensor(\n                rowptr=row,\n                col=col,\n                sparse_sizes=edge_attr.size,\n                is_sorted=True,\n                trust_data=True,\n            )\n        else:  # edge_attr.layout == EdgeLayout.CSC:\n            size = edge_attr.size[::-1] if edge_attr.size is not None else None\n            self.adj_t = SparseTensor(\n                rowptr=col,\n                col=row,\n                sparse_sizes=size,\n                is_sorted=True,\n                trust_data=True,\n            )\n        return True\n\n    def _get_edge_index(self, edge_attr: 
EdgeAttr) -> Optional[EdgeTensorType]:\n        if edge_attr.size is None:\n            edge_attr.size = self.size()  # Modify in-place.\n\n        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self:\n            row, col = self.edge_index\n            return row, col\n        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self:\n            rowptr, col, _ = self.adj.csr()\n            return rowptr, col\n        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self:\n            colptr, row, _ = self.adj_t.csr()\n            return row, colptr\n        return None\n\n    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:\n        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in self:\n            del self.edge_index\n            if hasattr(self, '_edge_attrs'):\n                self._edge_attrs.pop(EdgeLayout.COO, None)\n            return True\n        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in self:\n            del self.adj\n            if hasattr(self, '_edge_attrs'):\n                self._edge_attrs.pop(EdgeLayout.CSR, None)\n            return True\n        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in self:\n            del self.adj_t\n            if hasattr(self, '_edge_attrs'):\n                self._edge_attrs.pop(EdgeLayout.CSC, None)\n            return True\n        return False\n\n    def get_all_edge_attrs(self) -> List[EdgeAttr]:\n        edge_attrs = getattr(self, '_edge_attrs', {})\n\n        if 'edge_index' in self and EdgeLayout.COO not in edge_attrs:\n            edge_attrs[EdgeLayout.COO] = DataEdgeAttr('coo', is_sorted=False)\n        if 'adj' in self and EdgeLayout.CSR not in edge_attrs:\n            size = self.adj.sparse_sizes()\n            edge_attrs[EdgeLayout.CSR] = DataEdgeAttr('csr', size=size)\n        if 'adj_t' in self and EdgeLayout.CSC not in edge_attrs:\n            size = self.adj_t.sparse_sizes()[::-1]\n            edge_attrs[EdgeLayout.CSC] = DataEdgeAttr('csc', 
size=size)\n\n        return list(edge_attrs.values())\n\n    # Connected Components Helper Functions ###################################\n\n    def _find_parent(self, node: int) -> int:\n        r\"\"\"Finds and returns the representative parent of the given node in a\n        disjoint-set (union-find) data structure. Implements path compression\n        to optimize future queries.\n\n        Args:\n            node (int): The node for which to find the representative parent.\n\n        Returns:\n            int: The representative parent of the node.\n        \"\"\"\n        if node not in self._parents:\n            self._parents[node] = node\n            self._ranks[node] = 0\n        if self._parents[node] != node:\n            self._parents[node] = self._find_parent(self._parents[node])\n        return self._parents[node]\n\n    def _union(self, node1: int, node2: int):\n        r\"\"\"Merges the sets containing node1 and node2 in the disjoint-set\n        data structure.\n\n        Finds the root parents of node1 and node2 using the _find_parent\n        method. 
If they belong to different sets, attaches the root of\n        lower rank to the root of higher rank (union by rank); on equal\n        ranks, root2 is attached to root1 and the rank of root1 is\n        incremented.\n\n        Args:\n            node1 (int): The index of the first node to union.\n            node2 (int): The index of the second node to union.\n        \"\"\"\n        root1 = self._find_parent(node1)\n        root2 = self._find_parent(node2)\n        if root1 != root2:\n            if self._ranks[root1] < self._ranks[root2]:\n                self._parents[root1] = root2\n            elif self._ranks[root1] > self._ranks[root2]:\n                self._parents[root2] = root1\n            else:\n                self._parents[root2] = root1\n                self._ranks[root1] += 1\n\n\n###############################################################################\n\n\ndef size_repr(key: Any, value: Any, indent: int = 0) -> str:\n    pad = ' ' * indent\n    if isinstance(value, Tensor) and value.dim() == 0:\n        out = value.item()\n    elif isinstance(value, Tensor) and getattr(value, 'is_nested', False):\n        out = str(list(value.to_padded_tensor(padding=0.0).size()))\n    elif isinstance(value, Tensor):\n        out = str(list(value.size()))\n    elif isinstance(value, np.ndarray):\n        out = str(list(value.shape))\n    elif isinstance(value, SparseTensor):\n        out = str(value.sizes())[:-1] + f', nnz={value.nnz()}]'\n    elif isinstance(value, TensorFrame):\n        out = (f'{value.__class__.__name__}('\n               f'[{value.num_rows}, {value.num_cols}])')\n    elif isinstance(value, str):\n        out = f\"'{value}'\"\n    elif isinstance(value, (Sequence, set)):\n        out = str([len(value)])\n    elif isinstance(value, Mapping) and len(value) == 0:\n        out = '{}'\n    elif (isinstance(value, Mapping) and len(value) == 1\n          and not isinstance(list(value.values())[0], Mapping)):\n        lines = [size_repr(k, v, 0) for k, v in value.items()]\n        out = '{ ' + ', '.join(lines) + ' }'\n    elif 
isinstance(value, Mapping):\n        lines = [size_repr(k, v, indent + 2) for k, v in value.items()]\n        out = '{\\n' + ',\\n'.join(lines) + ',\\n' + pad + '}'\n    else:\n        out = str(value)\n\n    key = str(key).replace(\"'\", '')\n    return f'{pad}{key}={out}'\n\n\ndef warn_or_raise(msg: str, raise_on_error: bool = True):\n    if raise_on_error:\n        raise ValueError(msg)\n    else:\n        warnings.warn(msg, stacklevel=2)\n"
  },
  {
    "path": "torch_geometric/data/database.py",
    "content": "import io\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom functools import cached_property\nfrom typing import Any, Dict, List, Optional, Sequence, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.edge_index import SortOrder\nfrom torch_geometric.utils.mixin import CastMixin\n\n\n@dataclass\nclass TensorInfo(CastMixin):\n    dtype: torch.dtype\n    size: Tuple[int, ...] = (-1, )\n    is_index: bool = False\n    is_edge_index: bool = False\n\n    def __post_init__(self) -> None:\n        if self.is_index and self.is_edge_index:\n            raise ValueError(\"Tensor cannot be a 'Index' and 'EdgeIndex' \"\n                             \"tensor at the same time\")\n\n        if self.is_index:\n            self.size = (-1, )\n\n        if self.is_edge_index:\n            self.size = (2, -1)\n\n\ndef maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]:\n    if not isinstance(value, dict):\n        return value\n    if len(value) < 1 or len(value) > 3:\n        return value\n    if 'dtype' not in value:\n        return value\n    valid_keys = {'dtype', 'size', 'is_index', 'is_edge_index'}\n    if len(set(value.keys()) | valid_keys) != len(valid_keys):\n        return value\n    return TensorInfo.cast(value)\n\n\nSchema = Union[Any, Dict[str, Any], Tuple[Any], List[Any]]\n\nSORT_ORDER_TO_INDEX: Dict[Optional[SortOrder], int] = {\n    None: -1,\n    SortOrder.ROW: 0,\n    SortOrder.COL: 1,\n}\nINDEX_TO_SORT_ORDER = {v: k for k, v in SORT_ORDER_TO_INDEX.items()}\n\n\nclass Database(ABC):\n    r\"\"\"Base class for inserting and retrieving data from a database.\n\n    A database acts as a persisted, out-of-memory and index-based key/value\n    store for tensor and custom data:\n\n    .. 
code-block:: python\n\n        db = Database()\n        db[0] = Data(x=torch.randn(5, 16), y=0, z='id_0')\n        print(db[0])\n        >>> Data(x=[5, 16], y=0, z='id_0')\n\n    To improve efficiency, it is recommended to specify the underlying\n    :obj:`schema` of the data:\n\n    .. code-block:: python\n\n        db = Database(schema={  # Custom schema:\n            # Tensor information can be specified through a dictionary:\n            'x': dict(dtype=torch.float, size=(-1, 16)),\n            'y': int,\n            'z': str,\n        })\n        db[0] = dict(x=torch.randn(5, 16), y=0, z='id_0')\n        print(db[0])\n        >>> {'x': torch.tensor(...), 'y': 0, 'z': 'id_0'}\n\n    In addition, databases support batch-wise insert and get, and support\n    syntactic sugar known from indexing :python:`Python` lists, *e.g.*:\n\n    .. code-block:: python\n\n        db = Database()\n        db[2:5] = torch.randn(3, 16)\n        print(db[torch.tensor([2, 3])])\n        >>> [torch.tensor(...), torch.tensor(...)]\n\n    Args:\n        schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of\n            the input data.\n            Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a\n            dictionary with :obj:`dtype` and :obj:`size` keys (for specifying\n            tensor data) as input, and can be nested as a tuple or dictionary.\n            Specifying the schema will improve efficiency, since by default the\n            database will use python pickling for serializing and\n            deserializing. 
(default: :obj:`object`)\n    \"\"\"\n    def __init__(self, schema: Schema = object) -> None:\n        schema_dict = self._to_dict(maybe_cast_to_tensor_info(schema))\n        self.schema: Dict[Union[str, int], Any] = {\n            key: maybe_cast_to_tensor_info(value)\n            for key, value in schema_dict.items()\n        }\n\n    @abstractmethod\n    def connect(self) -> None:\n        r\"\"\"Connects to the database.\n        Databases will automatically connect on instantiation.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def close(self) -> None:\n        r\"\"\"Closes the connection to the database.\"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def insert(self, index: int, data: Any) -> None:\n        r\"\"\"Inserts data at the specified index.\n\n        Args:\n            index (int): The index at which to insert.\n            data (Any): The object to insert.\n        \"\"\"\n        raise NotImplementedError\n\n    def multi_insert(\n        self,\n        indices: Union[Sequence[int], Tensor, slice, range],\n        data_list: Sequence[Any],\n        batch_size: Optional[int] = None,\n        log: bool = False,\n    ) -> None:\n        r\"\"\"Inserts a chunk of data at the specified indices.\n\n        Args:\n            indices (List[int] or torch.Tensor or range): The indices at which\n                to insert.\n            data_list (List[Any]): The objects to insert.\n            batch_size (int, optional): If specified, will insert the data to\n                the database in batches of size :obj:`batch_size`.\n                (default: :obj:`None`)\n            log (bool, optional): If set to :obj:`True`, will log progress to\n                the console. 
(default: :obj:`False`)\n        \"\"\"\n        if isinstance(indices, slice):\n            indices = self.slice_to_range(indices)\n\n        length = min(len(indices), len(data_list))\n        batch_size = length if batch_size is None else batch_size\n\n        if log and length > batch_size:\n            desc = f'Insert {length} entries'\n            offsets = tqdm(range(0, length, batch_size), desc=desc)\n        else:\n            offsets = range(0, length, batch_size)\n\n        for start in offsets:\n            self._multi_insert(\n                indices[start:start + batch_size],\n                data_list[start:start + batch_size],\n            )\n\n    def _multi_insert(\n        self,\n        indices: Union[Sequence[int], Tensor, range],\n        data_list: Sequence[Any],\n    ) -> None:\n        if isinstance(indices, Tensor):\n            indices = indices.tolist()\n        for index, data in zip(indices, data_list):\n            self.insert(index, data)\n\n    @abstractmethod\n    def get(self, index: int) -> Any:\n        r\"\"\"Gets data from the specified index.\n\n        Args:\n            index (int): The index to query.\n        \"\"\"\n        raise NotImplementedError\n\n    def multi_get(\n        self,\n        indices: Union[Sequence[int], Tensor, slice, range],\n        batch_size: Optional[int] = None,\n    ) -> List[Any]:\n        r\"\"\"Gets a chunk of data from the specified indices.\n\n        Args:\n            indices (List[int] or torch.Tensor or range): The indices to query.\n            batch_size (int, optional): If specified, will request the data\n                from the database in batches of size :obj:`batch_size`.\n                (default: :obj:`None`)\n        \"\"\"\n        if isinstance(indices, slice):\n            indices = self.slice_to_range(indices)\n\n        length = len(indices)\n        batch_size = length if batch_size is None else batch_size\n\n        data_list: List[Any] = []\n        for start in 
range(0, length, batch_size):\n            chunk_indices = indices[start:start + batch_size]\n            data_list.extend(self._multi_get(chunk_indices))\n        return data_list\n\n    def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]:\n        if isinstance(indices, Tensor):\n            indices = indices.tolist()\n        return [self.get(index) for index in indices]\n\n    # Helper functions ########################################################\n\n    @staticmethod\n    def _to_dict(\n        value: Union[Dict[Union[int, str], Any], Sequence[Any], Any],\n    ) -> Dict[Union[str, int], Any]:\n        if isinstance(value, dict):\n            return value\n        if isinstance(value, (tuple, list)):\n            return {i: v for i, v in enumerate(value)}\n        else:\n            return {0: value}\n\n    def slice_to_range(self, indices: slice) -> range:\n        start = 0 if indices.start is None else indices.start\n        stop = len(self) if indices.stop is None else indices.stop\n        step = 1 if indices.step is None else indices.step\n\n        return range(start, stop, step)\n\n    # Python built-ins ########################################################\n\n    def __len__(self) -> int:\n        raise NotImplementedError\n\n    def __getitem__(\n        self,\n        key: Union[int, Sequence[int], Tensor, slice, range],\n    ) -> Union[Any, List[Any]]:\n\n        if isinstance(key, int):\n            return self.get(key)\n        else:\n            return self.multi_get(key)\n\n    def __setitem__(\n        self,\n        key: Union[int, Sequence[int], Tensor, slice, range],\n        value: Union[Any, Sequence[Any]],\n    ) -> None:\n        if isinstance(key, int):\n            self.insert(key, value)\n        else:\n            self.multi_insert(key, value)\n\n    def __repr__(self) -> str:\n        try:\n            return f'{self.__class__.__name__}({len(self)})'\n        except NotImplementedError:\n            
return f'{self.__class__.__name__}()'\n\n\nclass SQLiteDatabase(Database):\n    r\"\"\"An index-based key/value database based on :obj:`sqlite3`.\n\n    .. note::\n        This database implementation requires the :obj:`sqlite3` package.\n\n    Args:\n        path (str): The path to where the database should be saved.\n        name (str): The name of the table to save the data to.\n        schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of\n            the input data.\n            Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a\n            dictionary with :obj:`dtype` and :obj:`size` keys (for specifying\n            tensor data) as input, and can be nested as a tuple or dictionary.\n            Specifying the schema will improve efficiency, since by default the\n            database will use python pickling for serializing and\n            deserializing. (default: :obj:`object`)\n    \"\"\"\n    def __init__(self, path: str, name: str, schema: Schema = object) -> None:\n        super().__init__(schema)\n\n        warnings.filterwarnings('ignore', '.*given buffer is not writable.*')\n\n        import sqlite3\n\n        self.path = path\n        self.name = name\n\n        self._connection: Optional[sqlite3.Connection] = None\n        self._cursor: Optional[sqlite3.Cursor] = None\n\n        self.connect()\n\n        # Create the table (if it does not exist) by mapping the Python schema\n        # to the corresponding SQL schema:\n        sql_schema = ',\\n'.join([\n            f'  {col_name} {self._to_sql_type(type_info)}' for col_name,\n            type_info in zip(self._col_names, self.schema.values())\n        ])\n        query = (f'CREATE TABLE IF NOT EXISTS {self.name} (\\n'\n                 f'  id INTEGER PRIMARY KEY,\\n'\n                 f'{sql_schema}\\n'\n                 f')')\n        self.cursor.execute(query)\n\n    def connect(self) -> None:\n        import sqlite3\n        self._connection = 
sqlite3.connect(self.path)\n        self._cursor = self._connection.cursor()\n\n    def close(self) -> None:\n        if self._connection is not None:\n            self._connection.commit()\n            self._connection.close()\n            self._connection = None\n            self._cursor = None\n\n    @property\n    def connection(self) -> Any:\n        if self._connection is None:\n            raise RuntimeError(\"No open database connection\")\n        return self._connection\n\n    @property\n    def cursor(self) -> Any:\n        if self._cursor is None:\n            raise RuntimeError(\"No open database connection\")\n        return self._cursor\n\n    def insert(self, index: int, data: Any) -> None:\n        query = (f'INSERT INTO {self.name} '\n                 f'(id, {self._joined_col_names}) '\n                 f'VALUES (?, {self._dummies})')\n        self.cursor.execute(query, (index, *self._serialize(data)))\n        self.connection.commit()\n\n    def _multi_insert(\n        self,\n        indices: Union[Sequence[int], Tensor, range],\n        data_list: Sequence[Any],\n    ) -> None:\n        if isinstance(indices, Tensor):\n            indices = indices.tolist()\n\n        data_list = [(index, *self._serialize(data))\n                     for index, data in zip(indices, data_list)]\n\n        query = (f'INSERT INTO {self.name} '\n                 f'(id, {self._joined_col_names}) '\n                 f'VALUES (?, {self._dummies})')\n        self.cursor.executemany(query, data_list)\n        self.connection.commit()\n\n    def get(self, index: int) -> Any:\n        query = (f'SELECT {self._joined_col_names} FROM {self.name} '\n                 f'WHERE id = ?')\n        self.cursor.execute(query, (index, ))\n        return self._deserialize(self.cursor.fetchone())\n\n    def multi_get(\n        self,\n        indices: Union[Sequence[int], Tensor, slice, range],\n        batch_size: Optional[int] = None,\n    ) -> List[Any]:\n\n        if 
isinstance(indices, slice):\n            indices = self.slice_to_range(indices)\n        elif isinstance(indices, Tensor):\n            indices = indices.tolist()\n\n        # We create a temporary ID table to then perform an INNER JOIN.\n        # This avoids having a long IN clause and guarantees sorted outputs:\n        join_table_name = f'{self.name}__join'\n        # Temporary tables do not lock the database.\n        query = (f'CREATE TEMP TABLE {join_table_name} (\\n'\n                 f'  id INTEGER,\\n'\n                 f'  row_id INTEGER\\n'\n                 f')')\n        self.cursor.execute(query)\n\n        query = f'INSERT INTO {join_table_name} (id, row_id) VALUES (?, ?)'\n        self.cursor.executemany(query, zip(indices, range(len(indices))))\n        self.connection.commit()\n\n        query = f'SELECT * FROM {join_table_name}'\n        self.cursor.execute(query)\n\n        query = (f'SELECT {self._joined_col_names} '\n                 f'FROM {self.name} INNER JOIN {join_table_name} '\n                 f'ON {self.name}.id = {join_table_name}.id '\n                 f'ORDER BY {join_table_name}.row_id')\n        self.cursor.execute(query)\n\n        if batch_size is None:\n            data_list = self.cursor.fetchall()\n        else:\n            data_list = []\n            while True:\n                chunk_list = self.cursor.fetchmany(size=batch_size)\n                if len(chunk_list) == 0:\n                    break\n                data_list.extend(chunk_list)\n\n        query = f'DROP TABLE {join_table_name}'\n        self.cursor.execute(query)\n\n        return [self._deserialize(data) for data in data_list]\n\n    def __len__(self) -> int:\n        query = f'SELECT COUNT(*) FROM {self.name}'\n        self.cursor.execute(query)\n        return self.cursor.fetchone()[0]\n\n    # Helper functions ########################################################\n\n    @cached_property\n    def _col_names(self) -> List[str]:\n        return 
[f'COL_{key}' for key in self.schema.keys()]\n\n    @cached_property\n    def _joined_col_names(self) -> str:\n        return ', '.join(self._col_names)\n\n    @cached_property\n    def _dummies(self) -> str:\n        return ', '.join(['?'] * len(self.schema.keys()))\n\n    def _to_sql_type(self, type_info: Any) -> str:\n        if type_info == int:\n            return 'INTEGER NOT NULL'\n        if type_info == float:\n            return 'FLOAT'\n        if type_info == str:\n            return 'TEXT NOT NULL'\n        else:\n            return 'BLOB NOT NULL'\n\n    def _serialize(self, row: Any) -> List[Any]:\n        # Serializes the given input data according to `schema`:\n        # * {int, float, str}: Use as they are.\n        # * torch.Tensor: Convert into the raw byte string\n        # * object: Dump via pickle\n        # If we find a `torch.Tensor` that is not registered as such in\n        # `schema`, we modify the schema in-place for improved efficiency.\n        out: List[Any] = []\n        row_dict = self._to_dict(row)\n        for key, schema in self.schema.items():\n            col = row_dict[key]\n\n            if isinstance(col, Tensor) and not isinstance(schema, TensorInfo):\n                self.schema[key] = schema = TensorInfo(\n                    col.dtype,\n                    is_index=isinstance(col, Index),\n                    is_edge_index=isinstance(col, EdgeIndex),\n                )\n\n            if isinstance(schema, TensorInfo) and schema.is_index:\n                assert isinstance(col, Index)\n\n                meta = torch.tensor([\n                    col.dim_size if col.dim_size is not None else -1,\n                    col.is_sorted,\n                ], dtype=torch.long)\n\n                out.append(meta.numpy().tobytes() +\n                           col.as_tensor().numpy().tobytes())\n\n            elif isinstance(schema, TensorInfo) and schema.is_edge_index:\n                assert isinstance(col, EdgeIndex)\n\n          
      num_rows, num_cols = col.sparse_size()\n                meta = torch.tensor([\n                    num_rows if num_rows is not None else -1,\n                    num_cols if num_cols is not None else -1,\n                    SORT_ORDER_TO_INDEX[col._sort_order],\n                    col.is_undirected,\n                ], dtype=torch.long)\n\n                out.append(meta.numpy().tobytes() +\n                           col.as_tensor().numpy().tobytes())\n\n            elif isinstance(schema, TensorInfo):\n                assert isinstance(col, Tensor)\n                out.append(col.numpy().tobytes())\n\n            elif schema in {int, float, str}:\n                out.append(col)\n\n            else:\n                buffer = io.BytesIO()\n                torch.save(col, buffer)\n                out.append(buffer.getvalue())\n\n        return out\n\n    def _deserialize(self, row: Tuple[Any]) -> Any:\n        # Deserializes the DB data according to `schema`:\n        # * {int, float, str}: Use as they are.\n        # * torch.Tensor: Load raw byte string with `dtype` and `size`\n        #   information from `schema`\n        # * object: Load via pickle\n        out_dict = {}\n        for i, (key, schema) in enumerate(self.schema.items()):\n            value = row[i]\n\n            if isinstance(schema, TensorInfo) and schema.is_index:\n                meta = torch.frombuffer(value[:16], dtype=torch.long).tolist()\n                dim_size = meta[0] if meta[0] >= 0 else None\n                is_sorted = meta[1] > 0\n\n                if len(value) > 16:\n                    tensor = torch.frombuffer(value[16:], dtype=schema.dtype)\n                else:\n                    tensor = torch.empty(0, dtype=schema.dtype)\n\n                out_dict[key] = Index(\n                    tensor.view(*schema.size),\n                    dim_size=dim_size,\n                    is_sorted=is_sorted,\n                )\n\n            elif isinstance(schema, TensorInfo) and 
schema.is_edge_index:\n                meta = torch.frombuffer(value[:32], dtype=torch.long).tolist()\n                num_rows = meta[0] if meta[0] >= 0 else None\n                num_cols = meta[1] if meta[1] >= 0 else None\n                sort_order = INDEX_TO_SORT_ORDER[meta[2]]\n                is_undirected = meta[3] > 0\n\n                if len(value) > 32:\n                    tensor = torch.frombuffer(value[32:], dtype=schema.dtype)\n                else:\n                    tensor = torch.empty(0, dtype=schema.dtype)\n\n                out_dict[key] = EdgeIndex(\n                    tensor.view(*schema.size),\n                    sparse_size=(num_rows, num_cols),\n                    sort_order=sort_order,\n                    is_undirected=is_undirected,\n                )\n\n            elif isinstance(schema, TensorInfo):\n                if len(value) > 0:\n                    tensor = torch.frombuffer(value, dtype=schema.dtype)\n                else:\n                    tensor = torch.empty(0, dtype=schema.dtype)\n                out_dict[key] = tensor.view(*schema.size)\n\n            elif schema == float:\n                out_dict[key] = value if value is not None else float('NaN')\n\n            elif schema in {int, str}:\n                out_dict[key] = value\n\n            else:\n                out_dict[key] = torch.load(\n                    io.BytesIO(value),\n                    weights_only=False,\n                )\n\n        # In case `0` exists as integer in the schema, this means that the\n        # schema was passed as either a single entry or a tuple:\n        if 0 in self.schema:\n            if len(self.schema) == 1:\n                return out_dict[0]\n            else:\n                return tuple(out_dict.values())\n        else:  # Otherwise, return the dictionary as it is:\n            return out_dict\n\n\nclass RocksDatabase(Database):\n    r\"\"\"An index-based key/value database based on :obj:`RocksDB`.\n\n    .. 
note::\n        This database implementation requires the :obj:`rocksdict` package.\n\n    .. warning::\n        :class:`RocksDatabase` is currently less optimized than\n        :class:`SQLiteDatabase`.\n\n    Args:\n        path (str): The path to where the database should be saved.\n        schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of\n            the input data.\n            Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a\n            dictionary with :obj:`dtype` and :obj:`size` keys (for specifying\n            tensor data) as input, and can be nested as a tuple or dictionary.\n            Specifying the schema will improve efficiency, since by default the\n            database will use python pickling for serializing and\n            deserializing. (default: :obj:`object`)\n    \"\"\"\n    def __init__(self, path: str, schema: Schema = object) -> None:\n        super().__init__(schema)\n\n        import rocksdict\n\n        self.path = path\n\n        self._db: Optional[rocksdict.Rdict] = None\n\n        self.connect()\n\n    def connect(self) -> None:\n        import rocksdict\n        self._db = rocksdict.Rdict(\n            self.path,\n            options=rocksdict.Options(raw_mode=True),\n        )\n\n    def close(self) -> None:\n        if self._db is not None:\n            self._db.close()\n            self._db = None\n\n    @property\n    def db(self) -> Any:\n        if self._db is None:\n            raise RuntimeError(\"No open database connection\")\n        return self._db\n\n    @staticmethod\n    def to_key(index: int) -> bytes:\n        return index.to_bytes(8, byteorder='big', signed=True)\n\n    def insert(self, index: int, data: Any) -> None:\n        self.db[self.to_key(index)] = self._serialize(data)\n\n    def get(self, index: int) -> Any:\n        return self._deserialize(self.db[self.to_key(index)])\n\n    def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]:\n        if 
isinstance(indices, Tensor):\n            indices = indices.tolist()\n        data_list = self.db[[self.to_key(index) for index in indices]]\n        return [self._deserialize(data) for data in data_list]\n\n    # Helper functions ########################################################\n\n    def _serialize(self, row: Any) -> bytes:\n        # Ensure that data is not a view of a larger tensor:\n        if isinstance(row, Tensor):\n            row = row.clone()\n        buffer = io.BytesIO()\n        torch.save(row, buffer)\n        return buffer.getvalue()\n\n    def _deserialize(self, row: bytes) -> Any:\n        return torch.load(\n            io.BytesIO(row),\n            weights_only=False,\n        )\n"
  },
  {
    "path": "torch_geometric/data/datapipes.py",
    "content": "import copy\nfrom typing import Any, Callable, Iterator, Optional, Sequence\n\nimport torch\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.utils import from_smiles\n\ntry:\n    from torch.utils.data import IterDataPipe, functional_datapipe\n    from torch.utils.data.datapipes.iter import Batcher as IterBatcher\nexcept ImportError:\n    IterDataPipe = IterBatcher = object  # type: ignore\n\n    def functional_datapipe(name: str) -> Callable:  # type: ignore\n        return lambda cls: cls\n\n\n@functional_datapipe('batch_graphs')\nclass Batcher(IterBatcher):\n    def __init__(\n        self,\n        dp: IterDataPipe,\n        batch_size: int,\n        drop_last: bool = False,\n    ) -> None:\n        super().__init__(\n            dp,\n            batch_size=batch_size,\n            drop_last=drop_last,\n            wrapper_class=Batch.from_data_list,\n        )\n\n\n@functional_datapipe('parse_smiles')\nclass SMILESParser(IterDataPipe):\n    def __init__(\n        self,\n        dp: IterDataPipe,\n        smiles_key: str = 'smiles',\n        target_key: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n        self.dp = dp\n        self.smiles_key = smiles_key\n        self.target_key = target_key\n\n    def __iter__(self) -> Iterator:\n        for d in self.dp:\n            if isinstance(d, str):\n                data = from_smiles(d)\n            elif isinstance(d, dict):\n                data = from_smiles(d[self.smiles_key])\n                if self.target_key is not None:\n                    y = d.get(self.target_key, None)\n                    if y is not None:\n                        y = float(y) if len(y) > 0 else float('NaN')\n                        data.y = torch.tensor([y], dtype=torch.float)\n            else:\n                raise ValueError(\n                    f\"'{self.__class__.__name__}' expected either a string or \"\n                    f\"a dict as input (got '{type(d)}')\")\n\n            
yield data\n\n\nclass DatasetAdapter(IterDataPipe):\n    def __init__(self, dataset: Sequence[Any]) -> None:\n        super().__init__()\n        self.dataset = dataset\n        self.range = range(len(self))\n\n    def is_shardable(self) -> bool:\n        return True\n\n    def apply_sharding(self, num_shards: int, shard_idx: int) -> None:\n        self.range = range(shard_idx, len(self), num_shards)\n\n    def __iter__(self) -> Iterator:\n        for i in self.range:\n            yield self.dataset[i]\n\n    def __len__(self) -> int:\n        return len(self.dataset)\n\n\ndef functional_transform(name: str) -> Callable:\n    def wrapper(cls: Any) -> Any:\n        @functional_datapipe(name)\n        class DynamicMapper(IterDataPipe):\n            def __init__(\n                self,\n                dp: IterDataPipe,\n                *args: Any,\n                **kwargs: Any,\n            ) -> None:\n                super().__init__()\n                self.dp = dp\n                self.fn = cls(*args, **kwargs)\n\n            def __iter__(self) -> Iterator:\n                for data in self.dp:\n                    yield self.fn(copy.copy(data))\n\n        return cls\n\n    return wrapper\n"
  },
  {
    "path": "torch_geometric/data/dataset.py",
    "content": "import copy\nimport os\nimport os.path as osp\nimport re\nimport sys\nimport warnings\nfrom collections.abc import Sequence\nfrom typing import (\n    Any,\n    Callable,\n    Iterable,\n    Iterator,\n    List,\n    Optional,\n    Tuple,\n    Union,\n)\n\nimport numpy as np\nimport torch.utils.data\nfrom torch import Tensor\n\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.io import fs\n\nIndexType = Union[slice, Tensor, np.ndarray, Sequence]\nMISSING = '???'\n\n\nclass Dataset(torch.utils.data.Dataset):\n    r\"\"\"Dataset base class for creating graph datasets.\n    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/\n    create_dataset.html>`__ for the accompanying tutorial.\n\n    Args:\n        root (str, optional): Root directory where the dataset should be saved.\n            (optional: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in a\n            :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            transformed version.\n            The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            a :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            transformed version.\n            The data object will be transformed before being saved to disk.\n            (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in a\n            :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            boolean value, indicating whether the data object should be\n            included in the final dataset. 
(default: :obj:`None`)\n        log (bool, optional): Whether to print any console output while\n            downloading and processing the dataset. (default: :obj:`True`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    @property\n    def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:\n        r\"\"\"The name of the files in the :obj:`self.raw_dir` folder that must\n        be present in order to skip downloading.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:\n        r\"\"\"The name of the files in the :obj:`self.processed_dir` folder that\n        must be present in order to skip processing.\n        \"\"\"\n        raise NotImplementedError\n\n    def download(self) -> None:\n        r\"\"\"Downloads the dataset to the :obj:`self.raw_dir` folder.\"\"\"\n        raise NotImplementedError\n\n    def process(self) -> None:\n        r\"\"\"Processes the dataset to the :obj:`self.processed_dir` folder.\"\"\"\n        raise NotImplementedError\n\n    def len(self) -> int:\n        r\"\"\"Returns the number of data objects stored in the dataset.\"\"\"\n        raise NotImplementedError\n\n    def get(self, idx: int) -> BaseData:\n        r\"\"\"Gets the data object at index :obj:`idx`.\"\"\"\n        raise NotImplementedError\n\n    def __init__(\n        self,\n        root: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        log: bool = True,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__()\n\n        if isinstance(root, str):\n            root = osp.expanduser(fs.normpath(root))\n\n        self.root = root or MISSING\n        self.transform = transform\n        self.pre_transform = pre_transform\n        
self.pre_filter = pre_filter\n        self.log = log\n        self._indices: Optional[Sequence] = None\n        self.force_reload = force_reload\n\n        if self.has_download:\n            self._download()\n\n        if self.has_process:\n            self._process()\n\n    def indices(self) -> Sequence:\n        return range(self.len()) if self._indices is None else self._indices\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, 'processed')\n\n    @property\n    def num_node_features(self) -> int:\n        r\"\"\"Returns the number of features per node in the dataset.\"\"\"\n        data = self[0]\n        # Do not fill cache for `InMemoryDataset`:\n        if hasattr(self, '_data_list') and self._data_list is not None:\n            self._data_list[0] = None\n        data = data[0] if isinstance(data, tuple) else data\n        if hasattr(data, 'num_node_features'):\n            return data.num_node_features\n        raise AttributeError(f\"'{data.__class__.__name__}' object has no \"\n                             f\"attribute 'num_node_features'\")\n\n    @property\n    def num_features(self) -> int:\n        r\"\"\"Returns the number of features per node in the dataset.\n        Alias for :py:attr:`~num_node_features`.\n        \"\"\"\n        return self.num_node_features\n\n    @property\n    def num_edge_features(self) -> int:\n        r\"\"\"Returns the number of features per edge in the dataset.\"\"\"\n        data = self[0]\n        # Do not fill cache for `InMemoryDataset`:\n        if hasattr(self, '_data_list') and self._data_list is not None:\n            self._data_list[0] = None\n        data = data[0] if isinstance(data, tuple) else data\n        if hasattr(data, 'num_edge_features'):\n            return data.num_edge_features\n        raise AttributeError(f\"'{data.__class__.__name__}' object has no \"\n                
             f\"attribute 'num_edge_features'\")\n\n    def _infer_num_classes(self, y: Optional[Tensor]) -> int:\n        if y is None:\n            return 0\n        elif y.numel() == y.size(0) and not torch.is_floating_point(y):\n            return int(y.max()) + 1\n        elif y.numel() == y.size(0) and torch.is_floating_point(y):\n            num_classes = torch.unique(y).numel()\n            if num_classes > 2:\n                warnings.warn(\n                    \"Found floating-point labels while calling \"\n                    \"`dataset.num_classes`. Returning the number of \"\n                    \"unique elements. Please make sure that this \"\n                    \"is expected before proceeding.\", stacklevel=2)\n            return num_classes\n        else:\n            return y.size(-1)\n\n    @property\n    def num_classes(self) -> int:\n        r\"\"\"Returns the number of classes in the dataset.\"\"\"\n        # We iterate over the dataset and collect all labels to determine the\n        # maximum number of classes. 
Importantly, in rare cases, `__getitem__`\n        # may produce a tuple of data objects (e.g., when used in combination\n        # with `RandomLinkSplit`), so we take care of this case here as well:\n        data_list = _get_flattened_data_list([data for data in self])\n        if 'y' in data_list[0] and isinstance(data_list[0].y, Tensor):\n            y = torch.cat([data.y for data in data_list if 'y' in data], dim=0)\n        else:\n            y = torch.as_tensor([data.y for data in data_list if 'y' in data])\n\n        # Do not fill cache for `InMemoryDataset`:\n        if hasattr(self, '_data_list') and self._data_list is not None:\n            self._data_list = self.len() * [None]\n        return self._infer_num_classes(y)\n\n    @property\n    def raw_paths(self) -> List[str]:\n        r\"\"\"The absolute filepaths that must be present in order to skip\n        downloading.\n        \"\"\"\n        files = self.raw_file_names\n        # Prevent a common source of error in which `file_names` are not\n        # defined as a property.\n        if isinstance(files, Callable):\n            files = files()\n        return [osp.join(self.raw_dir, f) for f in to_list(files)]\n\n    @property\n    def processed_paths(self) -> List[str]:\n        r\"\"\"The absolute filepaths that must be present in order to skip\n        processing.\n        \"\"\"\n        files = self.processed_file_names\n        # Prevent a common source of error in which `file_names` are not\n        # defined as a property.\n        if isinstance(files, Callable):\n            files = files()\n        return [osp.join(self.processed_dir, f) for f in to_list(files)]\n\n    @property\n    def has_download(self) -> bool:\n        r\"\"\"Checks whether the dataset defines a :meth:`download` method.\"\"\"\n        return overrides_method(self.__class__, 'download')\n\n    def _download(self):\n        if files_exist(self.raw_paths):  # pragma: no cover\n            return\n\n        
fs.makedirs(self.raw_dir, exist_ok=True)\n        self.download()\n\n    @property\n    def has_process(self) -> bool:\n        r\"\"\"Checks whether the dataset defines a :meth:`process` method.\"\"\"\n        return overrides_method(self.__class__, 'process')\n\n    def _process(self):\n        f = osp.join(self.processed_dir, 'pre_transform.pt')\n        if not self.force_reload and osp.exists(f) and torch.load(\n                f, weights_only=False) != _repr(self.pre_transform):\n            warnings.warn(\n                \"The `pre_transform` argument differs from the one used in \"\n                \"the pre-processed version of this dataset. If you want to \"\n                \"make use of another pre-processing technique, pass \"\n                \"`force_reload=True` explicitly to reload the dataset.\",\n                stacklevel=2)\n\n        f = osp.join(self.processed_dir, 'pre_filter.pt')\n        if not self.force_reload and osp.exists(f) and torch.load(\n                f, weights_only=False) != _repr(self.pre_filter):\n            warnings.warn(\n                \"The `pre_filter` argument differs from the one used in \"\n                \"the pre-processed version of this dataset. 
If you want to \"\n                \"make use of another pre-filtering technique, pass \"\n                \"`force_reload=True` explicitly to reload the dataset.\",\n                stacklevel=2)\n\n        if not self.force_reload and files_exist(self.processed_paths):\n            return\n\n        if self.log and 'PYTEST_CURRENT_TEST' not in os.environ:\n            print('Processing...', file=sys.stderr)\n\n        fs.makedirs(self.processed_dir, exist_ok=True)\n        self.process()\n\n        path = osp.join(self.processed_dir, 'pre_transform.pt')\n        fs.torch_save(_repr(self.pre_transform), path)\n        path = osp.join(self.processed_dir, 'pre_filter.pt')\n        fs.torch_save(_repr(self.pre_filter), path)\n\n        if self.log and 'PYTEST_CURRENT_TEST' not in os.environ:\n            print('Done!', file=sys.stderr)\n\n    def __len__(self) -> int:\n        r\"\"\"The number of examples in the dataset.\"\"\"\n        return len(self.indices())\n\n    def __getitem__(\n        self,\n        idx: Union[int, np.integer, IndexType],\n    ) -> Union['Dataset', BaseData]:\n        r\"\"\"In case :obj:`idx` is of type integer, will return the data object\n        at index :obj:`idx` (and transforms it in case :obj:`transform` is\n        present).\n        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a\n        tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or\n        bool, will return a subset of the dataset at the specified indices.\n        \"\"\"\n        if (isinstance(idx, (int, np.integer))\n                or (isinstance(idx, Tensor) and idx.dim() == 0)\n                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):\n\n            data = self.get(self.indices()[idx])\n            data = data if self.transform is None else self.transform(data)\n            return data\n\n        else:\n            return self.index_select(idx)\n\n    def __iter__(self) -> Iterator[BaseData]:\n        for i in 
range(len(self)):\n            yield self[i]\n\n    def index_select(self, idx: IndexType) -> 'Dataset':\n        r\"\"\"Creates a subset of the dataset from specified indices :obj:`idx`.\n        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a\n        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type\n        long or bool.\n        \"\"\"\n        indices = self.indices()\n\n        if isinstance(idx, slice):\n            start, stop, step = idx.start, idx.stop, idx.step\n            # Allow floating-point slicing, e.g., dataset[:0.9]\n            if isinstance(start, float):\n                start = round(start * len(self))\n            if isinstance(stop, float):\n                stop = round(stop * len(self))\n            idx = slice(start, stop, step)\n\n            indices = indices[idx]\n\n        elif isinstance(idx, Tensor) and idx.dtype == torch.long:\n            return self.index_select(idx.flatten().tolist())\n\n        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:\n            idx = idx.flatten().nonzero(as_tuple=False)\n            return self.index_select(idx.flatten().tolist())\n\n        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:\n            return self.index_select(idx.flatten().tolist())\n\n        elif isinstance(idx, np.ndarray) and idx.dtype == bool:\n            idx = idx.flatten().nonzero()[0]\n            return self.index_select(idx.flatten().tolist())\n\n        elif isinstance(idx, Sequence) and not isinstance(idx, str):\n            indices = [indices[i] for i in idx]\n\n        else:\n            raise IndexError(\n                f\"Only slices (':'), list, tuples, torch.tensor and \"\n                f\"np.ndarray of dtype long or bool are valid indices (got \"\n                f\"'{type(idx).__name__}')\")\n\n        dataset = copy.copy(self)\n        dataset._indices = indices\n        return dataset\n\n    def shuffle(\n        self,\n        return_perm: bool 
= False,\n    ) -> Union['Dataset', Tuple['Dataset', Tensor]]:\n        r\"\"\"Randomly shuffles the examples in the dataset.\n\n        Args:\n            return_perm (bool, optional): If set to :obj:`True`, will also\n                return the random permutation used to shuffle the dataset.\n                (default: :obj:`False`)\n        \"\"\"\n        perm = torch.randperm(len(self))\n        dataset = self.index_select(perm)\n        return (dataset, perm) if return_perm is True else dataset\n\n    def __repr__(self) -> str:\n        arg_repr = str(len(self)) if len(self) > 1 else ''\n        return f'{self.__class__.__name__}({arg_repr})'\n\n    def get_summary(self) -> Any:\n        r\"\"\"Collects summary statistics for the dataset.\"\"\"\n        from torch_geometric.data.summary import Summary\n        return Summary.from_dataset(self)\n\n    def print_summary(self, fmt: str = \"psql\") -> None:\n        r\"\"\"Prints summary statistics of the dataset to the console.\n\n        Args:\n            fmt (str, optional): Summary tables format. Available table formats\n                can be found `here <https://github.com/astanin/python-tabulate?\n                tab=readme-ov-file#table-format>`__. (default: :obj:`\"psql\"`)\n        \"\"\"\n        print(self.get_summary().format(fmt=fmt))\n\n    def to_datapipe(self) -> Any:\n        r\"\"\"Converts the dataset into a :class:`torch.utils.data.DataPipe`.\n\n        The returned instance can then be used with :pyg:`PyG's` built-in\n        :class:`DataPipes` for batching graphs as follows:\n\n        .. 
code-block:: python\n\n            from torch_geometric.datasets import QM9\n\n            dp = QM9(root='./data/QM9/').to_datapipe()\n            dp = dp.batch_graphs(batch_size=2, drop_last=True)\n\n            for batch in dp:\n                pass\n\n        See the `PyTorch tutorial\n        <https://pytorch.org/data/main/tutorial.html>`_ for further background\n        on DataPipes.\n        \"\"\"\n        from torch_geometric.data.datapipes import DatasetAdapter\n\n        return DatasetAdapter(self)\n\n\ndef overrides_method(cls, method_name: str) -> bool:\n    from torch_geometric.data import InMemoryDataset\n\n    if method_name in cls.__dict__:\n        return True\n\n    out = False\n    for base in cls.__bases__:\n        if base != Dataset and base != InMemoryDataset:\n            out |= overrides_method(base, method_name)\n    return out\n\n\ndef to_list(value: Any) -> Sequence:\n    if isinstance(value, Sequence) and not isinstance(value, str):\n        return value\n    else:\n        return [value]\n\n\ndef files_exist(files: List[str]) -> bool:\n    # NOTE: We return `False` in case `files` is empty, leading to a\n    # re-processing of files on every instantiation.\n    return len(files) != 0 and all([fs.exists(f) for f in files])\n\n\ndef _repr(obj: Any) -> str:\n    if obj is None:\n        return 'None'\n    return re.sub('(<.*?)\\\\s.*(>)', r'\\1\\2', str(obj))\n\n\ndef _get_flattened_data_list(data_list: Iterable[Any]) -> List[BaseData]:\n    outs: List[BaseData] = []\n    for data in data_list:\n        if isinstance(data, BaseData):\n            outs.append(data)\n        elif isinstance(data, (tuple, list)):\n            outs.extend(_get_flattened_data_list(data))\n        elif isinstance(data, dict):\n            outs.extend(_get_flattened_data_list(data.values()))\n    return outs\n"
  },
  {
    "path": "torch_geometric/data/download.py",
    "content": "import os\nimport os.path as osp\nimport ssl\nimport sys\nimport urllib\nfrom typing import Optional\n\nimport fsspec\n\nfrom torch_geometric.io import fs\n\n\ndef download_url(\n    url: str,\n    folder: str,\n    log: bool = True,\n    filename: Optional[str] = None,\n):\n    r\"\"\"Downloads the content of an URL to a specific folder.\n\n    Args:\n        url (str): The URL.\n        folder (str): The folder.\n        log (bool, optional): If :obj:`False`, will not print anything to the\n            console. (default: :obj:`True`)\n        filename (str, optional): The filename of the downloaded file. If set\n            to :obj:`None`, will correspond to the filename given by the URL.\n            (default: :obj:`None`)\n    \"\"\"\n    if filename is None:\n        filename = url.rpartition('/')[2]\n        filename = filename if filename[0] == '?' else filename.split('?')[0]\n\n    path = osp.join(folder, filename)\n\n    if fs.exists(path):  # pragma: no cover\n        if log and 'PYTEST_CURRENT_TEST' not in os.environ:\n            print(f'Using existing file {filename}', file=sys.stderr)\n        return path\n\n    if log and 'PYTEST_CURRENT_TEST' not in os.environ:\n        print(f'Downloading {url}', file=sys.stderr)\n\n    os.makedirs(folder, exist_ok=True)\n\n    context = ssl._create_unverified_context()\n    data = urllib.request.urlopen(url, context=context)\n\n    with fsspec.open(path, 'wb') as f:\n        # workaround for https://bugs.python.org/issue42853\n        while True:\n            chunk = data.read(10 * 1024 * 1024)\n            if not chunk:\n                break\n            f.write(chunk)\n\n    return path\n\n\ndef download_google_url(\n    id: str,\n    folder: str,\n    filename: str,\n    log: bool = True,\n):\n    r\"\"\"Downloads the content of a Google Drive ID to a specific folder.\"\"\"\n    url = f'https://drive.usercontent.google.com/download?id={id}&confirm=t'\n    return download_url(url, folder, log, 
filename)\n"
  },
  {
    "path": "torch_geometric/data/extract.py",
    "content": "import bz2\nimport gzip\nimport os\nimport os.path as osp\nimport sys\nimport tarfile\nimport zipfile\n\n\ndef maybe_log(path: str, log: bool = True) -> None:\n    if log and 'PYTEST_CURRENT_TEST' not in os.environ:\n        print(f'Extracting {path}', file=sys.stderr)\n\n\ndef extract_tar(\n    path: str,\n    folder: str,\n    mode: str = 'r:gz',\n    log: bool = True,\n) -> None:\n    r\"\"\"Extracts a tar archive to a specific folder.\n\n    Args:\n        path (str): The path to the tar archive.\n        folder (str): The folder.\n        mode (str, optional): The compression mode. (default: :obj:`\"r:gz\"`)\n        log (bool, optional): If :obj:`False`, will not print anything to the\n            console. (default: :obj:`True`)\n    \"\"\"\n    maybe_log(path, log)\n    with tarfile.open(path, mode) as f:\n        f.extractall(folder, filter='data')\n\n\ndef extract_zip(path: str, folder: str, log: bool = True) -> None:\n    r\"\"\"Extracts a zip archive to a specific folder.\n\n    Args:\n        path (str): The path to the tar archive.\n        folder (str): The folder.\n        log (bool, optional): If :obj:`False`, will not print anything to the\n            console. (default: :obj:`True`)\n    \"\"\"\n    maybe_log(path, log)\n    with zipfile.ZipFile(path, 'r') as f:\n        f.extractall(folder)\n\n\ndef extract_bz2(path: str, folder: str, log: bool = True) -> None:\n    r\"\"\"Extracts a bz2 archive to a specific folder.\n\n    Args:\n        path (str): The path to the tar archive.\n        folder (str): The folder.\n        log (bool, optional): If :obj:`False`, will not print anything to the\n            console. 
(default: :obj:`True`)\n    \"\"\"\n    maybe_log(path, log)\n    path = osp.abspath(path)\n    with bz2.open(path, 'r') as r:\n        with open(osp.join(folder, '.'.join(path.split('.')[:-1])), 'wb') as w:\n            w.write(r.read())\n\n\ndef extract_gz(path: str, folder: str, log: bool = True) -> None:\n    r\"\"\"Extracts a gz archive to a specific folder.\n\n    Args:\n        path (str): The path to the gz archive.\n        folder (str): The folder.\n        log (bool, optional): If :obj:`False`, will not print anything to the\n            console. (default: :obj:`True`)\n    \"\"\"\n    maybe_log(path, log)\n    path = osp.abspath(path)\n    with gzip.open(path, 'r') as r:\n        with open(osp.join(folder, '.'.join(path.split('.')[:-1])), 'wb') as w:\n            w.write(r.read())\n"
  },
  {
    "path": "torch_geometric/data/feature_store.py",
    "content": "r\"\"\"This class defines the abstraction for a backend-agnostic feature store.\nThe goal of the feature store is to abstract away all node and edge feature\nmemory management so that varying implementations can allow for independent\nscale-out.\n\nThis particular feature store abstraction makes a few key assumptions:\n* The features we care about storing are node and edge features of a graph.\n  To this end, the attributes that the feature store supports include a\n  `group_name` (e.g. a heterogeneous node name or a heterogeneous edge type),\n  an `attr_name` (e.g. `x` or `edge_attr`), and an index.\n* A feature can be uniquely identified from any associated attributes specified\n  in `TensorAttr`.\n\nIt is the job of a feature store implementer class to handle these assumptions\nproperly. For example, a simple in-memory feature store implementation may\nconcatenate all metadata values with a feature index and use this as a unique\nindex in a KV store. More complicated implementations may choose to partition\nfeatures in interesting manners based on the provided metadata.\n\nMajor TODOs for future implementation:\n* Async `put` and `get` functionality\n\"\"\"\nimport copy\nfrom abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import FeatureTensorType, NodeType\nfrom torch_geometric.utils.mixin import CastMixin\n\n# We allow indexing with a tensor, numpy array, Python slicing, or a single\n# integer index.\nIndexType = Union[torch.Tensor, np.ndarray, slice, int]\n\n\nclass _FieldStatus(Enum):\n    UNSET = None\n\n\n@dataclass\nclass TensorAttr(CastMixin):\n    r\"\"\"Defines the attributes of a :class:`FeatureStore` tensor.\n    It holds all the parameters necessary to uniquely identify a tensor from\n    the :class:`FeatureStore`.\n\n    Note that the order of the 
attributes is important; this is the order in\n    which attributes must be provided for indexing calls. :class:`FeatureStore`\n    implementations can define a different ordering by overriding\n    :meth:`TensorAttr.__init__`.\n    \"\"\"\n\n    # The group name that the tensor corresponds to. Defaults to UNSET.\n    group_name: Optional[NodeType] = _FieldStatus.UNSET\n\n    # The name of the tensor within its group. Defaults to UNSET.\n    attr_name: Optional[str] = _FieldStatus.UNSET\n\n    # The node indices the rows of the tensor correspond to. Defaults to UNSET.\n    index: Optional[IndexType] = _FieldStatus.UNSET\n\n    # Convenience methods #####################################################\n\n    def is_set(self, key: str) -> bool:\n        r\"\"\"Whether an attribute is set in :obj:`TensorAttr`.\"\"\"\n        assert key in self.__dataclass_fields__\n        return getattr(self, key) != _FieldStatus.UNSET\n\n    def is_fully_specified(self) -> bool:\n        r\"\"\"Whether the :obj:`TensorAttr` has no unset fields.\"\"\"\n        return all([self.is_set(key) for key in self.__dataclass_fields__])\n\n    def update(self, attr: 'TensorAttr') -> 'TensorAttr':\n        r\"\"\"Updates an :class:`TensorAttr` with set attributes from another\n        :class:`TensorAttr`.\n        \"\"\"\n        for key in self.__dataclass_fields__:\n            if attr.is_set(key):\n                setattr(self, key, getattr(attr, key))\n        return self\n\n\nclass AttrView(CastMixin):\n    r\"\"\"Defines a view of a :class:`FeatureStore` that is obtained from a\n    specification of attributes on the feature store. The view stores a\n    reference to the backing feature store as well as a :class:`TensorAttr`\n    object that represents the view's state.\n\n    Users can create views either using the :class:`AttrView` constructor,\n    :meth:`FeatureStore.view`, or by incompletely indexing a feature store.\n    For example, the following calls all create views:\n\n    .. 
code-block:: python\n\n        store[group_name]\n        store[group_name].feat\n        store[group_name, feat]\n\n    While the following calls all materialize those views and produce tensors\n    by either calling the view or fully-specifying the view:\n\n    .. code-block:: python\n\n        store[group_name]()\n        store[group_name].feat[index]\n        store[group_name, feat][index]\n    \"\"\"\n    def __init__(self, store: 'FeatureStore', attr: TensorAttr):\n        self.__dict__['_store'] = store\n        self.__dict__['_attr'] = attr\n\n    # Advanced indexing #######################################################\n\n    def __getattr__(self, key: Any) -> Union['AttrView', FeatureTensorType]:\n        r\"\"\"Sets the first unset field of the backing :class:`TensorAttr`\n        object to the attribute.\n\n        This allows for :class:`AttrView` to be indexed by different values of\n        attributes, in order.\n        In particular, for a feature store that we want to index by\n        :obj:`group_name` and :obj:`attr_name`, the following code will do so:\n\n        .. 
code-block:: python\n\n            store[group, attr]\n            store[group].attr\n            store.group.attr\n        \"\"\"\n        out = copy.copy(self)\n\n        # Find the first attribute name that is UNSET:\n        attr_name: Optional[str] = None\n        for field in out._attr.__dataclass_fields__:\n            if getattr(out._attr, field) == _FieldStatus.UNSET:\n                attr_name = field\n                break\n\n        if attr_name is None:\n            raise AttributeError(f\"Cannot access attribute '{key}' on view \"\n                                 f\"'{out}' as all attributes have already \"\n                                 f\"been set in this view\")\n\n        setattr(out._attr, attr_name, key)\n\n        if out._attr.is_fully_specified():\n            return out._store.get_tensor(out._attr)\n\n        return out\n\n    def __getitem__(self, key: Any) -> Union['AttrView', FeatureTensorType]:\n        r\"\"\"Sets the first unset field of the backing :class:`TensorAttr`\n        object to the attribute via indexing.\n\n        This allows for :class:`AttrView` to be indexed by different values of\n        attributes, in order.\n        In particular, for a feature store that we want to index by\n        :obj:`group_name` and :obj:`attr_name`, the following code will do so:\n\n        .. code-block:: python\n\n            store[group, attr]\n            store[group][attr]\n\n        \"\"\"\n        return self.__getattr__(key)\n\n    # Setting attributes ######################################################\n\n    def __setattr__(self, key: str, value: Any):\n        r\"\"\"Supports attribute assignment to the backing :class:`TensorAttr` of\n        an :class:`AttrView`.\n\n        This allows for :class:`AttrView` objects to set their backing\n        attribute values.\n        In particular, the following operation sets the :obj:`index` of an\n        :class:`AttrView`:\n\n        .. 
code-block:: python\n\n            view = store.view(group_name)\n            view.index = torch.tensor([1, 2, 3])\n        \"\"\"\n        if key not in self._attr.__dataclass_fields__:\n            raise ValueError(f\"Attempted to set nonexistent attribute '{key}' \"\n                             f\"(acceptable attributes are \"\n                             f\"{self._attr.__dataclass_fields__})\")\n\n        setattr(self._attr, key, value)\n\n    def __setitem__(self, key: str, value: Any):\n        r\"\"\"Supports attribute assignment to the backing :class:`TensorAttr` of\n        an :class:`AttrView` via indexing.\n\n        This allows for :class:`AttrView` objects to set their backing\n        attribute values.\n        In particular, the following operation sets the `index` of an\n        :class:`AttrView`:\n\n        .. code-block:: python\n\n            view = store.view(TensorAttr(group_name))\n            view['index'] = torch.tensor([1, 2, 3])\n        \"\"\"\n        self.__setattr__(key, value)\n\n    # Miscellaneous built-ins #################################################\n\n    def __call__(self) -> FeatureTensorType:\n        r\"\"\"Supports :class:`AttrView` as a callable to force retrieval from\n        the currently specified attributes.\n\n        In particular, this passes the current :class:`TensorAttr` object to a\n        GET call, regardless of whether all attributes have been specified.\n        It returns the result of this call.\n        In particular, the following operation returns a tensor by performing a\n        GET operation on the backing feature store:\n\n        .. 
code-block:: python\n\n            store[group_name, attr_name]()\n        \"\"\"\n        attr = copy.copy(self._attr)\n        for key in attr.__dataclass_fields__:  # Set all UNSET values to None.\n            if not attr.is_set(key):\n                setattr(attr, key, None)\n        return self._store.get_tensor(attr)\n\n    def __copy__(self) -> 'AttrView':\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = value\n        out.__dict__['_attr'] = copy.copy(out.__dict__['_attr'])\n        return out\n\n    def __eq__(self, obj: Any) -> bool:\n        r\"\"\"Compares two :class:`AttrView` objects by checking equality of\n        their :class:`FeatureStore` references and :class:`TensorAttr`\n        attributes.\n        \"\"\"\n        if not isinstance(obj, AttrView):\n            return False\n        return self._store == obj._store and self._attr == obj._attr\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(store={self._store}, '\n                f'attr={self._attr})')\n\n\n# TODO (manan, matthias) Ideally, we want to let `FeatureStore` inherit from\n# `MutableMapping` to clearly indicate its behavior and usage to the user.\n# However, having `MutableMapping` as a base class leads to strange behavior\n# in combination with PyTorch and PyTorch Lightning, in particular since these\n# libraries use customized logic during mini-batch for `Mapping` base classes.\n\n\nclass FeatureStore(ABC):\n    r\"\"\"An abstract base class to access features from a remote feature store.\n\n    Args:\n        tensor_attr_cls (TensorAttr, optional): A user-defined\n            :class:`TensorAttr` class to customize the required attributes and\n            their ordering to uniquely identify tensor values.\n            (default: :obj:`None`)\n    \"\"\"\n    _tensor_attr_cls: TensorAttr\n\n    def __init__(self, tensor_attr_cls: Optional[Any] = None):\n        
super().__init__()\n        self.__dict__['_tensor_attr_cls'] = tensor_attr_cls or TensorAttr\n\n    # Core (CRUD) #############################################################\n\n    @abstractmethod\n    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:\n        r\"\"\"To be implemented by :class:`FeatureStore` subclasses.\"\"\"\n\n    def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool:\n        r\"\"\"Synchronously adds a :obj:`tensor` to the :class:`FeatureStore`.\n        Returns whether insertion was successful.\n\n        Args:\n            tensor (torch.Tensor or np.ndarray): The feature tensor to be\n                added.\n            *args: Arguments passed to :class:`TensorAttr`.\n            **kwargs: Keyword arguments passed to :class:`TensorAttr`.\n\n        Raises:\n            ValueError: If the input :class:`TensorAttr` is not fully\n                specified.\n        \"\"\"\n        attr = self._tensor_attr_cls.cast(*args, **kwargs)\n        if not attr.is_fully_specified():\n            raise ValueError(f\"The input TensorAttr '{attr}' is not fully \"\n                             f\"specified. 
Please fully-specify the input by \"\n                             f\"specifying all 'UNSET' fields\")\n        return self._put_tensor(tensor, attr)\n\n    @abstractmethod\n    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:\n        r\"\"\"To be implemented by :class:`FeatureStore` subclasses.\"\"\"\n\n    def get_tensor(\n        self,\n        *args,\n        convert_type: bool = False,\n        **kwargs,\n    ) -> FeatureTensorType:\n        r\"\"\"Synchronously obtains a :class:`tensor` from the\n        :class:`FeatureStore`.\n\n        Args:\n            *args: Arguments passed to :class:`TensorAttr`.\n            convert_type (bool, optional): Whether to convert the type of the\n                output tensor to the type of the attribute index.\n                (default: :obj:`False`)\n            **kwargs: Keyword arguments passed to :class:`TensorAttr`.\n\n        Raises:\n            ValueError: If the input :class:`TensorAttr` is not fully\n                specified.\n        \"\"\"\n        attr = self._tensor_attr_cls.cast(*args, **kwargs)\n        if not attr.is_fully_specified():\n            raise ValueError(f\"The input TensorAttr '{attr}' is not fully \"\n                             f\"specified. 
Please fully-specify the input by \"\n                             f\"specifying all 'UNSET' fields.\")\n\n        tensor = self._get_tensor(attr)\n        if convert_type:\n            tensor = self._to_type(attr, tensor)\n        return tensor\n\n    def _multi_get_tensor(\n        self,\n        attrs: List[TensorAttr],\n    ) -> List[Optional[FeatureTensorType]]:\n        r\"\"\"To be implemented by :class:`FeatureStore` subclasses.\"\"\"\n        return [self._get_tensor(attr) for attr in attrs]\n\n    def multi_get_tensor(\n        self,\n        attrs: List[TensorAttr],\n        convert_type: bool = False,\n    ) -> List[FeatureTensorType]:\n        r\"\"\"Synchronously obtains a list of tensors from the\n        :class:`FeatureStore` for each tensor associated with the attributes in\n        :obj:`attrs`.\n\n        .. note::\n            The default implementation simply iterates over all calls to\n            :meth:`get_tensor`. Implementer classes that can provide\n            additional, more performant functionality are recommended to\n            override this method.\n\n        Args:\n            attrs (List[TensorAttr]): A list of input :class:`TensorAttr`\n                objects that identify the tensors to obtain.\n            convert_type (bool, optional): Whether to convert the type of the\n                output tensor to the type of the attribute index.\n                (default: :obj:`False`)\n\n        Raises:\n            ValueError: If any input :class:`TensorAttr` is not fully\n                specified.\n        \"\"\"\n        attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs]\n        bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()]\n        if len(bad_attrs) > 0:\n            raise ValueError(\n                f\"The input TensorAttr(s) '{bad_attrs}' are not fully \"\n                f\"specified. 
Please fully-specify them by specifying all \"\n                f\"'UNSET' fields\")\n\n        tensors = self._multi_get_tensor(attrs)\n        if convert_type:\n            tensors = [\n                self._to_type(attr, tensor)\n                for attr, tensor in zip(attrs, tensors)\n            ]\n        return tensors\n\n    @abstractmethod\n    def _remove_tensor(self, attr: TensorAttr) -> bool:\n        r\"\"\"To be implemented by :obj:`FeatureStore` subclasses.\"\"\"\n\n    def remove_tensor(self, *args, **kwargs) -> bool:\n        r\"\"\"Removes a tensor from the :class:`FeatureStore`.\n        Returns whether deletion was successful.\n\n        Args:\n            *args: Arguments passed to :class:`TensorAttr`.\n            **kwargs: Keyword arguments passed to :class:`TensorAttr`.\n\n        Raises:\n            ValueError: If the input :class:`TensorAttr` is not fully\n                specified.\n        \"\"\"\n        attr = self._tensor_attr_cls.cast(*args, **kwargs)\n        if not attr.is_fully_specified():\n            raise ValueError(f\"The input TensorAttr '{attr}' is not fully \"\n                             f\"specified. Please fully-specify the input by \"\n                             f\"specifying all 'UNSET' fields.\")\n        return self._remove_tensor(attr)\n\n    def update_tensor(self, tensor: FeatureTensorType, *args,\n                      **kwargs) -> bool:\n        r\"\"\"Updates a :obj:`tensor` in the :class:`FeatureStore` with a new\n        value. Returns whether the update was successful.\n\n        .. 
note::\n            Implementer classes can choose to define more efficient update\n            methods; the default performs a removal and insertion.\n\n        Args:\n            tensor (torch.Tensor or np.ndarray): The feature tensor to be\n                updated.\n            *args: Arguments passed to :class:`TensorAttr`.\n            **kwargs: Keyword arguments passed to :class:`TensorAttr`.\n        \"\"\"\n        attr = self._tensor_attr_cls.cast(*args, **kwargs)\n        self.remove_tensor(attr)\n        return self.put_tensor(tensor, attr)\n\n    # Additional methods ######################################################\n\n    @abstractmethod\n    def _get_tensor_size(self, attr: TensorAttr) -> Optional[Tuple[int, ...]]:\n        pass\n\n    def get_tensor_size(self, *args, **kwargs) -> Optional[Tuple[int, ...]]:\n        r\"\"\"Obtains the size of a tensor given its :class:`TensorAttr`, or\n        :obj:`None` if the tensor does not exist.\n        \"\"\"\n        attr = self._tensor_attr_cls.cast(*args, **kwargs)\n        if not attr.is_set('index'):\n            attr.index = None\n        return self._get_tensor_size(attr)\n\n    @abstractmethod\n    def get_all_tensor_attrs(self) -> List[TensorAttr]:\n        r\"\"\"Returns all registered tensor attributes.\"\"\"\n\n    # `AttrView` methods ######################################################\n\n    def view(self, *args, **kwargs) -> AttrView:\n        r\"\"\"Returns a view of the :class:`FeatureStore` given a not yet\n        fully-specified :class:`TensorAttr`.\n        \"\"\"\n        attr = self._tensor_attr_cls.cast(*args, **kwargs)\n        return AttrView(self, attr)\n\n    # Helper functions ########################################################\n\n    @staticmethod\n    def _to_type(\n        attr: TensorAttr,\n        tensor: FeatureTensorType,\n    ) -> FeatureTensorType:\n        if isinstance(attr.index, Tensor) and isinstance(tensor, np.ndarray):\n            return 
torch.from_numpy(tensor)\n        if isinstance(attr.index, np.ndarray) and isinstance(tensor, Tensor):\n            return tensor.detach().cpu().numpy()\n        return tensor\n\n    # Python built-ins ########################################################\n\n    def __setitem__(self, key: TensorAttr, value: FeatureTensorType):\n        r\"\"\"Supports :obj:`store[tensor_attr] = tensor`.\"\"\"\n        # CastMixin will handle the case of key being a tuple or TensorAttr\n        # object:\n        key = self._tensor_attr_cls.cast(key)\n        assert key.is_fully_specified()\n        self.put_tensor(value, key)\n\n    def __getitem__(self, key: TensorAttr) -> Any:\n        r\"\"\"Supports pythonic indexing into the :class:`FeatureStore`.\n\n        In particular, the following rules are followed for indexing:\n\n        * A fully-specified :obj:`key` will produce a tensor output.\n\n        * A partially-specified :obj:`key` will produce an :class:`AttrView`\n          output, which is a view on the :class:`FeatureStore`. 
If a view is\n          called, it will produce a tensor output from the corresponding\n          (partially specified) attributes.\n        \"\"\"\n        # CastMixin will handle the case of key being a tuple or TensorAttr:\n        attr = self._tensor_attr_cls.cast(key)\n        if attr.is_fully_specified():\n            return self.get_tensor(attr)\n        # If the view is not fully-specified, return a :class:`AttrView`:\n        return self.view(attr)\n\n    def __delitem__(self, attr: TensorAttr):\n        r\"\"\"Supports :obj:`del store[tensor_attr]`.\"\"\"\n        # CastMixin will handle the case of key being a tuple or TensorAttr\n        # object:\n        attr = self._tensor_attr_cls.cast(attr)\n        attr = copy.copy(attr)\n        for key in attr.__dataclass_fields__:  # Set all UNSET values to None.\n            if not attr.is_set(key):\n                setattr(attr, key, None)\n        self.remove_tensor(attr)\n\n    def __iter__(self):\n        raise NotImplementedError\n\n    def __eq__(self, obj: object) -> bool:\n        return id(self) == id(obj)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/data/graph_store.py",
    "content": "r\"\"\"This class defines the abstraction for a backend-agnostic graph store. The\ngoal of the graph store is to abstract away all graph edge index memory\nmanagement so that varying implementations can allow for independent scale-out.\n\nThis particular graph store abstraction makes a few key assumptions:\n* The edge indices we care about storing are represented either in COO, CSC,\n  or CSR format. They can be uniquely identified by an edge type (in PyG,\n  this is a tuple of the source node, relation type, and destination node).\n* Edge indices are static once they are stored in the graph. That is, we do not\n  support dynamic modification of edge indices once they have been inserted\n  into the graph store.\n\nIt is the job of a graph store implementer class to handle these assumptions\nproperly. For example, a simple in-memory graph store implementation may\nconcatenate all metadata values with an edge index and use this as a unique\nindex in a KV store. More complicated implementations may choose to partition\nthe graph in interesting manners based on the provided metadata.\n\"\"\"\nimport copy\nfrom abc import ABC, abstractmethod\nfrom collections import defaultdict\nfrom dataclasses import dataclass\nfrom enum import Enum\nfrom typing import Any, Dict, List, Optional, Tuple\n\nfrom torch import Tensor\n\nfrom torch_geometric.index import index2ptr, ptr2index\nfrom torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor\nfrom torch_geometric.utils import index_sort\nfrom torch_geometric.utils.mixin import CastMixin\n\n# The output of converting between two types in the GraphStore is a Tuple of\n# dictionaries: row, col, and perm. 
The dictionaries are keyed by the edge\n# type of the input edge attribute.\n#   * The row dictionary contains the row tensor for COO, the row pointer for\n#     CSR, or the row tensor for CSC\n#   * The col dictionary contains the col tensor for COO, the col tensor for\n#     CSR, or the col pointer for CSC\n#   * The perm dictionary contains the permutation of edges that was applied\n#     in converting between formats, if applicable.\nConversionOutputType = Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor],\n                             Dict[EdgeType, OptTensor]]\n\n\nclass EdgeLayout(Enum):\n    COO = 'coo'\n    CSC = 'csc'\n    CSR = 'csr'\n\n\n@dataclass\nclass EdgeAttr(CastMixin):\n    r\"\"\"Defines the attributes of a :obj:`GraphStore` edge.\n    It holds all the parameters necessary to uniquely identify an edge from\n    the :class:`GraphStore`.\n\n    Note that the order of the attributes is important; this is the order in\n    which attributes must be provided for indexing calls. :class:`GraphStore`\n    implementations can define a different ordering by overriding\n    :meth:`EdgeAttr.__init__`.\n    \"\"\"\n\n    # The type of the edge:\n    edge_type: EdgeType\n\n    # The layout of the edge representation:\n    layout: EdgeLayout\n\n    # Whether the edge index is sorted by destination node. 
Useful for\n    # avoiding sorting costs when performing neighbor sampling, and only\n    # meaningful for COO (CSC is sorted and CSR is not sorted by definition):\n    is_sorted: bool = False\n\n    # The number of source and destination nodes in this edge type:\n    size: Optional[Tuple[int, int]] = None\n\n    # NOTE we define __init__ to force-cast layout\n    def __init__(\n        self,\n        edge_type: EdgeType,\n        layout: EdgeLayout,\n        is_sorted: bool = False,\n        size: Optional[Tuple[int, int]] = None,\n    ):\n        layout = EdgeLayout(layout)\n\n        if layout == EdgeLayout.CSR and is_sorted:\n            raise ValueError(\"Cannot create a 'CSR' edge attribute with \"\n                             \"option 'is_sorted=True'\")\n\n        if layout == EdgeLayout.CSC:\n            is_sorted = True\n\n        self.edge_type = edge_type\n        self.layout = layout\n        self.is_sorted = is_sorted\n        self.size = size\n\n\nclass GraphStore(ABC):\n    r\"\"\"An abstract base class to access edges from a remote graph store.\n\n    Args:\n        edge_attr_cls (EdgeAttr, optional): A user-defined\n            :class:`EdgeAttr` class to customize the required attributes and\n            their ordering to uniquely identify edges. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(self, edge_attr_cls: Optional[Any] = None):\n        super().__init__()\n        self.__dict__['_edge_attr_cls'] = edge_attr_cls or EdgeAttr\n\n    # Core (CRUD) #############################################################\n\n    @abstractmethod\n    def _put_edge_index(self, edge_index: EdgeTensorType,\n                        edge_attr: EdgeAttr) -> bool:\n        r\"\"\"To be implemented by :class:`GraphStore` subclasses.\"\"\"\n\n    def put_edge_index(self, edge_index: EdgeTensorType, *args,\n                       **kwargs) -> bool:\n        r\"\"\"Synchronously adds an :obj:`edge_index` tuple to the\n        :class:`GraphStore`.\n        Returns whether insertion was successful.\n\n        Args:\n            edge_index (Tuple[torch.Tensor, torch.Tensor]): The\n                :obj:`edge_index` tuple in a format specified in\n                :class:`EdgeAttr`.\n            *args: Arguments passed to :class:`EdgeAttr`.\n            **kwargs: Keyword arguments passed to :class:`EdgeAttr`.\n        \"\"\"\n        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)\n        return self._put_edge_index(edge_index, edge_attr)\n\n    @abstractmethod\n    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:\n        r\"\"\"To be implemented by :class:`GraphStore` subclasses.\"\"\"\n\n    def get_edge_index(self, *args, **kwargs) -> EdgeTensorType:\n        r\"\"\"Synchronously obtains an :obj:`edge_index` tuple from the\n        :class:`GraphStore`.\n\n        Args:\n            *args: Arguments passed to :class:`EdgeAttr`.\n            **kwargs: Keyword arguments passed to :class:`EdgeAttr`.\n\n        Raises:\n            KeyError: If the :obj:`edge_index` corresponding to the input\n                :class:`EdgeAttr` was not found.\n        \"\"\"\n        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)\n        edge_index = self._get_edge_index(edge_attr)\n        if edge_index 
is None:\n            raise KeyError(f\"'edge_index' for '{edge_attr}' not found\")\n        return edge_index\n\n    @abstractmethod\n    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:\n        r\"\"\"To be implemented by :class:`GraphStore` subclasses.\"\"\"\n\n    def remove_edge_index(self, *args, **kwargs) -> bool:\n        r\"\"\"Synchronously deletes an :obj:`edge_index` tuple from the\n        :class:`GraphStore`.\n        Returns whether deletion was successful.\n\n        Args:\n            *args: Arguments passed to :class:`EdgeAttr`.\n            **kwargs: Keyword arguments passed to :class:`EdgeAttr`.\n        \"\"\"\n        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)\n        return self._remove_edge_index(edge_attr)\n\n    @abstractmethod\n    def get_all_edge_attrs(self) -> List[EdgeAttr]:\n        r\"\"\"Returns all registered edge attributes.\"\"\"\n\n    # Layout Conversion #######################################################\n\n    def coo(\n        self,\n        edge_types: Optional[List[Any]] = None,\n        store: bool = False,\n    ) -> ConversionOutputType:\n        r\"\"\"Returns the edge indices in the :class:`GraphStore` in COO format.\n\n        Args:\n            edge_types (List[Any], optional): The edge types of edge indices\n                to obtain. If set to :obj:`None`, will return the edge indices\n                of all existing edge types. (default: :obj:`None`)\n            store (bool, optional): Whether to store converted edge indices in\n                the :class:`GraphStore`. 
(default: :obj:`False`)\n        \"\"\"\n        return self._edges_to_layout(EdgeLayout.COO, edge_types, store)\n\n    def csr(\n        self,\n        edge_types: Optional[List[Any]] = None,\n        store: bool = False,\n    ) -> ConversionOutputType:\n        r\"\"\"Returns the edge indices in the :class:`GraphStore` in CSR format.\n\n        Args:\n            edge_types (List[Any], optional): The edge types of edge indices\n                to obtain. If set to :obj:`None`, will return the edge indices\n                of all existing edge types. (default: :obj:`None`)\n            store (bool, optional): Whether to store converted edge indices in\n                the :class:`GraphStore`. (default: :obj:`False`)\n        \"\"\"\n        return self._edges_to_layout(EdgeLayout.CSR, edge_types, store)\n\n    def csc(\n        self,\n        edge_types: Optional[List[Any]] = None,\n        store: bool = False,\n    ) -> ConversionOutputType:\n        r\"\"\"Returns the edge indices in the :class:`GraphStore` in CSC format.\n\n        Args:\n            edge_types (List[Any], optional): The edge types of edge indices\n                to obtain. If set to :obj:`None`, will return the edge indices\n                of all existing edge types. (default: :obj:`None`)\n            store (bool, optional): Whether to store converted edge indices in\n                the :class:`GraphStore`. 
(default: :obj:`False`)\n        \"\"\"\n        return self._edges_to_layout(EdgeLayout.CSC, edge_types, store)\n\n    # Python built-ins ########################################################\n\n    def __setitem__(self, key: EdgeAttr, value: EdgeTensorType):\n        self.put_edge_index(value, key)\n\n    def __getitem__(self, key: EdgeAttr) -> Optional[EdgeTensorType]:\n        return self.get_edge_index(key)\n\n    def __delitem__(self, key: EdgeAttr):\n        return self.remove_edge_index(key)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n\n    # Helper methods ##########################################################\n\n    def _edge_to_layout(\n        self,\n        attr: EdgeAttr,\n        layout: EdgeLayout,\n        store: bool = False,\n    ) -> Tuple[Tensor, Tensor, OptTensor]:\n\n        (row, col), perm = self.get_edge_index(attr), None\n\n        if layout == EdgeLayout.COO:  # COO output requested:\n            if attr.layout == EdgeLayout.CSR:  # CSR->COO\n                row = ptr2index(row)\n            elif attr.layout == EdgeLayout.CSC:  # CSC->COO\n                col = ptr2index(col)\n\n        elif layout == EdgeLayout.CSR:  # CSR output requested:\n            if attr.layout == EdgeLayout.CSC:  # CSC->COO\n                col = ptr2index(col)\n\n            if attr.layout != EdgeLayout.CSR:  # COO->CSR\n                num_rows = attr.size[0] if attr.size is not None else int(\n                    row.max()) + 1\n                row, perm = index_sort(row, max_value=num_rows)\n                col = col[perm]\n                row = index2ptr(row, num_rows)\n\n        else:  # CSC output requested:\n            if attr.layout == EdgeLayout.CSR:  # CSR->COO\n                row = ptr2index(row)\n\n            if attr.layout != EdgeLayout.CSC:  # COO->CSC\n                if hasattr(self, 'meta') and self.meta.get('is_hetero', False):\n                    # Hotfix for `LocalGraphStore`, where in 
heterogeneous\n                    # graphs, edge indices for different edge types have\n                    # continuous indices not starting at 0.\n                    num_cols = int(col.max()) + 1\n                elif attr.size is not None:\n                    num_cols = attr.size[1]\n                else:\n                    num_cols = int(col.max()) + 1\n\n                if not attr.is_sorted:  # Not sorted by destination.\n                    col, perm = index_sort(col, max_value=num_cols)\n                    row = row[perm]\n                col = index2ptr(col, num_cols)\n\n        if attr.layout != layout and store:\n            attr = copy.copy(attr)\n            attr.layout = layout\n            if perm is not None:\n                attr.is_sorted = False\n            self.put_edge_index((row, col), attr)\n\n        return row, col, perm\n\n    def _edges_to_layout(\n        self,\n        layout: EdgeLayout,\n        edge_types: Optional[List[Any]] = None,\n        store: bool = False,\n    ) -> ConversionOutputType:\n\n        edge_attrs: List[EdgeAttr] = self.get_all_edge_attrs()\n\n        if hasattr(self, 'meta'):  # `LocalGraphStore` hack.\n            is_hetero = self.meta.get('is_hetero', False)\n        else:\n            is_hetero = all(attr.edge_type is not None for attr in edge_attrs)\n\n        if not is_hetero:\n            return self._edge_to_layout(edge_attrs[0], layout, store)\n\n        # Obtain all edge attributes, grouped by type:\n        edge_type_attrs: Dict[EdgeType, List[EdgeAttr]] = defaultdict(list)\n        for attr in self.get_all_edge_attrs():\n            edge_type_attrs[attr.edge_type].append(attr)\n\n        # Check that requested edge types exist and filter:\n        if edge_types is not None:\n            for edge_type in edge_types:\n                if edge_type not in edge_type_attrs:\n                    raise ValueError(f\"The 'edge_index' of type '{edge_type}' \"\n                                     f\"was 
not found in the graph store.\")\n\n            edge_type_attrs = {\n                key: attr\n                for key, attr in edge_type_attrs.items() if key in edge_types\n            }\n\n        # Convert layout from its most favorable original layout:\n        row_dict, col_dict, perm_dict = {}, {}, {}\n        for edge_type, attrs in edge_type_attrs.items():\n            layouts = [attr.layout for attr in attrs]\n\n            if layout in layouts:  # No conversion needed.\n                attr = attrs[layouts.index(layout)]\n            elif EdgeLayout.COO in layouts:  # Prefer COO for conversion.\n                attr = attrs[layouts.index(EdgeLayout.COO)]\n            elif EdgeLayout.CSC in layouts:\n                attr = attrs[layouts.index(EdgeLayout.CSC)]\n            elif EdgeLayout.CSR in layouts:\n                attr = attrs[layouts.index(EdgeLayout.CSR)]\n\n            row_dict[edge_type], col_dict[edge_type], perm_dict[edge_type] = (\n                self._edge_to_layout(attr, layout, store))\n\n        return row_dict, col_dict, perm_dict\n"
  },
  {
    "path": "torch_geometric/data/hetero_data.py",
    "content": "import copy\nimport re\nimport warnings\nfrom collections import defaultdict, namedtuple\nfrom collections.abc import Mapping\nfrom itertools import chain\nfrom typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom typing_extensions import Self\n\nfrom torch_geometric import Index\nfrom torch_geometric.data import EdgeAttr, FeatureStore, GraphStore, TensorAttr\nfrom torch_geometric.data.data import BaseData, Data, size_repr, warn_or_raise\nfrom torch_geometric.data.graph_store import EdgeLayout\nfrom torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage\nfrom torch_geometric.typing import (\n    DEFAULT_REL,\n    EdgeTensorType,\n    EdgeType,\n    FeatureTensorType,\n    NodeOrEdgeType,\n    NodeType,\n    QueryType,\n    SparseTensor,\n    TensorFrame,\n    torch_frame,\n)\nfrom torch_geometric.utils import (\n    bipartite_subgraph,\n    contains_isolated_nodes,\n    is_sparse,\n    is_undirected,\n    mask_select,\n)\n\nNodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]\n\n_DISPLAYED_TYPE_NAME_WARNING: bool = False\n\n\nclass HeteroData(BaseData, FeatureStore, GraphStore):\n    r\"\"\"A data object describing a heterogeneous graph, holding multiple node\n    and/or edge types in disjunct storage objects.\n    Storage objects can hold either node-level, link-level or graph-level\n    attributes.\n    In general, :class:`~torch_geometric.data.HeteroData` tries to mimic the\n    behavior of a regular **nested** :python:`Python` dictionary.\n    In addition, it provides useful functionality for analyzing graph\n    structures, and provides basic PyTorch tensor functionalities.\n\n    .. 
code-block::\n\n        from torch_geometric.data import HeteroData\n\n        data = HeteroData()\n\n        # Create two node types \"paper\" and \"author\" holding a feature matrix:\n        data['paper'].x = torch.randn(num_papers, num_paper_features)\n        data['author'].x = torch.randn(num_authors, num_authors_features)\n\n        # Create an edge type \"(author, writes, paper)\" and build the\n        # graph connectivity:\n        data['author', 'writes', 'paper'].edge_index = ...  # [2, num_edges]\n\n        data['paper'].num_nodes\n        >>> 23\n\n        data['author', 'writes', 'paper'].num_edges\n        >>> 52\n\n        # PyTorch tensor functionality:\n        data = data.pin_memory()\n        data = data.to('cuda:0', non_blocking=True)\n\n    Note that there exist multiple ways to create heterogeneous graph data,\n    *e.g.*:\n\n    * To initialize a node of type :obj:`\"paper\"` holding a node feature\n      matrix :obj:`x_paper` named :obj:`x`:\n\n      .. code-block:: python\n\n        from torch_geometric.data import HeteroData\n\n        # (1) Assign attributes after initialization,\n        data = HeteroData()\n        data['paper'].x = x_paper\n\n        # or (2) pass them as keyword arguments during initialization,\n        data = HeteroData(paper={ 'x': x_paper })\n\n        # or (3) pass them as dictionaries during initialization,\n        data = HeteroData({'paper': { 'x': x_paper }})\n\n    * To initialize an edge from source node type :obj:`\"author\"` to\n      destination node type :obj:`\"paper\"` with relation type :obj:`\"writes\"`\n      holding a graph connectivity matrix :obj:`edge_index_author_paper` named\n      :obj:`edge_index`:\n\n      .. 
code-block:: python\n\n        # (1) Assign attributes after initialization,\n        data = HeteroData()\n        data['author', 'writes', 'paper'].edge_index = edge_index_author_paper\n\n        # or (2) pass them as keyword arguments during initialization,\n        data = HeteroData(author__writes__paper={\n            'edge_index': edge_index_author_paper\n        })\n\n        # or (3) pass them as dictionaries during initialization,\n        data = HeteroData({\n            ('author', 'writes', 'paper'):\n            { 'edge_index': edge_index_author_paper }\n        })\n    \"\"\"\n    def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):\n        super().__init__()\n\n        self.__dict__['_global_store'] = BaseStorage(_parent=self)\n        self.__dict__['_node_store_dict'] = {}\n        self.__dict__['_edge_store_dict'] = {}\n\n        for key, value in chain((_mapping or {}).items(), kwargs.items()):\n            if '__' in key and isinstance(value, Mapping):\n                key = tuple(key.split('__'))\n\n            if isinstance(value, Mapping):\n                self[key].update(value)\n            else:\n                setattr(self, key, value)\n\n    @classmethod\n    def from_dict(cls, mapping: Dict[str, Any]) -> Self:\n        r\"\"\"Creates a :class:`~torch_geometric.data.HeteroData` object from a\n        dictionary.\n        \"\"\"\n        out = cls()\n        for key, value in mapping.items():\n            if key == '_global_store':\n                out.__dict__['_global_store'] = BaseStorage(\n                    _parent=out, **value)\n            elif isinstance(key, str):\n                out._node_store_dict[key] = NodeStorage(\n                    _parent=out, _key=key, **value)\n            else:\n                out._edge_store_dict[key] = EdgeStorage(\n                    _parent=out, _key=key, **value)\n        return out\n\n    def __getattr__(self, key: str) -> Any:\n        # `data.*_dict` => Link to node and 
edge stores.\n        # `data.*` => Link to the `_global_store`.\n        # Using `data.*_dict` is the same as using `collect()` for collecting\n        # node and edge features.\n        if hasattr(self._global_store, key):\n            return getattr(self._global_store, key)\n        elif bool(re.search('_dict$', key)):\n            return self.collect(key[:-5])\n        raise AttributeError(f\"'{self.__class__.__name__}' has no \"\n                             f\"attribute '{key}'\")\n\n    def __setattr__(self, key: str, value: Any):\n        # NOTE: We aim to prevent duplicates in node or edge types.\n        if key in self.node_types:\n            raise AttributeError(f\"'{key}' is already present as a node type\")\n        elif key in self.edge_types:\n            raise AttributeError(f\"'{key}' is already present as an edge type\")\n        setattr(self._global_store, key, value)\n\n    def __delattr__(self, key: str):\n        delattr(self._global_store, key)\n\n    def __getitem__(self, *args: QueryType) -> Any:\n        # `data[*]` => Link to either `_global_store`, `_node_store_dict` or\n        # `_edge_store_dict`.\n        # If neither is present, we create a new `Storage` object for the given\n        # node/edge-type.\n        key = self._to_canonical(*args)\n\n        out = self._global_store.get(key, None)\n        if out is not None:\n            return out\n\n        if isinstance(key, tuple):\n            return self.get_edge_store(*key)\n        else:\n            return self.get_node_store(key)\n\n    def __setitem__(self, key: str, value: Any):\n        if key in self.node_types:\n            raise AttributeError(f\"'{key}' is already present as a node type\")\n        elif key in self.edge_types:\n            raise AttributeError(f\"'{key}' is already present as an edge type\")\n        self._global_store[key] = value\n\n    def __delitem__(self, *args: QueryType):\n        # `del data[*]` => Link to `_node_store_dict` or 
`_edge_store_dict`.\n        key = self._to_canonical(*args)\n        if key in self.edge_types:\n            del self._edge_store_dict[key]\n        elif key in self.node_types:\n            del self._node_store_dict[key]\n        else:\n            del self._global_store[key]\n\n    def __copy__(self):\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = value\n        out.__dict__['_global_store'] = copy.copy(self._global_store)\n        out._global_store._parent = out\n        out.__dict__['_node_store_dict'] = {}\n        for key, store in self._node_store_dict.items():\n            out._node_store_dict[key] = copy.copy(store)\n            out._node_store_dict[key]._parent = out\n        out.__dict__['_edge_store_dict'] = {}\n        for key, store in self._edge_store_dict.items():\n            out._edge_store_dict[key] = copy.copy(store)\n            out._edge_store_dict[key]._parent = out\n        return out\n\n    def __deepcopy__(self, memo):\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = copy.deepcopy(value, memo)\n        out._global_store._parent = out\n        for key in self._node_store_dict.keys():\n            out._node_store_dict[key]._parent = out\n        for key in out._edge_store_dict.keys():\n            out._edge_store_dict[key]._parent = out\n        return out\n\n    def __repr__(self) -> str:\n        info1 = [size_repr(k, v, 2) for k, v in self._global_store.items()]\n        info2 = [size_repr(k, v, 2) for k, v in self._node_store_dict.items()]\n        info3 = [size_repr(k, v, 2) for k, v in self._edge_store_dict.items()]\n        info = ',\\n'.join(info1 + info2 + info3)\n        info = f'\\n{info}\\n' if len(info) > 0 else info\n        return f'{self.__class__.__name__}({info})'\n\n    def stores_as(self, data: Self):\n        for node_type in data.node_types:\n  
          self.get_node_store(node_type)\n        for edge_type in data.edge_types:\n            self.get_edge_store(*edge_type)\n        return self\n\n    @property\n    def stores(self) -> List[BaseStorage]:\n        r\"\"\"Returns a list of all storages of the graph.\"\"\"\n        return ([self._global_store] + list(self.node_stores) +\n                list(self.edge_stores))\n\n    @property\n    def node_types(self) -> List[NodeType]:\n        r\"\"\"Returns a list of all node types of the graph.\"\"\"\n        return list(self._node_store_dict.keys())\n\n    @property\n    def node_stores(self) -> List[NodeStorage]:\n        r\"\"\"Returns a list of all node storages of the graph.\"\"\"\n        return list(self._node_store_dict.values())\n\n    @property\n    def edge_types(self) -> List[EdgeType]:\n        r\"\"\"Returns a list of all edge types of the graph.\"\"\"\n        return list(self._edge_store_dict.keys())\n\n    @property\n    def edge_stores(self) -> List[EdgeStorage]:\n        r\"\"\"Returns a list of all edge storages of the graph.\"\"\"\n        return list(self._edge_store_dict.values())\n\n    def node_items(self) -> List[Tuple[NodeType, NodeStorage]]:\n        r\"\"\"Returns a list of node type and node storage pairs.\"\"\"\n        return list(self._node_store_dict.items())\n\n    def edge_items(self) -> List[Tuple[EdgeType, EdgeStorage]]:\n        r\"\"\"Returns a list of edge type and edge storage pairs.\"\"\"\n        return list(self._edge_store_dict.items())\n\n    @property\n    def input_type(self) -> Optional[Union[NodeType, EdgeType]]:\n        r\"\"\"Returns the seed/input node/edge type of the graph in case it\n        refers to a sampled subgraph, *e.g.*, obtained via\n        :class:`~torch_geometric.loader.NeighborLoader` or\n        :class:`~torch_geometric.loader.LinkNeighborLoader`.\n        \"\"\"\n        for node_type, store in self.node_items():\n            if hasattr(store, 'input_id'):\n                return 
node_type\n        for edge_type, store in self.edge_items():\n            if hasattr(store, 'input_id'):\n                return edge_type\n        return None\n\n    def to_dict(self) -> Dict[str, Any]:\n        out_dict: Dict[str, Any] = {}\n        out_dict['_global_store'] = self._global_store.to_dict()\n        for key, store in chain(self._node_store_dict.items(),\n                                self._edge_store_dict.items()):\n            out_dict[key] = store.to_dict()\n        return out_dict\n\n    def to_namedtuple(self) -> NamedTuple:\n        field_names = list(self._global_store.keys())\n        field_values = list(self._global_store.values())\n        field_names += [\n            '__'.join(key) if isinstance(key, tuple) else key\n            for key in self.node_types + self.edge_types\n        ]\n        field_values += [\n            store.to_namedtuple()\n            for store in self.node_stores + self.edge_stores\n        ]\n        DataTuple = namedtuple('DataTuple', field_names)\n        return DataTuple(*field_values)\n\n    def set_value_dict(\n        self,\n        key: str,\n        value_dict: Dict[str, Any],\n    ) -> Self:\n        r\"\"\"Sets the values in the dictionary :obj:`value_dict` to the\n        attribute with name :obj:`key` to all node/edge types present in the\n        dictionary.\n\n        .. 
code-block:: python\n\n           data = HeteroData()\n\n           data.set_value_dict('x', {\n               'paper': torch.randn(4, 16),\n               'author': torch.randn(8, 32),\n           })\n\n           print(data['paper'].x)\n        \"\"\"\n        for k, v in (value_dict or {}).items():\n            self[k][key] = v\n        return self\n\n    def update(self, data: Self) -> Self:\n        for store in data.stores:\n            for key, value in store.items():\n                self[store._key][key] = value\n        return self\n\n    def __cat_dim__(self, key: str, value: Any,\n                    store: Optional[NodeOrEdgeStorage] = None, *args,\n                    **kwargs) -> Any:\n        if is_sparse(value) and ('adj' in key or 'edge_index' in key):\n            return (0, 1)\n        elif isinstance(store, EdgeStorage) and 'index' in key:\n            return -1\n        return 0\n\n    def __inc__(self, key: str, value: Any,\n                store: Optional[NodeOrEdgeStorage] = None, *args,\n                **kwargs) -> Any:\n        if 'batch' in key and isinstance(value, Tensor):\n            if isinstance(value, Index):\n                return value.get_dim_size()\n            return int(value.max()) + 1\n        elif isinstance(store, EdgeStorage) and 'index' in key:\n            return torch.tensor(store.size()).view(2, 1)\n        else:\n            return 0\n\n    @property\n    def num_nodes(self) -> Optional[int]:\n        r\"\"\"Returns the number of nodes in the graph.\"\"\"\n        return super().num_nodes\n\n    @property\n    def num_node_features(self) -> Dict[NodeType, int]:\n        r\"\"\"Returns the number of features per node type in the graph.\"\"\"\n        return {\n            key: store.num_node_features\n            for key, store in self._node_store_dict.items()\n        }\n\n    @property\n    def num_features(self) -> Dict[NodeType, int]:\n        r\"\"\"Returns the number of features per node type in the graph.\n 
       Alias for :py:attr:`~num_node_features`.\n        \"\"\"\n        return self.num_node_features\n\n    @property\n    def num_edge_features(self) -> Dict[EdgeType, int]:\n        r\"\"\"Returns the number of features per edge type in the graph.\"\"\"\n        return {\n            key: store.num_edge_features\n            for key, store in self._edge_store_dict.items()\n        }\n\n    def has_isolated_nodes(self) -> bool:\n        r\"\"\"Returns :obj:`True` if the graph contains isolated nodes.\"\"\"\n        edge_index, _, _ = to_homogeneous_edge_index(self)\n        return contains_isolated_nodes(edge_index, num_nodes=self.num_nodes)\n\n    def is_undirected(self) -> bool:\n        r\"\"\"Returns :obj:`True` if graph edges are undirected.\"\"\"\n        edge_index, _, _ = to_homogeneous_edge_index(self)\n        return is_undirected(edge_index, num_nodes=self.num_nodes)\n\n    def validate(self, raise_on_error: bool = True) -> bool:\n        r\"\"\"Validates the correctness of the data.\"\"\"\n        cls_name = self.__class__.__name__\n        status = True\n\n        node_types = set(self.node_types)\n        num_src_node_types = {src for src, _, _ in self.edge_types}\n        num_dst_node_types = {dst for _, _, dst in self.edge_types}\n\n        dangling_types = (num_src_node_types | num_dst_node_types) - node_types\n        if len(dangling_types) > 0:\n            status = False\n            warn_or_raise(\n                f\"The node types {dangling_types} are referenced in edge \"\n                f\"types but do not exist as node types\", raise_on_error)\n\n        dangling_types = node_types - (num_src_node_types | num_dst_node_types)\n        if len(dangling_types) > 0:\n            warn_or_raise(  # May be intended.\n                f\"The node types {dangling_types} are isolated and are not \"\n                f\"referenced by any edge type \", raise_on_error=False)\n\n        for edge_type, store in self._edge_store_dict.items():\n            
src, _, dst = edge_type\n\n            num_src_nodes = self[src].num_nodes\n            num_dst_nodes = self[dst].num_nodes\n            if num_src_nodes is None:\n                status = False\n                warn_or_raise(\n                    f\"'num_nodes' is undefined in node type '{src}' of \"\n                    f\"'{cls_name}'\", raise_on_error)\n\n            if num_dst_nodes is None:\n                status = False\n                warn_or_raise(\n                    f\"'num_nodes' is undefined in node type '{dst}' of \"\n                    f\"'{cls_name}'\", raise_on_error)\n\n            if 'edge_index' in store:\n                if (store.edge_index.dim() != 2\n                        or store.edge_index.size(0) != 2):\n                    status = False\n                    warn_or_raise(\n                        f\"'edge_index' of edge type {edge_type} needs to be \"\n                        f\"of shape [2, num_edges] in '{cls_name}' (found \"\n                        f\"{store.edge_index.size()})\", raise_on_error)\n\n            if 'edge_index' in store and store.edge_index.numel() > 0:\n                if store.edge_index.min() < 0:\n                    status = False\n                    warn_or_raise(\n                        f\"'edge_index' of edge type {edge_type} contains \"\n                        f\"negative indices in '{cls_name}' \"\n                        f\"(found {int(store.edge_index.min())})\",\n                        raise_on_error)\n\n                if (num_src_nodes is not None\n                        and store.edge_index[0].max() >= num_src_nodes):\n                    status = False\n                    warn_or_raise(\n                        f\"'edge_index' of edge type {edge_type} contains \"\n                        f\"larger source indices than the number of nodes \"\n                        f\"({num_src_nodes}) of this node type in '{cls_name}' \"\n                        f\"(found 
{int(store.edge_index[0].max())})\",\n                        raise_on_error)\n\n                if (num_dst_nodes is not None\n                        and store.edge_index[1].max() >= num_dst_nodes):\n                    status = False\n                    warn_or_raise(\n                        f\"'edge_index' of edge type {edge_type} contains \"\n                        f\"larger destination indices than the number of nodes \"\n                        f\"({num_dst_nodes}) of this node type in '{cls_name}' \"\n                        f\"(found {int(store.edge_index[1].max())})\",\n                        raise_on_error)\n\n        return status\n\n    def connected_components(self) -> List[Self]:\n        r\"\"\"Extracts connected components of the heterogeneous graph using\n        a union-find algorithm. The components are returned as a list of\n        :class:`~torch_geometric.data.HeteroData` objects.\n\n        .. code-block::\n\n            data = HeteroData()\n            data[\"red\"].x = torch.tensor([[1.0], [2.0], [3.0], [4.0]])\n            data[\"blue\"].x = torch.tensor([[5.0], [6.0]])\n            data[\"red\", \"to\", \"red\"].edge_index = torch.tensor(\n                [[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long\n            )\n\n            components = data.connected_components()\n            print(len(components))\n            >>> 4\n\n            print(components[0])\n            >>> HeteroData(\n                red={x: tensor([[1.], [2.]])},\n                blue={x: tensor([[]])},\n                red, to, red={edge_index: tensor([[0, 1], [1, 0]])}\n            )\n\n        Returns:\n            List[HeteroData]: A list of connected components.\n        \"\"\"\n        # Initialize union-find structures\n        self._parents: Dict[Tuple[str, int], Tuple[str, int]] = {}\n        self._ranks: Dict[Tuple[str, int], int] = {}\n\n        # Union-Find algorithm to find connected components\n        for edge_type in self.edge_types:\n           
 src, _, dst = edge_type\n            edge_index = self[edge_type].edge_index\n            for src_node, dst_node in edge_index.t().tolist():\n                self._union((src, src_node), (dst, dst_node))\n\n        # Rerun _find_parent to ensure all nodes are covered correctly\n        for node_type in self.node_types:\n            for node_index in range(self[node_type].num_nodes):\n                self._find_parent((node_type, node_index))\n\n        # Group nodes by their representative parent\n        components_map = defaultdict(list)\n        for node, parent in self._parents.items():\n            components_map[parent].append(node)\n        del self._parents\n        del self._ranks\n\n        components: List[Self] = []\n        for nodes in components_map.values():\n            # Prefill subset_dict with all node types to ensure all are present\n            subset_dict = {node_type: [] for node_type in self.node_types}\n\n            # Convert the list of (node_type, node_id) tuples to a subset_dict\n            for node_type, node_id in nodes:\n                subset_dict[node_type].append(node_id)\n\n            # Convert lists to tensors\n            for node_type, node_ids in subset_dict.items():\n                subset_dict[node_type] = torch.tensor(node_ids,\n                                                      dtype=torch.long)\n\n            # Use the existing subgraph function to do all the heavy lifting\n            component_data = self.subgraph(subset_dict)\n            components.append(component_data)\n\n        return components\n\n    def debug(self):\n        pass  # TODO\n\n    ###########################################################################\n\n    def _to_canonical(self, *args: QueryType) -> NodeOrEdgeType:\n        # Converts a given `QueryType` to its \"canonical type\":\n        # 1. `relation_type` will get mapped to the unique\n        #    `(src_node_type, relation_type, dst_node_type)` tuple.\n        # 2. 
`(src_node_type, dst_node_type)` will get mapped to the unique\n        #    `(src_node_type, *, dst_node_type)` tuple, and\n        #    `(src_node_type, 'to', dst_node_type)` otherwise.\n        if len(args) == 1:\n            args = args[0]\n\n        if isinstance(args, str):\n            node_types = [key for key in self.node_types if key == args]\n            if len(node_types) == 1:\n                args = node_types[0]\n                return args\n\n            # Try to map to edge type based on unique relation type:\n            edge_types = [key for key in self.edge_types if key[1] == args]\n            if len(edge_types) == 1:\n                args = edge_types[0]\n                return args\n\n        elif len(args) == 2:\n            # Try to find the unique source/destination node tuple:\n            edge_types = [\n                key for key in self.edge_types\n                if key[0] == args[0] and key[-1] == args[-1]\n            ]\n            if len(edge_types) == 1:\n                args = edge_types[0]\n                return args\n            elif len(edge_types) == 0:\n                args = (args[0], DEFAULT_REL, args[1])\n                return args\n\n        return args\n\n    def metadata(self) -> Tuple[List[NodeType], List[EdgeType]]:\n        r\"\"\"Returns the heterogeneous meta-data, *i.e.* its node and edge\n        types.\n\n        .. code-block:: python\n\n            data = HeteroData()\n            data['paper'].x = ...\n            data['author'].x = ...\n            data['author', 'writes', 'paper'].edge_index = ...\n\n            print(data.metadata())\n            >>> (['paper', 'author'], [('author', 'writes', 'paper')])\n        \"\"\"\n        return self.node_types, self.edge_types\n\n    def collect(\n        self,\n        key: str,\n        allow_empty: bool = False,\n    ) -> Dict[NodeOrEdgeType, Any]:\n        r\"\"\"Collects the attribute :attr:`key` from all node and edge types.\n\n        .. 
code-block:: python\n\n            data = HeteroData()\n            data['paper'].x = ...\n            data['author'].x = ...\n\n            print(data.collect('x'))\n            >>> { 'paper': ..., 'author': ...}\n\n        .. note::\n\n            This is equivalent to writing :obj:`data.x_dict`.\n\n        Args:\n            key (str): The attribute to collect from all node and edge types.\n            allow_empty (bool, optional): If set to :obj:`True`, will not raise\n                an error in case the attribute does not exist in any node or\n                edge type. (default: :obj:`False`)\n        \"\"\"\n        mapping = {}\n        for subtype, store in chain(self._node_store_dict.items(),\n                                    self._edge_store_dict.items()):\n            if hasattr(store, key):\n                mapping[subtype] = getattr(store, key)\n        if not allow_empty and len(mapping) == 0:\n            raise KeyError(f\"Tried to collect '{key}' but did not find any \"\n                           f\"occurrences of it in any node and/or edge type\")\n        return mapping\n\n    def _check_type_name(self, name: str):\n        global _DISPLAYED_TYPE_NAME_WARNING\n        if not _DISPLAYED_TYPE_NAME_WARNING and '__' in name:\n            _DISPLAYED_TYPE_NAME_WARNING = True\n            warnings.warn(\n                f\"There exist type names in the \"\n                f\"'{self.__class__.__name__}' object that contain \"\n                f\"double underscores '__' (e.g., '{name}'). This \"\n                f\"may lead to unexpected behavior. 
To avoid any \"\n                f\"issues, ensure that your type names only contain \"\n                f\"single underscores.\", stacklevel=2)\n\n    def get_node_store(self, key: NodeType) -> NodeStorage:\n        r\"\"\"Gets the :class:`~torch_geometric.data.storage.NodeStorage` object\n        of a particular node type :attr:`key`.\n        If the storage is not present yet, will create a new\n        :class:`torch_geometric.data.storage.NodeStorage` object for the given\n        node type.\n\n        .. code-block:: python\n\n            data = HeteroData()\n            node_storage = data.get_node_store('paper')\n        \"\"\"\n        out = self._node_store_dict.get(key, None)\n        if out is None:\n            self._check_type_name(key)\n            out = NodeStorage(_parent=self, _key=key)\n            self._node_store_dict[key] = out\n        return out\n\n    def get_edge_store(self, src: str, rel: str, dst: str) -> EdgeStorage:\n        r\"\"\"Gets the :class:`~torch_geometric.data.storage.EdgeStorage` object\n        of a particular edge type given by the tuple :obj:`(src, rel, dst)`.\n        If the storage is not present yet, will create a new\n        :class:`torch_geometric.data.storage.EdgeStorage` object for the given\n        edge type.\n\n        .. 
code-block:: python\n\n            data = HeteroData()\n            edge_storage = data.get_edge_store('author', 'writes', 'paper')\n        \"\"\"\n        key = (src, rel, dst)\n        out = self._edge_store_dict.get(key, None)\n        if out is None:\n            self._check_type_name(rel)\n            out = EdgeStorage(_parent=self, _key=key)\n            self._edge_store_dict[key] = out\n        return out\n\n    def rename(self, name: NodeType, new_name: NodeType) -> Self:\n        r\"\"\"Renames the node type :obj:`name` to :obj:`new_name` in-place.\"\"\"\n        node_store = self._node_store_dict.pop(name)\n        node_store._key = new_name\n        self._node_store_dict[new_name] = node_store\n\n        for edge_type in self.edge_types:\n            src, rel, dst = edge_type\n            if src == name or dst == name:\n                edge_store = self._edge_store_dict.pop(edge_type)\n                src = new_name if src == name else src\n                dst = new_name if dst == name else dst\n                edge_type = (src, rel, dst)\n                edge_store._key = edge_type\n                self._edge_store_dict[edge_type] = edge_store\n\n        return self\n\n    def subgraph(self, subset_dict: Dict[NodeType, Tensor]) -> Self:\n        r\"\"\"Returns the induced subgraph containing the node types and\n        corresponding nodes in :obj:`subset_dict`.\n\n        If a node type is not a key in :obj:`subset_dict` then all nodes of\n        that type remain in the graph.\n\n        .. 
code-block:: python\n\n            data = HeteroData()\n            data['paper'].x = ...\n            data['author'].x = ...\n            data['conference'].x = ...\n            data['paper', 'cites', 'paper'].edge_index = ...\n            data['author', 'paper'].edge_index = ...\n            data['paper', 'conference'].edge_index = ...\n            print(data)\n            >>> HeteroData(\n                paper={ x=[10, 16] },\n                author={ x=[5, 32] },\n                conference={ x=[5, 8] },\n                (paper, cites, paper)={ edge_index=[2, 50] },\n                (author, to, paper)={ edge_index=[2, 30] },\n                (paper, to, conference)={ edge_index=[2, 25] }\n            )\n\n            subset_dict = {\n                'paper': torch.tensor([3, 4, 5, 6]),\n                'author': torch.tensor([0, 2]),\n            }\n\n            print(data.subgraph(subset_dict))\n            >>> HeteroData(\n                paper={ x=[4, 16] },\n                author={ x=[2, 32] },\n                conference={ x=[5, 8] },\n                (paper, cites, paper)={ edge_index=[2, 24] },\n                (author, to, paper)={ edge_index=[2, 5] },\n                (paper, to, conference)={ edge_index=[2, 10] }\n            )\n\n        Args:\n            subset_dict (Dict[str, LongTensor or BoolTensor]): A dictionary\n                holding the nodes to keep for each node type.\n        \"\"\"\n        data = copy.copy(self)\n        subset_dict = copy.copy(subset_dict)\n\n        for node_type, subset in subset_dict.items():\n            for key, value in self[node_type].items():\n                if key == 'num_nodes':\n                    if subset.dtype == torch.bool:\n                        data[node_type].num_nodes = int(subset.sum())\n                    else:\n                        data[node_type].num_nodes = subset.size(0)\n                elif self[node_type].is_node_attr(key):\n                    data[node_type][key] = 
value[subset]\n                else:\n                    data[node_type][key] = value\n\n        for edge_type in self.edge_types:\n            if 'edge_index' not in self[edge_type]:\n                continue\n\n            src, _, dst = edge_type\n\n            src_subset = subset_dict.get(src)\n            if src_subset is None:\n                src_subset = torch.arange(data[src].num_nodes)\n            dst_subset = subset_dict.get(dst)\n            if dst_subset is None:\n                dst_subset = torch.arange(data[dst].num_nodes)\n\n            edge_index, _, edge_mask = bipartite_subgraph(\n                (src_subset, dst_subset),\n                self[edge_type].edge_index,\n                relabel_nodes=True,\n                size=(self[src].num_nodes, self[dst].num_nodes),\n                return_edge_mask=True,\n            )\n\n            for key, value in self[edge_type].items():\n                if key == 'edge_index':\n                    data[edge_type].edge_index = edge_index\n                elif self[edge_type].is_edge_attr(key):\n                    data[edge_type][key] = value[edge_mask]\n                else:\n                    data[edge_type][key] = value\n\n        return data\n\n    def edge_subgraph(\n        self,\n        subset_dict: Dict[EdgeType, Tensor],\n    ) -> Self:\n        r\"\"\"Returns the induced subgraph given by the edge indices in\n        :obj:`subset_dict` for certain edge types.\n        Will currently preserve all the nodes in the graph, even if they are\n        isolated after subgraph computation.\n\n        Args:\n            subset_dict (Dict[Tuple[str, str, str], LongTensor or BoolTensor]):\n                A dictionary holding the edges to keep for each edge type.\n        \"\"\"\n        data = copy.copy(self)\n\n        for edge_type, subset in subset_dict.items():\n            edge_store, new_edge_store = self[edge_type], data[edge_type]\n            for key, value in edge_store.items():\n             
   if edge_store.is_edge_attr(key):\n                    dim = self.__cat_dim__(key, value, edge_store)\n                    if subset.dtype == torch.bool:\n                        new_edge_store[key] = mask_select(value, dim, subset)\n                    else:\n                        new_edge_store[key] = value.index_select(dim, subset)\n\n        return data\n\n    def node_type_subgraph(self, node_types: List[NodeType]) -> Self:\n        r\"\"\"Returns the subgraph induced by the given :obj:`node_types`, *i.e.*\n        the returned :class:`HeteroData` object only contains the node types\n        which are included in :obj:`node_types`, and only contains the edge\n        types where both end points are included in :obj:`node_types`.\n        \"\"\"\n        data = copy.copy(self)\n        for edge_type in self.edge_types:\n            src, _, dst = edge_type\n            if src not in node_types or dst not in node_types:\n                del data[edge_type]\n        for node_type in self.node_types:\n            if node_type not in node_types:\n                del data[node_type]\n        return data\n\n    def edge_type_subgraph(self, edge_types: List[EdgeType]) -> Self:\n        r\"\"\"Returns the subgraph induced by the given :obj:`edge_types`, *i.e.*\n        the returned :class:`HeteroData` object only contains the edge types\n        which are included in :obj:`edge_types`, and only contains the node\n        types that are end points of the edge types in :obj:`edge_types`.\n        \"\"\"\n        edge_types = [self._to_canonical(e) for e in edge_types]\n\n        data = copy.copy(self)\n        for edge_type in self.edge_types:\n            if edge_type not in edge_types:\n                del data[edge_type]\n        node_types = {e[0] for e in edge_types}\n        node_types |= {e[-1] for e in edge_types}\n        for node_type in self.node_types:\n            if node_type not in node_types:\n                del data[node_type]\n        return data\n\n 
   def to_homogeneous(\n        self,\n        node_attrs: Optional[List[str]] = None,\n        edge_attrs: Optional[List[str]] = None,\n        add_node_type: bool = True,\n        add_edge_type: bool = True,\n        dummy_values: bool = True,\n    ) -> Data:\n        \"\"\"Converts a :class:`~torch_geometric.data.HeteroData` object to a\n        homogeneous :class:`~torch_geometric.data.Data` object.\n        By default, all features with same feature dimensionality across\n        different types will be merged into a single representation, unless\n        otherwise specified via the :obj:`node_attrs` and :obj:`edge_attrs`\n        arguments.\n        Furthermore, attributes named :obj:`node_type` and :obj:`edge_type`\n        will be added to the returned :class:`~torch_geometric.data.Data`\n        object, denoting node-level and edge-level vectors holding the\n        node and edge type as integers, respectively.\n\n        Args:\n            node_attrs (List[str], optional): The node features to combine\n                across all node types. These node features need to be of the\n                same feature dimensionality. If set to :obj:`None`, will\n                automatically determine which node features to combine.\n                (default: :obj:`None`)\n            edge_attrs (List[str], optional): The edge features to combine\n                across all edge types. These edge features need to be of the\n                same feature dimensionality. 
If set to :obj:`None`, will\n                automatically determine which edge features to combine.\n                (default: :obj:`None`)\n            add_node_type (bool, optional): If set to :obj:`False`, will not\n                add the node-level vector :obj:`node_type` to the returned\n                :class:`~torch_geometric.data.Data` object.\n                (default: :obj:`True`)\n            add_edge_type (bool, optional): If set to :obj:`False`, will not\n                add the edge-level vector :obj:`edge_type` to the returned\n                :class:`~torch_geometric.data.Data` object.\n                (default: :obj:`True`)\n            dummy_values (bool, optional): If set to :obj:`True`, will fill\n                attributes of remaining types with dummy values.\n                Dummy values are :obj:`NaN` for floating point attributes,\n                :obj:`False` for booleans, and :obj:`-1` for integers.\n                (default: :obj:`True`)\n        \"\"\"\n        def get_sizes(stores: List[BaseStorage]) -> Dict[str, List[Tuple]]:\n            sizes_dict = defaultdict(list)\n            for store in stores:\n                for key, value in store.items():\n                    if key in [\n                            'edge_index', 'edge_label_index', 'adj', 'adj_t'\n                    ]:\n                        continue\n                    if isinstance(value, Tensor):\n                        dim = self.__cat_dim__(key, value, store)\n                        size = value.size()[:dim] + value.size()[dim + 1:]\n                        sizes_dict[key].append(tuple(size))\n            return sizes_dict\n\n        def fill_dummy_(stores: List[BaseStorage],\n                        keys: Optional[List[str]] = None):\n            sizes_dict = get_sizes(stores)\n\n            if keys is not None:\n                sizes_dict = {\n                    key: sizes\n                    for key, sizes in sizes_dict.items() if key in keys\n          
      }\n\n            sizes_dict = {\n                key: sizes\n                for key, sizes in sizes_dict.items() if len(set(sizes)) == 1\n            }\n\n            for store in stores:  # Fill stores with dummy features:\n                for key, sizes in sizes_dict.items():\n                    if key not in store:\n                        ref = list(self.collect(key).values())[0]\n                        dim = self.__cat_dim__(key, ref, store)\n                        if ref.is_floating_point():\n                            dummy = float('NaN')\n                        elif ref.dtype == torch.bool:\n                            dummy = False\n                        else:\n                            dummy = -1\n                        if isinstance(store, NodeStorage):\n                            dim_size = store.num_nodes\n                        else:\n                            dim_size = store.num_edges\n                        shape = sizes[0][:dim] + (dim_size, ) + sizes[0][dim:]\n                        store[key] = torch.full(shape, dummy, dtype=ref.dtype,\n                                                device=ref.device)\n\n        def _consistent_size(stores: List[BaseStorage]) -> List[str]:\n            sizes_dict = get_sizes(stores)\n            keys = []\n            for key, sizes in sizes_dict.items():\n                # The attribute needs to exist in all types:\n                if len(sizes) != len(stores):\n                    continue\n                # The attribute needs to have the same number of dimensions:\n                lengths = {len(size) for size in sizes}\n                if len(lengths) != 1:\n                    continue\n                # The attribute needs to have the same size in all dimensions:\n                if len(sizes[0]) != 1 and len(set(sizes)) != 1:\n                    continue\n                keys.append(key)\n\n            # Check for consistent column names in `TensorFrame`:\n            tf_cols = 
defaultdict(list)\n            for store in stores:\n                for key, value in store.items():\n                    if isinstance(value, TensorFrame):\n                        cols = tuple(chain(*value.col_names_dict.values()))\n                        tf_cols[key].append(cols)\n\n            for key, cols in tf_cols.items():\n                # The attribute needs to exist in all types:\n                if len(cols) != len(stores):\n                    continue\n                # The attribute needs to have the same column names:\n                lengths = set(cols)\n                if len(lengths) != 1:\n                    continue\n                keys.append(key)\n\n            return keys\n\n        if dummy_values:\n            self = copy.copy(self)\n            fill_dummy_(self.node_stores, node_attrs)\n            fill_dummy_(self.edge_stores, edge_attrs)\n\n        edge_index, node_slices, edge_slices = to_homogeneous_edge_index(self)\n        device = edge_index.device if edge_index is not None else None\n\n        data = Data(**self._global_store.to_dict())\n        if edge_index is not None:\n            data.edge_index = edge_index\n        data._node_type_names = list(node_slices.keys())\n        data._edge_type_names = list(edge_slices.keys())\n\n        # Combine node attributes into a single tensor:\n        if node_attrs is None:\n            node_attrs = _consistent_size(self.node_stores)\n        for key in node_attrs:\n            if key in {'ptr'}:\n                continue\n            values = [store[key] for store in self.node_stores]\n            if isinstance(values[0], TensorFrame):\n                value = torch_frame.cat(values, dim=0)\n            else:\n                dim = self.__cat_dim__(key, values[0], self.node_stores[0])\n                dim = values[0].dim() + dim if dim < 0 else dim\n                # For two-dimensional features, we allow arbitrary shapes and\n                # pad them with zeros if necessary in 
case their size doesn't\n                # match:\n                if values[0].dim() == 2 and dim == 0:\n                    _max = max([value.size(-1) for value in values])\n                    for i, v in enumerate(values):\n                        if v.size(-1) < _max:\n                            pad = v.new_zeros(v.size(0), _max - v.size(-1))\n                            values[i] = torch.cat([v, pad], dim=-1)\n                value = torch.cat(values, dim)\n            data[key] = value\n\n        if not data.can_infer_num_nodes:\n            data.num_nodes = list(node_slices.values())[-1][1]\n\n        # Combine edge attributes into a single tensor:\n        if edge_attrs is None:\n            edge_attrs = _consistent_size(self.edge_stores)\n        for key in edge_attrs:\n            values = [store[key] for store in self.edge_stores]\n            dim = self.__cat_dim__(key, values[0], self.edge_stores[0])\n            value = torch.cat(values, dim) if len(values) > 1 else values[0]\n            data[key] = value\n\n        if 'edge_label_index' in self:\n            edge_label_index_dict = self.edge_label_index_dict\n            for edge_type, edge_label_index in edge_label_index_dict.items():\n                edge_label_index = edge_label_index.clone()\n                edge_label_index[0] += node_slices[edge_type[0]][0]\n                edge_label_index[1] += node_slices[edge_type[-1]][0]\n                edge_label_index_dict[edge_type] = edge_label_index\n            data.edge_label_index = torch.cat(\n                list(edge_label_index_dict.values()), dim=-1)\n\n        if add_node_type:\n            sizes = [offset[1] - offset[0] for offset in node_slices.values()]\n            sizes = torch.tensor(sizes, dtype=torch.long, device=device)\n            node_type = torch.arange(len(sizes), device=device)\n            data.node_type = node_type.repeat_interleave(sizes)\n\n        if add_edge_type and edge_index is not None:\n            sizes = 
[offset[1] - offset[0] for offset in edge_slices.values()]\n            sizes = torch.tensor(sizes, dtype=torch.long, device=device)\n            edge_type = torch.arange(len(sizes), device=device)\n            data.edge_type = edge_type.repeat_interleave(sizes)\n\n        return data\n\n    # FeatureStore interface ##################################################\n\n    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:\n        if not attr.is_set('index'):\n            attr.index = None\n\n        out = self._node_store_dict.get(attr.group_name, None)\n        if out:\n            # Group name exists, handle index or create new attribute name:\n            val = getattr(out, attr.attr_name, None)\n            if val is not None:\n                val[attr.index] = tensor\n            else:\n                assert attr.index is None\n                setattr(self[attr.group_name], attr.attr_name, tensor)\n        else:\n            # No node storage found, just store tensor in new one:\n            setattr(self[attr.group_name], attr.attr_name, tensor)\n        return True\n\n    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:\n        # Retrieve tensor and index accordingly:\n        tensor = getattr(self[attr.group_name], attr.attr_name, None)\n        if tensor is not None:\n            # TODO this behavior is a bit odd, since TensorAttr requires that\n            # we set `index`. 
So, we assume here that indexing by `None` is\n            # equivalent to not indexing at all, which is not in line with\n            # Python semantics.\n            return tensor[attr.index] if attr.index is not None else tensor\n        return None\n\n    def _remove_tensor(self, attr: TensorAttr) -> bool:\n        # Remove tensor entirely:\n        if hasattr(self[attr.group_name], attr.attr_name):\n            delattr(self[attr.group_name], attr.attr_name)\n            return True\n        return False\n\n    def _get_tensor_size(self, attr: TensorAttr) -> Tuple:\n        return self._get_tensor(attr).size()\n\n    def get_all_tensor_attrs(self) -> List[TensorAttr]:\n        out = []\n        for group_name, group in self.node_items():\n            for attr_name in group:\n                if group.is_node_attr(attr_name):\n                    out.append(TensorAttr(group_name, attr_name))\n        return out\n\n    # GraphStore interface ####################################################\n\n    def _put_edge_index(self, edge_index: EdgeTensorType,\n                        edge_attr: EdgeAttr) -> bool:\n        if not hasattr(self, '_edge_attrs'):\n            self._edge_attrs = {}\n        self._edge_attrs[(edge_attr.edge_type, edge_attr.layout)] = edge_attr\n\n        row, col = edge_index\n        store = self[edge_attr.edge_type]\n\n        if edge_attr.layout == EdgeLayout.COO:\n            store.edge_index = torch.stack([row, col], dim=0)\n        elif edge_attr.layout == EdgeLayout.CSR:\n            store.adj = SparseTensor(\n                rowptr=row,\n                col=col,\n                sparse_sizes=edge_attr.size,\n                is_sorted=True,\n                trust_data=True,\n            )\n        else:  # edge_attr.layout == EdgeLayout.CSC:\n            size = edge_attr.size[::-1] if edge_attr.size is not None else None\n            store.adj_t = SparseTensor(\n                rowptr=col,\n                col=row,\n                
sparse_sizes=size,\n                is_sorted=True,\n                trust_data=True,\n            )\n        return True\n\n    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:\n        r\"\"\"Gets an edge index from edge storage, in the specified layout.\"\"\"\n        store = self[edge_attr.edge_type]\n\n        edge_attrs = getattr(self, '_edge_attrs', {})\n        if (edge_attr.edge_type, edge_attr.layout) in edge_attrs:\n            edge_attr = edge_attrs[(edge_attr.edge_type, edge_attr.layout)]\n        if edge_attr.size is None:\n            edge_attr.size = store.size()  # Modify in-place.\n\n        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in store:\n            row, col = store.edge_index\n            return row, col\n        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in store:\n            rowptr, col, _ = store.adj.csr()\n            return rowptr, col\n        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in store:\n            colptr, row, _ = store.adj_t.csr()\n            return row, colptr\n        return None\n\n    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:\n        edge_type = edge_attr.edge_type\n        store = self[edge_type]\n        if edge_attr.layout == EdgeLayout.COO and 'edge_index' in store:\n            del store.edge_index\n            if hasattr(self, '_edge_attrs'):\n                self._edge_attrs.pop((edge_type, EdgeLayout.COO), None)\n            return True\n        elif edge_attr.layout == EdgeLayout.CSR and 'adj' in store:\n            del store.adj\n            if hasattr(self, '_edge_attrs'):\n                self._edge_attrs.pop((edge_type, EdgeLayout.CSR), None)\n            return True\n        elif edge_attr.layout == EdgeLayout.CSC and 'adj_t' in store:\n            del store.adj_t\n            if hasattr(self, '_edge_attrs'):\n                self._edge_attrs.pop((edge_type, EdgeLayout.CSC), None)\n            return True\n        return False\n\n 
   def get_all_edge_attrs(self) -> List[EdgeAttr]:\n        edge_attrs = getattr(self, '_edge_attrs', {})\n\n        for store in self.edge_stores:\n            if ('edge_index' in store\n                    and (store._key, EdgeLayout.COO) not in edge_attrs):\n                edge_attrs[(store._key, EdgeLayout.COO)] = EdgeAttr(\n                    store._key, 'coo', is_sorted=False)\n            if ('adj' in store\n                    and (store._key, EdgeLayout.CSR) not in edge_attrs):\n                size = store.adj.sparse_sizes()\n                edge_attrs[(store._key, EdgeLayout.CSR)] = EdgeAttr(\n                    store._key, 'csr', size=size)\n            if ('adj_t' in store\n                    and (store._key, EdgeLayout.CSC) not in edge_attrs):\n                size = store.adj_t.sparse_sizes()[::-1]\n                edge_attrs[(store._key, EdgeLayout.CSC)] = EdgeAttr(\n                    store._key, 'csc', size=size)\n\n        return list(edge_attrs.values())\n\n    # Connected Components Helper Functions ###################################\n\n    def _find_parent(self, node: Tuple[str, int]) -> Tuple[str, int]:\n        r\"\"\"Finds and returns the representative parent of the given node in a\n        disjoint-set (union-find) data structure. 
Implements path compression\n        to optimize future queries.\n\n        Args:\n            node (tuple[str, int]): The node for which to find the parent.\n            First element is the node type, second is the node index.\n\n        Returns:\n            tuple[str, int]: The representative parent of the node.\n        \"\"\"\n        if node not in self._parents:\n            self._parents[node] = node\n            self._ranks[node] = 0\n        if self._parents[node] != node:\n            self._parents[node] = self._find_parent(self._parents[node])\n        return self._parents[node]\n\n    def _union(self, node1: Tuple[str, int], node2: Tuple[str, int]):\n        r\"\"\"Merges the node1 and node2 in the disjoint-set data structure.\n\n        Finds the root parents of node1 and node2 using the _find_parent\n        method. If they belong to different sets, updates the parent of\n        root2 to be root1, effectively merging the two sets.\n\n        Args:\n            node1 (Tuple[str, int]): The first node to union. First element is\n                the node type, second is the node index.\n            node2 (Tuple[str, int]): The second node to union. 
First element is\n                the node type, second is the node index.\n        \"\"\"\n        root1 = self._find_parent(node1)\n        root2 = self._find_parent(node2)\n        if root1 != root2:\n            if self._ranks[root1] < self._ranks[root2]:\n                self._parents[root1] = root2\n            elif self._ranks[root1] > self._ranks[root2]:\n                self._parents[root2] = root1\n            else:\n                self._parents[root2] = root1\n                self._ranks[root1] += 1\n\n\n# Helper functions ############################################################\n\n\ndef get_node_slices(num_nodes: Dict[str, int]) -> Dict[str, Tuple[int, int]]:\n    r\"\"\"Returns the boundaries of each node type in a graph.\"\"\"\n    node_slices: Dict[NodeType, Tuple[int, int]] = {}\n    cumsum = 0\n    for node_type, N in num_nodes.items():\n        node_slices[node_type] = (cumsum, cumsum + N)\n        cumsum += N\n    return node_slices\n\n\ndef offset_edge_index(\n    node_slices: Dict[NodeType, Tuple[int, int]],\n    edge_type: EdgeType,\n    edge_index: Tensor,\n) -> Tensor:\n    r\"\"\"Increases the edge indices by the offsets of source and destination\n    node types.\n    \"\"\"\n    src, _, dst = edge_type\n    offset = [[node_slices[src][0]], [node_slices[dst][0]]]\n    offset = torch.tensor(offset, device=edge_index.device)\n    return edge_index + offset\n\n\ndef to_homogeneous_edge_index(\n    data: HeteroData,\n) -> Tuple[Optional[Tensor], Dict[NodeType, Any], Dict[EdgeType, Any]]:\n    r\"\"\"Converts a heterogeneous graph into a homogeneous typed graph.\"\"\"\n    # Record slice information per node type:\n    node_slices = get_node_slices(data.num_nodes_dict)\n\n    # Record edge indices and slice information per edge type:\n    cumsum = 0\n    edge_indices: List[Tensor] = []\n    edge_slices: Dict[EdgeType, Tuple[int, int]] = {}\n    for edge_type, edge_index in data.collect('edge_index', True).items():\n        edge_index = 
offset_edge_index(node_slices, edge_type, edge_index)\n        edge_indices.append(edge_index)\n        edge_slices[edge_type] = (cumsum, cumsum + edge_index.size(1))\n        cumsum += edge_index.size(1)\n\n    edge_index: Optional[Tensor] = None\n    if len(edge_indices) == 1:  # Memory-efficient `torch.cat`:\n        edge_index = edge_indices[0]\n    elif len(edge_indices) > 1:\n        edge_index = torch.cat(edge_indices, dim=-1)\n\n    return edge_index, node_slices, edge_slices\n"
  },
  {
    "path": "torch_geometric/data/hypergraph_data.py",
    "content": "import copy\nimport warnings\nfrom typing import Any, List, Optional\n\nimport torch\nfrom torch import Tensor\nfrom typing_extensions import Self\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.typing import EdgeType, NodeType, OptTensor\nfrom torch_geometric.utils import select\nfrom torch_geometric.utils._subgraph import hyper_subgraph\n\n\nclass HyperGraphData(Data):\n    r\"\"\"A data object describing a hypergraph.\n\n    The data object can hold node-level, link-level and graph-level attributes.\n    This object differs from a standard :obj:`~torch_geometric.data.Data`\n    object by having hyperedges, i.e. edges that connect more\n    than two nodes. For example, in the hypergraph scenario\n    :math:`\\mathcal{G} = (\\mathcal{V}, \\mathcal{E})` with\n    :math:`\\mathcal{V} = \\{ 0, 1, 2, 3, 4 \\}` and\n    :math:`\\mathcal{E} = \\{ \\{ 0, 1, 2 \\}, \\{ 1, 2, 3, 4 \\} \\}`, the\n    hyperedge index :obj:`edge_index` is represented as:\n\n    .. code-block:: python\n\n        # hyper graph with two hyperedges\n        # connecting 3 and 4 nodes, respectively\n        edge_index = torch.tensor([\n            [0, 1, 2, 1, 2, 3, 4],\n            [0, 0, 0, 1, 1, 1, 1],\n        ])\n\n    Args:\n        x (torch.Tensor, optional): Node feature matrix with shape\n            :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`)\n        edge_index (LongTensor, optional): Hyperedge tensor\n            with shape :obj:`[2, num_edges*num_nodes_per_edge]`.\n            Where `edge_index[1]` denotes the hyperedge index and\n            `edge_index[0]` denotes the node indices that are connected\n            by the hyperedge. 
(default: :obj:`None`)\n        edge_attr (torch.Tensor, optional): Edge feature matrix with shape\n            :obj:`[num_edges, num_edge_features]`.\n            (default: :obj:`None`)\n        y (torch.Tensor, optional): Graph-level or node-level ground-truth\n            labels with arbitrary shape. (default: :obj:`None`)\n        pos (torch.Tensor, optional): Node position matrix with shape\n            :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`)\n        **kwargs (optional): Additional attributes.\n    \"\"\"\n    def __init__(\n        self,\n        x: OptTensor = None,\n        edge_index: OptTensor = None,\n        edge_attr: OptTensor = None,\n        y: OptTensor = None,\n        pos: OptTensor = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            x=x,\n            edge_index=edge_index,\n            edge_attr=edge_attr,\n            y=y,\n            pos=pos,\n            **kwargs,\n        )\n\n    @property\n    def num_edges(self) -> int:\n        r\"\"\"Returns the number of hyperedges in the hypergraph.\"\"\"\n        if self.edge_index is None:\n            return 0\n        return max(self.edge_index[1]) + 1\n\n    @property\n    def num_nodes(self) -> Optional[int]:\n        num_nodes = super().num_nodes\n\n        # For hypergraphs, `edge_index[1]` does not contain node indices.\n        # Therefore, the below code is used to prevent `num_nodes` being\n        # estimated as the number of hyperedges.\n        if (self.edge_index is not None and num_nodes == self.num_edges):\n            return max(self.edge_index[0]) + 1\n\n        return num_nodes\n\n    @num_nodes.setter\n    def num_nodes(self, num_nodes: Optional[int]) -> None:\n        self._store.num_nodes = num_nodes\n\n    def is_edge_attr(self, key: str) -> bool:\n        val = super().is_edge_attr(key)\n        if not val and self.edge_index is not None:\n            return key in self and 
self[key].size(0) == self.num_edges\n        return val\n\n    def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any:\n        if key == 'edge_index':\n            return torch.tensor([[self.num_nodes], [self.num_edges]])\n        else:\n            return super().__inc__(key, value, *args, **kwargs)\n\n    def subgraph(self, subset: Tensor) -> 'HyperGraphData':\n        r\"\"\"Returns the induced subgraph given by the node indices\n        :obj:`subset`.\n\n        .. note::\n\n            If only a subset of a hyperedge's nodes are to be\n            selected in the subgraph, the hyperedge will remain in the\n            subgraph, but only the selected nodes will be connected by\n            the hyperedge. Hyperedges that connect only one node in the\n            subgraph will be removed.\n\n        Examples:\n            >>> x = torch.randn(4, 16)\n            >>> edge_index = torch.tensor([\n            ...     [0, 1, 0, 2, 1, 1, 2, 4],\n            ...     [0, 0, 1, 1, 1, 2, 2, 2]\n            ... ])\n            >>> data = HyperGraphData(x=x, edge_index=edge_index)\n            >>> subset = torch.tensor([1, 2, 4])\n            >>> subgraph = data.subgraph(subset)\n            >>> subgraph.edge_index\n            tensor([[2, 1, 1, 2, 4],\n            [0, 0, 1, 1, 1]])\n\n        Args:\n            subset (LongTensor or BoolTensor): The nodes to keep.\n        \"\"\"\n        assert self.edge_index is not None\n        out = hyper_subgraph(subset, self.edge_index, relabel_nodes=True,\n                             num_nodes=self.num_nodes, return_edge_mask=True)\n        edge_index, _, edge_mask = out\n\n        data = copy.copy(self)\n\n        for key, value in self.items():\n            if key == 'edge_index':\n                data.edge_index = edge_index\n            elif key == 'num_nodes':\n                if subset.dtype == torch.bool:\n                    data.num_nodes = int(subset.sum())\n                else:\n                 
   data.num_nodes = subset.size(0)\n            elif self.is_node_attr(key):\n                cat_dim = self.__cat_dim__(key, value)\n                data[key] = select(value, subset, dim=cat_dim)\n            elif self.is_edge_attr(key):\n                cat_dim = self.__cat_dim__(key, value)\n                data[key] = select(value, edge_mask, dim=cat_dim)\n\n        return data\n\n    def edge_subgraph(self, subset: Tensor) -> Self:\n        raise NotImplementedError\n\n    def to_heterogeneous(\n        self,\n        node_type: Optional[Tensor] = None,\n        edge_type: Optional[Tensor] = None,\n        node_type_names: Optional[List[NodeType]] = None,\n        edge_type_names: Optional[List[EdgeType]] = None,\n    ) -> HeteroData:\n        raise NotImplementedError\n\n    def has_isolated_nodes(self) -> bool:\n        if self.edge_index is None:\n            return False\n        return torch.unique(self.edge_index[0]).size(0) < self.num_nodes\n\n    def is_directed(self) -> bool:\n        raise NotImplementedError\n\n    def is_undirected(self) -> bool:\n        raise NotImplementedError\n\n    def has_self_loops(self) -> bool:\n        raise NotImplementedError\n\n    def validate(self, raise_on_error: bool = True) -> bool:\n        r\"\"\"Validates the correctness of the data.\"\"\"\n        cls_name = self.__class__.__name__\n        status = True\n\n        num_nodes = self.num_nodes\n        if num_nodes is None:\n            status = False\n            warn_or_raise(f\"'num_nodes' is undefined in '{cls_name}'\",\n                          raise_on_error)\n\n        if self.edge_index is not None:\n            if self.edge_index.dim() != 2 or self.edge_index.size(0) != 2:\n                status = False\n                warn_or_raise(\n                    f\"'edge_index' needs to be of shape [2, num_edges] in \"\n                    f\"'{cls_name}' (found {self.edge_index.size()})\",\n                    raise_on_error)\n\n        if self.edge_index 
is not None and self.edge_index.numel() > 0:\n            if self.edge_index.min() < 0:\n                status = False\n                warn_or_raise(\n                    f\"'edge_index' contains negative indices in \"\n                    f\"'{cls_name}' (found {int(self.edge_index.min())})\",\n                    raise_on_error)\n\n            if num_nodes is not None and self.edge_index[0].max() >= num_nodes:\n                status = False\n                warn_or_raise(\n                    f\"'edge_index' contains larger indices than the number \"\n                    f\"of nodes ({num_nodes}) in '{cls_name}' \"\n                    f\"(found {int(self.edge_index[0].max())})\", raise_on_error)\n\n        return status\n\n\ndef warn_or_raise(msg: str, raise_on_error: bool = True) -> None:\n    if raise_on_error:\n        raise ValueError(msg)\n    else:\n        warnings.warn(msg, stacklevel=2)\n"
  },
  {
    "path": "torch_geometric/data/in_memory_dataset.py",
    "content": "import copy\nimport os.path as osp\nimport warnings\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Mapping,\n    MutableSequence,\n    Optional,\n    Sequence,\n    Tuple,\n    Type,\n    Union,\n)\n\nimport torch\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nimport torch_geometric\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.data.collate import collate\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.data.dataset import Dataset, IndexType\nfrom torch_geometric.data.separate import separate\nfrom torch_geometric.io import fs\n\n\nclass InMemoryDataset(Dataset):\n    r\"\"\"Dataset base class for creating graph datasets which easily fit\n    into CPU memory.\n    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/\n    create_dataset.html#creating-in-memory-datasets>`__ for the accompanying\n    tutorial.\n\n    Args:\n        root (str, optional): Root directory where the dataset should be saved.\n            (optional: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in a\n            :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            transformed version.\n            The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            a :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            transformed version.\n            The data object will be transformed before being saved to disk.\n            (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in a\n            :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            boolean 
value, indicating whether the data object should be\n            included in the final dataset. (default: :obj:`None`)\n        log (bool, optional): Whether to print any console output while\n            downloading and processing the dataset. (default: :obj:`True`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    @property\n    def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:\n        raise NotImplementedError\n\n    @property\n    def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]:\n        raise NotImplementedError\n\n    def __init__(\n        self,\n        root: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        log: bool = True,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter, log,\n                         force_reload)\n\n        self._data: Optional[BaseData] = None\n        self.slices: Optional[Dict[str, Tensor]] = None\n        self._data_list: Optional[MutableSequence[Optional[BaseData]]] = None\n\n    @property\n    def num_classes(self) -> int:\n        if self.transform is None:\n            return self._infer_num_classes(self._data.y)\n        return super().num_classes\n\n    def len(self) -> int:\n        if self.slices is None:\n            return 1\n        for _, value in nested_iter(self.slices):\n            return len(value) - 1\n        return 0\n\n    def get(self, idx: int) -> BaseData:\n        # TODO (matthias) Avoid unnecessary copy here.\n        if self.len() == 1:\n            return copy.copy(self._data)\n\n        if not hasattr(self, '_data_list') or self._data_list is None:\n            self._data_list = self.len() * [None]\n        elif self._data_list[idx] is not None:\n            return 
copy.copy(self._data_list[idx])\n\n        data = separate(\n            cls=self._data.__class__,\n            batch=self._data,\n            idx=idx,\n            slice_dict=self.slices,\n            decrement=False,\n        )\n\n        self._data_list[idx] = copy.copy(data)\n\n        return data\n\n    @classmethod\n    def save(cls, data_list: Sequence[BaseData], path: str) -> None:\n        r\"\"\"Saves a list of data objects to the file path :obj:`path`.\"\"\"\n        data, slices = cls.collate(data_list)\n        fs.torch_save((data.to_dict(), slices, data.__class__), path)\n\n    def load(self, path: str, data_cls: Type[BaseData] = Data) -> None:\n        r\"\"\"Loads the dataset from the file path :obj:`path`.\"\"\"\n        out = fs.torch_load(path)\n        assert isinstance(out, tuple)\n        assert len(out) == 2 or len(out) == 3\n        if len(out) == 2:  # Backward compatibility.\n            data, self.slices = out\n        else:\n            data, self.slices, data_cls = out\n\n        if not isinstance(data, dict):  # Backward compatibility.\n            self.data = data\n        else:\n            self.data = data_cls.from_dict(data)\n\n    @staticmethod\n    def collate(\n        data_list: Sequence[BaseData],\n    ) -> Tuple[BaseData, Optional[Dict[str, Tensor]]]:\n        r\"\"\"Collates a list of :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` objects to the internal\n        storage format of :class:`~torch_geometric.data.InMemoryDataset`.\n        \"\"\"\n        if len(data_list) == 1:\n            return data_list[0], None\n\n        data, slices, _ = collate(\n            data_list[0].__class__,\n            data_list=data_list,\n            increment=False,\n            add_batch=False,\n        )\n\n        return data, slices\n\n    def copy(self, idx: Optional[IndexType] = None) -> 'InMemoryDataset':\n        r\"\"\"Performs a deep-copy of the dataset. 
If :obj:`idx` is not given,\n        will clone the full dataset. Otherwise, will only clone a subset of the\n        dataset from indices :obj:`idx`.\n        Indices can be slices, lists, tuples, and a :obj:`torch.Tensor` or\n        :obj:`np.ndarray` of type long or bool.\n        \"\"\"\n        if idx is None:\n            data_list = [self.get(i) for i in self.indices()]\n        else:\n            data_list = [self.get(i) for i in self.index_select(idx).indices()]\n\n        dataset = copy.copy(self)\n        dataset._indices = None\n        dataset._data_list = None\n        dataset.data, dataset.slices = self.collate(data_list)\n        return dataset\n\n    def to_on_disk_dataset(\n        self,\n        root: Optional[str] = None,\n        backend: str = 'sqlite',\n        log: bool = True,\n    ) -> 'torch_geometric.data.OnDiskDataset':\n        r\"\"\"Converts the :class:`InMemoryDataset` to a :class:`OnDiskDataset`\n        variant. Useful for distributed training and hardware instances with\n        limited amount of shared memory.\n\n        root (str, optional): Root directory where the dataset should be saved.\n            If set to :obj:`None`, will save the dataset in\n            :obj:`root/on_disk`.\n            Note that it is important to specify :obj:`root` to account for\n            different dataset splits. (optional: :obj:`None`)\n        backend (str): The :class:`Database` backend to use.\n            (default: :obj:`\"sqlite\"`)\n        log (bool, optional): Whether to print any console output while\n            processing the dataset. (default: :obj:`True`)\n        \"\"\"\n        if root is None and (self.root is None or not osp.exists(self.root)):\n            raise ValueError(f\"The root directory of \"\n                             f\"'{self.__class__.__name__}' is not specified. 
\"\n                             f\"Please pass in 'root' when creating on-disk \"\n                             f\"datasets from it.\")\n\n        root = root or osp.join(self.root, 'on_disk')\n\n        in_memory_dataset = self\n        ref_data = in_memory_dataset.get(0)\n        if not isinstance(ref_data, Data):\n            raise NotImplementedError(\n                f\"`{self.__class__.__name__}.to_on_disk_dataset()` is \"\n                f\"currently only supported on homogeneous graphs\")\n\n        # Parse the schema ====================================================\n\n        schema: Dict[str, Any] = {}\n        for key, value in ref_data.to_dict().items():\n            if isinstance(value, (int, float, str)):\n                schema[key] = value.__class__\n            elif isinstance(value, Tensor) and value.dim() == 0:\n                schema[key] = dict(dtype=value.dtype, size=(-1, ))\n            elif isinstance(value, Tensor):\n                size = list(value.size())\n                size[ref_data.__cat_dim__(key, value)] = -1\n                schema[key] = dict(dtype=value.dtype, size=tuple(size))\n            else:\n                schema[key] = object\n\n        # Create the on-disk dataset ==========================================\n\n        class OnDiskDataset(torch_geometric.data.OnDiskDataset):\n            def __init__(\n                self,\n                root: str,\n                transform: Optional[Callable] = None,\n            ):\n                super().__init__(\n                    root=root,\n                    transform=transform,\n                    backend=backend,\n                    schema=schema,\n                )\n\n            def process(self):\n                _iter = [\n                    in_memory_dataset.get(i)\n                    for i in in_memory_dataset.indices()\n                ]\n                if log:  # pragma: no cover\n                    _iter = tqdm(_iter, desc='Converting to 
OnDiskDataset')\n\n                data_list: List[Data] = []\n                for i, data in enumerate(_iter):\n                    data_list.append(data)\n                    if i + 1 == len(in_memory_dataset) or (i + 1) % 1000 == 0:\n                        self.extend(data_list)\n                        data_list = []\n\n            def serialize(self, data: Data) -> Dict[str, Any]:\n                return data.to_dict()\n\n            def deserialize(self, data: Dict[str, Any]) -> Data:\n                return Data.from_dict(data)\n\n            def __repr__(self) -> str:\n                arg_repr = str(len(self)) if len(self) > 1 else ''\n                return (f'OnDisk{in_memory_dataset.__class__.__name__}('\n                        f'{arg_repr})')\n\n        return OnDiskDataset(root, transform=in_memory_dataset.transform)\n\n    @property\n    def data(self) -> Any:\n        msg1 = (\"It is not recommended to directly access the internal \"\n                \"storage format `data` of an 'InMemoryDataset'.\")\n        msg2 = (\"The given 'InMemoryDataset' only references a subset of \"\n                \"examples of the full dataset, but 'data' will contain \"\n                \"information of the full dataset.\")\n        msg3 = (\"The data of the dataset is already cached, so any \"\n                \"modifications to `data` will not be reflected when accessing \"\n                \"its elements. Clearing the cache now by removing all \"\n                \"elements in `dataset._data_list`.\")\n        msg4 = (\"If you are absolutely certain what you are doing, access the \"\n                \"internal storage via `InMemoryDataset._data` instead to \"\n                \"suppress this warning. 
Alternatively, you can access stacked \"\n                \"individual attributes of every graph via \"\n                \"`dataset.{attr_name}`.\")\n\n        msg = msg1\n        if self._indices is not None:\n            msg += f' {msg2}'\n        if self._data_list is not None:\n            msg += f' {msg3}'\n            self._data_list = None\n        msg += f' {msg4}'\n\n        warnings.warn(msg, stacklevel=2)\n\n        return self._data\n\n    @data.setter\n    def data(self, value: Any):\n        self._data = value\n        self._data_list = None\n\n    def __getattr__(self, key: str) -> Any:\n        data = self.__dict__.get('_data')\n        if isinstance(data, Data) and key in data:\n            if self._indices is None and data.__inc__(key, data[key]) == 0:\n                return data[key]\n            else:\n                data_list = [self.get(i) for i in self.indices()]\n                return Batch.from_data_list(data_list)[key]\n\n        raise AttributeError(f\"'{self.__class__.__name__}' object has no \"\n                             f\"attribute '{key}'\")\n\n    def to(self, device: Union[int, str]) -> 'InMemoryDataset':\n        r\"\"\"Performs device conversion of the whole dataset.\"\"\"\n        if self._indices is not None:\n            raise ValueError(\"The given 'InMemoryDataset' only references a \"\n                             \"subset of examples of the full dataset\")\n        if self._data_list is not None:\n            raise ValueError(\"The data of the dataset is already cached\")\n        self._data.to(device)\n        return self\n\n    def cpu(self, *args: str) -> 'InMemoryDataset':\n        r\"\"\"Moves the dataset to CPU memory.\"\"\"\n        return self.to(torch.device('cpu'))\n\n    def cuda(\n        self,\n        device: Optional[Union[int, str]] = None,\n    ) -> 'InMemoryDataset':\n        r\"\"\"Moves the dataset to CUDA memory.\"\"\"\n        if isinstance(device, int):\n            device = f'cuda:{device}'\n   
     elif device is None:\n            device = 'cuda'\n        return self.to(device)\n\n\ndef nested_iter(node: Union[Mapping, Sequence]) -> Iterable:\n    if isinstance(node, Mapping):\n        for value in node.values():\n            yield from nested_iter(value)\n    elif isinstance(node, Sequence):\n        yield from enumerate(node)\n    else:\n        yield None, node\n"
  },
  {
    "path": "torch_geometric/data/lightning/__init__.py",
    "content": "from .datamodule import LightningDataset, LightningNodeData, LightningLinkData\n\n__all__ = classes = [\n    'LightningDataset',\n    'LightningNodeData',\n    'LightningLinkData',\n]\n"
  },
  {
    "path": "torch_geometric/data/lightning/datamodule.py",
    "content": "import copy\nimport inspect\nimport warnings\nfrom typing import Any, Dict, Optional, Tuple, Type, Union\n\nimport torch\n\nfrom torch_geometric.data import Data, Dataset, HeteroData\nfrom torch_geometric.loader import DataLoader, LinkLoader, NodeLoader\nfrom torch_geometric.sampler import BaseSampler, NeighborSampler\nfrom torch_geometric.typing import InputEdges, InputNodes, OptTensor\n\ntry:\n    from lightning.pytorch import LightningDataModule as _LightningDataModule\n    _pl_is_available = True\nexcept ImportError:\n    try:\n        from pytorch_lightning import \\\n            LightningDataModule as _LightningDataModule\n        _pl_is_available = True\n    except ImportError:\n        _pl_is_available = False\n        _LightningDataModule = object\n\n\nclass LightningDataModule(_LightningDataModule):\n    def __init__(self, has_val: bool, has_test: bool, **kwargs: Any) -> None:\n        super().__init__()\n\n        if not _pl_is_available:\n            raise ModuleNotFoundError(\n                \"No module named 'pytorch_lightning' (or 'lightning') found \"\n                \"in your Python environment. Run 'pip install \"\n                \"pytorch_lightning' or 'pip install lightning'\")\n\n        if not has_val:\n            self.val_dataloader = None  # type: ignore\n\n        if not has_test:\n            self.test_dataloader = None  # type: ignore\n\n        kwargs.setdefault('batch_size', 1)\n        kwargs.setdefault('num_workers', 0)\n        kwargs.setdefault('pin_memory', True)\n        kwargs.setdefault('persistent_workers',\n                          kwargs.get('num_workers', 0) > 0)\n\n        if 'shuffle' in kwargs:\n            warnings.warn(\n                f\"The 'shuffle={kwargs['shuffle']}' option is \"\n                f\"ignored in '{self.__class__.__name__}'. 
Remove it \"\n                f\"from the argument list to disable this warning\",\n                stacklevel=2)\n            del kwargs['shuffle']\n\n        self.kwargs = kwargs\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({kwargs_repr(**self.kwargs)})'\n\n\nclass LightningData(LightningDataModule):\n    def __init__(\n        self,\n        data: Union[Data, HeteroData],\n        has_val: bool,\n        has_test: bool,\n        loader: str = 'neighbor',\n        graph_sampler: Optional[BaseSampler] = None,\n        eval_loader_kwargs: Optional[Dict[str, Any]] = None,\n        **kwargs: Any,\n    ) -> None:\n        kwargs.setdefault('batch_size', 1)\n        kwargs.setdefault('num_workers', 0)\n\n        if graph_sampler is not None:\n            loader = 'custom'\n\n        # For full-batch training, we use reasonable defaults for a lot of\n        # data-loading options:\n        if loader not in ['full', 'neighbor', 'link_neighbor', 'custom']:\n            raise ValueError(f\"Undefined 'loader' option (got '{loader}')\")\n\n        if loader == 'full' and kwargs['batch_size'] != 1:\n            warnings.warn(\n                f\"Re-setting 'batch_size' to 1 in \"\n                f\"'{self.__class__.__name__}' for loader='full' \"\n                f\"(got '{kwargs['batch_size']}')\", stacklevel=2)\n            kwargs['batch_size'] = 1\n\n        if loader == 'full' and kwargs['num_workers'] != 0:\n            warnings.warn(\n                f\"Re-setting 'num_workers' to 0 in \"\n                f\"'{self.__class__.__name__}' for loader='full' \"\n                f\"(got '{kwargs['num_workers']}')\", stacklevel=2)\n            kwargs['num_workers'] = 0\n\n        if loader == 'full' and kwargs.get('sampler') is not None:\n            warnings.warn(\n                \"'sampler' option is not supported for \"\n                \"loader='full'\", stacklevel=2)\n            kwargs.pop('sampler', None)\n\n        if loader == 
'full' and kwargs.get('batch_sampler') is not None:\n            warnings.warn(\n                \"'batch_sampler' option is not supported for \"\n                \"loader='full'\", stacklevel=2)\n            kwargs.pop('batch_sampler', None)\n\n        super().__init__(has_val, has_test, **kwargs)\n\n        if loader == 'full':\n            if kwargs.get('pin_memory', False):\n                warnings.warn(\n                    f\"Re-setting 'pin_memory' to 'False' in \"\n                    f\"'{self.__class__.__name__}' for loader='full' \"\n                    f\"(got 'True')\", stacklevel=2)\n            self.kwargs['pin_memory'] = False\n\n        self.data = data\n        self.loader = loader\n\n        # Determine sampler and loader arguments ##############################\n\n        if loader in ['neighbor', 'link_neighbor']:\n\n            # Define a new `NeighborSampler` to be re-used across data loaders:\n            sampler_kwargs, self.loader_kwargs = split_kwargs(\n                self.kwargs,\n                NeighborSampler,\n            )\n            sampler_kwargs.setdefault('share_memory',\n                                      self.kwargs['num_workers'] > 0)\n            self.graph_sampler: BaseSampler = NeighborSampler(\n                data, **sampler_kwargs)\n\n        elif graph_sampler is not None:\n            sampler_kwargs, self.loader_kwargs = split_kwargs(\n                self.kwargs,\n                graph_sampler.__class__,\n            )\n            if len(sampler_kwargs) > 0:\n                warnings.warn(\n                    f\"Ignoring the arguments \"\n                    f\"{list(sampler_kwargs.keys())} in \"\n                    f\"'{self.__class__.__name__}' since a custom \"\n                    f\"'graph_sampler' was passed\", stacklevel=2)\n            self.graph_sampler = graph_sampler\n\n        else:\n            assert loader == 'full'\n            self.loader_kwargs = self.kwargs\n\n        # Determine 
validation sampler and loader arguments ###################\n\n        self.eval_loader_kwargs = copy.copy(self.loader_kwargs)\n        if eval_loader_kwargs is not None:\n            # If the user wants to override certain values during evaluation,\n            # we shallow-copy the graph sampler and update its attributes.\n            if hasattr(self, 'graph_sampler'):\n                self.eval_graph_sampler = copy.copy(self.graph_sampler)\n\n                eval_sampler_kwargs, eval_loader_kwargs = split_kwargs(\n                    eval_loader_kwargs,\n                    self.graph_sampler.__class__,\n                )\n                for key, value in eval_sampler_kwargs.items():\n                    setattr(self.eval_graph_sampler, key, value)\n\n            self.eval_loader_kwargs.update(eval_loader_kwargs)\n\n        elif hasattr(self, 'graph_sampler'):\n            self.eval_graph_sampler = self.graph_sampler\n\n        self.eval_loader_kwargs.pop('sampler', None)\n        self.eval_loader_kwargs.pop('batch_sampler', None)\n\n        if 'batch_sampler' in self.loader_kwargs:\n            self.loader_kwargs.pop('batch_size', None)\n\n    @property\n    def train_shuffle(self) -> bool:\n        shuffle = self.loader_kwargs.get('sampler', None) is None\n        shuffle &= self.loader_kwargs.get('batch_sampler', None) is None\n        return shuffle\n\n    def prepare_data(self) -> None:\n        if self.loader == 'full':\n            assert self.trainer is not None\n            try:\n                num_devices = self.trainer.num_devices\n            except AttributeError:\n                # PyTorch Lightning < 1.6 backward compatibility:\n                num_devices = self.trainer.num_processes  # type: ignore\n                num_gpus = self.trainer.num_gpus  # type: ignore\n                num_devices = max(num_devices, num_gpus)\n\n            if num_devices > 1:\n                raise ValueError(\n                    f\"'{self.__class__.__name__}' 
with loader='full' requires \"\n                    f\"training on a single device\")\n        super().prepare_data()\n\n    def full_dataloader(self, **kwargs: Any) -> torch.utils.data.DataLoader:\n        warnings.filterwarnings('ignore', '.*does not have many workers.*')\n        warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')\n\n        return torch.utils.data.DataLoader(\n            [self.data],  # type: ignore\n            collate_fn=lambda xs: xs[0],\n            **kwargs,\n        )\n\n    def __repr__(self) -> str:\n        kwargs = kwargs_repr(data=self.data, loader=self.loader, **self.kwargs)\n        return f'{self.__class__.__name__}({kwargs})'\n\n\nclass LightningDataset(LightningDataModule):\n    r\"\"\"Converts a set of :class:`~torch_geometric.data.Dataset` objects into a\n    :class:`pytorch_lightning.LightningDataModule` variant. It can then be\n    automatically used as a :obj:`datamodule` for multi-GPU graph-level\n    training via :lightning:`null`\n    `PyTorch Lightning <https://www.pytorchlightning.ai>`__.\n    :class:`LightningDataset` will take care of providing mini-batches via\n    :class:`~torch_geometric.loader.DataLoader`.\n\n    .. note::\n\n        Currently only the\n        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and\n        :class:`pytorch_lightning.strategies.DDPStrategy` training\n        strategies of :lightning:`null` `PyTorch Lightning\n        <https://pytorch-lightning.readthedocs.io/en/latest/guides/\n        speed.html>`__ are supported in order to correctly share data across\n        all devices/processes:\n\n        .. 
code-block:: python\n\n            import pytorch_lightning as pl\n            trainer = pl.Trainer(strategy=\"ddp_spawn\", accelerator=\"gpu\",\n                                 devices=4)\n            trainer.fit(model, datamodule)\n\n    Args:\n        train_dataset (Dataset): The training dataset.\n        val_dataset (Dataset, optional): The validation dataset.\n            (default: :obj:`None`)\n        test_dataset (Dataset, optional): The test dataset.\n            (default: :obj:`None`)\n        pred_dataset (Dataset, optional): The prediction dataset.\n            (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.loader.DataLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        train_dataset: Dataset,\n        val_dataset: Optional[Dataset] = None,\n        test_dataset: Optional[Dataset] = None,\n        pred_dataset: Optional[Dataset] = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            has_val=val_dataset is not None,\n            has_test=test_dataset is not None,\n            **kwargs,\n        )\n\n        self.train_dataset = train_dataset\n        self.val_dataset = val_dataset\n        self.test_dataset = test_dataset\n        self.pred_dataset = pred_dataset\n\n    def dataloader(self, dataset: Dataset, **kwargs: Any) -> DataLoader:\n        return DataLoader(dataset, **kwargs)\n\n    def train_dataloader(self) -> DataLoader:\n        from torch.utils.data import IterableDataset\n\n        shuffle = not isinstance(self.train_dataset, IterableDataset)\n        shuffle &= self.kwargs.get('sampler', None) is None\n        shuffle &= self.kwargs.get('batch_sampler', None) is None\n\n        return self.dataloader(\n            self.train_dataset,\n            shuffle=shuffle,\n            **self.kwargs,\n        )\n\n    def val_dataloader(self) -> DataLoader:\n        assert self.val_dataset is not None\n\n        kwargs = 
copy.copy(self.kwargs)\n        kwargs.pop('sampler', None)\n        kwargs.pop('batch_sampler', None)\n\n        return self.dataloader(self.val_dataset, shuffle=False, **kwargs)\n\n    def test_dataloader(self) -> DataLoader:\n        assert self.test_dataset is not None\n\n        kwargs = copy.copy(self.kwargs)\n        kwargs.pop('sampler', None)\n        kwargs.pop('batch_sampler', None)\n\n        return self.dataloader(self.test_dataset, shuffle=False, **kwargs)\n\n    def predict_dataloader(self) -> DataLoader:\n        assert self.pred_dataset is not None\n\n        kwargs = copy.copy(self.kwargs)\n        kwargs.pop('sampler', None)\n        kwargs.pop('batch_sampler', None)\n\n        return self.dataloader(self.pred_dataset, shuffle=False, **kwargs)\n\n    def __repr__(self) -> str:\n        kwargs = kwargs_repr(\n            train_dataset=self.train_dataset,\n            val_dataset=self.val_dataset,\n            test_dataset=self.test_dataset,\n            pred_dataset=self.pred_dataset,\n            **self.kwargs,\n        )\n        return f'{self.__class__.__name__}({kwargs})'\n\n\nclass LightningNodeData(LightningData):\n    r\"\"\"Converts a :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData` object into a\n    :class:`pytorch_lightning.LightningDataModule` variant. It can then be\n    automatically used as a :obj:`datamodule` for multi-GPU node-level\n    training via :lightning:`null`\n    `PyTorch Lightning <https://www.pytorchlightning.ai>`__.\n    :class:`LightningNodeData` will take care of providing mini-batches via\n    :class:`~torch_geometric.loader.NeighborLoader`.\n\n    .. 
note::\n\n        Currently only the\n        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and\n        :class:`pytorch_lightning.strategies.DDPStrategy` training\n        strategies of :lightning:`null` `PyTorch Lightning\n        <https://pytorch-lightning.readthedocs.io/en/latest/guides/\n        speed.html>`__ are supported in order to correctly share data across\n        all devices/processes:\n\n        .. code-block:: python\n\n            import pytorch_lightning as pl\n            trainer = pl.Trainer(strategy=\"ddp_spawn\", accelerator=\"gpu\",\n                                 devices=4)\n            trainer.fit(model, datamodule)\n\n    Args:\n        data (Data or HeteroData): The :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` graph object.\n        input_train_nodes (torch.Tensor or str or (str, torch.Tensor)): The\n            indices of training nodes.\n            If not given, will try to automatically infer them from the\n            :obj:`data` object by searching for :obj:`train_mask`,\n            :obj:`train_idx`, or :obj:`train_index` attributes.\n            (default: :obj:`None`)\n        input_train_time (torch.Tensor, optional): The timestamp\n            of training nodes. (default: :obj:`None`)\n        input_val_nodes (torch.Tensor or str or (str, torch.Tensor)): The\n            indices of validation nodes.\n            If not given, will try to automatically infer them from the\n            :obj:`data` object by searching for :obj:`val_mask`,\n            :obj:`valid_mask`, :obj:`val_idx`, :obj:`valid_idx`,\n            :obj:`val_index`, or :obj:`valid_index` attributes.\n            (default: :obj:`None`)\n        input_val_time (torch.Tensor, optional): The timestamp\n            of validation nodes. 
(default: :obj:`None`)\n        input_test_nodes (torch.Tensor or str or (str, torch.Tensor)): The\n            indices of test nodes.\n            If not given, will try to automatically infer them from the\n            :obj:`data` object by searching for :obj:`test_mask`,\n            :obj:`test_idx`, or :obj:`test_index` attributes.\n            (default: :obj:`None`)\n        input_test_time (torch.Tensor, optional): The timestamp\n            of test nodes. (default: :obj:`None`)\n        input_pred_nodes (torch.Tensor or str or (str, torch.Tensor)): The\n            indices of prediction nodes.\n            If not given, will try to automatically infer them from the\n            :obj:`data` object by searching for :obj:`pred_mask`,\n            :obj:`pred_idx`, or :obj:`pred_index` attributes.\n            (default: :obj:`None`)\n        input_pred_time (torch.Tensor, optional): The timestamp\n            of prediction nodes. (default: :obj:`None`)\n        loader (str): The scalability technique to use (:obj:`\"full\"`,\n            :obj:`\"neighbor\"`). (default: :obj:`\"neighbor\"`)\n        node_sampler (BaseSampler, optional): A custom sampler object to\n            generate mini-batches. If set, will ignore the :obj:`loader`\n            option. (default: :obj:`None`)\n        eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments\n            that override the :class:`torch_geometric.loader.NeighborLoader`\n            configuration during evaluation. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.loader.NeighborLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData],\n        input_train_nodes: InputNodes = None,\n        input_train_time: OptTensor = None,\n        input_val_nodes: InputNodes = None,\n        input_val_time: OptTensor = None,\n        input_test_nodes: InputNodes = None,\n        input_test_time: OptTensor = None,\n        input_pred_nodes: InputNodes = None,\n        input_pred_time: OptTensor = None,\n        loader: str = 'neighbor',\n        node_sampler: Optional[BaseSampler] = None,\n        eval_loader_kwargs: Optional[Dict[str, Any]] = None,\n        **kwargs: Any,\n    ) -> None:\n        if input_train_nodes is None:\n            input_train_nodes = infer_input_nodes(data, split='train')\n\n        if input_val_nodes is None:\n            input_val_nodes = infer_input_nodes(data, split='val')\n            if input_val_nodes is None:\n                input_val_nodes = infer_input_nodes(data, split='valid')\n\n        if input_test_nodes is None:\n            input_test_nodes = infer_input_nodes(data, split='test')\n\n        if input_pred_nodes is None:\n            input_pred_nodes = infer_input_nodes(data, split='pred')\n\n        super().__init__(\n            data=data,\n            has_val=input_val_nodes is not None,\n            has_test=input_test_nodes is not None,\n            loader=loader,\n            graph_sampler=node_sampler,\n            eval_loader_kwargs=eval_loader_kwargs,\n            **kwargs,\n        )\n\n        self.input_train_nodes = input_train_nodes\n        self.input_train_time = input_train_time\n        self.input_train_id: OptTensor = None\n\n        self.input_val_nodes = input_val_nodes\n        self.input_val_time = input_val_time\n        self.input_val_id: OptTensor = None\n\n        self.input_test_nodes = input_test_nodes\n        
self.input_test_time = input_test_time\n        self.input_test_id: OptTensor = None\n\n        self.input_pred_nodes = input_pred_nodes\n        self.input_pred_time = input_pred_time\n        self.input_pred_id: OptTensor = None\n\n    def dataloader(\n        self,\n        input_nodes: InputNodes,\n        input_time: OptTensor = None,\n        input_id: OptTensor = None,\n        node_sampler: Optional[BaseSampler] = None,\n        **kwargs: Any,\n    ) -> torch.utils.data.DataLoader:\n        if self.loader == 'full':\n            return self.full_dataloader(**kwargs)\n\n        assert node_sampler is not None\n\n        return NodeLoader(\n            self.data,\n            node_sampler=node_sampler,\n            input_nodes=input_nodes,\n            input_time=input_time,\n            input_id=input_id,\n            **kwargs,\n        )\n\n    def train_dataloader(self) -> torch.utils.data.DataLoader:\n        return self.dataloader(\n            self.input_train_nodes,\n            self.input_train_time,\n            self.input_train_id,\n            node_sampler=getattr(self, 'graph_sampler', None),\n            shuffle=self.train_shuffle,\n            **self.loader_kwargs,\n        )\n\n    def val_dataloader(self) -> torch.utils.data.DataLoader:\n        return self.dataloader(\n            self.input_val_nodes,\n            self.input_val_time,\n            self.input_val_id,\n            node_sampler=getattr(self, 'eval_graph_sampler', None),\n            shuffle=False,\n            **self.eval_loader_kwargs,\n        )\n\n    def test_dataloader(self) -> torch.utils.data.DataLoader:\n        return self.dataloader(\n            self.input_test_nodes,\n            self.input_test_time,\n            self.input_test_id,\n            node_sampler=getattr(self, 'eval_graph_sampler', None),\n            shuffle=False,\n            **self.eval_loader_kwargs,\n        )\n\n    def predict_dataloader(self) -> torch.utils.data.DataLoader:\n        return 
self.dataloader(\n            self.input_pred_nodes,\n            self.input_pred_time,\n            self.input_pred_id,\n            node_sampler=getattr(self, 'eval_graph_sampler', None),\n            shuffle=False,\n            **self.eval_loader_kwargs,\n        )\n\n\nclass LightningLinkData(LightningData):\n    r\"\"\"Converts a :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData` object into a\n    :class:`pytorch_lightning.LightningDataModule` variant. It can then be\n    automatically used as a :obj:`datamodule` for multi-GPU link-level\n    training via :lightning:`null`\n    `PyTorch Lightning <https://www.pytorchlightning.ai>`__.\n    :class:`LightningLinkData` will take care of providing mini-batches via\n    :class:`~torch_geometric.loader.LinkNeighborLoader`.\n\n    .. note::\n\n        Currently only the\n        :class:`pytorch_lightning.strategies.SingleDeviceStrategy` and\n        :class:`pytorch_lightning.strategies.DDPStrategy` training\n        strategies of :lightning:`null` `PyTorch Lightning\n        <https://pytorch-lightning.readthedocs.io/en/latest/guides/\n        speed.html>`__ are supported in order to correctly share data across\n        all devices/processes:\n\n        .. code-block:: python\n\n            import pytorch_lightning as pl\n            trainer = pl.Trainer(strategy=\"ddp_spawn\", accelerator=\"gpu\",\n                                 devices=4)\n            trainer.fit(model, datamodule)\n\n    Args:\n        data (Data or HeteroData or Tuple[FeatureStore, GraphStore]): The\n            :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` graph object, or a\n            tuple of a :class:`~torch_geometric.data.FeatureStore` and\n            :class:`~torch_geometric.data.GraphStore` objects.\n        input_train_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):\n            The training edges. 
(default: :obj:`None`)\n        input_train_labels (torch.Tensor, optional):\n            The labels of training edges. (default: :obj:`None`)\n        input_train_time (torch.Tensor, optional): The timestamp\n            of training edges. (default: :obj:`None`)\n        input_val_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):\n            The validation edges. (default: :obj:`None`)\n        input_val_labels (torch.Tensor, optional):\n            The labels of validation edges. (default: :obj:`None`)\n        input_val_time (torch.Tensor, optional): The timestamp\n            of validation edges. (default: :obj:`None`)\n        input_test_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):\n            The test edges. (default: :obj:`None`)\n        input_test_labels (torch.Tensor, optional):\n            The labels of test edges. (default: :obj:`None`)\n        input_test_time (torch.Tensor, optional): The timestamp\n            of test edges. (default: :obj:`None`)\n        input_pred_edges (Tensor or EdgeType or Tuple[EdgeType, Tensor]):\n            The prediction edges. (default: :obj:`None`)\n        input_pred_labels (torch.Tensor, optional):\n            The labels of prediction edges. (default: :obj:`None`)\n        input_pred_time (torch.Tensor, optional): The timestamp\n            of prediction edges. (default: :obj:`None`)\n        loader (str): The scalability technique to use (:obj:`\"full\"`,\n            :obj:`\"neighbor\"`). (default: :obj:`\"neighbor\"`)\n        link_sampler (BaseSampler, optional): A custom sampler object to\n            generate mini-batches. If set, will ignore the :obj:`loader`\n            option. (default: :obj:`None`)\n        eval_loader_kwargs (Dict[str, Any], optional): Custom keyword arguments\n            that override the\n            :class:`torch_geometric.loader.LinkNeighborLoader` configuration\n            during evaluation. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.loader.LinkNeighborLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData],\n        input_train_edges: InputEdges = None,\n        input_train_labels: OptTensor = None,\n        input_train_time: OptTensor = None,\n        input_val_edges: InputEdges = None,\n        input_val_labels: OptTensor = None,\n        input_val_time: OptTensor = None,\n        input_test_edges: InputEdges = None,\n        input_test_labels: OptTensor = None,\n        input_test_time: OptTensor = None,\n        input_pred_edges: InputEdges = None,\n        input_pred_labels: OptTensor = None,\n        input_pred_time: OptTensor = None,\n        loader: str = 'neighbor',\n        link_sampler: Optional[BaseSampler] = None,\n        eval_loader_kwargs: Optional[Dict[str, Any]] = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__(\n            data=data,\n            has_val=input_val_edges is not None,\n            has_test=input_test_edges is not None,\n            loader=loader,\n            graph_sampler=link_sampler,\n            eval_loader_kwargs=eval_loader_kwargs,\n            **kwargs,\n        )\n\n        self.input_train_edges = input_train_edges\n        self.input_train_labels = input_train_labels\n        self.input_train_time = input_train_time\n        self.input_train_id: OptTensor = None\n\n        self.input_val_edges = input_val_edges\n        self.input_val_labels = input_val_labels\n        self.input_val_time = input_val_time\n        self.input_val_id: OptTensor = None\n\n        self.input_test_edges = input_test_edges\n        self.input_test_labels = input_test_labels\n        self.input_test_time = input_test_time\n        self.input_test_id: OptTensor = None\n\n        self.input_pred_edges = input_pred_edges\n        self.input_pred_labels = input_pred_labels\n        self.input_pred_time = 
input_pred_time\n        self.input_pred_id: OptTensor = None\n\n    def dataloader(\n        self,\n        input_edges: InputEdges,\n        input_labels: OptTensor = None,\n        input_time: OptTensor = None,\n        input_id: OptTensor = None,\n        link_sampler: Optional[BaseSampler] = None,\n        **kwargs: Any,\n    ) -> torch.utils.data.DataLoader:\n        if self.loader == 'full':\n            return self.full_dataloader(**kwargs)\n\n        assert link_sampler is not None\n\n        return LinkLoader(\n            self.data,\n            link_sampler=link_sampler,\n            edge_label_index=input_edges,\n            edge_label=input_labels,\n            edge_label_time=input_time,\n            input_id=input_id,\n            **kwargs,\n        )\n\n    def train_dataloader(self) -> torch.utils.data.DataLoader:\n        return self.dataloader(\n            self.input_train_edges,\n            self.input_train_labels,\n            self.input_train_time,\n            self.input_train_id,\n            link_sampler=getattr(self, 'graph_sampler', None),\n            shuffle=self.train_shuffle,\n            **self.loader_kwargs,\n        )\n\n    def val_dataloader(self) -> torch.utils.data.DataLoader:\n        return self.dataloader(\n            self.input_val_edges,\n            self.input_val_labels,\n            self.input_val_time,\n            self.input_val_id,\n            link_sampler=getattr(self, 'eval_graph_sampler', None),\n            shuffle=False,\n            **self.eval_loader_kwargs,\n        )\n\n    def test_dataloader(self) -> torch.utils.data.DataLoader:\n        return self.dataloader(\n            self.input_test_edges,\n            self.input_test_labels,\n            self.input_test_time,\n            self.input_test_id,\n            link_sampler=getattr(self, 'eval_graph_sampler', None),\n            shuffle=False,\n            **self.eval_loader_kwargs,\n        )\n\n    def predict_dataloader(self) -> 
torch.utils.data.DataLoader:\n        return self.dataloader(\n            self.input_pred_edges,\n            self.input_pred_labels,\n            self.input_pred_time,\n            self.input_pred_id,\n            link_sampler=getattr(self, 'eval_graph_sampler', None),\n            shuffle=False,\n            **self.eval_loader_kwargs,\n        )\n\n\n###############################################################################\n\n\n# TODO Support Tuple[FeatureStore, GraphStore]\ndef infer_input_nodes(data: Union[Data, HeteroData], split: str) -> InputNodes:\n    attr_name: Optional[str] = None\n    if f'{split}_mask' in data:\n        attr_name = f'{split}_mask'\n    elif f'{split}_idx' in data:\n        attr_name = f'{split}_idx'\n    elif f'{split}_index' in data:\n        attr_name = f'{split}_index'\n\n    if attr_name is None:\n        return None\n\n    if isinstance(data, Data):\n        return data[attr_name]\n    if isinstance(data, HeteroData):\n        input_nodes_dict = {\n            node_type: store[attr_name]\n            for node_type, store in data.node_items() if attr_name in store\n        }\n        if len(input_nodes_dict) != 1:\n            raise ValueError(f\"Could not automatically determine the input \"\n                             f\"nodes of {data} since there exist multiple \"\n                             f\"types with attribute '{attr_name}'\")\n        return list(input_nodes_dict.items())[0]\n    return None\n\n\ndef kwargs_repr(**kwargs: Any) -> str:\n    return ', '.join([f'{k}={v}' for k, v in kwargs.items() if v is not None])\n\n\ndef split_kwargs(\n    kwargs: Dict[str, Any],\n    sampler_cls: Type,\n) -> Tuple[Dict[str, Any], Dict[str, Any]]:\n    r\"\"\"Splits keyword arguments into sampler and loader arguments.\"\"\"\n    sampler_args = inspect.signature(sampler_cls).parameters\n\n    sampler_kwargs: Dict[str, Any] = {}\n    loader_kwargs: Dict[str, Any] = {}\n\n    for key, value in kwargs.items():\n        if key in 
sampler_args:\n            sampler_kwargs[key] = value\n        else:\n            loader_kwargs[key] = value\n\n    return sampler_kwargs, loader_kwargs\n"
  },
  {
    "path": "torch_geometric/data/makedirs.py",
    "content": "from torch_geometric.deprecation import deprecated\nfrom torch_geometric.io import fs\n\n\n@deprecated(\"use 'os.makedirs(path, exist_ok=True)' instead\")\ndef makedirs(path: str):\n    r\"\"\"Recursively creates a directory.\n\n    .. warning::\n\n        :meth:`makedirs` is deprecated and will be removed soon.\n        Please use :obj:`os.makedirs(path, exist_ok=True)` instead.\n\n    Args:\n        path (str): The path to create.\n    \"\"\"\n    fs.makedirs(path, exist_ok=True)\n"
  },
  {
    "path": "torch_geometric/data/on_disk_dataset.py",
    "content": "import os\nfrom typing import Any, Callable, Iterable, List, Optional, Sequence, Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import Database, RocksDatabase, SQLiteDatabase\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.data.database import Schema\nfrom torch_geometric.data.dataset import Dataset\n\n\nclass OnDiskDataset(Dataset):\n    r\"\"\"Dataset base class for creating large graph datasets which do not\n    easily fit into CPU memory at once by leveraging a :class:`Database`\n    backend for on-disk storage and access of data objects.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in a\n            :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            transformed version.\n            The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in a\n            :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` object and returns a\n            boolean value, indicating whether the data object should be\n            included in the final dataset. 
(default: :obj:`None`)\n        backend (str): The :class:`Database` backend to use\n            (one of :obj:`\"sqlite\"` or :obj:`\"rocksdb\"`).\n            (default: :obj:`\"sqlite\"`)\n        schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of\n            the input data.\n            Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a\n            dictionary with :obj:`dtype` and :obj:`size` keys (for specifying\n            tensor data) as input, and can be nested as a tuple or dictionary.\n            Specifying the schema will improve efficiency, since by default the\n            database will use python pickling for serializing and\n            deserializing. If specified to anything different than\n            :obj:`object`, implementations of :class:`OnDiskDataset` need to\n            override :meth:`serialize` and :meth:`deserialize` methods.\n            (default: :obj:`object`)\n        log (bool, optional): Whether to print any console output while\n            downloading and processing the dataset. 
(default: :obj:`True`)\n    \"\"\"\n    BACKENDS = {\n        'sqlite': SQLiteDatabase,\n        'rocksdb': RocksDatabase,\n    }\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        backend: str = 'sqlite',\n        schema: Schema = object,\n        log: bool = True,\n    ) -> None:\n        if backend not in self.BACKENDS:\n            raise ValueError(f\"Database backend must be one of \"\n                             f\"{set(self.BACKENDS.keys())} \"\n                             f\"(got '{backend}')\")\n\n        self.backend = backend\n        self.schema = schema\n\n        self._db: Optional[Database] = None\n        self._numel: Optional[int] = None\n\n        super().__init__(root, transform, pre_filter=pre_filter, log=log)\n\n    @property\n    def processed_file_names(self) -> str:\n        return f'{self.backend}.db'\n\n    @property\n    def db(self) -> Database:\n        r\"\"\"Returns the underlying :class:`Database`.\"\"\"\n        if self._db is not None:\n            return self._db\n\n        kwargs = {}\n        cls = self.BACKENDS[self.backend]\n        if issubclass(cls, SQLiteDatabase):\n            kwargs['name'] = self.__class__.__name__\n\n        os.makedirs(self.processed_dir, exist_ok=True)\n        path = self.processed_paths[0]\n        self._db = cls(path=path, schema=self.schema, **kwargs)\n        self._numel = len(self._db)\n        return self._db\n\n    def close(self) -> None:\n        r\"\"\"Closes the connection to the underlying database.\"\"\"\n        if self._db is not None:\n            self._db.close()\n\n    def serialize(self, data: BaseData) -> Any:\n        r\"\"\"Serializes the :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` object into the expected DB\n        schema.\n        \"\"\"\n        if self.schema == object:\n            return data\n        raise 
NotImplementedError(f\"`{self.__class__.__name__}.serialize()` \"\n                                  f\"needs to be overridden in case a \"\n                                  f\"non-default schema was passed\")\n\n    def deserialize(self, data: Any) -> BaseData:\n        r\"\"\"Deserializes the DB entry into a\n        :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` object.\n        \"\"\"\n        if self.schema == object:\n            return data\n        raise NotImplementedError(f\"`{self.__class__.__name__}.deserialize()` \"\n                                  f\"needs to be overridden in case a \"\n                                  f\"non-default schema was passed\")\n\n    def append(self, data: BaseData) -> None:\n        r\"\"\"Appends the data object to the dataset.\"\"\"\n        index = len(self)\n        self.db.insert(index, self.serialize(data))\n        self._numel += 1\n\n    def extend(\n        self,\n        data_list: Sequence[BaseData],\n        batch_size: Optional[int] = None,\n    ) -> None:\n        r\"\"\"Extends the dataset by a list of data objects.\"\"\"\n        start = len(self)\n        end = start + len(data_list)\n        data_list = [self.serialize(data) for data in data_list]\n        self.db.multi_insert(range(start, end), data_list, batch_size)\n        self._numel += (end - start)\n\n    def get(self, idx: int) -> BaseData:\n        r\"\"\"Gets the data object at index :obj:`idx`.\"\"\"\n        return self.deserialize(self.db.get(idx))\n\n    def multi_get(\n        self,\n        indices: Union[Iterable[int], Tensor, slice, range],\n        batch_size: Optional[int] = None,\n    ) -> List[BaseData]:\n        r\"\"\"Gets a list of data objects from the specified indices.\"\"\"\n        if len(indices) == 1:\n            data_list = [self.db.get(indices[0])]\n        else:\n            data_list = self.db.multi_get(indices, batch_size)\n\n        data_list = [self.deserialize(data) 
for data in data_list]\n        if self.transform is not None:\n            data_list = [self.transform(data) for data in data_list]\n        return data_list\n\n    def __getitems__(self, indices: List[int]) -> List[BaseData]:\n        return self.multi_get(indices)\n\n    def len(self) -> int:\n        if self._numel is None:\n            self._numel = len(self.db)\n        return self._numel\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/data/remote_backend_utils.py",
    "content": "# This file defines a set of utilities for remote backends (backends that are\n# characterized as Tuple[FeatureStore, GraphStore]). TODO support for\n# non-heterogeneous graphs (feature stores with a group_name=None).\nfrom typing import Optional, Tuple, Union, overload\n\nfrom torch_geometric.data import FeatureStore, GraphStore\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\n@overload\ndef _internal_num_nodes(\n    feature_store: FeatureStore,\n    graph_store: GraphStore,\n    query: NodeType,\n) -> int:\n    pass\n\n\n@overload\ndef _internal_num_nodes(\n    feature_store: FeatureStore,\n    graph_store: GraphStore,\n    query: EdgeType,\n) -> Tuple[int, int]:\n    pass\n\n\n# NOTE PyG also supports querying by a relation type `rel` in an edge type\n# (src, rel, dst). It may be worth supporting this in remote backends as well.\ndef _internal_num_nodes(\n    feature_store: FeatureStore,\n    graph_store: GraphStore,\n    query: Union[NodeType, EdgeType],\n) -> Union[int, Tuple[int, int]]:\n    r\"\"\"Returns the number of nodes in the node type or the number of source\n    and destination nodes in an edge type by sequentially accessing attributes\n    in the feature and graph stores that reveal this number.\n    \"\"\"\n    def _matches_node_type(\n        query: Union[NodeType, EdgeType],\n        node_type: Optional[NodeType],\n    ) -> bool:\n        if isinstance(query, (list, tuple)):  # EdgeType:\n            return query[0] == node_type or query[-1] == node_type\n        else:\n            return query == node_type\n\n    node_query = isinstance(query, NodeType)\n\n    # TODO: In general, a feature store and graph store should be able to\n    # expose methods that allow for easy access to individual attributes,\n    # instead of requiring iteration to identify a particular attribute.\n    # Implementing this should reduce the iteration below.\n\n    # 1. 
Check the edges in the GraphStore, for each node type in each edge:\n    num_rows = num_cols = None\n    for edge_attr in graph_store.get_all_edge_attrs():\n        if edge_attr.size is None:\n            continue\n        if _matches_node_type(query, edge_attr.edge_type[0]):\n            num_rows = num_rows or edge_attr.size[0]\n        if _matches_node_type(query, edge_attr.edge_type[-1]):\n            num_cols = num_cols or edge_attr.size[-1]\n\n        if node_query and num_rows is not None:\n            return num_rows\n        if node_query and num_cols is not None:\n            return num_cols\n        if not node_query and num_rows is not None and num_cols is not None:\n            return num_rows, num_cols\n\n    # 2. Check the node types stored in the FeatureStore:\n    tensor_attrs = feature_store.get_all_tensor_attrs()\n    matching_attrs = [\n        attr for attr in tensor_attrs\n        if _matches_node_type(query, attr.group_name)\n    ]\n    if node_query:\n        if len(matching_attrs) > 0:\n            size = feature_store.get_tensor_size(matching_attrs[0])\n            if size is not None:\n                return size[0]\n    else:\n        matching_src_attrs = [\n            attr for attr in matching_attrs if attr.group_name == query[0]\n        ]\n        matching_dst_attrs = [\n            attr for attr in matching_attrs if attr.group_name == query[-1]\n        ]\n        if len(matching_src_attrs) > 0 and len(matching_dst_attrs) > 0:\n            src_size = feature_store.get_tensor_size(matching_src_attrs[0])\n            dst_size = feature_store.get_tensor_size(matching_dst_attrs[0])\n            if src_size is not None and dst_size is not None:\n                return src_size[0], dst_size[0]\n\n    raise ValueError(\n        f\"Unable to accurately infer the number of nodes corresponding to \"\n        f\"query {query} from feature store {feature_store} and graph store \"\n        f\"{graph_store}. 
Please consider either adding an edge containing \"\n        f\"the nodes in this query or feature tensors for the nodes in this \"\n        f\"query.\")\n\n\ndef num_nodes(\n    feature_store: FeatureStore,\n    graph_store: GraphStore,\n    query: NodeType,\n) -> int:\n    r\"\"\"Returns the number of nodes in a given node type stored in a remote\n    backend.\n    \"\"\"\n    return _internal_num_nodes(feature_store, graph_store, query)\n\n\ndef size(\n    feature_store: FeatureStore,\n    graph_store: GraphStore,\n    query: EdgeType,\n) -> Tuple[int, int]:\n    r\"\"\"Returns the size of an edge type (number of source nodes, number of\n    destination nodes) stored in a remote backend.\n    \"\"\"\n    return _internal_num_nodes(feature_store, graph_store, query)\n"
  },
  {
    "path": "torch_geometric/data/separate.py",
    "content": "from collections.abc import Mapping, Sequence\nfrom typing import Any, Type, TypeVar\n\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex, Index\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.data.storage import BaseStorage\nfrom torch_geometric.typing import SparseTensor, TensorFrame\nfrom torch_geometric.utils import narrow\n\nT = TypeVar('T')\n\n\ndef separate(\n    cls: Type[T],\n    batch: Any,\n    idx: int,\n    slice_dict: Any,\n    inc_dict: Any = None,\n    decrement: bool = True,\n) -> T:\n    # Separates the individual element from a `batch` at index `idx`.\n    # `separate` can handle both homogeneous and heterogeneous data objects by\n    # individually separating all their stores.\n    # In addition, `separate` can handle nested data structures such as\n    # dictionaries and lists.\n\n    data = cls().stores_as(batch)\n\n    # Iterate over each storage object and recursively separate its attributes:\n    for batch_store, data_store in zip(batch.stores, data.stores):\n        key = batch_store._key\n        if key is not None:  # Heterogeneous:\n            attrs = slice_dict[key].keys()\n        else:  # Homogeneous:\n            attrs = set(batch_store.keys())\n            attrs = [attr for attr in slice_dict.keys() if attr in attrs]\n\n        for attr in attrs:\n            if key is not None:\n                slices = slice_dict[key][attr]\n                incs = inc_dict[key][attr] if decrement else None\n            else:\n                slices = slice_dict[attr]\n                incs = inc_dict[attr] if decrement else None\n\n            data_store[attr] = _separate(attr, batch_store[attr], idx, slices,\n                                         incs, batch, batch_store, decrement)\n\n        # The `num_nodes` attribute needs special treatment, as we cannot infer\n        # the real number of nodes from the total number of nodes alone:\n        if hasattr(batch_store, '_num_nodes'):\n     
       data_store.num_nodes = batch_store._num_nodes[idx]\n\n    return data\n\n\ndef _separate(\n    key: str,\n    values: Any,\n    idx: int,\n    slices: Any,\n    incs: Any,\n    batch: BaseData,\n    store: BaseStorage,\n    decrement: bool,\n) -> Any:\n\n    if isinstance(values, Tensor):\n        # Narrow a `torch.Tensor` based on `slices`.\n        # NOTE: We need to take care of decrementing elements appropriately.\n        key = str(key)\n        cat_dim = batch.__cat_dim__(key, values, store)\n        start, end = int(slices[idx]), int(slices[idx + 1])\n        value = narrow(values, cat_dim or 0, start, end - start)\n        value = value.squeeze(0) if cat_dim is None else value\n\n        if isinstance(values, Index) and values._cat_metadata is not None:\n            # Reconstruct original `Index` metadata:\n            value._dim_size = values._cat_metadata.dim_size[idx]\n            value._is_sorted = values._cat_metadata.is_sorted[idx]\n\n        if isinstance(values, EdgeIndex) and values._cat_metadata is not None:\n            # Reconstruct original `EdgeIndex` metadata:\n            value._sparse_size = values._cat_metadata.sparse_size[idx]\n            value._sort_order = values._cat_metadata.sort_order[idx]\n            value._is_undirected = values._cat_metadata.is_undirected[idx]\n\n        if (decrement and incs is not None\n                and (incs.dim() > 1 or int(incs[idx]) != 0)):\n            value = value - incs[idx].to(value.device)\n\n        return value\n\n    elif isinstance(values, SparseTensor) and decrement:\n        # Narrow a `SparseTensor` based on `slices`.\n        # NOTE: `cat_dim` may return a tuple to allow for diagonal stacking.\n        key = str(key)\n        cat_dim = batch.__cat_dim__(key, values, store)\n        cat_dims = (cat_dim, ) if isinstance(cat_dim, int) else cat_dim\n        for i, dim in enumerate(cat_dims):\n            start, end = int(slices[idx][i]), int(slices[idx + 1][i])\n            values = 
values.narrow(dim, start, end - start)\n        return values\n\n    elif isinstance(values, TensorFrame):\n        key = str(key)\n        start, end = int(slices[idx]), int(slices[idx + 1])\n        value = values[start:end]\n        return value\n\n    elif isinstance(values, Mapping):\n        # Recursively separate elements of dictionaries.\n        return {\n            key:\n            _separate(\n                key,\n                value,\n                idx,\n                slices=slices[key],\n                incs=incs[key] if decrement else None,\n                batch=batch,\n                store=store,\n                decrement=decrement,\n            )\n            for key, value in values.items()\n        }\n\n    elif (isinstance(values, Sequence) and isinstance(values[0], Sequence)\n          and not isinstance(values[0], str) and len(values[0]) > 0\n          and isinstance(values[0][0], (Tensor, SparseTensor))\n          and isinstance(slices, Sequence)):\n        # Recursively separate elements of lists of lists.\n        return [value[idx] for value in values]\n\n    elif (isinstance(values, Sequence) and not isinstance(values, str)\n          and isinstance(values[0], (Tensor, SparseTensor))\n          and isinstance(slices, Sequence)):\n        # Recursively separate elements of lists of Tensors/SparseTensors.\n        return [\n            _separate(\n                key,\n                value,\n                idx,\n                slices=slices[i],\n                incs=incs[i] if decrement else None,\n                batch=batch,\n                store=store,\n                decrement=decrement,\n            ) for i, value in enumerate(values)\n        ]\n\n    else:\n        return values[idx]\n"
  },
  {
    "path": "torch_geometric/data/storage.py",
    "content": "import copy\nimport warnings\nimport weakref\nfrom collections import defaultdict, namedtuple\nfrom collections.abc import Mapping, MutableMapping, Sequence\nfrom enum import Enum\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    Iterator,\n    List,\n    NamedTuple,\n    Optional,\n    Set,\n    Tuple,\n    Union,\n    overload,\n)\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom typing_extensions import Self\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.data.view import ItemsView, KeysView, ValuesView\nfrom torch_geometric.typing import (\n    EdgeType,\n    NodeType,\n    SparseTensor,\n    TensorFrame,\n)\nfrom torch_geometric.utils import (\n    coalesce,\n    contains_isolated_nodes,\n    is_torch_sparse_tensor,\n    is_undirected,\n    select,\n    sort_edge_index,\n)\n\nN_KEYS = {'x', 'feat', 'pos', 'batch', 'node_type', 'n_id', 'tf'}\nE_KEYS = {'edge_index', 'edge_weight', 'edge_attr', 'edge_type', 'e_id'}\n\n\nclass AttrType(Enum):\n    NODE = 'NODE'\n    EDGE = 'EDGE'\n    OTHER = 'OTHER'\n\n\nclass BaseStorage(MutableMapping):\n    # This class wraps a Python dictionary and extends it as follows:\n    # 1. It allows attribute assignments, e.g.:\n    #    `storage.x = ...` in addition to `storage['x'] = ...`\n    # 2. It allows private attributes that are not exposed to the user, e.g.:\n    #    `storage._{key} = ...` and accessible via `storage._{key}`\n    # 3. It holds an (optional) weak reference to its parent object, e.g.:\n    #    `storage._parent = weakref.ref(parent)`\n    # 4. It allows iterating over only a subset of keys, e.g.:\n    #    `storage.values('x', 'y')` or `storage.items('x', 'y')`\n    # 5. 
It adds additional PyTorch Tensor functionality, e.g.:\n    #    `storage.cpu()`, `storage.cuda()` or `storage.share_memory_()`.\n    def __init__(\n        self,\n        _mapping: Optional[Dict[str, Any]] = None,\n        **kwargs: Any,\n    ) -> None:\n        super().__init__()\n        self._mapping: Dict[str, Any] = {}\n        for key, value in (_mapping or {}).items():\n            setattr(self, key, value)\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n    @property\n    def _key(self) -> Any:\n        return None\n\n    def _pop_cache(self, key: str) -> None:\n        for cache in getattr(self, '_cached_attr', {}).values():\n            cache.discard(key)\n\n    def __len__(self) -> int:\n        return len(self._mapping)\n\n    def __getattr__(self, key: str) -> Any:\n        if key == '_mapping':\n            self._mapping = {}\n            return self._mapping\n        try:\n            return self[key]\n        except KeyError:\n            raise AttributeError(\n                f\"'{self.__class__.__name__}' object has no attribute '{key}'\"\n            ) from None\n\n    def __setattr__(self, key: str, value: Any) -> None:\n        propobj = getattr(self.__class__, key, None)\n        if propobj is not None and getattr(propobj, 'fset', None) is not None:\n            propobj.fset(self, value)\n        elif key == '_parent':\n            self.__dict__[key] = weakref.ref(value)\n        elif key[:1] == '_':\n            self.__dict__[key] = value\n        else:\n            self[key] = value\n\n    def __delattr__(self, key: str) -> None:\n        if key[:1] == '_':\n            del self.__dict__[key]\n        else:\n            del self[key]\n\n    def __getitem__(self, key: str) -> Any:\n        return self._mapping[key]\n\n    def __setitem__(self, key: str, value: Any) -> None:\n        self._pop_cache(key)\n        if value is None and key in self._mapping:\n            del self._mapping[key]\n        elif 
value is not None:\n            self._mapping[key] = value\n\n    def __delitem__(self, key: str) -> None:\n        if key in self._mapping:\n            self._pop_cache(key)\n            del self._mapping[key]\n\n    def __iter__(self) -> Iterator[Any]:\n        return iter(self._mapping)\n\n    def __copy__(self) -> Self:\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            if key != '_cached_attr':\n                out.__dict__[key] = value\n        out._mapping = copy.copy(out._mapping)\n        return out\n\n    def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> Self:\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = value\n        out._mapping = copy.deepcopy(out._mapping, memo)\n        return out\n\n    def __getstate__(self) -> Dict[str, Any]:\n        out = self.__dict__.copy()\n\n        _parent = out.get('_parent', None)\n        if _parent is not None:\n            out['_parent'] = _parent()\n\n        return out\n\n    def __setstate__(self, mapping: Dict[str, Any]) -> None:\n        for key, value in mapping.items():\n            self.__dict__[key] = value\n\n        _parent = self.__dict__.get('_parent', None)\n        if _parent is not None:\n            self.__dict__['_parent'] = weakref.ref(_parent)\n\n    def __repr__(self) -> str:\n        return repr(self._mapping)\n\n    # Allow iterating over subsets ############################################\n\n    # In contrast to standard `keys()`, `values()` and `items()` functions of\n    # Python dictionaries, we allow to only iterate over a subset of items\n    # denoted by a list of keys `args`.\n    # This is especially useful for adding PyTorch Tensor functionality to the\n    # storage object, e.g., in case we only want to transfer a subset of keys\n    # to the GPU (i.e. 
the ones that are relevant to the deep learning model).\n\n    def keys(self, *args: str) -> KeysView:  # type: ignore\n        return KeysView(self._mapping, *args)\n\n    def values(self, *args: str) -> ValuesView:  # type: ignore\n        return ValuesView(self._mapping, *args)\n\n    def items(self, *args: str) -> ItemsView:  # type: ignore\n        return ItemsView(self._mapping, *args)\n\n    def apply_(self, func: Callable, *args: str) -> Self:\n        r\"\"\"Applies the in-place function :obj:`func`, either to all attributes\n        or only the ones given in :obj:`*args`.\n        \"\"\"\n        for value in self.values(*args):\n            recursive_apply_(value, func)\n        return self\n\n    def apply(self, func: Callable, *args: str) -> Self:\n        r\"\"\"Applies the function :obj:`func`, either to all attributes or only\n        the ones given in :obj:`*args`.\n        \"\"\"\n        for key, value in self.items(*args):\n            self[key] = recursive_apply(value, func)\n        return self\n\n    # Additional functionality ################################################\n\n    def get(self, key: str, value: Optional[Any] = None) -> Any:\n        return self._mapping.get(key, value)\n\n    def to_dict(self) -> Dict[str, Any]:\n        r\"\"\"Returns a dictionary of stored key/value pairs.\"\"\"\n        out_dict = copy.copy(self._mapping)\n        # Needed to preserve individual `num_nodes` attributes when calling\n        # `BaseData.collate`.\n        # TODO (matthias) Try to make this more generic.\n        if '_num_nodes' in self.__dict__:\n            out_dict['_num_nodes'] = self.__dict__['_num_nodes']\n        return out_dict\n\n    def to_namedtuple(self) -> NamedTuple:\n        r\"\"\"Returns a :obj:`NamedTuple` of stored key/value pairs.\"\"\"\n        field_names = list(self.keys())\n        typename = f'{self.__class__.__name__}Tuple'\n        StorageTuple = namedtuple(typename, field_names)  # type: ignore\n        return 
StorageTuple(*[self[key] for key in field_names])\n\n    def clone(self, *args: str) -> Self:\n        r\"\"\"Performs a deep-copy of the object.\"\"\"\n        return copy.deepcopy(self)\n\n    def contiguous(self, *args: str) -> Self:\n        r\"\"\"Ensures a contiguous memory layout, either for all attributes or\n        only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.contiguous(), *args)\n\n    def to(\n        self,\n        device: Union[int, str],\n        *args: str,\n        non_blocking: bool = False,\n    ) -> Self:\n        r\"\"\"Performs tensor dtype and/or device conversion, either for all\n        attributes or only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(\n            lambda x: x.to(device=device, non_blocking=non_blocking), *args)\n\n    def cpu(self, *args: str) -> Self:\n        r\"\"\"Copies attributes to CPU memory, either for all attributes or only\n        the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.cpu(), *args)\n\n    def cuda(\n        self,\n        device: Optional[Union[int, str]] = None,\n        *args: str,\n        non_blocking: bool = False,\n    ) -> Self:  # pragma: no cover\n        r\"\"\"Copies attributes to CUDA memory, either for all attributes or only\n        the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.cuda(device, non_blocking=non_blocking),\n                          *args)\n\n    def pin_memory(self, *args: str) -> Self:\n        r\"\"\"Copies attributes to pinned memory, either for all attributes or\n        only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.pin_memory(), *args)\n\n    def share_memory_(self, *args: str) -> Self:\n        r\"\"\"Moves attributes to shared memory, either for all attributes or\n        only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: 
x.share_memory_(), *args)\n\n    def detach_(self, *args: str) -> Self:\n        r\"\"\"Detaches attributes from the computation graph, either for all\n        attributes or only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.detach_(), *args)\n\n    def detach(self, *args: str) -> Self:\n        r\"\"\"Detaches attributes from the computation graph by creating a new\n        tensor, either for all attributes or only the ones given in\n        :obj:`*args`.\n        \"\"\"\n        return self.apply(lambda x: x.detach(), *args)\n\n    def requires_grad_(self, *args: str, requires_grad: bool = True) -> Self:\n        r\"\"\"Tracks gradient computation, either for all attributes or only the\n        ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply(\n            lambda x: x.requires_grad_(requires_grad=requires_grad), *args)\n\n    def record_stream(self, stream: torch.cuda.Stream, *args: str) -> Self:\n        r\"\"\"Ensures that the tensor memory is not reused for another tensor\n        until all current work queued on :obj:`stream` has been completed,\n        either for all attributes or only the ones given in :obj:`*args`.\n        \"\"\"\n        return self.apply_(lambda x: x.record_stream(stream), *args)\n\n    # Time Handling ###########################################################\n\n    def _cat_dims(self, keys: Iterable[str]) -> Dict[str, int]:\n        return {\n            key: self._parent().__cat_dim__(key, self[key], self)\n            for key in keys\n        }\n\n    def _select(\n        self,\n        keys: Iterable[str],\n        index_or_mask: Tensor,\n    ) -> Self:\n\n        for key, dim in self._cat_dims(keys).items():\n            self[key] = select(self[key], index_or_mask, dim)\n\n        return self\n\n    def concat(self, other: Self) -> Self:\n        if not (set(self.keys()) == set(other.keys())):\n            raise AttributeError('Given storage is not compatible')\n\n  
      for key, dim in self._cat_dims(self.keys()).items():\n            value1 = self[key]\n            value2 = other[key]\n\n            if key in {'num_nodes', 'num_edges'}:\n                self[key] = value1 + value2\n\n            elif isinstance(value1, list):\n                self[key] = value1 + value2\n\n            elif isinstance(value1, Tensor):\n                self[key] = torch.cat([value1, value2], dim=dim)\n\n            else:\n                raise NotImplementedError(\n                    f\"'{self.__class__.__name__}.concat' not yet implemented \"\n                    f\"for '{type(value1)}'\")\n\n        return self\n\n    def is_sorted_by_time(self) -> bool:\n        if 'time' in self:\n            return bool(torch.all(self.time[:-1] <= self.time[1:]))\n        return True\n\n    def sort_by_time(self) -> Self:\n        if self.is_sorted_by_time():\n            return self\n\n        if 'time' in self:\n            _, perm = torch.sort(self.time, stable=True)\n\n            if self.is_node_attr('time'):\n                keys = self.node_attrs()\n            elif self.is_edge_attr('time'):\n                keys = self.edge_attrs()\n\n            self._select(keys, perm)\n\n        return self\n\n    def snapshot(\n        self,\n        start_time: Union[float, int],\n        end_time: Union[float, int],\n        attr: str = 'time',\n    ) -> Self:\n        if attr in self:\n            time = self[attr]\n            mask = (time >= start_time) & (time <= end_time)\n\n            if self.is_node_attr(attr):\n                keys = self.node_attrs()\n            elif self.is_edge_attr(attr):\n                keys = self.edge_attrs()\n\n            self._select(keys, mask)\n\n            if self.is_node_attr(attr) and 'num_nodes' in self:\n                self.num_nodes: Optional[int] = int(mask.sum())\n\n        return self\n\n    def up_to(self, time: Union[float, int]) -> Self:\n        if 'time' in self:\n            return 
self.snapshot(self.time.min().item(), time)\n        return self\n\n\nclass NodeStorage(BaseStorage):\n    r\"\"\"A storage for node-level information.\"\"\"\n    @property\n    def _key(self) -> NodeType:\n        key = self.__dict__.get('_key', None)\n        if key is None or not isinstance(key, str):\n            raise ValueError(\"'_key' does not denote a valid node type\")\n        return key\n\n    @property\n    def can_infer_num_nodes(self) -> bool:\n        keys = set(self.keys())\n        num_node_keys = {\n            'num_nodes', 'x', 'pos', 'batch', 'adj', 'adj_t', 'edge_index',\n            'face'\n        }\n        if len(keys & num_node_keys) > 0:\n            return True\n        elif len([key for key in keys if 'node' in key]) > 0:\n            return True\n        else:\n            return False\n\n    @property\n    def num_nodes(self) -> Optional[int]:\n        # We sequentially access attributes that reveal the number of nodes.\n        if 'num_nodes' in self:\n            return self['num_nodes']\n        for key, value in self.items():\n            if isinstance(value, Tensor) and key in N_KEYS:\n                cat_dim = self._parent().__cat_dim__(key, value, self)\n                return value.size(cat_dim)\n            if isinstance(value, np.ndarray) and key in N_KEYS:\n                cat_dim = self._parent().__cat_dim__(key, value, self)\n                return value.shape[cat_dim]\n            if isinstance(value, TensorFrame) and key in N_KEYS:\n                return value.num_rows\n        for key, value in self.items():\n            if isinstance(value, Tensor) and 'node' in key:\n                cat_dim = self._parent().__cat_dim__(key, value, self)\n                return value.size(cat_dim)\n            if isinstance(value, np.ndarray) and 'node' in key:\n                cat_dim = self._parent().__cat_dim__(key, value, self)\n                return value.shape[cat_dim]\n            if isinstance(value, TensorFrame) and 'node' 
in key:\n                return value.num_rows\n        if 'edge_index' in self and isinstance(self.edge_index, EdgeIndex):\n            if self.edge_index.sparse_size(0) is not None:\n                return self.edge_index.sparse_size(0)\n            if self.edge_index.sparse_size(1) is not None:\n                return self.edge_index.sparse_size(1)\n        if 'adj' in self and isinstance(self.adj, (Tensor, SparseTensor)):\n            return self.adj.size(0)\n        if 'adj_t' in self and isinstance(self.adj_t, (Tensor, SparseTensor)):\n            return self.adj_t.size(1)\n        warnings.warn(\n            f\"Unable to accurately infer 'num_nodes' from the attribute set \"\n            f\"'{set(self.keys())}'. Please explicitly set 'num_nodes' as an \"\n            f\"attribute of \" +\n            (\"'data'\" if self._key is None else f\"'data[{self._key}]'\") +\n            \" to suppress this warning\", stacklevel=2)\n        if 'edge_index' in self and isinstance(self.edge_index, Tensor):\n            if self.edge_index.numel() > 0:\n                return int(self.edge_index.max()) + 1\n            return 0\n        if 'face' in self and isinstance(self.face, Tensor):\n            if self.face.numel() > 0:\n                return int(self.face.max()) + 1\n            return 0\n        return None\n\n    @num_nodes.setter\n    def num_nodes(self, num_nodes: Optional[int]) -> None:\n        self['num_nodes'] = num_nodes\n\n    @property\n    def num_node_features(self) -> int:\n        x: Optional[Any] = self.get('x')\n        if isinstance(x, Tensor):\n            return 1 if x.dim() == 1 else x.size(-1)\n        if isinstance(x, np.ndarray):\n            return 1 if x.ndim == 1 else x.shape[-1]\n        if isinstance(x, SparseTensor):\n            return 1 if x.dim() == 1 else x.size(-1)\n        if isinstance(x, TensorFrame):\n            return x.num_cols\n\n        tf: Optional[Any] = self.get('tf')\n        if isinstance(tf, TensorFrame):\n        
    return tf.num_cols\n\n        return 0\n\n    @property\n    def num_features(self) -> int:\n        return self.num_node_features\n\n    def is_node_attr(self, key: str) -> bool:\n        if '_cached_attr' not in self.__dict__:\n            self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set)\n\n        if key in self._cached_attr[AttrType.NODE]:\n            return True\n        if key in self._cached_attr[AttrType.OTHER]:\n            return False\n\n        value = self[key]\n\n        if (isinstance(value, (list, tuple, TensorFrame))\n                and len(value) == self.num_nodes):\n            self._cached_attr[AttrType.NODE].add(key)\n            return True\n\n        if not isinstance(value, (Tensor, np.ndarray)):\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        if value.ndim == 0:\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        cat_dim = self._parent().__cat_dim__(key, value, self)\n        if value.shape[cat_dim] != self.num_nodes:\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        self._cached_attr[AttrType.NODE].add(key)\n        return True\n\n    def is_edge_attr(self, key: str) -> bool:\n        return False\n\n    def node_attrs(self) -> List[str]:\n        return [key for key in self.keys() if self.is_node_attr(key)]\n\n\nclass EdgeStorage(BaseStorage):\n    r\"\"\"A storage for edge-level information.\n\n    We support multiple ways to store edge connectivity in a\n    :class:`EdgeStorage` object:\n\n    * :obj:`edge_index`: A :class:`torch.LongTensor` holding edge indices in\n      COO format with shape :obj:`[2, num_edges]` (the default format)\n\n    * :obj:`adj`: A :class:`torch_sparse.SparseTensor` holding edge indices in\n      a sparse format, supporting both COO and CSR format.\n\n    * :obj:`adj_t`: A **transposed** :class:`torch_sparse.SparseTensor` holding\n      edge indices in a sparse 
format, supporting both COO and CSR format.\n      This is the most efficient one for graph-based deep learning models as\n      indices are sorted based on target nodes.\n    \"\"\"\n    @property\n    def _key(self) -> EdgeType:\n        key = self.__dict__.get('_key', None)\n        if key is None or not isinstance(key, tuple) or not len(key) == 3:\n            raise ValueError(\"'_key' does not denote a valid edge type\")\n        return key\n\n    @property\n    def edge_index(self) -> Tensor:\n        if 'edge_index' in self:\n            return self['edge_index']\n        if 'adj' in self and isinstance(self.adj, SparseTensor):\n            return torch.stack(self.adj.coo()[:2], dim=0)\n        if 'adj_t' in self and isinstance(self.adj_t, SparseTensor):\n            return torch.stack(self.adj_t.coo()[:2][::-1], dim=0)\n        raise AttributeError(\n            f\"'{self.__class__.__name__}' object has no attribute \"\n            f\"'edge_index', 'adj' or 'adj_t'\")\n\n    @edge_index.setter\n    def edge_index(self, edge_index: Optional[Tensor]) -> None:\n        self['edge_index'] = edge_index\n\n    @property\n    def num_edges(self) -> int:\n        # We sequentially access attributes that reveal the number of edges.\n        if 'num_edges' in self:\n            return self['num_edges']\n        for key, value in self.items():\n            if isinstance(value, Tensor) and key in E_KEYS:\n                cat_dim = self._parent().__cat_dim__(key, value, self)\n                return value.size(cat_dim)\n            if isinstance(value, np.ndarray) and key in E_KEYS:\n                cat_dim = self._parent().__cat_dim__(key, value, self)\n                return value.shape[cat_dim]\n            if isinstance(value, TensorFrame) and key in E_KEYS:\n                return value.num_rows\n        for key, value in self.items():\n            if isinstance(value, Tensor) and 'edge' in key:\n                cat_dim = self._parent().__cat_dim__(key, value, 
self)\n                return value.size(cat_dim)\n            if isinstance(value, np.ndarray) and 'edge' in key:\n                cat_dim = self._parent().__cat_dim__(key, value, self)\n                return value.shape[cat_dim]\n            if isinstance(value, TensorFrame) and 'edge' in key:\n                return value.num_rows\n        for value in self.values('adj', 'adj_t'):\n            if isinstance(value, SparseTensor):\n                return value.nnz()\n            elif is_torch_sparse_tensor(value):\n                return value._nnz()\n        return 0\n\n    @property\n    def num_edge_features(self) -> int:\n        edge_attr: Optional[Any] = self.get('edge_attr')\n        if isinstance(edge_attr, Tensor):\n            return 1 if edge_attr.dim() == 1 else edge_attr.size(-1)\n        if isinstance(edge_attr, np.ndarray):\n            return 1 if edge_attr.ndim == 1 else edge_attr.shape[-1]\n        if isinstance(edge_attr, TensorFrame):\n            return edge_attr.num_cols\n        return 0\n\n    @property\n    def num_features(self) -> int:\n        return self.num_edge_features\n\n    @overload\n    def size(self) -> Tuple[Optional[int], Optional[int]]:\n        pass\n\n    @overload\n    def size(self, dim: int) -> Optional[int]:\n        pass\n\n    def size(\n        self, dim: Optional[int] = None\n    ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]:\n\n        if self._key is None:\n            raise NameError(\"Unable to infer 'size' without explicit \"\n                            \"'_key' assignment\")\n\n        size = (self._parent()[self._key[0]].num_nodes,\n                self._parent()[self._key[-1]].num_nodes)\n\n        return size if dim is None else size[dim]\n\n    def is_node_attr(self, key: str) -> bool:\n        return False\n\n    def is_edge_attr(self, key: str) -> bool:\n        if '_cached_attr' not in self.__dict__:\n            self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set)\n\n      
  if key in self._cached_attr[AttrType.EDGE]:\n            return True\n        if key in self._cached_attr[AttrType.OTHER]:\n            return False\n\n        value = self[key]\n\n        if (isinstance(value, (list, tuple, TensorFrame))\n                and len(value) == self.num_edges):\n            self._cached_attr[AttrType.EDGE].add(key)\n            return True\n\n        if not isinstance(value, (Tensor, np.ndarray)):\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        if value.ndim == 0:\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        cat_dim = self._parent().__cat_dim__(key, value, self)\n        if value.shape[cat_dim] != self.num_edges:\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        self._cached_attr[AttrType.EDGE].add(key)\n        return True\n\n    def edge_attrs(self) -> List[str]:\n        return [key for key in self.keys() if self.is_edge_attr(key)]\n\n    def is_sorted(self, sort_by_row: bool = True) -> bool:\n        if 'edge_index' in self:\n            index = self.edge_index[0] if sort_by_row else self.edge_index[1]\n            return bool(torch.all(index[:-1] <= index[1:]))\n        return True\n\n    def sort(self, sort_by_row: bool = True) -> Self:\n        if 'edge_index' in self:\n            edge_attrs = self.edge_attrs()\n            edge_attrs.remove('edge_index')\n            edge_feats = [self[edge_attr] for edge_attr in edge_attrs]\n            self.edge_index, edge_feats = sort_edge_index(\n                self.edge_index, edge_feats, sort_by_row=sort_by_row)\n            for key, edge_feat in zip(edge_attrs, edge_feats):\n                self[key] = edge_feat\n        return self\n\n    def is_coalesced(self) -> bool:\n        for value in self.values('adj', 'adj_t'):\n            return value.is_coalesced()\n\n        if 'edge_index' in self:\n            size = [s for s in self.size() if s is 
not None]\n            num_nodes = max(size) if len(size) > 0 else None\n\n            new_edge_index = coalesce(self.edge_index, num_nodes=num_nodes)\n\n            return (self.edge_index.numel() == new_edge_index.numel()\n                    and torch.equal(self.edge_index, new_edge_index))\n\n        return True\n\n    def coalesce(self, reduce: str = 'sum') -> Self:\n        for key, value in self.items('adj', 'adj_t'):\n            self[key] = value.coalesce(reduce)\n\n        if 'edge_index' in self:\n\n            size = [s for s in self.size() if s is not None]\n            num_nodes = max(size) if len(size) > 0 else None\n\n            self.edge_index, self.edge_attr = coalesce(\n                self.edge_index,\n                edge_attr=self.get('edge_attr'),\n                num_nodes=num_nodes,\n            )\n\n        return self\n\n    def has_isolated_nodes(self) -> bool:\n        edge_index, num_nodes = self.edge_index, self.size(1)\n        if num_nodes is None:\n            raise NameError(\"Unable to infer 'num_nodes'\")\n        if self.is_bipartite():\n            return torch.unique(edge_index[1]).numel() < num_nodes\n        else:\n            return contains_isolated_nodes(edge_index, num_nodes)\n\n    def has_self_loops(self) -> bool:\n        if self.is_bipartite():\n            return False\n        edge_index = self.edge_index\n        return int((edge_index[0] == edge_index[1]).sum()) > 0\n\n    def is_undirected(self) -> bool:\n        if self.is_bipartite():\n            return False\n\n        for value in self.values('adj', 'adj_t'):\n            return value.is_symmetric()\n\n        edge_index = self.edge_index\n        edge_attr = self.edge_attr if 'edge_attr' in self else None\n        return is_undirected(edge_index, edge_attr, num_nodes=self.size(0))\n\n    def is_directed(self) -> bool:\n        return not self.is_undirected()\n\n    def is_bipartite(self) -> bool:\n        return self._key is not None and self._key[0] != 
self._key[-1]\n\n\nclass GlobalStorage(NodeStorage, EdgeStorage):\n    r\"\"\"A storage for both node-level and edge-level information.\"\"\"\n    @property\n    def _key(self) -> Any:\n        return None\n\n    @property\n    def num_features(self) -> int:\n        return self.num_node_features\n\n    @overload\n    def size(self) -> Tuple[Optional[int], Optional[int]]:\n        pass\n\n    @overload\n    def size(self, dim: int) -> Optional[int]:\n        pass\n\n    def size(\n        self, dim: Optional[int] = None\n    ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]:\n        size = (self.num_nodes, self.num_nodes)\n        return size if dim is None else size[dim]\n\n    def is_node_attr(self, key: str) -> bool:\n        if '_cached_attr' not in self.__dict__:\n            self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set)\n\n        if key in self._cached_attr[AttrType.NODE]:\n            return True\n        if key in self._cached_attr[AttrType.EDGE]:\n            return False\n        if key in self._cached_attr[AttrType.OTHER]:\n            return False\n\n        value = self[key]\n\n        if (isinstance(value, (list, tuple, TensorFrame))\n                and len(value) == self.num_nodes):\n            self._cached_attr[AttrType.NODE].add(key)\n            return True\n\n        if not isinstance(value, (Tensor, np.ndarray)):\n            return False\n\n        if value.ndim == 0:\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        cat_dim = self._parent().__cat_dim__(key, value, self)\n\n        if not isinstance(cat_dim, int):\n            return False\n\n        num_nodes, num_edges = self.num_nodes, self.num_edges\n\n        if value.shape[cat_dim] != num_nodes:\n            if value.shape[cat_dim] == num_edges:\n                self._cached_attr[AttrType.EDGE].add(key)\n            else:\n                self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n    
    if num_nodes != num_edges:\n            self._cached_attr[AttrType.NODE].add(key)\n            return True\n\n        if 'edge' not in key:\n            self._cached_attr[AttrType.NODE].add(key)\n            return True\n        else:\n            self._cached_attr[AttrType.EDGE].add(key)\n            return False\n\n    def is_edge_attr(self, key: str) -> bool:\n        if '_cached_attr' not in self.__dict__:\n            self._cached_attr = defaultdict(set)\n\n        if key in self._cached_attr[AttrType.EDGE]:\n            return True\n        if key in self._cached_attr[AttrType.NODE]:\n            return False\n        if key in self._cached_attr[AttrType.OTHER]:\n            return False\n\n        value = self[key]\n\n        if (isinstance(value, (list, tuple, TensorFrame))\n                and len(value) == self.num_edges):\n            self._cached_attr[AttrType.EDGE].add(key)\n            return True\n\n        if not isinstance(value, (Tensor, np.ndarray)):\n            return False\n\n        if value.ndim == 0:\n            self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        cat_dim = self._parent().__cat_dim__(key, value, self)\n\n        if not isinstance(cat_dim, int):\n            return False\n\n        num_nodes, num_edges = self.num_nodes, self.num_edges\n\n        if value.shape[cat_dim] != num_edges:\n            if value.shape[cat_dim] == num_nodes:\n                self._cached_attr[AttrType.NODE].add(key)\n            else:\n                self._cached_attr[AttrType.OTHER].add(key)\n            return False\n\n        if num_edges != num_nodes:\n            self._cached_attr[AttrType.EDGE].add(key)\n            return True\n\n        if 'edge' in key:\n            self._cached_attr[AttrType.EDGE].add(key)\n            return True\n        else:\n            self._cached_attr[AttrType.NODE].add(key)\n            return False\n\n\ndef recursive_apply_(data: Any, func: Callable) -> Any:\n    if 
isinstance(data, Tensor):\n        func(data)\n    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple\n        for value in data:\n            recursive_apply_(value, func)\n    elif isinstance(data, Sequence) and not isinstance(data, str):\n        for value in data:\n            recursive_apply_(value, func)\n    elif isinstance(data, Mapping):\n        for value in data.values():\n            recursive_apply_(value, func)\n    else:\n        try:\n            func(data)\n        except Exception:\n            pass\n\n\ndef recursive_apply(data: Any, func: Callable) -> Any:\n    if isinstance(data, Tensor):\n        return func(data)\n    elif isinstance(data, torch.nn.utils.rnn.PackedSequence):\n        return func(data)\n    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple\n        return type(data)(*(recursive_apply(d, func) for d in data))\n    elif isinstance(data, Sequence) and not isinstance(data, str):\n        return [recursive_apply(d, func) for d in data]\n    elif isinstance(data, Mapping):\n        return {key: recursive_apply(data[key], func) for key in data}\n    else:\n        try:\n            return func(data)\n        except Exception:\n            return data\n"
  },
  {
    "path": "torch_geometric/data/summary.py",
    "content": "from collections import defaultdict\nfrom dataclasses import dataclass\nfrom typing import Dict, List, Optional, Union\n\nimport torch\nfrom tqdm import tqdm\nfrom typing_extensions import Self\n\nfrom torch_geometric.data import Dataset, HeteroData\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\n@dataclass\nclass Stats:\n    mean: float\n    std: float\n    min: float\n    quantile25: float\n    median: float\n    quantile75: float\n    max: float\n\n    @classmethod\n    def from_data(\n        cls,\n        data: Union[List[int], List[float], torch.Tensor],\n    ) -> Self:\n        if not isinstance(data, torch.Tensor):\n            data = torch.tensor(data)\n        data = data.to(torch.float)\n\n        return cls(\n            mean=data.mean().item(),\n            std=data.std().item(),\n            min=data.min().item(),\n            quantile25=data.quantile(0.25).item(),\n            median=data.median().item(),\n            quantile75=data.quantile(0.75).item(),\n            max=data.max().item(),\n        )\n\n\n@dataclass(repr=False)\nclass Summary:\n    name: str\n    num_graphs: int\n    num_nodes: Stats\n    num_edges: Stats\n    num_nodes_per_type: Optional[Dict[NodeType, Stats]] = None\n    num_edges_per_type: Optional[Dict[EdgeType, Stats]] = None\n\n    @classmethod\n    def from_dataset(\n        cls,\n        dataset: Dataset,\n        progress_bar: Optional[bool] = None,\n        per_type: bool = True,\n    ) -> Self:\n        r\"\"\"Creates a summary of a :class:`~torch_geometric.data.Dataset`\n        object.\n\n        Args:\n            dataset (Dataset): The dataset.\n            progress_bar (bool, optional): If set to :obj:`True`, will show a\n                progress bar during stats computation. If set to :obj:`None`,\n                will automatically decide whether to show a progress bar based\n                on dataset size. 
(default: :obj:`None`)\n            per_type (bool, optional): If set to :obj:`True`, will separate\n                statistics per node and edge type (only applicable in\n                heterogeneous graph datasets). (default: :obj:`True`)\n        \"\"\"\n        name = dataset.__class__.__name__\n\n        if progress_bar is None:\n            progress_bar = len(dataset) >= 10000\n\n        if progress_bar:\n            dataset = tqdm(dataset)\n\n        num_nodes, num_edges = [], []\n        _num_nodes_per_type = defaultdict(list)\n        _num_edges_per_type = defaultdict(list)\n\n        for data in dataset:\n            assert data.num_nodes is not None\n            num_nodes.append(data.num_nodes)\n            num_edges.append(data.num_edges)\n\n            if per_type and isinstance(data, HeteroData):\n                for node_type in data.node_types:\n                    _num_nodes_per_type[node_type].append(\n                        data[node_type].num_nodes)\n                for edge_type in data.edge_types:\n                    _num_edges_per_type[edge_type].append(\n                        data[edge_type].num_edges)\n\n        num_nodes_per_type = None\n        if len(_num_nodes_per_type) > 0:\n            num_nodes_per_type = {\n                node_type: Stats.from_data(num_nodes_list)\n                for node_type, num_nodes_list in _num_nodes_per_type.items()\n            }\n\n        num_edges_per_type = None\n        if len(_num_edges_per_type) > 0:\n            num_edges_per_type = {\n                edge_type: Stats.from_data(num_edges_list)\n                for edge_type, num_edges_list in _num_edges_per_type.items()\n            }\n\n        return cls(\n            name=name,\n            num_graphs=len(dataset),\n            num_nodes=Stats.from_data(num_nodes),\n            num_edges=Stats.from_data(num_edges),\n            num_nodes_per_type=num_nodes_per_type,\n            num_edges_per_type=num_edges_per_type,\n        )\n\n    def 
format(self, fmt: str = \"psql\") -> str:\n        r\"\"\"Formats summary statistics of the dataset.\n\n        Args:\n            fmt (str, optional): Summary tables format. Available table formats\n                can be found `here <https://github.com/astanin/python-tabulate?\n                tab=readme-ov-file#table-format>`__. (default: :obj:`\"psql\"`)\n        \"\"\"\n        from tabulate import tabulate\n\n        body = f'{self.name} (#graphs={self.num_graphs}):\\n'\n\n        content = [['', '#nodes', '#edges']]\n        stats = [self.num_nodes, self.num_edges]\n        for field in Stats.__dataclass_fields__:\n            row = [field] + [f'{getattr(s, field):.1f}' for s in stats]\n            content.append(row)\n        body += tabulate(content, headers='firstrow', tablefmt=fmt)\n\n        if self.num_nodes_per_type is not None:\n            content = [['']]\n            content[0] += list(self.num_nodes_per_type.keys())\n\n            for field in Stats.__dataclass_fields__:\n                row = [field] + [\n                    f'{getattr(s, field):.1f}'\n                    for s in self.num_nodes_per_type.values()\n                ]\n                content.append(row)\n            body += \"\\nNumber of nodes per node type:\\n\"\n            body += tabulate(content, headers='firstrow', tablefmt=fmt)\n\n        if self.num_edges_per_type is not None:\n            content = [['']]\n            content[0] += [\n                f\"({', '.join(edge_type)})\"\n                for edge_type in self.num_edges_per_type.keys()\n            ]\n\n            for field in Stats.__dataclass_fields__:\n                row = [field] + [\n                    f'{getattr(s, field):.1f}'\n                    for s in self.num_edges_per_type.values()\n                ]\n                content.append(row)\n            body += \"\\nNumber of edges per edge type:\\n\"\n            body += tabulate(content, headers='firstrow', tablefmt=fmt)\n\n        return body\n\n  
  def __repr__(self) -> str:\n        return self.format()\n"
  },
  {
    "path": "torch_geometric/data/temporal.py",
    "content": "import copy\nfrom typing import (\n    Any,\n    Dict,\n    Iterable,\n    List,\n    NamedTuple,\n    Optional,\n    Tuple,\n    Union,\n)\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data.data import BaseData, size_repr\nfrom torch_geometric.data.storage import (\n    BaseStorage,\n    EdgeStorage,\n    GlobalStorage,\n    NodeStorage,\n)\n\n\nclass TemporalData(BaseData):\n    r\"\"\"A data object composed by a stream of events describing a temporal\n    graph.\n    The :class:`~torch_geometric.data.TemporalData` object can hold a list of\n    events (that can be understood as temporal edges in a graph) with\n    structured messages.\n    An event is composed by a source node, a destination node, a timestamp\n    and a message. Any *Continuous-Time Dynamic Graph* (CTDG) can be\n    represented with these four values.\n\n    In general, :class:`~torch_geometric.data.TemporalData` tries to mimic\n    the behavior of a regular :python:`Python` dictionary.\n    In addition, it provides useful functionality for analyzing graph\n    structures, and provides basic PyTorch tensor functionalities.\n\n    .. 
code-block:: python\n\n        from torch import Tensor\n        from torch_geometric.data import TemporalData\n\n        events = TemporalData(\n            src=Tensor([1,2,3,4]),\n            dst=Tensor([2,3,4,5]),\n            t=Tensor([1000,1010,1100,2000]),\n            msg=Tensor([1,1,0,0])\n        )\n\n        # Add additional arguments to `events`:\n        events.y = Tensor([1,1,0,0])\n\n        # It is also possible to set additional arguments in the constructor\n        events = TemporalData(\n            ...,\n            y=Tensor([1,1,0,0])\n        )\n\n        # Get the number of events:\n        events.num_events\n        >>> 4\n\n        # Analyzing the graph structure:\n        events.num_nodes\n        >>> 5\n\n        # PyTorch tensor functionality:\n        events = events.pin_memory()\n        events = events.to('cuda:0', non_blocking=True)\n\n    Args:\n        src (torch.Tensor, optional): A list of source nodes for the events\n            with shape :obj:`[num_events]`. (default: :obj:`None`)\n        dst (torch.Tensor, optional): A list of destination nodes for the\n            events with shape :obj:`[num_events]`. (default: :obj:`None`)\n        t (torch.Tensor, optional): The timestamps for each event with shape\n            :obj:`[num_events]`. (default: :obj:`None`)\n        msg (torch.Tensor, optional): Messages feature matrix with shape\n            :obj:`[num_events, num_msg_features]`. (default: :obj:`None`)\n        **kwargs (optional): Additional attributes.\n\n    .. 
note::\n        The shape of :obj:`src`, :obj:`dst`, :obj:`t` and the first dimension\n        of :obj:`msg` should be the same (:obj:`num_events`).\n    \"\"\"\n    def __init__(\n        self,\n        src: Optional[Tensor] = None,\n        dst: Optional[Tensor] = None,\n        t: Optional[Tensor] = None,\n        msg: Optional[Tensor] = None,\n        **kwargs,\n    ):\n        super().__init__()\n        self.__dict__['_store'] = GlobalStorage(_parent=self)\n\n        self.src = src\n        self.dst = dst\n        self.t = t\n        self.msg = msg\n\n        for key, value in kwargs.items():\n            setattr(self, key, value)\n\n    @classmethod\n    def from_dict(cls, mapping: Dict[str, Any]) -> 'TemporalData':\n        r\"\"\"Creates a :class:`~torch_geometric.data.TemporalData` object from\n        a Python dictionary.\n        \"\"\"\n        return cls(**mapping)\n\n    def index_select(self, idx: Any) -> 'TemporalData':\n        idx = prepare_idx(idx)\n        data = copy.copy(self)\n        for key, value in data._store.items():\n            if value.size(0) == self.num_events:\n                data[key] = value[idx]\n        return data\n\n    def __getitem__(self, idx: Any) -> Any:\n        if isinstance(idx, str):\n            return self._store[idx]\n        return self.index_select(idx)\n\n    def __setitem__(self, key: str, value: Any):\n        \"\"\"Sets the attribute :obj:`key` to :obj:`value`.\"\"\"\n        self._store[key] = value\n\n    def __delitem__(self, key: str):\n        if key in self._store:\n            del self._store[key]\n\n    def __getattr__(self, key: str) -> Any:\n        if '_store' not in self.__dict__:\n            raise RuntimeError(\n                \"The 'data' object was created by an older version of PyG. 
\"\n                \"If this error occurred while loading an already existing \"\n                \"dataset, remove the 'processed/' directory in the dataset's \"\n                \"root folder and try again.\")\n        return getattr(self._store, key)\n\n    def __setattr__(self, key: str, value: Any):\n        setattr(self._store, key, value)\n\n    def __delattr__(self, key: str):\n        delattr(self._store, key)\n\n    def __iter__(self) -> Iterable:\n        for i in range(self.num_events):\n            yield self[i]\n\n    def __len__(self) -> int:\n        return self.num_events\n\n    def __call__(self, *args: List[str]) -> Iterable:\n        yield from self._store.items(*args)\n\n    def __copy__(self):\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = value\n        out.__dict__['_store'] = copy.copy(self._store)\n        out._store._parent = out\n        return out\n\n    def __deepcopy__(self, memo):\n        out = self.__class__.__new__(self.__class__)\n        for key, value in self.__dict__.items():\n            out.__dict__[key] = copy.deepcopy(value, memo)\n        out._store._parent = out\n        return out\n\n    def stores_as(self, data: 'TemporalData'):\n        return self\n\n    @property\n    def stores(self) -> List[BaseStorage]:\n        return [self._store]\n\n    @property\n    def node_stores(self) -> List[NodeStorage]:\n        return [self._store]\n\n    @property\n    def edge_stores(self) -> List[EdgeStorage]:\n        return [self._store]\n\n    def to_dict(self) -> Dict[str, Any]:\n        return self._store.to_dict()\n\n    def to_namedtuple(self) -> NamedTuple:\n        return self._store.to_namedtuple()\n\n    def debug(self):\n        pass  # TODO\n\n    @property\n    def num_nodes(self) -> int:\n        r\"\"\"Returns the number of nodes in the graph.\"\"\"\n        return max(int(self.src.max()), int(self.dst.max())) + 1\n\n    
@property\n    def num_events(self) -> int:\n        r\"\"\"Returns the number of events loaded.\n\n        .. note::\n            In a :class:`~torch_geometric.data.TemporalData`, each row denotes\n            an event.\n            Thus, they can be also understood as edges.\n        \"\"\"\n        return self.src.size(0)\n\n    @property\n    def num_edges(self) -> int:\n        r\"\"\"Alias for :meth:`~torch_geometric.data.TemporalData.num_events`.\"\"\"\n        return self.num_events\n\n    @property\n    def edge_index(self) -> Tensor:\n        r\"\"\"Returns the edge indices of the graph.\"\"\"\n        if 'edge_index' in self:\n            return self._store['edge_index']\n        if self.src is not None and self.dst is not None:\n            return torch.stack([self.src, self.dst], dim=0)\n        raise ValueError(f\"{self.__class__.__name__} does not contain \"\n                         f\"'edge_index' information\")\n\n    def size(\n        self, dim: Optional[int] = None\n    ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]:\n        r\"\"\"Returns the size of the adjacency matrix induced by the graph.\"\"\"\n        size = (int(self.src.max()), int(self.dst.max()))\n        return size if dim is None else size[dim]\n\n    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:\n        return 0\n\n    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:\n        if 'batch' in key and isinstance(value, Tensor):\n            return int(value.max()) + 1\n        elif key in ['src', 'dst']:\n            return self.num_nodes\n        else:\n            return 0\n\n    def __repr__(self) -> str:\n        cls = self.__class__.__name__\n        info = ', '.join([size_repr(k, v) for k, v in self._store.items()])\n        return f'{cls}({info})'\n\n    ###########################################################################\n\n    def train_val_test_split(self, val_ratio: float = 0.15,\n                             
test_ratio: float = 0.15):\n        r\"\"\"Splits the data into training, validation and test sets according\n        to time.\n\n        Args:\n            val_ratio (float, optional): The fraction (between 0 and 1) of the\n                dataset to include in the validation split.\n                (default: :obj:`0.15`)\n            test_ratio (float, optional): The fraction (between 0 and 1) of the\n                dataset to include in the test split. (default: :obj:`0.15`)\n        \"\"\"\n        val_time, test_time = np.quantile(\n            self.t.cpu().numpy(),\n            [1. - val_ratio - test_ratio, 1. - test_ratio])\n\n        val_idx = int((self.t <= val_time).sum())\n        test_idx = int((self.t <= test_time).sum())\n\n        return self[:val_idx], self[val_idx:test_idx], self[test_idx:]\n\n    ###########################################################################\n\n    def coalesce(self):\n        raise NotImplementedError\n\n    def has_isolated_nodes(self) -> bool:\n        raise NotImplementedError\n\n    def has_self_loops(self) -> bool:\n        raise NotImplementedError\n\n    def is_undirected(self) -> bool:\n        raise NotImplementedError\n\n    def is_directed(self) -> bool:\n        raise NotImplementedError\n\n\n###############################################################################\n\n\ndef prepare_idx(idx):\n    if isinstance(idx, int):\n        return slice(idx, idx + 1)\n    if isinstance(idx, (list, tuple)):\n        return torch.tensor(idx)\n    elif isinstance(idx, slice):\n        return idx\n    elif isinstance(idx, torch.Tensor) and idx.dtype == torch.long:\n        return idx\n    elif isinstance(idx, torch.Tensor) and idx.dtype == torch.bool:\n        return idx\n\n    raise IndexError(\n        f\"Only strings, integers, slices (`:`), list, tuples, and long or \"\n        f\"bool tensors are valid indices (got '{type(idx).__name__}')\")\n"
  },
  {
    "path": "torch_geometric/data/view.py",
    "content": "from typing import Any, Iterator, List, Mapping, Tuple\n\n\nclass MappingView:\n    def __init__(self, mapping: Mapping[str, Any], *args: str):\n        self._mapping = mapping\n        self._args = args\n\n    def _keys(self) -> List[str]:\n        if len(self._args) == 0:\n            return list(self._mapping.keys())\n        else:\n            return [arg for arg in self._args if arg in self._mapping]\n\n    def __len__(self) -> int:\n        return len(self._keys())\n\n    def __repr__(self) -> str:\n        mapping = {key: self._mapping[key] for key in self._keys()}\n        return f'{self.__class__.__name__}({mapping})'\n\n    __class_getitem__ = classmethod(type([]))  # type: ignore\n\n\nclass KeysView(MappingView):\n    def __iter__(self) -> Iterator[str]:\n        yield from self._keys()\n\n\nclass ValuesView(MappingView):\n    def __iter__(self) -> Iterator[Any]:\n        for key in self._keys():\n            yield self._mapping[key]\n\n\nclass ItemsView(MappingView):\n    def __iter__(self) -> Iterator[Tuple[str, Any]]:\n        for key in self._keys():\n            yield (key, self._mapping[key])\n"
  },
  {
    "path": "torch_geometric/datasets/__init__.py",
    "content": "# flake8: noqa\n\nfrom .karate import KarateClub\nfrom .tu_dataset import TUDataset\nfrom .gnn_benchmark_dataset import GNNBenchmarkDataset\nfrom .planetoid import Planetoid\nfrom .nell import NELL\nfrom .citation_full import CitationFull, CoraFull\nfrom .coauthor import Coauthor\nfrom .amazon import Amazon\nfrom .ppi import PPI\nfrom .reddit import Reddit\nfrom .reddit2 import Reddit2\nfrom .flickr import Flickr\nfrom .yelp import Yelp\nfrom .amazon_products import AmazonProducts\nfrom .qm7 import QM7b\nfrom .qm9 import QM9\nfrom .md17 import MD17\nfrom .zinc import ZINC\nfrom .aqsol import AQSOL\nfrom .molecule_net import MoleculeNet\nfrom .pcqm4m import PCQM4Mv2\nfrom .entities import Entities\nfrom .rel_link_pred_dataset import RelLinkPredDataset\nfrom .ged_dataset import GEDDataset\nfrom .attributed_graph_dataset import AttributedGraphDataset\nfrom .mnist_superpixels import MNISTSuperpixels\nfrom .faust import FAUST\nfrom .dynamic_faust import DynamicFAUST\nfrom .shapenet import ShapeNet\nfrom .modelnet import ModelNet\nfrom .medshapenet import MedShapeNet\nfrom .coma import CoMA\nfrom .shrec2016 import SHREC2016\nfrom .tosca import TOSCA\nfrom .pcpnet_dataset import PCPNetDataset\nfrom .s3dis import S3DIS\nfrom .geometry import GeometricShapes\nfrom .bitcoin_otc import BitcoinOTC\nfrom .gdelt_lite import GDELTLite\nfrom .icews import ICEWS18\nfrom .gdelt import GDELT\nfrom .willow_object_class import WILLOWObjectClass\nfrom .pascal import PascalVOCKeypoints\nfrom .pascal_pf import PascalPF\nfrom .snap_dataset import SNAPDataset\nfrom .suite_sparse import SuiteSparseMatrixCollection\nfrom .word_net import WordNet18, WordNet18RR\nfrom .freebase import FB15k_237\nfrom .wikics import WikiCS\nfrom .webkb import WebKB\nfrom .wikipedia_network import WikipediaNetwork\nfrom .heterophilous_graph_dataset import HeterophilousGraphDataset\nfrom .actor import Actor\nfrom .upfd import UPFD\nfrom .github import GitHub\nfrom .facebook import 
FacebookPagePage\nfrom .lastfm_asia import LastFMAsia\nfrom .deezer_europe import DeezerEurope\nfrom .gemsec import GemsecDeezer\nfrom .twitch import Twitch\nfrom .airports import Airports\nfrom .lrgb import LRGBDataset\nfrom .malnet_tiny import MalNetTiny\nfrom .omdb import OMDB\nfrom .polblogs import PolBlogs\nfrom .email_eu_core import EmailEUCore\nfrom .linkx_dataset import LINKXDataset\nfrom .elliptic import EllipticBitcoinDataset\nfrom .elliptic_temporal import EllipticBitcoinTemporalDataset\nfrom .dgraph import DGraphFin\nfrom .hydro_net import HydroNet\nfrom .airfrans import AirfRANS\nfrom .jodie import JODIEDataset\nfrom .wikidata import Wikidata5M\nfrom .myket import MyketDataset\nfrom .brca_tgca import BrcaTcga\nfrom .neurograph import NeuroGraphDataset\nfrom .web_qsp_dataset import WebQSPDataset, CWQDataset\nfrom .git_mol_dataset import GitMolDataset\nfrom .molecule_gpt_dataset import MoleculeGPTDataset\nfrom .instruct_mol_dataset import InstructMolDataset\nfrom .protein_mpnn_dataset import ProteinMPNNDataset\nfrom .tag_dataset import TAGDataset\nfrom .city import CityNetwork\nfrom .teeth3ds import Teeth3DS\n\nfrom .dbp15k import DBP15K\nfrom .aminer import AMiner\nfrom .ogb_mag import OGB_MAG\nfrom .dblp import DBLP\nfrom .movie_lens import MovieLens\nfrom .movie_lens_100k import MovieLens100K\nfrom .movie_lens_1m import MovieLens1M\nfrom .imdb import IMDB\nfrom .last_fm import LastFM\nfrom .hgb_dataset import HGBDataset\nfrom .taobao import Taobao\nfrom .igmc_dataset import IGMCDataset\nfrom .amazon_book import AmazonBook\nfrom .hm import HM\nfrom .ose_gvcs import OSE_GVCS\nfrom .rcdd import RCDD\nfrom .opf import OPFDataset\n\nfrom .cornell import CornellTemporalHyperGraphDataset\n\nfrom .fake import FakeDataset, FakeHeteroDataset\nfrom .sbm_dataset import StochasticBlockModelDataset\nfrom .sbm_dataset import RandomPartitionGraphDataset\nfrom .mixhop_synthetic_dataset import MixHopSyntheticDataset\nfrom .explainer_dataset import 
ExplainerDataset\nfrom .infection_dataset import InfectionDataset\nfrom .ba2motif_dataset import BA2MotifDataset\nfrom .ba_multi_shapes import BAMultiShapesDataset\nfrom .ba_shapes import BAShapes\n\nimport torch_geometric.datasets.utils\n\nhomo_datasets = [\n    'KarateClub',\n    'TUDataset',\n    'GNNBenchmarkDataset',\n    'Planetoid',\n    'NELL',\n    'CitationFull',\n    'CoraFull',\n    'Coauthor',\n    'Amazon',\n    'PPI',\n    'Reddit',\n    'Reddit2',\n    'Flickr',\n    'Yelp',\n    'AmazonProducts',\n    'QM7b',\n    'QM9',\n    'MD17',\n    'ZINC',\n    'AQSOL',\n    'MoleculeNet',\n    'PCQM4Mv2',\n    'Entities',\n    'RelLinkPredDataset',\n    'GEDDataset',\n    'AttributedGraphDataset',\n    'MNISTSuperpixels',\n    'FAUST',\n    'DynamicFAUST',\n    'ShapeNet',\n    'ModelNet',\n    'MedShapeNet',\n    'CoMA',\n    'SHREC2016',\n    'TOSCA',\n    'PCPNetDataset',\n    'S3DIS',\n    'GeometricShapes',\n    'BitcoinOTC',\n    'GDELTLite',\n    'ICEWS18',\n    'GDELT',\n    'WILLOWObjectClass',\n    'PascalVOCKeypoints',\n    'PascalPF',\n    'SNAPDataset',\n    'SuiteSparseMatrixCollection',\n    'WordNet18',\n    'WordNet18RR',\n    'FB15k_237',\n    'WikiCS',\n    'WebKB',\n    'WikipediaNetwork',\n    'HeterophilousGraphDataset',\n    'Actor',\n    'UPFD',\n    'GitHub',\n    'FacebookPagePage',\n    'LastFMAsia',\n    'DeezerEurope',\n    'GemsecDeezer',\n    'Twitch',\n    'Airports',\n    'LRGBDataset',\n    'MalNetTiny',\n    'OMDB',\n    'PolBlogs',\n    'EmailEUCore',\n    'LINKXDataset',\n    'EllipticBitcoinDataset',\n    'EllipticBitcoinTemporalDataset',\n    'DGraphFin',\n    'HydroNet',\n    'AirfRANS',\n    'JODIEDataset',\n    'Wikidata5M',\n    'MyketDataset',\n    'BrcaTcga',\n    'NeuroGraphDataset',\n    'WebQSPDataset',\n    'CWQDataset',\n    'GitMolDataset',\n    'MoleculeGPTDataset',\n    'InstructMolDataset',\n    'ProteinMPNNDataset',\n    'TAGDataset',\n    'CityNetwork',\n    'Teeth3DS',\n]\n\nhetero_datasets = [\n    
'DBP15K',\n    'AMiner',\n    'OGB_MAG',\n    'DBLP',\n    'MovieLens',\n    'MovieLens100K',\n    'MovieLens1M',\n    'IMDB',\n    'LastFM',\n    'HGBDataset',\n    'Taobao',\n    'IGMCDataset',\n    'AmazonBook',\n    'HM',\n    'OSE_GVCS',\n    'RCDD',\n    'OPFDataset',\n]\nhyper_datasets = [\n    'CornellTemporalHyperGraphDataset',\n]\nsynthetic_datasets = [\n    'FakeDataset',\n    'FakeHeteroDataset',\n    'StochasticBlockModelDataset',\n    'RandomPartitionGraphDataset',\n    'MixHopSyntheticDataset',\n    'ExplainerDataset',\n    'InfectionDataset',\n    'BA2MotifDataset',\n    'BAMultiShapesDataset',\n    'BAShapes',\n]\n\n__all__ = homo_datasets + hetero_datasets + hyper_datasets + synthetic_datasets\n"
  },
  {
    "path": "torch_geometric/datasets/actor.py",
    "content": "from typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import coalesce\n\n\nclass Actor(InMemoryDataset):\n    r\"\"\"The actor-only induced subgraph of the film-director-actor-writer\n    network used in the\n    `\"Geom-GCN: Geometric Graph Convolutional Networks\"\n    <https://openreview.net/forum?id=S1e2agrFvS>`_ paper.\n    Each node corresponds to an actor, and the edge between two nodes denotes\n    co-occurrence on the same Wikipedia page.\n    Node features correspond to some keywords in the Wikipedia pages.\n    The task is to classify the nodes into five categories in term of words of\n    actor's Wikipedia.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        transform: A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform: A function/transform that takes in an\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before being saved to\n            disk.\n        force_reload: Whether to re-process the dataset.\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 7,600\n          - 30,019\n          - 932\n          - 5\n    \"\"\"\n\n    url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['out1_node_feature_label.txt', 'out1_graph_edges.txt'\n                ] + [f'film_split_0.6_0.2_{i}.npz' for i in range(10)]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for f in self.raw_file_names[:2]:\n            download_url(f'{self.url}/new_data/film/{f}', self.raw_dir)\n        for f in self.raw_file_names[2:]:\n            download_url(f'{self.url}/splits/{f}', self.raw_dir)\n\n    def process(self) -> None:\n        with open(self.raw_paths[0]) as f:\n            node_data = [x.split('\\t') for x in f.read().split('\\n')[1:-1]]\n\n            rows, cols = [], []\n            for n_id, line, _ in node_data:\n                indices = [int(x) for x in line.split(',')]\n                rows += [int(n_id)] * len(indices)\n                cols += indices\n            row, col = torch.tensor(rows), torch.tensor(cols)\n\n            x = torch.zeros(int(row.max()) + 1, int(col.max()) + 1)\n            x[row, col] = 1.\n\n            y = torch.empty(len(node_data), dtype=torch.long)\n            for n_id, _, label in node_data:\n                y[int(n_id)] = int(label)\n\n        with open(self.raw_paths[1]) as f:\n            
edge_data = f.read().split('\\n')[1:-1]\n            edge_indices = [[int(v) for v in r.split('\\t')] for r in edge_data]\n            edge_index = torch.tensor(edge_indices).t().contiguous()\n            edge_index = coalesce(edge_index, num_nodes=x.size(0))\n\n        train_masks, val_masks, test_masks = [], [], []\n        for path in self.raw_paths[2:]:\n            tmp = np.load(path)\n            train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]\n            val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]\n            test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]\n        train_mask = torch.stack(train_masks, dim=1)\n        val_mask = torch.stack(val_masks, dim=1)\n        test_mask = torch.stack(test_masks, dim=1)\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/airfrans.py",
    "content": "import json\nimport os\nfrom typing import Callable, List, Optional\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass AirfRANS(InMemoryDataset):\n    r\"\"\"The AirfRANS dataset from the `\"AirfRANS: High Fidelity Computational\n    Fluid Dynamics Dataset for Approximating Reynolds-Averaged Navier-Stokes\n    Solutions\" <https://arxiv.org/abs/2212.07564>`_ paper, consisting of 1,000\n    simulations of steady-state aerodynamics over 2D airfoils in a subsonic\n    flight regime.\n    The different tasks (:obj:`\"full\"`, :obj:`\"scarce\"`, :obj:`\"reynolds\"`,\n    :obj:`\"aoa\"`) define the utilized training and test splits.\n\n    Each simulation is given as a point cloud defined as the nodes of the\n    simulation mesh. Each point of a point cloud is described via 5\n    features: the inlet velocity (two components in meter per second), the\n    distance to the airfoil (one component in meter), and the normals (two\n    components in meter, set to :obj:`0` if the point is not on the airfoil).\n    Each point is given a target of 4 components for the underlying regression\n    task: the velocity (two components in meter per second), the pressure\n    divided by the specific mass (one component in meter squared per second\n    squared), the turbulent kinematic viscosity (one component in meter squared\n    per second).\n    Finally, a boolean is attached to each point to inform if this point lies\n    on the airfoil or not.\n\n    A library for manipulating simulations of the dataset is available `here\n    <https://airfrans.readthedocs.io/en/latest/index.html>`_.\n\n    The dataset is released under the `ODbL v1.0 License\n    <https://opendatacommons.org/licenses/odbl/1-0/>`_.\n\n    .. note::\n\n        Data objects contain no edge indices to be agnostic to the simulation\n        mesh. 
You are free to build a graph via the\n        :obj:`torch_geometric.transforms.RadiusGraph` transform.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        task: The task to study (:obj:`\"full\"`, :obj:`\"scarce\"`,\n            :obj:`\"reynolds\"`, :obj:`\"aoa\"`) that defines the utilized training\n            and test splits.\n        train: If :obj:`True`, loads the training dataset, otherwise the test\n            dataset.\n        transform: A function/transform that takes in an\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform: A function/transform that takes in an\n            :class:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk.\n        pre_filter: A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset.\n        force_reload: Whether to re-process the dataset.\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #tasks\n        * - 1,000\n          - ~180,000\n          - 0\n          - 5\n          - 4\n    \"\"\"\n    url = 'https://data.isir.upmc.fr/extrality/pytorch_geometric/AirfRANS.zip'\n    tasks = ['full', 'scarce', 'reynolds', 'aoa']\n\n    def __init__(\n        self,\n        root: str,\n        task: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        if task not in self.tasks:\n            raise ValueError(f\"Expected 'task' to be in {self.tasks} \"\n                             f\"got '{task}'\")\n\n        self.task = 'full' if task == 'scarce' and not train else task\n        self.split = 'train' if train else 'test'\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['AirfRANS.pt', 'manifest.json']\n\n    @property\n    def processed_file_names(self) -> str:\n        return f'{self.task}_{self.split}.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        with open(self.raw_paths[1]) as f:\n            manifest = json.load(f)\n        total = manifest['full_train'] + manifest['full_test']\n        partial = set(manifest[f'{self.task}_{self.split}'])\n\n        data_list = []\n        raw_data = fs.torch_load(self.raw_paths[0])\n        for k, s in enumerate(total):\n            if s in partial:\n                data = Data(**raw_data[k])\n\n                if 
self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'task={self.task}, split={self.split})')\n"
  },
  {
    "path": "torch_geometric/datasets/airports.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import coalesce\n\n\nclass Airports(InMemoryDataset):\n    r\"\"\"The Airports dataset from the `\"struc2vec: Learning Node\n    Representations from Structural Identity\"\n    <https://arxiv.org/abs/1704.03165>`_ paper, where nodes denote airports\n    and labels correspond to activity levels.\n    Features are given by one-hot encoded node identifiers, as described in the\n    `\"GraLSP: Graph Neural Networks with Local Structural Patterns\"\n    <https://arxiv.org/abs/1911.07675>`_ paper.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        name: The name of the dataset (:obj:`\"USA\"`, :obj:`\"Brazil\"`,\n            :obj:`\"Europe\"`).\n        transform: A function/transform that takes in an\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform (callable, optional): A function/transform that takes in\n            :class:`torch_geometric.data.Data` object and returns a\n            transformed version. 
The data object will be transformed before\n            being saved to disk.\n        force_reload: Whether to re-process the dataset.\n    \"\"\"\n    edge_url = ('https://github.com/leoribeiro/struc2vec/'\n                'raw/master/graph/{}-airports.edgelist')\n    label_url = ('https://github.com/leoribeiro/struc2vec/'\n                 'raw/master/graph/labels-{}-airports.txt')\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in ['usa', 'brazil', 'europe']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            f'{self.name}-airports.edgelist',\n            f'labels-{self.name}-airports.txt',\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.edge_url.format(self.name), self.raw_dir)\n        download_url(self.label_url.format(self.name), self.raw_dir)\n\n    def process(self) -> None:\n        index_map, ys = {}, []\n        with open(self.raw_paths[1]) as f:\n            rows = f.read().split('\\n')[1:-1]\n            for i, row in enumerate(rows):\n                idx, label = row.split()\n                index_map[int(idx)] = i\n                ys.append(int(label))\n        y = torch.tensor(ys, dtype=torch.long)\n        x = torch.eye(y.size(0))\n\n        edge_indices = []\n        with open(self.raw_paths[0]) as 
f:\n            rows = f.read().split('\\n')[:-1]\n            for row in rows:\n                src, dst = row.split()\n                edge_indices.append([index_map[int(src)], index_map[int(dst)]])\n        edge_index = torch.tensor(edge_indices).t().contiguous()\n        edge_index = coalesce(edge_index, num_nodes=y.size(0))\n\n        data = Data(x=x, edge_index=edge_index, y=y)\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name.capitalize()}Airports()'\n"
  },
  {
    "path": "torch_geometric/datasets/amazon.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nfrom torch_geometric.data import InMemoryDataset, download_url\nfrom torch_geometric.io import read_npz\n\n\nclass Amazon(InMemoryDataset):\n    r\"\"\"The Amazon Computers and Amazon Photo networks from the\n    `\"Pitfalls of Graph Neural Network Evaluation\"\n    <https://arxiv.org/abs/1811.05868>`_ paper.\n    Nodes represent goods and edges represent that two goods are frequently\n    bought together.\n    Given product reviews as bag-of-words node features, the task is to\n    map goods to their respective product category.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        name: The name of the dataset (:obj:`\"Computers\"`, :obj:`\"Photo\"`).\n        transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform: A function/transform that takes in an\n            :class:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk.\n        force_reload: Whether to re-process the dataset.\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Computers\n          - 13,752\n          - 491,722\n          - 767\n          - 10\n        * - Photo\n          - 7,650\n          - 238,162\n          - 745\n          - 8\n    \"\"\"\n\n    url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in ['computers', 'photo']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name.capitalize(), 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name.capitalize(), 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'amazon_electronics_{self.name.lower()}.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url + self.raw_file_names, self.raw_dir)\n\n    def process(self) -> None:\n        data = read_npz(self.raw_paths[0], to_undirected=True)\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}{self.name.capitalize()}()'\n"
  },
  {
    "path": "torch_geometric/datasets/amazon_book.py",
    "content": "from typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import HeteroData, InMemoryDataset, download_url\n\n\nclass AmazonBook(InMemoryDataset):\n    r\"\"\"A subset of the AmazonBook rating dataset from the\n    `\"LightGCN: Simplifying and Powering Graph Convolution Network for\n    Recommendation\" <https://arxiv.org/abs/2002.02126>`_ paper.\n    This is a heterogeneous dataset consisting of 52,643 users and 91,599 books\n    with approximately 2.9 million ratings between them.\n    No labels or features are provided.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        transform: A function/transform that takes in an\n            :class:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access.\n        pre_transform: A function/transform that takes in an\n            :class:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. 
The data object will be transformed before\n            being saved to disk.\n        force_reload: Whether to re-process the dataset.\n    \"\"\"\n    url = ('https://raw.githubusercontent.com/gusye1234/LightGCN-PyTorch/'\n           'master/data/amazon-book')\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['user_list.txt', 'item_list.txt', 'train.txt', 'test.txt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for name in self.raw_file_names:\n            download_url(f'{self.url}/{name}', self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        data = HeteroData()\n\n        # Process number of nodes for each node type:\n        node_types = ['user', 'book']\n        for path, node_type in zip(self.raw_paths, node_types):\n            df = pd.read_csv(path, sep=' ', header=0)\n            data[node_type].num_nodes = len(df)\n\n        # Process edge information for training and testing:\n        attr_names = ['edge_index', 'edge_label_index']\n        for path, attr_name in zip(self.raw_paths[2:], attr_names):\n            rows, cols = [], []\n            with open(path) as f:\n                lines = f.readlines()\n            for line in lines:\n                indices = line.strip().split(' ')\n                for dst in indices[1:]:\n                    rows.append(int(indices[0]))\n                    cols.append(int(dst))\n            index = torch.tensor([rows, cols])\n\n            data['user', 'rates', 'book'][attr_name] = index\n      
      if attr_name == 'edge_index':\n                data['book', 'rated_by', 'user'][attr_name] = index.flip([0])\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/amazon_products.py",
    "content": "import json\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_google_url\n\n\nclass AmazonProducts(InMemoryDataset):\n    r\"\"\"The Amazon dataset from the `\"GraphSAINT: Graph Sampling Based\n    Inductive Learning Method\" <https://arxiv.org/abs/1907.04931>`_ paper,\n    containing products and its categories.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        transform: A function/transform that takes in an\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk.\n        force_reload: Whether to re-process the dataset.\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 1,569,960\n          - 264,339,468\n          - 200\n          - 107\n    \"\"\"\n    adj_full_id = '17qhNA8H1IpbkkR-T2BmPQm8QNW5do-aa'\n    feats_id = '10SW8lCvAj-kb6ckkfTOC5y0l8XXdtMxj'\n    class_map_id = '1LIl4kimLfftj4-7NmValuWyCQE8AaE7P'\n    role_id = '1npK9xlmbnjNkV80hK2Q68wTEVOFjnt4K'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz')\n        download_google_url(self.feats_id, self.raw_dir, 'feats.npy')\n        download_google_url(self.class_map_id, self.raw_dir, 'class_map.json')\n        download_google_url(self.role_id, self.raw_dir, 'role.json')\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))\n        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])\n        adj = adj.tocoo()\n        row = torch.from_numpy(adj.row).to(torch.long)\n        col = torch.from_numpy(adj.col).to(torch.long)\n        edge_index = torch.stack([row, col], dim=0)\n\n        x = np.load(osp.join(self.raw_dir, 'feats.npy'))\n        x = torch.from_numpy(x).to(torch.float)\n\n        ys = [-1] * x.size(0)\n        with open(osp.join(self.raw_dir, 'class_map.json')) as f:\n  
          class_map = json.load(f)\n            for key, item in class_map.items():\n                ys[int(key)] = item\n        y = torch.tensor(ys)\n\n        with open(osp.join(self.raw_dir, 'role.json')) as f:\n            role = json.load(f)\n\n        train_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        train_mask[torch.tensor(role['tr'])] = True\n\n        val_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        val_mask[torch.tensor(role['va'])] = True\n\n        test_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        test_mask[torch.tensor(role['te'])] = True\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/aminer.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import coalesce\n\n\nclass AMiner(InMemoryDataset):\n    r\"\"\"The heterogeneous AMiner dataset from the `\"metapath2vec: Scalable\n    Representation Learning for Heterogeneous Networks\"\n    <https://ericdongyx.github.io/papers/\n    KDD17-dong-chawla-swami-metapath2vec.pdf>`_ paper, consisting of nodes from\n    type :obj:`\"paper\"`, :obj:`\"author\"` and :obj:`\"venue\"`.\n    Venue categories and author research interests are available as ground\n    truth labels for a subset of nodes.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        transform: A function/transform that takes in a\n            :class:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access.\n        pre_transform: A function/transform that takes in a\n            :class:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. 
The data object will be transformed before\n            being saved to disk.\n        force_reload: Whether to re-process the dataset.\n    \"\"\"\n\n    url = 'https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1'\n    y_url = 'https://www.dropbox.com/s/nkocx16rpl4ydde/label.zip?dl=1'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'id_author.txt', 'id_conf.txt', 'paper.txt', 'paper_author.txt',\n            'paper_conf.txt', 'label'\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        fs.rm(self.raw_dir)\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.rename(osp.join(self.root, 'net_aminer'), self.raw_dir)\n        os.unlink(path)\n        path = download_url(self.y_url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        data = HeteroData()\n\n        # Get author labels.\n        path = osp.join(self.raw_dir, 'id_author.txt')\n        author = pd.read_csv(path, sep='\\t', names=['idx', 'name'],\n                             index_col=1)\n\n        path = osp.join(self.raw_dir, 'label',\n                        'googlescholar.8area.author.label.txt')\n        df = pd.read_csv(path, sep=' ', names=['name', 'y'])\n        df = df.join(author, on='name')\n\n        data['author'].y = torch.from_numpy(df['y'].values) - 1\n        data['author'].y_index = torch.from_numpy(df['idx'].values)\n\n        
# Get venue labels.\n        path = osp.join(self.raw_dir, 'id_conf.txt')\n        venue = pd.read_csv(path, sep='\\t', names=['idx', 'name'], index_col=1)\n\n        path = osp.join(self.raw_dir, 'label',\n                        'googlescholar.8area.venue.label.txt')\n        df = pd.read_csv(path, sep=' ', names=['name', 'y'])\n        df = df.join(venue, on='name')\n\n        data['venue'].y = torch.from_numpy(df['y'].values) - 1\n        data['venue'].y_index = torch.from_numpy(df['idx'].values)\n\n        # Get paper<->author connectivity.\n        path = osp.join(self.raw_dir, 'paper_author.txt')\n        paper_author = pd.read_csv(path, sep='\\t', header=None)\n        paper_author = torch.from_numpy(paper_author.values)\n        paper_author = paper_author.t().contiguous()\n        M, N = int(paper_author[0].max() + 1), int(paper_author[1].max() + 1)\n        paper_author = coalesce(paper_author, num_nodes=max(M, N))\n        data['paper'].num_nodes = M\n        data['author'].num_nodes = N\n        data['paper', 'written_by', 'author'].edge_index = paper_author\n        data['author', 'writes', 'paper'].edge_index = paper_author.flip([0])\n\n        # Get paper<->venue connectivity.\n        path = osp.join(self.raw_dir, 'paper_conf.txt')\n        paper_venue = pd.read_csv(path, sep='\\t', header=None)\n        paper_venue = torch.from_numpy(paper_venue.values)\n        paper_venue = paper_venue.t().contiguous()\n        M, N = int(paper_venue[0].max() + 1), int(paper_venue[1].max() + 1)\n        paper_venue = coalesce(paper_venue, num_nodes=max(M, N))\n        data['venue'].num_nodes = N\n        data['paper', 'published_in', 'venue'].edge_index = paper_venue\n        data['venue', 'publishes', 'paper'].edge_index = paper_venue.flip([0])\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/aqsol.py",
    "content": "import os\nimport os.path as osp\nimport pickle\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass AQSOL(InMemoryDataset):\n    r\"\"\"The AQSOL dataset from the `Benchmarking Graph Neural Networks\n    <http://arxiv.org/abs/2003.00982>`_ paper based on\n    `AqSolDB <https://www.nature.com/articles/s41597-019-0151-1>`_, a\n    standardized database of 9,982 molecular graphs with their aqueous\n    solubility values, collected from 9 different data sources.\n\n    The aqueous solubility targets are collected from experimental measurements\n    and standardized to LogS units in AqSolDB. These final values denote the\n    property to regress in the :class:`AQSOL` dataset. After filtering out few\n    graphs with no bonds/edges, the total number of molecular graphs is 9,833.\n    For each molecular graph, the node features are the types of heavy atoms\n    and the edge features are the types of bonds between them, similar as in\n    the :class:`~torch_geometric.datasets.ZINC` dataset.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        split: If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n        transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a\n            transformed version. 
The data object will be transformed before\n            being saved to disk.\n        pre_filter (callable, optional): A function that takes in an\n            :class:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in\n            the final dataset.\n        force_reload: Whether to re-process the dataset.\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 9,833\n          - ~17.6\n          - ~35.8\n          - 1\n          - 1\n    \"\"\"\n    url = 'https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1'\n\n    def __init__(\n        self,\n        root: str,\n        split: str = 'train',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ):\n        assert split in ['train', 'val', 'test']\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = osp.join(self.processed_dir, f'{split}.pt')\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'train.pickle', 'val.pickle', 'test.pickle', 'atom_dict.pickle',\n            'bond_dict.pickle'\n        ]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def download(self) -> None:\n        fs.rm(self.raw_dir)\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.rename(osp.join(self.root, 'asqol_graph_raw'), self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        for raw_path, path in zip(self.raw_paths, self.processed_paths):\n           
 with open(raw_path, 'rb') as f:\n                graphs = pickle.load(f)\n\n            data_list: List[Data] = []\n            for graph in graphs:\n                x, edge_attr, edge_index, y = graph\n\n                x = torch.from_numpy(x)\n                edge_attr = torch.from_numpy(edge_attr)\n                edge_index = torch.from_numpy(edge_index)\n                y = torch.tensor([y]).float()\n\n                if edge_index.numel() == 0:\n                    continue  # Skipping for graphs with no bonds/edges.\n\n                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,\n                            y=y)\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n            self.save(data_list, path)\n\n    def atoms(self) -> List[str]:\n        return [\n            'Br', 'C', 'N', 'O', 'Cl', 'Zn', 'F', 'P', 'S', 'Na', 'Al', 'Si',\n            'Mo', 'Ca', 'W', 'Pb', 'B', 'V', 'Co', 'Mg', 'Bi', 'Fe', 'Ba', 'K',\n            'Ti', 'Sn', 'Cd', 'I', 'Re', 'Sr', 'H', 'Cu', 'Ni', 'Lu', 'Pr',\n            'Te', 'Ce', 'Nd', 'Gd', 'Zr', 'Mn', 'As', 'Hg', 'Sb', 'Cr', 'Se',\n            'La', 'Dy', 'Y', 'Pd', 'Ag', 'In', 'Li', 'Rh', 'Nb', 'Hf', 'Cs',\n            'Ru', 'Au', 'Sm', 'Ta', 'Pt', 'Ir', 'Be', 'Ge'\n        ]\n\n    def bonds(self) -> List[str]:\n        return ['NONE', 'SINGLE', 'DOUBLE', 'AROMATIC', 'TRIPLE']\n"
  },
  {
    "path": "torch_geometric/datasets/attributed_graph_dataset.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_google_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass AttributedGraphDataset(InMemoryDataset):\n    r\"\"\"A variety of attributed graph datasets from the\n    `\"Scaling Attributed Network Embedding to Massive Graphs\"\n    <https://arxiv.org/abs/2009.00826>`_ paper.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        name: The name of the dataset (:obj:`\"Wiki\"`, :obj:`\"Cora\"`,\n            :obj:`\"CiteSeer\"`, :obj:`\"PubMed\"`, :obj:`\"BlogCatalog\"`,\n            :obj:`\"PPI\"`, :obj:`\"Flickr\"`, :obj:`\"Facebook\"`, :obj:`\"Twitter\"`,\n            :obj:`\"TWeibo\"`, :obj:`\"MAG\"`).\n        transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk.\n        force_reload: Whether to re-process the dataset.\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Wiki\n          - 2,405\n          - 17,981\n          - 4,973\n          - 17\n        * - Cora\n          - 2,708\n          - 5,429\n          - 1,433\n          - 7\n        * - CiteSeer\n          - 3,312\n          - 4,715\n          - 3,703\n          - 6\n        * - PubMed\n          - 19,717\n          - 44,338\n          - 500\n          - 3\n        * - BlogCatalog\n          - 5,196\n          - 343,486\n          - 8,189\n          - 6\n        * - PPI\n          - 56,944\n          - 1,612,348\n          - 50\n          - 121\n        * - Flickr\n          - 7,575\n          - 479,476\n          - 12,047\n          - 9\n        * - Facebook\n          - 4,039\n          - 88,234\n          - 1,283\n          - 193\n        * - TWeibo\n          - 2,320,895\n          - 9,840,066\n          - 1,657\n          - 8\n        * - MAG\n          - 59,249,719\n          - 978,147,253\n          - 2,000\n          - 100\n    \"\"\"\n    datasets = {\n        'wiki': '1EPhlbziZTQv19OsTrKrAJwsElbVPEbiV',\n        'cora': '1FyVnpdsTT-lhkVPotUW8OVeuCi1vi3Ey',\n        'citeseer': '1d3uQIpHiemWJPgLgTafi70RFYye7hoCp',\n        'pubmed': '1DOK3FfslyJoGXUSCSrK5lzdyLfIwOz6k',\n        'blogcatalog': '178PqGqh67RUYMMP6-SoRHDoIBh8ku5FS',\n        'ppi': '1dvwRpPT4gGtOcNP_Q-G1TKl9NezYhtez',\n        'flickr': '1tZp3EB20fAC27SYWwa-x66_8uGsuU62X',\n        'facebook': '12aJWAGCM4IvdGI2fiydDNyWzViEOLZH8',\n        'twitter': '1fUYggzZlDrt9JsLsSdRUHiEzQRW1kSA4',\n        'tweibo': '1-2xHDPFCsuBuFdQN_7GLleWa8R_t50qU',\n        'mag': '1ggraUMrQgdUyA3DjSRzzqMv0jFkU65V5',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    
) -> None:\n        self.name = name.lower()\n        assert self.name in self.datasets.keys()\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['attrs.npz', 'edgelist.txt', 'labels.txt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        id = self.datasets[self.name]\n        path = download_google_url(id, self.raw_dir, 'data.zip')\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n        path = osp.join(self.raw_dir, f'{self.name}.attr')\n        if self.name == 'mag':\n            path = osp.join(self.raw_dir, self.name)\n        for name in self.raw_file_names:\n            os.rename(osp.join(path, name), osp.join(self.raw_dir, name))\n        fs.rm(path)\n\n    def process(self) -> None:\n        import pandas as pd\n        import scipy.sparse as sp\n\n        x = sp.load_npz(self.raw_paths[0]).tocsr()\n        if x.shape[-1] > 10000 or self.name == 'mag':\n            x = torch.sparse_csr_tensor(\n                crow_indices=x.indptr,\n                col_indices=x.indices,\n                values=x.data,\n                size=x.shape,\n            )\n        else:\n            x = torch.from_numpy(x.todense()).to(torch.float)\n\n        df = pd.read_csv(self.raw_paths[1], header=None, sep=None,\n                         engine='python')\n        edge_index = torch.from_numpy(df.values).t().contiguous()\n\n        with open(self.raw_paths[2]) as f:\n            rows = f.read().split('\\n')[:-1]\n            ys = [[int(y) - 1 for y in row.split()[1:]] for row in rows]\n   
         multilabel = max([len(y) for y in ys]) > 1\n\n        if not multilabel:\n            y = torch.tensor(ys).view(-1)\n        else:\n            num_classes = max([y for row in ys for y in row]) + 1\n            y = torch.zeros((len(ys), num_classes), dtype=torch.float)\n            for i, row in enumerate(ys):\n                for j in row:\n                    y[i, j] = 1.\n\n        data = Data(x=x, edge_index=edge_index, y=y)\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name.capitalize()}()'\n"
  },
  {
    "path": "torch_geometric/datasets/ba2motif_dataset.py",
    "content": "import pickle\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass BA2MotifDataset(InMemoryDataset):\n    r\"\"\"The synthetic BA-2motifs graph classification dataset for evaluating\n    explainabilty algorithms, as described in the `\"Parameterized Explainer\n    for Graph Neural Network\" <https://arxiv.org/abs/2011.04573>`_ paper.\n    :class:`~torch_geometric.datasets.BA2MotifDataset` contains 1000 random\n    Barabasi-Albert (BA) graphs.\n    Half of the graphs are attached with a\n    :class:`~torch_geometric.datasets.motif_generator.HouseMotif`, and the rest\n    are attached with a five-node\n    :class:`~torch_geometric.datasets.motif_generator.CycleMotif`.\n    The graphs are assigned to one of the two classes according to the type of\n    attached motifs.\n\n    This dataset is pre-computed from the official implementation. If you want\n    to create own variations of it, you can make use of the\n    :class:`~torch_geometric.datasets.ExplainerDataset`:\n\n    .. 
code-block:: python\n\n        import torch\n        from torch_geometric.datasets import ExplainerDataset\n        from torch_geometric.datasets.graph_generator import BAGraph\n        from torch_geometric.datasets.motif_generator import HouseMotif\n        from torch_geometric.datasets.motif_generator import CycleMotif\n\n        dataset1 = ExplainerDataset(\n            graph_generator=BAGraph(num_nodes=25, num_edges=1),\n            motif_generator=HouseMotif(),\n            num_motifs=1,\n            num_graphs=500,\n        )\n\n        dataset2 = ExplainerDataset(\n            graph_generator=BAGraph(num_nodes=25, num_edges=1),\n            motif_generator=CycleMotif(5),\n            num_motifs=1,\n            num_graphs=500,\n        )\n\n        dataset = torch.utils.data.ConcatDataset([dataset1, dataset2])\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 1000\n          - 25\n          - ~51.0\n          - 10\n          - 2\n    \"\"\"\n    url = 'https://github.com/flyingdoog/PGExplainer/raw/master/dataset'\n    filename = 'BA-2motif.pkl'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return self.filename\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(f'{self.url}/{self.filename}', self.raw_dir)\n\n    def process(self) -> None:\n        with open(self.raw_paths[0], 'rb') as f:\n            adj, x, y = pickle.load(f)\n\n        adjs = torch.from_numpy(adj)\n        xs = torch.from_numpy(x).to(torch.float)\n        ys = torch.from_numpy(y)\n\n        data_list: List[Data] = []\n        for i in range(xs.size(0)):\n            edge_index = adjs[i].nonzero().t()\n            x = xs[i]\n            y = int(ys[i].nonzero())\n\n            data = Data(x=x, edge_index=edge_index, y=y)\n\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/ba_multi_shapes.py",
    "content": "import pickle\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass BAMultiShapesDataset(InMemoryDataset):\n    r\"\"\"The synthetic BA-Multi-Shapes graph classification dataset for\n    evaluating explainabilty algorithms, as described in the\n    `\"Global Explainability of GNNs via Logic Combination of Learned Concepts\"\n    <https://arxiv.org/abs/2210.07147>`_ paper.\n\n    Given three atomic motifs, namely House (H), Wheel (W), and Grid (G),\n    :class:`~torch_geometric.datasets.BAMultiShapesDataset` contains 1,000\n    graphs where each graph is obtained by attaching the motifs to a random\n    Barabasi-Albert (BA) as follows:\n\n    * class 0: :math:`\\emptyset \\lor H \\lor W \\lor G \\lor \\{ H, W, G \\}`\n\n    * class 1: :math:`(H \\land W) \\lor (H \\land G) \\lor (W \\land G)`\n\n    This dataset is pre-computed from the official implementation.\n\n    Args:\n        root: Root directory where the dataset should be saved.\n        transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n        pre_transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk.\n        pre_filter: A function that takes in a\n            :class:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset.\n        force_reload: Whether to re-process the dataset.\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 1000\n          - 40\n          - ~87.0\n          - 10\n          - 2\n    \"\"\"\n    url = ('https://github.com/steveazzolin/gnn_logic_global_expl/raw/master/'\n           'datasets/BAMultiShapes/BAMultiShapes.pkl')\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'BAMultiShapes.pkl'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url, self.raw_dir)\n\n    def process(self) -> None:\n        with open(self.raw_paths[0], 'rb') as f:\n            adjs, xs, ys = pickle.load(f)\n\n        data_list: List[Data] = []\n        for adj, x, y in zip(adjs, xs, ys):\n            edge_index = torch.from_numpy(adj).nonzero().t()\n            x = torch.from_numpy(np.array(x)).to(torch.float)\n\n            data = Data(x=x, edge_index=edge_index, y=y)\n\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/ba_shapes.py",
    "content": "from typing import Callable, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.utils import barabasi_albert_graph\n\n\ndef house() -> Tuple[Tensor, Tensor]:\n    edge_index = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4],\n                               [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1]])\n    label = torch.tensor([1, 1, 2, 2, 3])\n    return edge_index, label\n\n\n@deprecated(\"use 'datasets.ExplainerDataset' in combination with \"\n            \"'datasets.graph_generator.BAGraph' instead\")\nclass BAShapes(InMemoryDataset):\n    r\"\"\"The BA-Shapes dataset from the `\"GNNExplainer: Generating Explanations\n    for Graph Neural Networks\" <https://arxiv.org/abs/1903.03894>`__ paper,\n    containing a Barabasi-Albert (BA) graph with 300 nodes and a set of 80\n    \"house\"-structured graphs connected to it.\n\n    .. warning::\n\n        :class:`BAShapes` is deprecated and will be removed in a future\n        release. Use :class:`ExplainerDataset` in combination with\n        :class:`torch_geometric.datasets.graph_generator.BAGraph` instead.\n\n    Args:\n        connection_distribution: Specifies how the houses and the BA graph get\n            connected. Valid inputs are :obj:`\"random\"`\n            (random BA graph nodes are selected for connection to the houses),\n            and :obj:`\"uniform\"` (uniformly distributed BA graph nodes are\n            selected for connection to the houses).\n        transform: A function/transform that takes in a\n            :class:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n    \"\"\"\n    def __init__(\n        self,\n        connection_distribution: str = \"random\",\n        transform: Optional[Callable] = None,\n    ) -> None:\n        super().__init__(None, transform)\n        assert connection_distribution in ['random', 'uniform']\n\n        # Build the Barabasi-Albert graph:\n        num_nodes = 300\n        edge_index = barabasi_albert_graph(num_nodes, num_edges=5)\n        edge_label = torch.zeros(edge_index.size(1), dtype=torch.int64)\n        node_label = torch.zeros(num_nodes, dtype=torch.int64)\n\n        # Select nodes to connect shapes:\n        num_houses = 80\n        if connection_distribution == 'random':\n            connecting_nodes = torch.randperm(num_nodes)[:num_houses]\n        else:\n            step = num_nodes // num_houses\n            connecting_nodes = torch.arange(0, num_nodes, step)\n\n        # Connect houses to Barabasi-Albert graph:\n        edge_indices = [edge_index]\n        edge_labels = [edge_label]\n        node_labels = [node_label]\n        for i in range(num_houses):\n            house_edge_index, house_label = house()\n\n            edge_indices.append(house_edge_index + num_nodes)\n            edge_indices.append(\n                torch.tensor([[int(connecting_nodes[i]), num_nodes],\n                              [num_nodes, int(connecting_nodes[i])]]))\n\n            edge_labels.append(\n                torch.ones(house_edge_index.size(1), dtype=torch.long))\n            edge_labels.append(torch.zeros(2, dtype=torch.long))\n\n            node_labels.append(house_label)\n\n            num_nodes += 5\n\n        edge_index = torch.cat(edge_indices, dim=1)\n        edge_label = torch.cat(edge_labels, dim=0)\n        node_label = torch.cat(node_labels, dim=0)\n\n        x = torch.ones((num_nodes, 10), dtype=torch.float)\n        expl_mask = torch.zeros(num_nodes, dtype=torch.bool)\n        expl_mask[torch.arange(400, num_nodes, 5)] = 
True\n\n        data = Data(x=x, edge_index=edge_index, y=node_label,\n                    expl_mask=expl_mask, edge_label=edge_label)\n\n        self.data, self.slices = self.collate([data])\n"
  },
  {
    "path": "torch_geometric/datasets/bitcoin_otc.py",
    "content": "import datetime\nimport os\nfrom typing import Callable, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_gz,\n)\n\n\nclass BitcoinOTC(InMemoryDataset):\n    r\"\"\"The Bitcoin-OTC dataset from the `\"EvolveGCN: Evolving Graph\n    Convolutional Networks for Dynamic Graphs\"\n    <https://arxiv.org/abs/1902.10191>`_ paper, consisting of 138\n    who-trusts-whom networks of sequential time steps.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        edge_window_size (int, optional): The window size for the existence of\n            an edge in the graph sequence since its initial creation.\n            (default: :obj:`10`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 138\n          - 6,005\n          - ~2,573.2\n          - 0\n          - 0\n    \"\"\"\n\n    url = 'https://snap.stanford.edu/data/soc-sign-bitcoinotc.csv.gz'\n\n    def __init__(\n        self,\n        root: str,\n        edge_window_size: int = 10,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.edge_window_size = edge_window_size\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'soc-sign-bitcoinotc.csv'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    @property\n    def num_nodes(self) -> int:\n        assert isinstance(self._data, Data)\n        assert self._data.edge_index is not None\n        return int(self._data.edge_index.max()) + 1\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_gz(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        with open(self.raw_paths[0]) as f:\n            lines = [[x for x in line.split(',')]\n                     for line in f.read().split('\\n')[:-1]]\n\n            edge_indices = [[int(line[0]), int(line[1])] for line in lines]\n            edge_index = torch.tensor(edge_indices, dtype=torch.long)\n            edge_index = edge_index - edge_index.min()\n            edge_index = edge_index.t().contiguous()\n            num_nodes = int(edge_index.max()) + 1\n\n            edge_attrs = [int(line[2]) for line in lines]\n            edge_attr = torch.tensor(edge_attrs, dtype=torch.long)\n\n            stamps = [\n      
          datetime.datetime.fromtimestamp(int(float(line[3])))\n                for line in lines\n            ]\n\n        offset = datetime.timedelta(days=13.8)  # Results in 138 time steps.\n        graph_indices, factor = [], 1\n        for t in stamps:\n            factor = factor if t < stamps[0] + factor * offset else factor + 1\n            graph_indices.append(factor - 1)\n        graph_idx = torch.tensor(graph_indices, dtype=torch.long)\n\n        data_list = []\n        for i in range(int(graph_idx.max()) + 1):\n            mask = (graph_idx > (i - self.edge_window_size)) & (graph_idx <= i)\n            data = Data()\n            data.edge_index = edge_index[:, mask]\n            data.edge_attr = edge_attr[mask]\n            data.num_nodes = num_nodes\n            data_list.append(data)\n\n        if self.pre_filter is not None:\n            data_list = [d for d in data_list if self.pre_filter(d)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(d) for d in data_list]\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/brca_tgca.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass BrcaTcga(InMemoryDataset):\n    r\"\"\"The breast cancer (BRCA TCGA Pan-Cancer Atlas) dataset consisting of\n    patients with survival information and gene expression data from\n    `cBioPortal <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4160307/>`_\n    and a network of biological interactions between those nodes from\n    `Pathway Commons <https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7145667/>`_.\n    The dataset contains the gene features of 1,082 patients, and the overall\n    survival time (in months) of each patient as label.\n\n    Pre-processing and example model codes on how to use this dataset can be\n    found `here <https://github.com/cannin/pyg_pathway_commons_cbioportal>`_.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n        * - 1,082\n          - 9,288\n          - 271,771\n          - 1,082\n    \"\"\"\n    url = 'https://zenodo.org/record/8251328/files/brca_tcga.zip?download=1'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['graph_idx.csv', 'graph_labels.csv', 'edge_index.pt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.unlink(path)\n        fs.rm(self.raw_dir)\n        os.rename(osp.join(self.root, 'brca_tcga'), self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        graph_feat = pd.read_csv(self.raw_paths[0], index_col=0).values\n        graph_feat = torch.from_numpy(graph_feat).to(torch.float)\n        graph_labels = np.loadtxt(self.raw_paths[1], delimiter=',')\n        graph_label = torch.from_numpy(graph_labels).to(torch.float)\n        edge_index = fs.torch_load(self.raw_paths[2])\n\n        data_list = []\n        for x, y in zip(graph_feat, graph_label):\n            data = Data(x=x.view(-1, 1), edge_index=edge_index, y=y)\n\n            if self.pre_filter is not None and not self.pre_filter(data):\n      
          continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/citation_full.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nfrom torch_geometric.data import InMemoryDataset, download_url\nfrom torch_geometric.io import read_npz\n\n\nclass CitationFull(InMemoryDataset):\n    r\"\"\"The full citation network datasets from the\n    `\"Deep Gaussian Embedding of Graphs: Unsupervised Inductive Learning via\n    Ranking\" <https://arxiv.org/abs/1707.03815>`_ paper.\n    Nodes represent documents and edges represent citation links.\n    Datasets include :obj:`\"Cora\"`, :obj:`\"Cora_ML\"`, :obj:`\"CiteSeer\"`,\n    :obj:`\"DBLP\"`, :obj:`\"PubMed\"`.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"Cora\"`, :obj:`\"Cora_ML\"`\n            :obj:`\"CiteSeer\"`, :obj:`\"DBLP\"`, :obj:`\"PubMed\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        to_undirected (bool, optional): Whether the original graph is\n            converted to an undirected one. (default: :obj:`True`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Cora\n          - 19,793\n          - 126,842\n          - 8,710\n          - 70\n        * - Cora_ML\n          - 2,995\n          - 16,316\n          - 2,879\n          - 7\n        * - CiteSeer\n          - 4,230\n          - 10,674\n          - 602\n          - 6\n        * - DBLP\n          - 17,716\n          - 105,734\n          - 1,639\n          - 4\n        * - PubMed\n          - 19,717\n          - 88,648\n          - 500\n          - 3\n    \"\"\"\n\n    url = 'https://github.com/abojchevski/graph2gauss/raw/master/data/{}.npz'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        to_undirected: bool = True,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        self.to_undirected = to_undirected\n        assert self.name in ['cora', 'cora_ml', 'citeseer', 'dblp', 'pubmed']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'{self.name}.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        suffix = 'undirected' if self.to_undirected else 'directed'\n        return f'data_{suffix}.pt'\n\n    def download(self) -> None:\n        download_url(self.url.format(self.name), self.raw_dir)\n\n    def process(self) -> None:\n        data = read_npz(self.raw_paths[0], to_undirected=self.to_undirected)\n     
   data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name.capitalize()}Full()'\n\n\nclass CoraFull(CitationFull):\n    r\"\"\"Alias for :class:`~torch_geometric.datasets.CitationFull` with\n    :obj:`name=\"Cora\"`.\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 19,793\n          - 126,842\n          - 8,710\n          - 70\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n    ) -> None:\n        super().__init__(root, 'cora', transform, pre_transform)\n\n    def download(self) -> None:\n        super().download()\n\n    def process(self) -> None:\n        super().process()\n"
  },
  {
    "path": "torch_geometric/datasets/city.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\nfrom torch_geometric.io import fs\n\n\nclass CityNetwork(InMemoryDataset):\n    r\"\"\"The City-Networks are introduced in\n    `\"Towards Quantifying Long-Range Interactions in Graph Machine Learning:\n    a Large Graph Dataset and a Measurement\"\n    <https://arxiv.org/abs/2503.09008>`_ paper.\n    The dataset contains four city networks: `paris`, `shanghai`, `la`,\n    and `london`, where nodes represent junctions and edges represent\n    undirected road segments. The task is to predict each node's eccentricity\n    score, which is approximated based on its 16-hop neighborhood and naturally\n    requires long-range information. The score indicates how accessible one\n    node is in the network, and is mapped to 10 quantiles for transductive\n    classification. See the original\n    `source code <https://github.com/LeonResearch/City-Networks>`_ for more\n    details on the individual networks.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (``\"paris\"``, ``\"shanghai\"``,\n            ``\"la\"``, ``\"london\"``).\n        augmented (bool, optional): Whether to use the augmented node features\n            from edge features.(default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :class:`~torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :class:`~torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - paris\n          - 114,127\n          - 182,511\n          - 37\n          - 10\n        * - shanghai\n          - 183,917\n          - 262,092\n          - 37\n          - 10\n        * - la\n          - 240,587\n          - 341,523\n          - 37\n          - 10\n        * - london\n          - 568,795\n          - 756,502\n          - 37\n          - 10\n    \"\"\"\n    url = \"https://github.com/LeonResearch/City-Networks/raw/refs/heads/main/data/\"  # noqa: E501\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        augmented: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n        delete_raw: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in [\"paris\", \"shanghai\", \"la\", \"london\"]\n        self.augmented = augmented\n        self.delete_raw = delete_raw\n        super().__init__(\n            root,\n            transform,\n            pre_transform,\n            force_reload=force_reload,\n        )\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, \"raw\")\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, \"processed\")\n\n    @property\n    def raw_file_names(self) -> str:\n        return f\"{self.name}.json\"\n\n    @property\n    def processed_file_names(self) -> str:\n        return \"data.pt\"\n\n    def download(self) -> None:\n        self.download_path = download_url(\n            self.url 
+ f\"{self.name}.tar.gz\",\n            self.raw_dir,\n        )\n\n    def process(self) -> None:\n        extract_tar(self.download_path, self.raw_dir)\n        data_path = osp.join(self.raw_dir, self.name)\n        node_feat = fs.torch_load(\n            osp.join(\n                data_path,\n                f\"node_features{'_augmented' if self.augmented else ''}.pt\",\n            ))\n        edge_index = fs.torch_load(osp.join(data_path, \"edge_indices.pt\"))\n        label = fs.torch_load(\n            osp.join(data_path, \"10-chunk_16-hop_node_labels.pt\"))\n        train_mask = fs.torch_load(osp.join(data_path, \"train_mask.pt\"))\n        val_mask = fs.torch_load(osp.join(data_path, \"valid_mask.pt\"))\n        test_mask = fs.torch_load(osp.join(data_path, \"test_mask.pt\"))\n        data = Data(\n            x=node_feat,\n            edge_index=edge_index,\n            y=label,\n            train_mask=train_mask,\n            val_mask=val_mask,\n            test_mask=test_mask,\n        )\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n        if self.delete_raw:\n            fs.rm(data_path)\n\n    def __repr__(self) -> str:\n        return (f\"{self.__class__.__name__}(\"\n                f\"root='{self.root}', \"\n                f\"name='{self.name}', \"\n                f\"augmented={self.augmented})\")\n"
  },
  {
    "path": "torch_geometric/datasets/coauthor.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nfrom torch_geometric.data import InMemoryDataset, download_url\nfrom torch_geometric.io import read_npz\n\n\nclass Coauthor(InMemoryDataset):\n    r\"\"\"The Coauthor CS and Coauthor Physics networks from the\n    `\"Pitfalls of Graph Neural Network Evaluation\"\n    <https://arxiv.org/abs/1811.05868>`_ paper.\n    Nodes represent authors that are connected by an edge if they co-authored a\n    paper.\n    Given paper keywords for each author's papers, the task is to map authors\n    to their respective field of study.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"CS\"`, :obj:`\"Physics\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - CS\n          - 18,333\n          - 163,788\n          - 6,805\n          - 15\n        * - Physics\n          - 34,493\n          - 495,924\n          - 8,415\n          - 5\n    \"\"\"\n\n    url = 'https://github.com/shchur/gnn-benchmark/raw/master/data/npz/'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert name.lower() in ['cs', 'physics']\n        self.name = 'CS' if name.lower() == 'cs' else 'Physics'\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'ms_academic_{self.name[:3].lower()}.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url + self.raw_file_names, self.raw_dir)\n\n    def process(self) -> None:\n        data = read_npz(self.raw_paths[0], to_undirected=True)\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}{self.name}()'\n"
  },
  {
    "path": "torch_geometric/datasets/coma.py",
    "content": "import os.path as osp\nfrom glob import glob\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, extract_zip\nfrom torch_geometric.io import read_ply\n\n\nclass CoMA(InMemoryDataset):\n    r\"\"\"The CoMA 3D faces dataset from the `\"Generating 3D faces using\n    Convolutional Mesh Autoencoders\" <https://arxiv.org/abs/1807.10267>`_\n    paper, containing 20,466 meshes of extreme expressions captured over 12\n    different subjects.\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 20,465\n          - 5,023\n          - 29,990\n          - 3\n          - 12\n    \"\"\"\n\n    url = 'https://coma.is.tue.mpg.de/'\n\n    categories = [\n        'bareteeth',\n        'cheeks_in',\n        'eyebrow',\n        'high_smile',\n        'lips_back',\n        'lips_up',\n        'mouth_down',\n        'mouth_extreme',\n        'mouth_middle',\n        'mouth_open',\n        'mouth_side',\n        'mouth_up',\n    ]\n\n    def __init__(\n        self,\n        root: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'COMA_data.zip'\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['training.pt', 'test.pt']\n\n    def download(self) -> None:\n        raise RuntimeError(\n            f\"Dataset not found. 
Please download 'COMA_data.zip' from \"\n            f\"'{self.url}' and move it to '{self.raw_dir}'\")\n\n    def process(self) -> None:\n        folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*')))\n        if len(folders) == 0:\n            extract_zip(self.raw_paths[0], self.raw_dir, log=False)\n            folders = sorted(glob(osp.join(self.raw_dir, 'FaceTalk_*')))\n\n        train_data_list, test_data_list = [], []\n        for folder in folders:\n            for i, category in enumerate(self.categories):\n                files = sorted(glob(osp.join(folder, category, '*.ply')))\n                for j, f in enumerate(files):\n                    data = read_ply(f)\n                    data.y = torch.tensor([i], dtype=torch.long)\n                    if self.pre_filter is not None and\\\n                       not self.pre_filter(data):\n                        continue\n                    if self.pre_transform is not None:\n                        data = self.pre_transform(data)\n\n                    if (j % 100) < 90:\n                        train_data_list.append(data)\n                    else:\n                        test_data_list.append(data)\n\n        self.save(train_data_list, self.processed_paths[0])\n        self.save(test_data_list, self.processed_paths[1])\n"
  },
  {
    "path": "torch_geometric/datasets/cornell.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, download_url\nfrom torch_geometric.data.hypergraph_data import HyperGraphData\n\n\nclass CornellTemporalHyperGraphDataset(InMemoryDataset):\n    r\"\"\"A collection of temporal higher-order network datasets from the\n    `\"Simplicial Closure and higher-order link prediction\"\n    <https://arxiv.org/abs/1802.06916>`_ paper.\n    Each of the datasets is a timestamped sequence of simplices, where a\n    simplex is a set of :math:`k` nodes.\n\n    See the original `datasets page\n    <https://www.cs.cornell.edu/~arb/data/>`_ for more details about\n    individual datasets.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            (default: :obj:`\"train\"`)\n        setting (str, optional): If :obj:`\"transductive\"`, loads the dataset\n            for transductive training.\n            If :obj:`\"inductive\"`, loads the dataset for inductive training.\n            (default: :obj:`\"transductive\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    names = [\n        'email-Eu',\n        'email-Enron',\n        'NDC-classes',\n        'tags-math-sx',\n        'email-Eu-25',\n        'NDC-substances',\n        'congress-bills',\n        'tags-ask-ubuntu',\n        'email-Enron-25',\n        'NDC-classes-25',\n        'threads-ask-ubuntu',\n        'contact-high-school',\n        'NDC-substances-25',\n        'congress-bills-25',\n        'contact-primary-school',\n    ]\n    url = ('https://huggingface.co/datasets/SauravMaheshkar/{}/raw/main/'\n           'processed/{}/{}')\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        split: str = 'train',\n        setting: str = 'transductive',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert name in self.names\n        assert setting in ['transductive', 'inductive']\n\n        self.name = name\n        self.setting = setting\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        if split == 'train':\n            path = self.processed_paths[0]\n        elif split == 'val':\n            path = self.processed_paths[1]\n        elif split == 'test':\n            path = self.processed_paths[2]\n        else:\n            raise ValueError(f\"Split '{split}' not found\")\n\n        self.load(path)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 
self.setting, 'raw')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['train_df.csv', 'val_df.csv', 'test_df.csv']\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, self.setting, 'processed')\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train_data.pt', 'val_data.pt', 'test_data.pt']\n\n    def download(self) -> None:\n        for filename in self.raw_file_names:\n            url = self.url.format(self.name, self.setting, filename)\n            download_url(url, self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        for raw_path, path in zip(self.raw_paths, self.processed_paths):\n            df = pd.read_csv(raw_path)\n\n            data_list = []\n            for i, row in df.iterrows():\n                edge_indices: List[List[int]] = [[], []]\n                for node in eval(row['nodes']):  # str(list) -> list:\n                    edge_indices[0].append(node)\n                    edge_indices[1].append(i)  # Use `i` as hyper-edge index.\n\n                x = torch.tensor([[row['timestamp']]], dtype=torch.float)\n                edge_index = torch.tensor(edge_indices)\n\n                data = HyperGraphData(x=x, edge_index=edge_index)\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n            self.save(data_list, path)\n"
  },
  {
    "path": "torch_geometric/datasets/dblp.py",
    "content": "import os\nimport os.path as osp\nfrom itertools import product\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\n\n\nclass DBLP(InMemoryDataset):\n    r\"\"\"A subset of the DBLP computer science bibliography website, as\n    collected in the `\"MAGNN: Metapath Aggregated Graph Neural Network for\n    Heterogeneous Graph Embedding\" <https://arxiv.org/abs/2002.01680>`_ paper.\n    DBLP is a heterogeneous graph containing four types of entities - authors\n    (4,057 nodes), papers (14,328 nodes), terms (7,723 nodes), and conferences\n    (20 nodes).\n    The authors are divided into four research areas (database, data mining,\n    artificial intelligence, information retrieval).\n    Each author is described by a bag-of-words representation of their paper\n    keywords.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 20 10 10 10\n        :header-rows: 1\n\n        * - Node/Edge Type\n          - #nodes/#edges\n          - #features\n          - #classes\n        * - Author\n          - 4,057\n          - 334\n          - 4\n        * - Paper\n          - 14,328\n          - 4,231\n          -\n        * - Term\n          - 7,723\n          - 50\n          -\n        * - Conference\n          - 20\n          - 0\n          -\n        * - Author-Paper\n          - 196,425\n          -\n          -\n        * - Paper-Term\n          - 85,810\n          -\n          -\n        * - Conference-Paper\n          - 14,328\n          -\n          -\n    \"\"\"\n\n    url = 'https://www.dropbox.com/s/yh4grpeks87ugr2/DBLP_processed.zip?dl=1'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npy',\n            'labels.npy', 'node_types.npy', 'train_val_test_idx.npz'\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.remove(path)\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        data = HeteroData()\n\n        node_types = ['author', 'paper', 'term', 'conference']\n        for i, node_type in enumerate(node_types[:2]):\n            x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz'))\n            data[node_type].x = torch.from_numpy(x.todense()).to(torch.float)\n\n 
       x = np.load(osp.join(self.raw_dir, 'features_2.npy'))\n        data['term'].x = torch.from_numpy(x).to(torch.float)\n\n        node_type_idx = np.load(osp.join(self.raw_dir, 'node_types.npy'))\n        node_type_idx = torch.from_numpy(node_type_idx).to(torch.long)\n        data['conference'].num_nodes = int((node_type_idx == 3).sum())\n\n        y = np.load(osp.join(self.raw_dir, 'labels.npy'))\n        data['author'].y = torch.from_numpy(y).to(torch.long)\n\n        split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz'))\n        for name in ['train', 'val', 'test']:\n            idx = split[f'{name}_idx']\n            idx = torch.from_numpy(idx).to(torch.long)\n            mask = torch.zeros(data['author'].num_nodes, dtype=torch.bool)\n            mask[idx] = True\n            data['author'][f'{name}_mask'] = mask\n\n        s = {}\n        N_a = data['author'].num_nodes\n        N_p = data['paper'].num_nodes\n        N_t = data['term'].num_nodes\n        N_c = data['conference'].num_nodes\n        s['author'] = (0, N_a)\n        s['paper'] = (N_a, N_a + N_p)\n        s['term'] = (N_a + N_p, N_a + N_p + N_t)\n        s['conference'] = (N_a + N_p + N_t, N_a + N_p + N_t + N_c)\n\n        A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz'))\n        for src, dst in product(node_types, node_types):\n            A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo()\n            if A_sub.nnz > 0:\n                row = torch.from_numpy(A_sub.row).to(torch.long)\n                col = torch.from_numpy(A_sub.col).to(torch.long)\n                data[src, dst].edge_index = torch.stack([row, col], dim=0)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/datasets/dbp15k.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, Dict, List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_google_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs, read_txt_array\nfrom torch_geometric.utils import sort_edge_index\n\n\nclass DBP15K(InMemoryDataset):\n    r\"\"\"The DBP15K dataset from the\n    `\"Cross-lingual Entity Alignment via Joint Attribute-Preserving Embedding\"\n    <https://arxiv.org/abs/1708.05045>`_ paper, where Chinese, Japanese and\n    French versions of DBpedia were linked to its English version.\n    Node features are given by pre-trained and aligned monolingual word\n    embeddings from the `\"Cross-lingual Knowledge Graph Alignment via Graph\n    Matching Neural Network\" <https://arxiv.org/abs/1905.11605>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        pair (str): The pair of languages (:obj:`\"en_zh\"`, :obj:`\"en_fr\"`,\n            :obj:`\"en_ja\"`, :obj:`\"zh_en\"`, :obj:`\"fr_en\"`, :obj:`\"ja_en\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    file_id = '1ggYlYf2_kTyi7oF9g07oTNn3VDhjl7so'\n\n    def __init__(\n        self,\n        root: str,\n        pair: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert pair in ['en_zh', 'en_fr', 'en_ja', 'zh_en', 'fr_en', 'ja_en']\n        self.pair = pair\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['en_zh', 'en_fr', 'en_ja', 'zh_en', 'fr_en', 'ja_en']\n\n    @property\n    def processed_file_names(self) -> str:\n        return f'{self.pair}.pt'\n\n    def download(self) -> None:\n        path = download_google_url(self.file_id, self.root, 'data.zip')\n        extract_zip(path, self.root)\n        os.unlink(path)\n        fs.rm(self.raw_dir)\n        os.rename(osp.join(self.root, 'DBP15K'), self.raw_dir)\n\n    def process(self) -> None:\n        embs = {}\n        with open(osp.join(self.raw_dir, 'sub.glove.300d')) as f:\n            for line in f:\n                info = line.strip().split(' ')\n                if len(info) > 300:\n                    embs[info[0]] = torch.tensor([float(x) for x in info[1:]])\n                else:\n                    embs['**UNK**'] = torch.tensor([float(x) for x in info])\n\n        g1_path = osp.join(self.raw_dir, self.pair, 'triples_1')\n        x1_path = osp.join(self.raw_dir, self.pair, 'id_features_1')\n        g2_path = osp.join(self.raw_dir, self.pair, 'triples_2')\n        x2_path = osp.join(self.raw_dir, self.pair, 'id_features_2')\n\n        x1, edge_index1, rel1, assoc1 = self.process_graph(\n            g1_path, x1_path, embs)\n        x2, edge_index2, rel2, 
assoc2 = self.process_graph(\n            g2_path, x2_path, embs)\n\n        train_path = osp.join(self.raw_dir, self.pair, 'train.examples.20')\n        train_y = self.process_y(train_path, assoc1, assoc2)\n\n        test_path = osp.join(self.raw_dir, self.pair, 'test.examples.1000')\n        test_y = self.process_y(test_path, assoc1, assoc2)\n\n        data = Data(x1=x1, edge_index1=edge_index1, rel1=rel1, x2=x2,\n                    edge_index2=edge_index2, rel2=rel2, train_y=train_y,\n                    test_y=test_y)\n        self.save([data], self.processed_paths[0])\n\n    def process_graph(\n        self,\n        triple_path: str,\n        feature_path: str,\n        embeddings: Dict[str, Tensor],\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n\n        g1 = read_txt_array(triple_path, sep='\\t', dtype=torch.long)\n        subj, rel, obj = g1.t()\n\n        x_dict = {}\n        with open(feature_path) as f:\n            for line in f:\n                info = line.strip().split('\\t')\n                info = info if len(info) == 2 else info + ['**UNK**']\n                seq = info[1].lower().split()\n                hs = [embeddings.get(w, embeddings['**UNK**']) for w in seq]\n                x_dict[int(info[0])] = torch.stack(hs, dim=0)\n\n        idx = torch.tensor(list(x_dict.keys()))\n        assoc = torch.full((int(idx.max()) + 1, ), -1, dtype=torch.long)\n        assoc[idx] = torch.arange(idx.size(0))\n\n        subj, obj = assoc[subj], assoc[obj]\n        edge_index = torch.stack([subj, obj], dim=0)\n        edge_index, rel = sort_edge_index(edge_index, rel)\n\n        xs = list(x_dict.values())\n        for i in x_dict.keys():\n            xs[assoc[i]] = x_dict[i]\n        x = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True)\n\n        return x, edge_index, rel, assoc\n\n    def process_y(self, path: str, assoc1: Tensor, assoc2: Tensor) -> Tensor:\n        row, col, mask = read_txt_array(path, sep='\\t', dtype=torch.long).t()\n        mask 
= mask.to(torch.bool)\n        return torch.stack([assoc1[row[mask]], assoc2[col[mask]]], dim=0)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.pair})'\n"
  },
  {
    "path": "torch_geometric/datasets/deezer_europe.py",
    "content": "from typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass DeezerEurope(InMemoryDataset):\n    r\"\"\"The Deezer Europe dataset introduced in the `\"Characteristic Functions\n    on Graphs: Birds of a Feather, from Statistical Descriptors to Parametric\n    Models\" <https://arxiv.org/abs/2005.07959>`_ paper.\n    Nodes represent European users of Deezer and edges are mutual follower\n    relationships.\n    It contains 28,281 nodes, 185,504 edges, 128 node features and 2 classes.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://graphmining.ai/datasets/ptg/deezer_europe.npz'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'deezer_europe.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url, self.raw_dir)\n\n    def process(self) -> None:\n        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n        x = torch.from_numpy(data['features']).to(torch.float)\n        y = torch.from_numpy(data['target']).to(torch.long)\n        edge_index = torch.from_numpy(data['edges']).to(torch.long)\n        edge_index = edge_index.t().contiguous()\n\n        data = Data(x=x, y=y, edge_index=edge_index)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/dgraph.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, extract_zip\nfrom torch_geometric.utils import index_to_mask\n\n\nclass DGraphFin(InMemoryDataset):\n    r\"\"\"The DGraphFin networks from the\n    `\"DGraph: A Large-Scale Financial Dataset for Graph Anomaly Detection\"\n    <https://arxiv.org/abs/2207.03579>`_ paper.\n    It is a directed, unweighted dynamic graph consisting of millions of\n    nodes and edges, representing a realistic user-to-user social network\n    in financial industry.\n    Node represents a Finvolution user, and an edge from one\n    user to another means that the user regards the other user\n    as the emergency contact person. Each edge is associated with a\n    timestamp ranging from 1 to 821 and a type of emergency contact\n    ranging from 0 to 11.\n\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 3,700,550\n          - 4,300,999\n          - 17\n          - 2\n    \"\"\"\n\n    url = \"https://dgraph.xinye.com\"\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    def download(self) -> None:\n        raise RuntimeError(\n            f\"Dataset not found. Please download '{self.raw_file_names}' from \"\n            f\"'{self.url}' and move it to '{self.raw_dir}'\")\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'DGraphFin.zip'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    @property\n    def num_classes(self) -> int:\n        return 2\n\n    def process(self) -> None:\n        extract_zip(self.raw_paths[0], self.raw_dir, log=False)\n        path = osp.join(self.raw_dir, \"dgraphfin.npz\")\n\n        with np.load(path) as loader:\n            x = torch.from_numpy(loader['x']).to(torch.float)\n            y = torch.from_numpy(loader['y']).to(torch.long)\n            edge_index = torch.from_numpy(loader['edge_index']).to(torch.long)\n            edge_type = torch.from_numpy(loader['edge_type']).to(torch.long)\n            edge_time = torch.from_numpy(loader['edge_timestamp']).to(\n                torch.long)\n            train_nodes = torch.from_numpy(loader['train_mask']).to(torch.long)\n            val_nodes = torch.from_numpy(loader['valid_mask']).to(torch.long)\n            test_nodes = torch.from_numpy(loader['test_mask']).to(torch.long)\n\n            train_mask = index_to_mask(train_nodes, size=x.size(0))\n    
        val_mask = index_to_mask(val_nodes, size=x.size(0))\n            test_mask = index_to_mask(test_nodes, size=x.size(0))\n            data = Data(x=x, edge_index=edge_index.t(), edge_type=edge_type,\n                        edge_time=edge_time, y=y, train_mask=train_mask,\n                        val_mask=val_mask, test_mask=test_mask)\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/dynamic_faust.py",
    "content": "from itertools import product\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\n\n\nclass DynamicFAUST(InMemoryDataset):\n    r\"\"\"The dynamic FAUST humans dataset from the `\"Dynamic FAUST: Registering\n    Human Bodies in Motion\"\n    <http://files.is.tue.mpg.de/black/papers/dfaust2017.pdf>`_ paper.\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        subjects (list, optional): List of subjects to include in the\n            dataset. Can include the subjects :obj:`\"50002\"`, :obj:`\"50004\"`,\n            :obj:`\"50007\"`, :obj:`\"50009\"`, :obj:`\"50020\"`, :obj:`\"50021\"`,\n            :obj:`\"50022\"`, :obj:`\"50025\"`, :obj:`\"50026\"`, :obj:`\"50027\"`.\n            If set to :obj:`None`, the dataset will contain all subjects.\n            (default: :obj:`None`)\n        categories (list, optional): List of categories to include in the\n            dataset. 
Can include the categories :obj:`\"chicken_wings\"`,\n            :obj:`\"hips\"`, :obj:`\"jiggle_on_toes\"`, :obj:`\"jumping_jacks\"`,\n            :obj:`\"knees\"`, :obj:`\"light_hopping_loose\"`,\n            :obj:`\"light_hopping_stiff\"`, :obj:`\"one_leg_jump\"`,\n            :obj:`\"one_leg_loose\"`, :obj:`\"personal_move\"`, :obj:`\"punching\"`,\n            :obj:`\"running_on_spot\"`, :obj:`\"running_on_spot_bugfix\"`,\n            :obj:`\"shake_arms\"`, :obj:`\"shake_hips\"`, :obj:`\"shake_shoulders\"`.\n            If set to :obj:`None`, the dataset will contain all categories.\n            (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'http://dfaust.is.tue.mpg.de/'\n\n    subjects = [\n        '50002', '50004', '50007', '50009', '50020', '50021', '50022', '50025',\n        '50026', '50027'\n    ]\n    categories = [\n        'chicken_wings', 'hips', 'jiggle_on_toes', 'jumping_jacks', 'knees',\n        'light_hopping_loose', 'light_hopping_stiff', 'one_leg_jump',\n        'one_leg_loose', 'personal_move', 'punching', 'running_on_spot',\n        'running_on_spot_bugfix', 'shake_arms', 'shake_hips', 'shake_shoulders'\n    ]\n\n    def __init__(\n        self,\n        root: str,\n        subjects: Optional[List[str]] = None,\n        categories: Optional[List[str]] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n\n        subjects = self.subjects if subjects is None else subjects\n        subjects = [sid.lower() for sid in subjects]\n        for sid in subjects:\n            assert sid in self.subjects\n        self.subjects = subjects\n\n        categories = self.categories if categories is None else categories\n        categories = [cat.lower() for cat in categories]\n        for cat in categories:\n            assert cat in self.categories\n        self.categories = categories\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['registrations_m.hdf5', 'registrations_f.hdf5']\n\n    @property\n    def processed_file_names(self) -> str:\n        sids = '_'.join([sid[-2:] for sid in self.subjects])\n        cats = '_'.join([\n            ''.join([w[0] for w in cat.split('_')]) for cat in 
self.categories\n        ])\n        return f'{sids}_{cats}.pt'\n\n    def download(self) -> None:\n        raise RuntimeError(\n            f\"Dataset not found. Please download male registrations \"\n            f\"'registrations_m.hdf5' and female registrations \"\n            f\"'registrations_f.hdf5' from '{self.url}' and move them to \"\n            f\"'{self.raw_dir}'\")\n\n    def process(self) -> None:\n        import h5py\n\n        fm = h5py.File(self.raw_paths[0], 'r')\n        ff = h5py.File(self.raw_paths[1], 'r')\n\n        face = torch.from_numpy(fm['faces'][()]).to(torch.long)\n        face = face.t().contiguous()\n\n        data_list = []\n        for (sid, cat) in product(self.subjects, self.categories):\n            idx = f'{sid}_{cat}'\n            if idx in fm:\n                pos = torch.from_numpy(fm[idx][()])\n            elif idx in ff:\n                pos = torch.from_numpy(ff[idx][()])\n            else:\n                continue\n            pos = pos.permute(2, 0, 1).contiguous()\n            data_list.append(Data(pos=pos, face=face, num_nodes=pos.size(1)))\n\n        if self.pre_filter is not None:\n            data_list = [d for d in data_list if self.pre_filter(d)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(d) for d in data_list]\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/elliptic.py",
    "content": "from typing import Any, Callable, List, Optional, Tuple\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs\n\n\nclass EllipticBitcoinDataset(InMemoryDataset):\n    r\"\"\"The Elliptic Bitcoin dataset of Bitcoin transactions from the\n    `\"Anti-Money Laundering in Bitcoin: Experimenting with Graph Convolutional\n    Networks for Financial Forensics\" <https://arxiv.org/abs/1908.02591>`_\n    paper.\n\n    :class:`EllipticBitcoinDataset` maps Bitcoin transactions to real entities\n    belonging to licit categories (exchanges, wallet providers, miners,\n    licit services, etc.) versus illicit ones (scams, malware, terrorist\n    organizations, ransomware, Ponzi schemes, etc.)\n\n    There exists 203,769 node transactions and 234,355 directed edge payments\n    flows, with two percent of nodes (4,545) labelled as illicit, and\n    twenty-one percent of nodes (42,019) labelled as licit.\n    The remaining transactions are unknown.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 203,769\n          - 234,355\n          - 165\n          - 2\n    \"\"\"\n    url = 'https://data.pyg.org/datasets/elliptic'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'elliptic_txs_features.csv',\n            'elliptic_txs_edgelist.csv',\n            'elliptic_txs_classes.csv',\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for file_name in self.raw_file_names:\n            fs.cp(f'{self.url}/{file_name}.zip', self.raw_dir, extract=True)\n\n    def _process_df(self, feat_df: Any, edge_df: Any,\n                    class_df: Any) -> Tuple[Any, Any, Any]:\n        return feat_df, edge_df, class_df\n\n    def process(self) -> None:\n        import pandas as pd\n\n        feat_df = pd.read_csv(self.raw_paths[0], header=None)\n        edge_df = pd.read_csv(self.raw_paths[1])\n        class_df = pd.read_csv(self.raw_paths[2])\n\n        columns = {0: 'txId', 1: 'time_step'}\n        feat_df = feat_df.rename(columns=columns)\n\n        feat_df, edge_df, class_df = self._process_df(\n            feat_df,\n            edge_df,\n            class_df,\n        )\n\n        x = torch.from_numpy(feat_df.loc[:, 2:].values).to(torch.float)\n\n        # There exist 3 different classes in the dataset:\n        # 0=licit, 1=illicit, 2=unknown\n        mapping = {'unknown': 2, '1': 1, '2': 0}\n        class_df['class'] 
= class_df['class'].map(mapping)\n        y = torch.from_numpy(class_df['class'].values)\n\n        mapping = {idx: i for i, idx in enumerate(feat_df['txId'].values)}\n        edge_df['txId1'] = edge_df['txId1'].map(mapping)\n        edge_df['txId2'] = edge_df['txId2'].map(mapping)\n        edge_index = torch.from_numpy(edge_df.values).t().contiguous()\n\n        # Time-step-based split over labeled nodes:\n        # train_mask: time steps 1-34, test_mask: time steps 35-49\n        time_step = torch.from_numpy(feat_df['time_step'].values)\n        train_mask = (time_step < 35) & (y != 2)\n        test_mask = (time_step >= 35) & (y != 2)\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                    test_mask=test_mask)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    @property\n    def num_classes(self) -> int:\n        return 2\n"
  },
  {
    "path": "torch_geometric/datasets/elliptic_temporal.py",
    "content": "from typing import Any, Callable, Optional, Tuple\n\nfrom torch_geometric.datasets import EllipticBitcoinDataset\n\n\nclass EllipticBitcoinTemporalDataset(EllipticBitcoinDataset):\n    r\"\"\"The time-step aware Elliptic Bitcoin dataset of Bitcoin transactions\n    from the `\"Anti-Money Laundering in Bitcoin: Experimenting with Graph\n    Convolutional Networks for Financial Forensics\"\n    <https://arxiv.org/abs/1908.02591>`_ paper.\n\n    :class:`EllipticBitcoinTemporalDataset` maps Bitcoin transactions to real\n    entities belonging to licit categories (exchanges, wallet providers,\n    miners, licit services, etc.) versus illicit ones (scams, malware,\n    terrorist organizations, ransomware, Ponzi schemes, etc.)\n\n    There exists 203,769 node transactions and 234,355 directed edge payments\n    flows, with two percent of nodes (4,545) labelled as illicit, and\n    twenty-one percent of nodes (42,019) labelled as licit.\n    The remaining transactions are unknown.\n\n    .. note::\n\n        In contrast to :class:`EllipticBitcoinDataset`, this dataset returns\n        Bitcoin transactions only for a given timestamp :obj:`t`.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        t (int): The Timestep for which nodes should be selected (from :obj:`1`\n            to :obj:`49`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 203,769\n          - 234,355\n          - 165\n          - 2\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        t: int,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        if t < 1 or t > 49:\n            raise ValueError(\"'t' needs to be between 1 and 49\")\n\n        self.t = t\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n\n    @property\n    def processed_file_names(self) -> str:\n        return f'data_t_{self.t}.pt'\n\n    def _process_df(self, feat_df: Any, edge_df: Any,\n                    class_df: Any) -> Tuple[Any, Any, Any]:\n\n        feat_df = feat_df[feat_df['time_step'] == self.t]\n\n        mask = edge_df['txId1'].isin(feat_df['txId'].values)\n        edge_df = edge_df[mask]\n\n        class_df = class_df.merge(feat_df[['txId']], how='right',\n                                  left_on='txId', right_on='txId')\n\n        return feat_df, edge_df, class_df\n"
  },
  {
    "path": "torch_geometric/datasets/email_eu_core.py",
    "content": "import os\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_gz,\n)\n\n\nclass EmailEUCore(InMemoryDataset):\n    r\"\"\"An e-mail communication network of a large European research\n    institution, taken from the `\"Local Higher-order Graph Clustering\"\n    <https://www-cs.stanford.edu/~jure/pubs/mappr-kdd17.pdf>`_ paper.\n    Nodes indicate members of the institution.\n    An edge between a pair of members indicates that they exchanged at least\n    one email.\n    Node labels indicate membership to one of the 42 departments.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    urls = [\n        'https://snap.stanford.edu/data/email-Eu-core.txt.gz',\n        'https://snap.stanford.edu/data/email-Eu-core-department-labels.txt.gz'\n    ]\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['email-Eu-core.txt', 'email-Eu-core-department-labels.txt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for url in self.urls:\n            path = download_url(url, self.raw_dir)\n            extract_gz(path, self.raw_dir)\n            os.unlink(path)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        edge_index = pd.read_csv(self.raw_paths[0], sep=' ', header=None)\n        edge_index = torch.from_numpy(edge_index.values).t().contiguous()\n\n        y = pd.read_csv(self.raw_paths[1], sep=' ', header=None, usecols=[1])\n        y = torch.from_numpy(y.values).view(-1)\n\n        data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/entities.py",
    "content": "import logging\nimport os\nimport os.path as osp\nfrom collections import Counter\nfrom typing import Any, Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\nfrom torch_geometric.utils import index_sort\n\n\nclass Entities(InMemoryDataset):\n    r\"\"\"The relational entities networks :obj:`\"AIFB\"`, :obj:`\"MUTAG\"`,\n    :obj:`\"BGS\"` and :obj:`\"AM\"` from the `\"Modeling Relational Data with Graph\n    Convolutional Networks\" <https://arxiv.org/abs/1703.06103>`_ paper.\n    Training and test splits are given by node indices.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"AIFB\"`, :obj:`\"MUTAG\"`,\n            :obj:`\"BGS\"`, :obj:`\"AM\"`).\n        hetero (bool, optional): If set to :obj:`True`, will save the dataset\n            as a :class:`~torch_geometric.data.HeteroData` object.\n            (default: :obj:`False`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - AIFB\n          - 8,285\n          - 58,086\n          - 0\n          - 4\n        * - AM\n          - 1,666,764\n          - 11,976,642\n          - 0\n          - 11\n        * - MUTAG\n          - 23,644\n          - 148,454\n          - 0\n          - 2\n        * - BGS\n          - 333,845\n          - 1,832,398\n          - 0\n          - 2\n    \"\"\"\n\n    url = 'https://data.dgl.ai/dataset/{}.tgz'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        hetero: bool = False,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        self.hetero = hetero\n        assert self.name in ['aifb', 'am', 'mutag', 'bgs']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        if hetero:\n            self.load(self.processed_paths[0], data_cls=HeteroData)\n        else:\n            self.load(self.processed_paths[0], data_cls=Data)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def num_relations(self) -> int:\n        return int(self._data.edge_type.max()) + 1  # type: ignore\n\n    @property\n    def num_classes(self) -> int:\n        return int(self._data.train_y.max()) + 1  # type: ignore\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            f'{self.name}_stripped.nt.gz',\n            'completeDataset.tsv',\n            'trainingSet.tsv',\n            'testSet.tsv',\n        ]\n\n    @property\n    def processed_file_names(self) 
-> str:\n        return 'hetero_data.pt' if self.hetero else 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url.format(self.name), self.root)\n        extract_tar(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        import gzip\n\n        import pandas as pd\n        import rdflib as rdf\n\n        graph_file, task_file, train_file, test_file = self.raw_paths\n\n        with hide_stdout():\n            g = rdf.Graph()\n            with gzip.open(graph_file, 'rb') as f:\n                g.parse(file=f, format='nt')  # type: ignore\n\n        freq = Counter(g.predicates())\n\n        relations = sorted(set(g.predicates()), key=lambda p: -freq.get(p, 0))\n        subjects = set(g.subjects())\n        objects = set(g.objects())\n        nodes = list(subjects.union(objects))\n\n        N = len(nodes)\n        R = 2 * len(relations)\n\n        relations_dict = {rel: i for i, rel in enumerate(relations)}\n        nodes_dict = {str(node): i for i, node in enumerate(nodes)}\n\n        edges = []\n        for s, p, o in g.triples((None, None, None)):\n            src, dst = nodes_dict[str(s)], nodes_dict[str(o)]\n            rel = relations_dict[p]\n            edges.append([src, dst, 2 * rel])\n            edges.append([dst, src, 2 * rel + 1])\n\n        edge = torch.tensor(edges, dtype=torch.long).t().contiguous()\n        _, perm = index_sort(N * R * edge[0] + R * edge[1] + edge[2])\n        edge = edge[:, perm]\n\n        edge_index, edge_type = edge[:2], edge[2]\n\n        if self.name == 'am':\n            label_header = 'label_cateogory'\n            nodes_header = 'proxy'\n        elif self.name == 'aifb':\n            label_header = 'label_affiliation'\n            nodes_header = 'person'\n        elif self.name == 'mutag':\n            label_header = 'label_mutagenic'\n            nodes_header = 'bond'\n        elif self.name == 'bgs':\n            label_header = 'label_lithogenesis'\n            
nodes_header = 'rock'\n\n        labels_df = pd.read_csv(task_file, sep='\\t')\n        labels_set = set(labels_df[label_header].values.tolist())\n        labels_dict = {lab: i for i, lab in enumerate(list(labels_set))}\n\n        train_labels_df = pd.read_csv(train_file, sep='\\t')\n        train_indices, train_labels = [], []\n        for nod, lab in zip(train_labels_df[nodes_header].values,\n                            train_labels_df[label_header].values):\n            train_indices.append(nodes_dict[nod])\n            train_labels.append(labels_dict[lab])\n\n        train_idx = torch.tensor(train_indices, dtype=torch.long)\n        train_y = torch.tensor(train_labels, dtype=torch.long)\n\n        test_labels_df = pd.read_csv(test_file, sep='\\t')\n        test_indices, test_labels = [], []\n        for nod, lab in zip(test_labels_df[nodes_header].values,\n                            test_labels_df[label_header].values):\n            test_indices.append(nodes_dict[nod])\n            test_labels.append(labels_dict[lab])\n\n        test_idx = torch.tensor(test_indices, dtype=torch.long)\n        test_y = torch.tensor(test_labels, dtype=torch.long)\n\n        data = Data(edge_index=edge_index, edge_type=edge_type,\n                    train_idx=train_idx, train_y=train_y, test_idx=test_idx,\n                    test_y=test_y, num_nodes=N)\n\n        if self.hetero:\n            data = data.to_heterogeneous(node_type_names=['v'])\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name.upper()}{self.__class__.__name__}()'\n\n\nclass hide_stdout:\n    def __enter__(self) -> None:\n        self.level = logging.getLogger().level\n        logging.getLogger().setLevel(logging.ERROR)\n\n    def __exit__(self, *args: Any) -> None:\n        logging.getLogger().setLevel(self.level)\n"
  },
  {
    "path": "torch_geometric/datasets/explainer_dataset.py",
    "content": "from typing import Any, Callable, Dict, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import InMemoryDataset\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.datasets.motif_generator import MotifGenerator\nfrom torch_geometric.explain import Explanation\n\n\nclass ExplainerDataset(InMemoryDataset):\n    r\"\"\"Generates a synthetic dataset for evaluating explainability algorithms,\n    as described in the `\"GNNExplainer: Generating Explanations for Graph\n    Neural Networks\" <https://arxiv.org/abs/1903.03894>`__ paper.\n    The :class:`~torch_geometric.datasets.ExplainerDataset` creates synthetic\n    graphs coming from a\n    :class:`~torch_geometric.datasets.graph_generator.GraphGenerator`, and\n    randomly attaches :obj:`num_motifs` many motifs to it coming from a\n    :class:`~torch_geometric.datasets.motif_generator.MotifGenerator`.\n    Ground-truth node-level and edge-level explainability masks are given based\n    on whether nodes and edges are part of a certain motif or not.\n\n    For example, to generate a random Barabasi-Albert (BA) graph with 300\n    nodes, in which we want to randomly attach 80 :obj:`\"house\"` motifs, write:\n\n    .. code-block:: python\n\n        from torch_geometric.datasets import ExplainerDataset\n        from torch_geometric.datasets.graph_generator import BAGraph\n\n        dataset = ExplainerDataset(\n            graph_generator=BAGraph(num_nodes=300, num_edges=5),\n            motif_generator='house',\n            num_motifs=80,\n        )\n\n    .. 
note::\n\n        For an example of using :class:`ExplainerDataset`, see\n        `examples/explain/gnn_explainer_ba_shapes.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        explain/gnn_explainer_ba_shapes.py>`_.\n\n    Args:\n        graph_generator (GraphGenerator or str): The graph generator to be\n            used, *e.g.*,\n            :class:`torch_geometric.datasets.graph_generator.BAGraph`\n            (or any string that automatically resolves to it).\n        motif_generator (MotifGenerator or str): The motif generator to be\n            used, *e.g.*,\n            :class:`torch_geometric.datasets.motif_generator.HouseMotif`\n            (or any string that automatically resolves to it).\n        num_motifs (int): The number of motifs to attach to the graph.\n        num_graphs (int, optional): The number of graphs to generate.\n            (default: :obj:`1`)\n        graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to\n            the respective graph generator module in case it gets automatically\n            resolved. (default: :obj:`None`)\n        motif_generator_kwargs (Dict[str, Any], optional): Arguments passed to\n            the respective motif generator module in case it gets automatically\n            resolved. (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        graph_generator: Union[GraphGenerator, str],\n        motif_generator: Union[MotifGenerator, str],\n        num_motifs: int,\n        num_graphs: int = 1,\n        graph_generator_kwargs: Optional[Dict[str, Any]] = None,\n        motif_generator_kwargs: Optional[Dict[str, Any]] = None,\n        transform: Optional[Callable] = None,\n    ):\n        super().__init__(root=None, transform=transform)\n\n        if num_motifs <= 0:\n            raise ValueError(f\"At least one motif needs to be attached to the \"\n                             f\"graph (got {num_motifs})\")\n\n        self.graph_generator = GraphGenerator.resolve(\n            graph_generator,\n            **(graph_generator_kwargs or {}),\n        )\n        self.motif_generator = MotifGenerator.resolve(\n            motif_generator,\n            **(motif_generator_kwargs or {}),\n        )\n        self.num_motifs = num_motifs\n\n        # TODO (matthias) support on-the-fly graph generation.\n        data_list = [self.get_graph() for _ in range(num_graphs)]\n        self.data, self.slices = self.collate(data_list)\n\n    def get_graph(self) -> Explanation:\n        data = self.graph_generator()\n        assert data.num_nodes is not None\n        assert data.edge_index is not None\n\n        edge_indices = [data.edge_index]\n        num_nodes = data.num_nodes\n        node_masks = [torch.zeros(data.num_nodes)]\n        edge_masks = [torch.zeros(data.num_edges)]\n        ys = [torch.zeros(num_nodes, dtype=torch.long)]\n\n        connecting_nodes = torch.randperm(num_nodes)[:self.num_motifs]\n        for i in connecting_nodes.tolist():\n            motif = self.motif_generator()\n            assert motif.num_nodes is not None\n            assert motif.edge_index is not None\n\n            # Add motif to the graph.\n            
edge_indices.append(motif.edge_index + num_nodes)\n            node_masks.append(torch.ones(motif.num_nodes))\n            edge_masks.append(torch.ones(motif.num_edges))\n\n            # Add random motif connection to the graph.\n            j = int(torch.randint(0, motif.num_nodes, (1, ))) + num_nodes\n            edge_indices.append(torch.tensor([[i, j], [j, i]]))\n            edge_masks.append(torch.zeros(2))\n\n            if isinstance(motif.y, Tensor):\n                ys.append(motif.y + 1 if motif.y.min() == 0 else motif.y)\n            else:\n                ys.append(torch.ones(motif.num_nodes, dtype=torch.long))\n\n            num_nodes += motif.num_nodes\n\n        return Explanation(\n            edge_index=torch.cat(edge_indices, dim=1),\n            y=torch.cat(ys, dim=0),\n            edge_mask=torch.cat(edge_masks, dim=0),\n            node_mask=torch.cat(node_masks, dim=0),\n        )\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'graph_generator={self.graph_generator}, '\n                f'motif_generator={self.motif_generator}, '\n                f'num_motifs={self.num_motifs})')\n"
  },
  {
    "path": "torch_geometric/datasets/facebook.py",
    "content": "from typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass FacebookPagePage(InMemoryDataset):\n    r\"\"\"The Facebook Page-Page network dataset introduced in the\n    `\"Multi-scale Attributed Node Embedding\"\n    <https://arxiv.org/abs/1909.13021>`_ paper.\n    Nodes represent verified pages on Facebook and edges are mutual likes.\n    It contains 22,470 nodes, 342,004 edges, 128 node features and 4 classes.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://graphmining.ai/datasets/ptg/facebook.npz'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'facebook.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url, self.raw_dir)\n\n    def process(self) -> None:\n        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n        x = torch.from_numpy(data['features']).to(torch.float)\n        y = torch.from_numpy(data['target']).to(torch.long)\n        edge_index = torch.from_numpy(data['edges']).to(torch.long)\n        edge_index = edge_index.t().contiguous()\n\n        data = Data(x=x, y=y, edge_index=edge_index)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/fake.py",
    "content": "import random\nfrom collections import defaultdict\nfrom itertools import product\nfrom typing import Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData, InMemoryDataset\nfrom torch_geometric.utils import coalesce, remove_self_loops, to_undirected\n\n\nclass FakeDataset(InMemoryDataset):\n    r\"\"\"A fake dataset that returns randomly generated\n    :class:`~torch_geometric.data.Data` objects.\n\n    Args:\n        num_graphs (int, optional): The number of graphs. (default: :obj:`1`)\n        avg_num_nodes (int, optional): The average number of nodes in a graph.\n            (default: :obj:`1000`)\n        avg_degree (float, optional): The average degree per node.\n            (default: :obj:`10.0`)\n        num_channels (int, optional): The number of node features.\n            (default: :obj:`64`)\n        edge_dim (int, optional): The number of edge features.\n            (default: :obj:`0`)\n        num_classes (int, optional): The number of classes in the dataset.\n            (default: :obj:`10`)\n        task (str, optional): Whether to return node-level or graph-level\n            labels (:obj:`\"node\"`, :obj:`\"graph\"`, :obj:`\"auto\"`).\n            If set to :obj:`\"auto\"`, will return graph-level labels if\n            :obj:`num_graphs > 1`, and node-level labels otherwise.\n            (default: :obj:`\"auto\"`)\n        is_undirected (bool, optional): Whether the graphs to generate are\n            undirected. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            every access. 
(default: :obj:`None`)\n        **kwargs (optional): Additional attributes and their shapes\n            *e.g.* :obj:`global_features=5`.\n    \"\"\"\n    def __init__(\n        self,\n        num_graphs: int = 1,\n        avg_num_nodes: int = 1000,\n        avg_degree: float = 10.0,\n        num_channels: int = 64,\n        edge_dim: int = 0,\n        num_classes: int = 10,\n        task: str = 'auto',\n        is_undirected: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        **kwargs: Union[int, Tuple[int, ...]],\n    ) -> None:\n        super().__init__(None, transform)\n\n        if task == 'auto':\n            task = 'graph' if num_graphs > 1 else 'node'\n        assert task in ['node', 'graph']\n\n        self.avg_num_nodes = max(avg_num_nodes, int(avg_degree))\n        self.avg_degree = max(avg_degree, 1)\n        self.num_channels = num_channels\n        self.edge_dim = edge_dim\n        self._num_classes = num_classes\n        self.task = task\n        self.is_undirected = is_undirected\n        self.kwargs = kwargs\n\n        data_list = [self.generate_data() for _ in range(max(num_graphs, 1))]\n        self.data, self.slices = self.collate(data_list)\n\n    def generate_data(self) -> Data:\n        num_nodes = get_num_nodes(self.avg_num_nodes, self.avg_degree)\n\n        data = Data()\n\n        if self._num_classes > 0 and self.task == 'node':\n            data.y = torch.randint(self._num_classes, (num_nodes, ))\n        elif self._num_classes > 0 and self.task == 'graph':\n            data.y = torch.tensor([random.randint(0, self._num_classes - 1)])\n\n        data.edge_index = get_edge_index(num_nodes, num_nodes, self.avg_degree,\n                                         self.is_undirected, remove_loops=True)\n\n        if self.num_channels > 0:\n            x = torch.randn(num_nodes, self.num_channels)\n            if self._num_classes > 0 and self.task == 'node':\n                
assert isinstance(data.y, Tensor)\n                x = x + data.y.unsqueeze(1)\n            elif self._num_classes > 0 and self.task == 'graph':\n                assert isinstance(data.y, Tensor)\n                x = x + data.y\n            data.x = x\n        else:\n            data.num_nodes = num_nodes\n\n        if self.edge_dim > 1:\n            data.edge_attr = torch.rand(data.num_edges, self.edge_dim)\n        elif self.edge_dim == 1:\n            data.edge_weight = torch.rand(data.num_edges)\n\n        for feature_name, feature_shape in self.kwargs.items():\n            setattr(data, feature_name, torch.randn(feature_shape))\n\n        return data\n\n\nclass FakeHeteroDataset(InMemoryDataset):\n    r\"\"\"A fake dataset that returns randomly generated\n    :class:`~torch_geometric.data.HeteroData` objects.\n\n    Args:\n        num_graphs (int, optional): The number of graphs. (default: :obj:`1`)\n        num_node_types (int, optional): The number of node types.\n            (default: :obj:`3`)\n        num_edge_types (int, optional): The number of edge types.\n            (default: :obj:`6`)\n        avg_num_nodes (int, optional): The average number of nodes in a graph.\n            (default: :obj:`1000`)\n        avg_degree (float, optional): The average degree per node.\n            (default: :obj:`10.0`)\n        avg_num_channels (int, optional): The average number of node features.\n            (default: :obj:`64`)\n        edge_dim (int, optional): The number of edge features.\n            (default: :obj:`0`)\n        num_classes (int, optional): The number of classes in the dataset.\n            (default: :obj:`10`)\n        task (str, optional): Whether to return node-level or graph-level\n            labels (:obj:`\"node\"`, :obj:`\"graph\"`, :obj:`\"auto\"`).\n            If set to :obj:`\"auto\"`, will return graph-level labels if\n            :obj:`num_graphs > 1`, and node-level labels otherwise.\n            (default: :obj:`\"auto\"`)\n       
 transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        **kwargs (optional): Additional attributes and their shapes\n            *e.g.* :obj:`global_features=5`.\n    \"\"\"\n    def __init__(\n        self,\n        num_graphs: int = 1,\n        num_node_types: int = 3,\n        num_edge_types: int = 6,\n        avg_num_nodes: int = 1000,\n        avg_degree: float = 10.0,\n        avg_num_channels: int = 64,\n        edge_dim: int = 0,\n        num_classes: int = 10,\n        task: str = \"auto\",\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        **kwargs: Union[int, Tuple[int, ...]],\n    ) -> None:\n        super().__init__(None, transform)\n\n        if task == 'auto':\n            task = 'graph' if num_graphs > 1 else 'node'\n        assert task in ['node', 'graph']\n\n        self.node_types = [f'v{i}' for i in range(max(num_node_types, 1))]\n\n        edge_types: List[Tuple[str, str]] = []\n        edge_type_product = list(product(self.node_types, self.node_types))\n        while len(edge_types) < max(num_edge_types, 1):\n            edge_types.extend(edge_type_product)\n        random.shuffle(edge_types)\n\n        self.edge_types: List[Tuple[str, str, str]] = []\n        count: Dict[Tuple[str, str], int] = defaultdict(int)\n        for edge_type in edge_types[:max(num_edge_types, 1)]:\n            rel = f'e{count[edge_type]}'\n            count[edge_type] += 1\n            self.edge_types.append((edge_type[0], rel, edge_type[1]))\n\n        self.avg_num_nodes = max(avg_num_nodes, int(avg_degree))\n        self.avg_degree = max(avg_degree, 1)\n        self.num_channels = [\n            get_num_channels(avg_num_channels) for _ in self.node_types\n        ]\n        self.edge_dim = 
edge_dim\n        self._num_classes = num_classes\n        self.task = task\n        self.kwargs = kwargs\n\n        data_list = [self.generate_data() for _ in range(max(num_graphs, 1))]\n        self.data, self.slices = self.collate(data_list)\n\n    def generate_data(self) -> HeteroData:\n        data = HeteroData()\n\n        iterator = zip(self.node_types, self.num_channels)\n        for i, (node_type, num_channels) in enumerate(iterator):\n            num_nodes = get_num_nodes(self.avg_num_nodes, self.avg_degree)\n\n            store = data[node_type]\n\n            if num_channels > 0:\n                store.x = torch.randn(num_nodes, num_channels)\n            else:\n                store.num_nodes = num_nodes\n\n            if self._num_classes > 0 and self.task == 'node' and i == 0:\n                store.y = torch.randint(self._num_classes, (num_nodes, ))\n\n        for (src, rel, dst) in self.edge_types:\n            store = data[(src, rel, dst)]\n\n            store.edge_index = get_edge_index(\n                data[src].num_nodes,\n                data[dst].num_nodes,\n                self.avg_degree,\n                is_undirected=False,\n                remove_loops=False,\n            )\n\n            if self.edge_dim > 1:\n                store.edge_attr = torch.rand(store.num_edges, self.edge_dim)\n            elif self.edge_dim == 1:\n                store.edge_weight = torch.rand(store.num_edges)\n\n        if self._num_classes > 0 and self.task == 'graph':\n            data.y = torch.tensor([random.randint(0, self._num_classes - 1)])\n\n        for feature_name, feature_shape in self.kwargs.items():\n            setattr(data, feature_name, torch.randn(feature_shape))\n\n        return data\n\n\n###############################################################################\n\n\ndef get_num_nodes(avg_num_nodes: int, avg_degree: float) -> int:\n    min_num_nodes = max(3 * avg_num_nodes // 4, int(avg_degree))\n    max_num_nodes = 5 * avg_num_nodes 
// 4\n    return random.randint(min_num_nodes, max_num_nodes)\n\n\ndef get_num_channels(num_channels: int) -> int:\n    min_num_channels = 3 * num_channels // 4\n    max_num_channels = 5 * num_channels // 4\n    return random.randint(min_num_channels, max_num_channels)\n\n\ndef get_edge_index(\n    num_src_nodes: int,\n    num_dst_nodes: int,\n    avg_degree: float,\n    is_undirected: bool = False,\n    remove_loops: bool = False,\n) -> Tensor:\n\n    num_edges = int(num_src_nodes * avg_degree)\n    row = torch.randint(num_src_nodes, (num_edges, ), dtype=torch.int64)\n    col = torch.randint(num_dst_nodes, (num_edges, ), dtype=torch.int64)\n    edge_index = torch.stack([row, col], dim=0)\n\n    if remove_loops:\n        edge_index, _ = remove_self_loops(edge_index)\n\n    num_nodes = max(num_src_nodes, num_dst_nodes)\n    if is_undirected:\n        edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n    else:\n        edge_index = coalesce(edge_index, num_nodes=num_nodes)\n\n    return edge_index\n"
  },
  {
    "path": "torch_geometric/datasets/faust.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, extract_zip\nfrom torch_geometric.io import fs, read_ply\n\n\nclass FAUST(InMemoryDataset):\n    r\"\"\"The FAUST humans dataset from the `\"FAUST: Dataset and Evaluation for\n    3D Mesh Registration\"\n    <http://files.is.tue.mpg.de/black/papers/FAUST2014.pdf>`_ paper,\n    containing 100 watertight meshes representing 10 different poses for 10\n    different subjects.\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 100\n          - 6,890\n          - 41,328\n          - 3\n          - 10\n    \"\"\"\n\n    url = 'http://faust.is.tue.mpg.de/'\n\n    def __init__(\n        self,\n        root: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'MPI-FAUST.zip'\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['training.pt', 'test.pt']\n\n    def download(self) -> None:\n        raise RuntimeError(\n            f\"Dataset not found. 
Please download '{self.raw_file_names}' from \"\n            f\"'{self.url}' and move it to '{self.raw_dir}'\")\n\n    def process(self) -> None:\n        extract_zip(self.raw_paths[0], self.raw_dir, log=False)\n\n        path = osp.join(self.raw_dir, 'MPI-FAUST', 'training', 'registrations')\n        path = osp.join(path, 'tr_reg_{0:03d}.ply')\n        data_list = []\n        for i in range(100):\n            data = read_ply(path.format(i))\n            data.y = torch.tensor([i % 10], dtype=torch.long)\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n            data_list.append(data)\n\n        self.save(data_list[:80], self.processed_paths[0])\n        self.save(data_list[80:], self.processed_paths[1])\n\n        fs.rm(osp.join(self.raw_dir, 'MPI-FAUST'))\n"
  },
  {
    "path": "torch_geometric/datasets/flickr.py",
    "content": "import json\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_google_url\n\n\nclass Flickr(InMemoryDataset):\n    r\"\"\"The Flickr dataset from the `\"GraphSAINT: Graph Sampling Based\n    Inductive Learning Method\" <https://arxiv.org/abs/1907.04931>`_ paper,\n    containing descriptions and common properties of images.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 89,250\n          - 899,756\n          - 500\n          - 7\n    \"\"\"\n    adj_full_id = '1crmsTbd1-2sEXsGwa2IKnIB7Zd3TmUsy'\n    feats_id = '1join-XdvX3anJU_MLVtick7MgeAQiWIZ'\n    class_map_id = '1uxIkbtg5drHTsKt-PAsZZ4_yJmgFmle9'\n    role_id = '1htXCtuktuCW8TR8KiKfrFDAxUgekQoV7'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz')\n        download_google_url(self.feats_id, self.raw_dir, 'feats.npy')\n        download_google_url(self.class_map_id, self.raw_dir, 'class_map.json')\n        download_google_url(self.role_id, self.raw_dir, 'role.json')\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))\n        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])\n        adj = adj.tocoo()\n        row = torch.from_numpy(adj.row).to(torch.long)\n        col = torch.from_numpy(adj.col).to(torch.long)\n        edge_index = torch.stack([row, col], dim=0)\n\n        x = np.load(osp.join(self.raw_dir, 'feats.npy'))\n        x = torch.from_numpy(x).to(torch.float)\n\n        ys = [-1] * x.size(0)\n        with open(osp.join(self.raw_dir, 'class_map.json')) as f:\n           
 class_map = json.load(f)\n            for key, item in class_map.items():\n                ys[int(key)] = item\n        y = torch.tensor(ys)\n\n        with open(osp.join(self.raw_dir, 'role.json')) as f:\n            role = json.load(f)\n\n        train_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        train_mask[torch.tensor(role['tr'])] = True\n\n        val_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        val_mask[torch.tensor(role['va'])] = True\n\n        test_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        test_mask[torch.tensor(role['te'])] = True\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/freebase.py",
    "content": "from typing import Callable, Dict, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass FB15k_237(InMemoryDataset):\n    r\"\"\"The FB15K237 dataset from the `\"Translating Embeddings for Modeling\n    Multi-Relational Data\"\n    <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling\n    -multi-relational-data>`_ paper,\n    containing 14,541 entities, 237 relations and 310,116 fact triples.\n\n    .. note::\n\n        The original :class:`FB15k` dataset suffers from major test leakage\n        through inverse relations, where a large number of test triples could\n        be obtained by inverting triples in the training set.\n        In order to create a dataset without this characteristic, the\n        :class:`~torch_geometric.datasets.FB15k_237` describes a subset of\n        :class:`FB15k` where inverse relations are removed.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = ('https://raw.githubusercontent.com/villmow/'\n           'datasets_knowledge_embedding/master/FB15k-237')\n\n    def __init__(\n        self,\n        root: str,\n        split: str = \"train\",\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n\n        if split not in {'train', 'val', 'test'}:\n            raise ValueError(f\"Invalid 'split' argument (got {split})\")\n\n        path = self.processed_paths[['train', 'val', 'test'].index(split)]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['train.txt', 'valid.txt', 'test.txt']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train_data.pt', 'val_data.pt', 'test_data.pt']\n\n    def download(self) -> None:\n        for filename in self.raw_file_names:\n            download_url(f'{self.url}/{filename}', self.raw_dir)\n\n    def process(self) -> None:\n        data_list: List[Data] = []\n        node_dict: Dict[str, int] = {}\n        rel_dict: Dict[str, int] = {}\n\n        for path in self.raw_paths:\n            with open(path) as f:\n                lines = [x.split('\\t') for x in f.read().split('\\n')[:-1]]\n\n            edge_index = torch.empty((2, len(lines)), dtype=torch.long)\n            edge_type = torch.empty(len(lines), dtype=torch.long)\n            for i, (src, rel, dst) in enumerate(lines):\n                if src not in node_dict:\n                    node_dict[src] = len(node_dict)\n                if dst not in node_dict:\n                    node_dict[dst] = len(node_dict)\n                if rel not in rel_dict:\n                    
rel_dict[rel] = len(rel_dict)\n\n                edge_index[0, i] = node_dict[src]\n                edge_index[1, i] = node_dict[dst]\n                edge_type[i] = rel_dict[rel]\n\n            data = Data(edge_index=edge_index, edge_type=edge_type)\n            data_list.append(data)\n\n        for data, path in zip(data_list, self.processed_paths):\n            data.num_nodes = len(node_dict)\n            self.save([data], path)\n"
  },
  {
    "path": "torch_geometric/datasets/gdelt.py",
    "content": "from typing import Callable, List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import download_url\nfrom torch_geometric.datasets.icews import EventDataset\nfrom torch_geometric.io import read_txt_array\n\n\nclass GDELT(EventDataset):\n    r\"\"\"The Global Database of Events, Language, and Tone (GDELT) dataset used\n    in the, *e.g.*, `\"Recurrent Event Network for Reasoning over Temporal\n    Knowledge Graphs\" <https://arxiv.org/abs/1904.05530>`_ paper, consisting of\n    events collected from 1/1/2018 to 1/31/2018 (15 minutes time granularity).\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://github.com/INK-USC/RE-Net/raw/master/data/GDELT'\n    splits = [0, 1734399, 1973164, 2278405]  # Train/Val/Test splits.\n\n    def __init__(\n        self,\n        root: str,\n        split: str = \"train\",\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert split in ['train', 'val', 'test']\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        idx = self.processed_file_names.index(f'{split}.pt')\n        self.load(self.processed_paths[idx])\n\n    @property\n    def num_nodes(self) -> int:\n        return 7691\n\n    @property\n    def num_rels(self) -> int:\n        return 240\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [f'{name}.txt' for name in ['train', 'valid', 'test']]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def download(self) -> None:\n        for filename in self.raw_file_names:\n            download_url(f'{self.url}/{filename}', self.raw_dir)\n\n    def process_events(self) -> Tensor:\n        events = []\n        for path in self.raw_paths:\n            data = read_txt_array(path, sep='\\t', end=4, dtype=torch.long)\n            data[:, 3] = data[:, 3] // 15\n            events += [data]\n        return torch.cat(events, dim=0)\n\n    def process(self) -> None:\n        s = self.splits\n        data_list = self._process_data_list()\n        self.save(data_list[s[0]:s[1]], self.processed_paths[0])\n        self.save(data_list[s[1]:s[2]], self.processed_paths[1])\n        self.save(data_list[s[2]:s[3]], self.processed_paths[2])\n"
  },
  {
    "path": "torch_geometric/datasets/gdelt_lite.py",
    "content": "import os\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass GDELTLite(InMemoryDataset):\n    r\"\"\"The (reduced) version of the Global Database of Events, Language, and\n    Tone (GDELT) dataset used in the `\"Do We Really Need Complicated Model\n    Architectures for Temporal Networks?\" <http://arxiv.org/abs/2302.11636>`_\n    paper, consisting of events collected from 2016 to 2020.\n\n    Each node (actor) holds a 413-dimensional multi-hot feature vector that\n    represents CAMEO codes attached to the corresponding actor to server.\n\n    Each edge (event) holds a timestamp and a 186-dimensional multi-hot vector\n    representing CAMEO codes attached to the corresponding event to server.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 8,831\n          - 1,912,909\n          - 413\n          -\n    \"\"\"\n    url = 'https://data.pyg.org/datasets/gdelt_lite.zip'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['node_features.pt', 'edges.csv', 'edge_features.pt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        x = fs.torch_load(self.raw_paths[0])\n        df = pd.read_csv(self.raw_paths[1])\n        edge_attr = fs.torch_load(self.raw_paths[2])\n\n        row = torch.from_numpy(df['src'].values)\n        col = torch.from_numpy(df['dst'].values)\n        edge_index = torch.stack([row, col], dim=0)\n        time = torch.from_numpy(df['time'].values).to(torch.long)\n\n        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, time=time)\n        data = data if self.pre_transform is None else self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/ged_dataset.py",
    "content": "import glob\nimport os\nimport os.path as osp\nimport pickle\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_google_url,\n    extract_tar,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import one_hot, to_undirected\n\n\nclass GEDDataset(InMemoryDataset):\n    r\"\"\"The GED datasets from the `\"Graph Edit Distance Computation via Graph\n    Neural Networks\" <https://arxiv.org/abs/1808.05689>`_ paper.\n\n    GEDs can be accessed via the global attributes :obj:`ged` and\n    :obj:`norm_ged` for all train/train graph pairs and all train/test graph\n    pairs:\n\n    .. code-block:: python\n\n        dataset = GEDDataset(root, name=\"LINUX\")\n        data1, data2 = dataset[0], dataset[1]\n        ged = dataset.ged[data1.i, data2.i]  # GED between `data1` and `data2`.\n\n    Note that GEDs are not available if both graphs are from the test set.\n    For evaluation, it is recommended to pair up each graph from the test set\n    with each graph in the training set.\n\n    .. note::\n\n        :obj:`ALKANE` is missing GEDs for train/test graph pairs since they are\n        not provided in the `official datasets\n        <https://github.com/yunshengb/SimGNN>`_.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (one of :obj:`\"AIDS700nef\"`,\n            :obj:`\"LINUX\"`, :obj:`\"ALKANE\"`, :obj:`\"IMDBMulti\"`).\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 20 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - AIDS700nef\n          - 700\n          - ~8.9\n          - ~17.6\n          - 29\n          - 0\n        * - LINUX\n          - 1,000\n          - ~7.6\n          - ~13.9\n          - 0\n          - 0\n        * - ALKANE\n          - 150\n          - ~8.9\n          - ~15.8\n          - 0\n          - 0\n        * - IMDBMulti\n          - 1,500\n          - ~13.0\n          - ~131.9\n          - 0\n          - 0\n    \"\"\"\n    datasets = {\n        'AIDS700nef': {\n            'id': '10czBPJDEzEDI2tq7Z7mkBjLhj55F-a2z',\n            'extract': extract_zip,\n            'pickle': '1OpV4bCHjBkdpqI6H5Mg0-BqlA2ee2eBW',\n        },\n        'LINUX': {\n            'id': '1nw0RRVgyLpit4V4XFQyDy0pI6wUEXSOI',\n            'extract': extract_tar,\n            'pickle': '14FDm3NSnrBvB7eNpLeGy5Bz6FjuCSF5v',\n        },\n        'ALKANE': {\n            'id': '1-LmxaWW3KulLh00YqscVEflbqr0g4cXt',\n            'extract': extract_tar,\n            'pickle': 
'15BpvMuHx77-yUGYgM27_sQett02HQNYu',\n        },\n        'IMDBMulti': {\n            'id': '12QxZ7EhYA7pJiF4cO-HuE8szhSOWcfST',\n            'extract': extract_zip,\n            'pickle': '1wy9VbZvZodkixxVIOuRllC-Lp-0zdoYZ',\n        },\n    }\n\n    # List of atoms contained in the AIDS700nef dataset:\n    types = [\n        'O', 'S', 'C', 'N', 'Cl', 'Br', 'B', 'Si', 'Hg', 'I', 'Bi', 'P', 'F',\n        'Cu', 'Ho', 'Pd', 'Ru', 'Pt', 'Sn', 'Li', 'Ga', 'Tb', 'As', 'Co', 'Pb',\n        'Sb', 'Se', 'Ni', 'Te'\n    ]\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name\n        assert self.name in self.datasets.keys()\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n        path = osp.join(self.processed_dir, f'{self.name}_ged.pt')\n        self.ged = fs.torch_load(path)\n        path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')\n        self.norm_ged = fs.torch_load(path)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        # Returns, e.g., ['LINUX/train', 'LINUX/test']\n        return [osp.join(self.name, s) for s in ['train', 'test']]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        # Returns, e.g., ['LINUX_training.pt', 'LINUX_test.pt']\n        return [f'{self.name}_{s}.pt' for s in ['training', 'test']]\n\n    def download(self) -> None:\n        # Downloads the .tar/.zip file of the graphs and extracts them:\n        id = self.datasets[self.name]['id']\n        assert isinstance(id, str)\n        path = download_google_url(id, self.raw_dir, 'data')\n        
extract_fn = self.datasets[self.name]['extract']\n        assert callable(extract_fn)\n        extract_fn(path, self.raw_dir)\n        os.unlink(path)\n\n        # Downloads the pickle file containing pre-computed GEDs:\n        id = self.datasets[self.name]['pickle']\n        assert isinstance(id, str)\n        path = download_google_url(id, self.raw_dir, 'ged.pickle')\n\n    def process(self) -> None:\n        import networkx as nx\n\n        ids, Ns = [], []\n        # Iterating over paths for raw and processed data (train + test):\n        for r_path, p_path in zip(self.raw_paths, self.processed_paths):\n            # Find the paths of all raw graphs:\n            names = glob.glob(osp.join(r_path, '*.gexf'))\n            # Get sorted graph IDs given filename: 123.gexf -> 123\n            ids.append(sorted([int(osp.basename(i)[:-5]) for i in names]))\n\n            data_list = []\n            # Convert graphs in .gexf format to a NetworkX Graph:\n            for i, idx in enumerate(ids[-1]):\n                i = i if len(ids) == 1 else i + len(ids[0])\n                # Reading the raw `*.gexf` graph:\n                G = nx.read_gexf(osp.join(r_path, f'{idx}.gexf'))\n                # Mapping of nodes in `G` to a contiguous number:\n                mapping = {name: j for j, name in enumerate(G.nodes())}\n                G = nx.relabel_nodes(G, mapping)\n                Ns.append(G.number_of_nodes())\n\n                edge_index = torch.tensor(list(G.edges)).t().contiguous()\n                if edge_index.numel() == 0:\n                    edge_index = torch.empty((2, 0), dtype=torch.long)\n                edge_index = to_undirected(edge_index, num_nodes=Ns[-1])\n\n                data = Data(edge_index=edge_index, i=i)\n                data.num_nodes = Ns[-1]\n\n                # Create a one-hot encoded feature matrix denoting the atom\n                # type (for the `AIDS700nef` dataset):\n                if self.name == 'AIDS700nef':\n                    
assert data.num_nodes is not None\n                    x = torch.zeros(data.num_nodes, dtype=torch.long)\n                    for node, info in G.nodes(data=True):\n                        x[int(node)] = self.types.index(info['type'])\n                    data.x = one_hot(x, num_classes=len(self.types))\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n            self.save(data_list, p_path)\n\n        assoc = {idx: i for i, idx in enumerate(ids[0])}\n        assoc.update({idx: i + len(ids[0]) for i, idx in enumerate(ids[1])})\n\n        # Extracting ground-truth GEDs from the GED pickle file\n        path = osp.join(self.raw_dir, self.name, 'ged.pickle')\n        # Initialize GEDs as float('inf'):\n        mat = torch.full((len(assoc), len(assoc)), float('inf'))\n        with open(path, 'rb') as f:\n            obj = pickle.load(f)\n            xs, ys, gs = [], [], []\n            for (_x, _y), g in obj.items():\n                xs += [assoc[_x]]\n                ys += [assoc[_y]]\n                gs += [g]\n            # The pickle file does not contain GEDs for test graph pairs, i.e.\n            # GEDs for (test_graph, test_graph) pairs are still float('inf'):\n            x, y = torch.tensor(xs), torch.tensor(ys)\n            ged = torch.tensor(gs, dtype=torch.float)\n            mat[x, y], mat[y, x] = ged, ged\n\n        path = osp.join(self.processed_dir, f'{self.name}_ged.pt')\n        torch.save(mat, path)\n\n        # Calculate the normalized GEDs:\n        N = torch.tensor(Ns, dtype=torch.float)\n        norm_mat = mat / (0.5 * (N.view(-1, 1) + N.view(1, -1)))\n\n        path = osp.join(self.processed_dir, f'{self.name}_norm_ged.pt')\n        torch.save(norm_mat, path)\n\n    def __repr__(self) -> str:\n        return 
f'{self.name}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/datasets/gemsec.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass GemsecDeezer(InMemoryDataset):\n    r\"\"\"The Deezer User Network datasets introduced in the\n    `\"GEMSEC: Graph Embedding with Self Clustering\"\n    <https://arxiv.org/abs/1802.03997>`_ paper.\n    Nodes represent Deezer user and edges are mutual friendships.\n    The task is multi-label multi-class node classification about\n    the genres liked by the users.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"HU\"`, :obj:`\"HR\"`,\n            :obj:`\"RO\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://graphmining.ai/datasets/ptg/gemsec'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name\n        assert self.name in ['HU', 'HR', 'RO']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'{self.name}.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        # Build the URL with '/' explicitly: `osp.join` would insert the\n        # OS-specific path separator (a backslash on Windows).\n        download_url(f'{self.url}/{self.name}.npz', self.raw_dir)\n\n    def process(self) -> None:\n        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n        y = torch.from_numpy(data['target']).to(torch.long)\n        edge_index = torch.from_numpy(data['edges']).to(torch.long)\n        edge_index = edge_index.t().contiguous()\n\n        data = Data(y=y, edge_index=edge_index)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/geometry.py",
    "content": "import glob\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import read_off\n\n\nclass GeometricShapes(InMemoryDataset):\n    r\"\"\"Synthetic dataset of various geometric shapes like cubes, spheres or\n    pyramids.\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 80\n          - ~148.8\n          - ~859.5\n          - 3\n          - 40\n    \"\"\"\n\n    url = 'https://github.com/Yannick-S/geometric_shapes/raw/master/raw.zip'\n\n    def __init__(\n        self,\n        root: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> str:\n        return '2d_circle'\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['training.pt', 'test.pt']\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.unlink(path)\n\n    def process(self) -> None:\n        self.save(self.process_set('train'), self.processed_paths[0])\n        self.save(self.process_set('test'), self.processed_paths[1])\n\n    def process_set(self, dataset: str) -> List[Data]:\n        categories = glob.glob(osp.join(self.raw_dir, '*', ''))\n        categories = sorted([x.split(os.sep)[-2] for x in categories])\n\n        data_list = []\n        for target, category in enumerate(categories):\n            folder = osp.join(self.raw_dir, category, dataset)\n            paths = glob.glob(f'{folder}/*.off')\n            for path in paths:\n                data = 
read_off(path)\n                assert data.pos is not None\n                data.pos = data.pos - data.pos.mean(dim=0, keepdim=True)\n                data.y = torch.tensor([target])\n                data_list.append(data)\n\n        if self.pre_filter is not None:\n            data_list = [d for d in data_list if self.pre_filter(d)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(d) for d in data_list]\n\n        return data_list\n"
  },
  {
    "path": "torch_geometric/datasets/git_mol_dataset.py",
    "content": "import sys\nfrom typing import Any, Callable, Dict, List, Optional\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_google_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\ndef safe_index(lst: List[Any], e: int) -> int:\n    return lst.index(e) if e in lst else len(lst) - 1\n\n\nclass GitMolDataset(InMemoryDataset):\n    r\"\"\"The dataset from the `\"GIT-Mol: A Multi-modal Large Language Model\n    for Molecular Science with Graph, Image, and Text\"\n    <https://arxiv.org/pdf/2308.06911>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        split (int, optional): Dataset split, train/valid/test=0/1/2.\n            (default: :obj:`0`)\n    \"\"\"\n\n    raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n        split: int = 0,\n    ):\n        from torchvision import transforms\n\n        self.split = split\n\n        if self.split == 0:\n            self.img_transform = transforms.Compose([\n                transforms.Resize((224, 224)),\n                transforms.RandomRotation(15),\n                transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),\n                transforms.ToTensor(),\n                transforms.Normalize(mean=[0.485, 0.456, 0.406],\n                                     std=[0.229, 0.224, 0.225])\n            ])\n        else:\n            self.img_transform = transforms.Compose([\n                transforms.Resize((224, 224)),\n                transforms.ToTensor(),\n                transforms.Normalize(mean=[0.485, 0.456, 0.406],\n                                     std=[0.229, 0.224, 0.225])\n            ])\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl']\n\n    @property\n    def processed_file_names(self) -> str:\n        return ['train.pt', 'valid.pt', 'test.pt'][self.split]\n\n    def download(self) -> None:\n        file_path = download_google_url(\n            self.raw_url_id,\n            self.raw_dir,\n            'gitmol.zip',\n     
   )\n        extract_zip(file_path, self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n        from PIL import Image\n\n        try:\n            from rdkit import Chem, RDLogger\n            RDLogger.DisableLog('rdApp.*')  # type: ignore[attr-defined]\n            WITH_RDKIT = True\n\n        except ImportError:\n            WITH_RDKIT = False\n\n        if not WITH_RDKIT:\n            print((\"Using a pre-processed version of the dataset. Please \"\n                   \"install 'rdkit' to alternatively process the raw data.\"),\n                  file=sys.stderr)\n\n            data_list = fs.torch_load(self.raw_paths[0])\n            data_list = [Data(**data_dict) for data_dict in data_list]\n\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n\n            self.save(data_list, self.processed_paths[0])\n            return\n\n        allowable_features: Dict[str, List[Any]] = {\n            'possible_atomic_num_list':\n            list(range(1, 119)) + ['misc'],\n            'possible_formal_charge_list':\n            [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'],\n            'possible_chirality_list': [\n                Chem.rdchem.ChiralType.CHI_UNSPECIFIED,\n                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,\n                Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,\n                Chem.rdchem.ChiralType.CHI_OTHER\n            ],\n            'possible_hybridization_list': [\n                Chem.rdchem.HybridizationType.SP,\n                Chem.rdchem.HybridizationType.SP2,\n                Chem.rdchem.HybridizationType.SP3,\n                Chem.rdchem.HybridizationType.SP3D,\n                Chem.rdchem.HybridizationType.SP3D2,\n                Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc'\n            ],\n            
'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'],\n            'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6],\n            'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'],\n            'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'],\n            'possible_is_aromatic_list': [False, True],\n            'possible_is_in_ring_list': [False, True],\n            'possible_bond_type_list': [\n                Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,\n                Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC,\n                Chem.rdchem.BondType.ZERO\n            ],\n            'possible_bond_dirs': [  # only for double bond stereo information\n                Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT,\n                Chem.rdchem.BondDir.ENDDOWNRIGHT\n            ],\n            'possible_bond_stereo_list': [\n                Chem.rdchem.BondStereo.STEREONONE,\n                Chem.rdchem.BondStereo.STEREOZ,\n                Chem.rdchem.BondStereo.STEREOE,\n                Chem.rdchem.BondStereo.STEREOCIS,\n                Chem.rdchem.BondStereo.STEREOTRANS,\n                Chem.rdchem.BondStereo.STEREOANY,\n            ],\n            'possible_is_conjugated_list': [False, True]\n        }\n\n        data = pd.read_pickle(\n            f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}')\n\n        data_list = []\n        for _, r in tqdm(data.iterrows(), total=data.shape[0]):\n            smiles = r['isosmiles']\n            mol = Chem.MolFromSmiles(smiles.strip('\\n'))\n            if mol is not None:\n                # text\n                summary = r['summary']\n                # image\n                cid = r['cid']\n                img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png'\n                img = Image.open(img_file).convert('RGB')\n                img = self.img_transform(img).unsqueeze(0)\n                # graph\n                
atom_features_list = []\n                for atom in mol.GetAtoms():\n                    atom_feature = [\n                        safe_index(\n                            allowable_features['possible_atomic_num_list'],\n                            atom.GetAtomicNum()),\n                        allowable_features['possible_chirality_list'].index(\n                            atom.GetChiralTag()),\n                        safe_index(allowable_features['possible_degree_list'],\n                                   atom.GetTotalDegree()),\n                        safe_index(\n                            allowable_features['possible_formal_charge_list'],\n                            atom.GetFormalCharge()),\n                        safe_index(allowable_features['possible_numH_list'],\n                                   atom.GetTotalNumHs()),\n                        safe_index(\n                            allowable_features[\n                                'possible_number_radical_e_list'],\n                            atom.GetNumRadicalElectrons()),\n                        safe_index(\n                            allowable_features['possible_hybridization_list'],\n                            atom.GetHybridization()),\n                        allowable_features['possible_is_aromatic_list'].index(\n                            atom.GetIsAromatic()),\n                        allowable_features['possible_is_in_ring_list'].index(\n                            atom.IsInRing()),\n                    ]\n                    atom_features_list.append(atom_feature)\n                x = torch.tensor(np.array(atom_features_list),\n                                 dtype=torch.long)\n\n                edges_list = []\n                edge_features_list = []\n                for bond in mol.GetBonds():\n                    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()\n                    edge_feature = [\n                        safe_index(\n                            
allowable_features['possible_bond_type_list'],\n                            bond.GetBondType()),\n                        allowable_features['possible_bond_stereo_list'].index(\n                            bond.GetStereo()),\n                        allowable_features['possible_is_conjugated_list'].\n                        index(bond.GetIsConjugated()),\n                    ]\n                    edges_list.append((i, j))\n                    edge_features_list.append(edge_feature)\n                    edges_list.append((j, i))\n                    edge_features_list.append(edge_feature)\n\n                edge_index = torch.tensor(\n                    np.array(edges_list).T,\n                    dtype=torch.long,\n                )\n                edge_attr = torch.tensor(\n                    np.array(edge_features_list),\n                    dtype=torch.long,\n                )\n\n                data = Data(\n                    x=x,\n                    edge_index=edge_index,\n                    smiles=smiles,\n                    edge_attr=edge_attr,\n                    image=img,\n                    caption=summary,\n                )\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/github.py",
    "content": "from typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass GitHub(InMemoryDataset):\n    r\"\"\"The GitHub Web and ML Developers dataset introduced in the\n    `\"Multi-scale Attributed Node Embedding\"\n    <https://arxiv.org/abs/1909.13021>`_ paper.\n    Nodes represent developers on :obj:`github:`GitHub` and edges are mutual\n    follower relationships.\n    It contains 37,300 nodes, 578,006 edges, 128 node features and 2 classes.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 37,700\n          - 578,006\n          - 128\n          - 2\n    \"\"\"\n    url = 'https://graphmining.ai/datasets/ptg/github.npz'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'github.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url, self.raw_dir)\n\n    def process(self) -> None:\n        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n        x = torch.from_numpy(data['features']).to(torch.float)\n        y = torch.from_numpy(data['target']).to(torch.long)\n        edge_index = torch.from_numpy(data['edges']).to(torch.long)\n        edge_index = edge_index.t().contiguous()\n\n        data = Data(x=x, y=y, edge_index=edge_index)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/gnn_benchmark_dataset.py",
    "content": "import logging\nimport os\nimport os.path as osp\nimport pickle\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import remove_self_loops\n\n\nclass GNNBenchmarkDataset(InMemoryDataset):\n    r\"\"\"A variety of artificially and semi-artificially generated graph\n    datasets from the `\"Benchmarking Graph Neural Networks\"\n    <https://arxiv.org/abs/2003.00982>`_ paper.\n\n    .. note::\n        The ZINC dataset is provided via\n        :class:`torch_geometric.datasets.ZINC`.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (one of :obj:`\"PATTERN\"`,\n            :obj:`\"CLUSTER\"`, :obj:`\"MNIST\"`, :obj:`\"CIFAR10\"`,\n            :obj:`\"TSP\"`, :obj:`\"CSL\"`)\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 20 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - PATTERN\n          - 14,000\n          - ~118.9\n          - ~6,098.9\n          - 3\n          - 2\n        * - CLUSTER\n          - 12,000\n          - ~117.2\n          - ~4,303.9\n          - 7\n          - 6\n        * - MNIST\n          - 70,000\n          - ~70.6\n          - ~564.5\n          - 3\n          - 10\n        * - CIFAR10\n          - 60,000\n          - ~117.6\n          - ~941.2\n          - 5\n          - 10\n        * - TSP\n          - 12,000\n          - ~275.4\n          - ~6,885.0\n          - 2\n          - 2\n        * - CSL\n          - 150\n          - ~41.0\n          - ~164.0\n          - 0\n          - 10\n    \"\"\"\n\n    names = ['PATTERN', 'CLUSTER', 'MNIST', 'CIFAR10', 'TSP', 'CSL']\n\n    root_url = 'https://data.pyg.org/datasets/benchmarking-gnns'\n    urls = {\n        'PATTERN': f'{root_url}/PATTERN_v2.zip',\n        'CLUSTER': f'{root_url}/CLUSTER_v2.zip',\n        'MNIST': f'{root_url}/MNIST_v2.zip',\n        'CIFAR10': f'{root_url}/CIFAR10_v2.zip',\n        'TSP': f'{root_url}/TSP_v2.zip',\n        'CSL': 'https://www.dropbox.com/s/rnbkp5ubgk82ocu/CSL.zip?dl=1',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        split: str = \"train\",\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n       
 pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name\n        assert self.name in self.names\n\n        if self.name == 'CSL' and split != 'train':\n            split = 'train'\n            logging.warning(\n                \"Dataset 'CSL' does not provide a standardized splitting. \"\n                \"Instead, it is recommended to perform 5-fold cross \"\n                \"validation with stratified sampling\")\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        if split == 'train':\n            path = self.processed_paths[0]\n        elif split == 'val':\n            path = self.processed_paths[1]\n        elif split == 'test':\n            path = self.processed_paths[2]\n        else:\n            raise ValueError(f\"Split '{split}' found, but expected either \"\n                             f\"'train', 'val', or 'test'\")\n\n        self.load(path)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        if self.name == 'CSL':\n            return [\n                'graphs_Kary_Deterministic_Graphs.pkl',\n                'y_Kary_Deterministic_Graphs.pt'\n            ]\n        else:\n            name = self.urls[self.name].split('/')[-1][:-4]\n            return [f'{name}.pt']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        if self.name == 'CSL':\n            return ['data.pt']\n        else:\n            return ['train_data.pt', 'val_data.pt', 'test_data.pt']\n\n    def download(self) -> None:\n        path = download_url(self.urls[self.name], self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n     
   if self.name == 'CSL':\n            data_list = self.process_CSL()\n            self.save(data_list, self.processed_paths[0])\n        else:\n            inputs = fs.torch_load(self.raw_paths[0])\n            for i in range(len(inputs)):\n                data_list = [Data(**data_dict) for data_dict in inputs[i]]\n\n                if self.pre_filter is not None:\n                    data_list = [d for d in data_list if self.pre_filter(d)]\n\n                if self.pre_transform is not None:\n                    data_list = [self.pre_transform(d) for d in data_list]\n\n                self.save(data_list, self.processed_paths[i])\n\n    def process_CSL(self) -> List[Data]:\n        with open(self.raw_paths[0], 'rb') as f:\n            adjs = pickle.load(f)\n\n        ys = fs.torch_load(self.raw_paths[1]).tolist()\n\n        data_list = []\n        for adj, y in zip(adjs, ys):\n            row, col = torch.from_numpy(adj.row), torch.from_numpy(adj.col)\n            edge_index = torch.stack([row, col], dim=0).to(torch.long)\n            edge_index, _ = remove_self_loops(edge_index)\n            data = Data(edge_index=edge_index, y=y, num_nodes=adj.shape[0])\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n            data_list.append(data)\n        return data_list\n\n    def __repr__(self) -> str:\n        return f'{self.name}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/datasets/graph_generator/__init__.py",
    "content": "from .base import GraphGenerator\nfrom .ba_graph import BAGraph\nfrom .er_graph import ERGraph\nfrom .grid_graph import GridGraph\nfrom .tree_graph import TreeGraph\n\n__all__ = classes = [\n    'GraphGenerator',\n    'BAGraph',\n    'ERGraph',\n    'GridGraph',\n    'TreeGraph',\n]\n"
  },
  {
    "path": "torch_geometric/datasets/graph_generator/ba_graph.py",
    "content": "from torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import barabasi_albert_graph\n\n\nclass BAGraph(GraphGenerator):\n    r\"\"\"Generates random Barabasi-Albert (BA) graphs.\n    See :meth:`~torch_geometric.utils.barabasi_albert_graph` for more\n    information.\n\n    Args:\n        num_nodes (int): The number of nodes.\n        num_edges (int): The number of edges from a new node to existing nodes.\n    \"\"\"\n    def __init__(self, num_nodes: int, num_edges: int):\n        super().__init__()\n        self.num_nodes = num_nodes\n        self.num_edges = num_edges\n\n    def __call__(self) -> Data:\n        edge_index = barabasi_albert_graph(self.num_nodes, self.num_edges)\n        return Data(num_nodes=self.num_nodes, edge_index=edge_index)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '\n                f'num_edges={self.num_edges})')\n"
  },
  {
    "path": "torch_geometric/datasets/graph_generator/base.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.resolver import resolver\n\n\nclass GraphGenerator(ABC):\n    r\"\"\"An abstract base class for generating synthetic graphs.\"\"\"\n    @abstractmethod\n    def __call__(self) -> Data:\n        r\"\"\"To be implemented by :class:`GraphGenerator` subclasses.\"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def resolve(query: Any, *args: Any, **kwargs: Any) -> 'GraphGenerator':\n        import torch_geometric.datasets.graph_generator as _graph_generators\n        graph_generators = [\n            gen for gen in vars(_graph_generators).values()\n            if isinstance(gen, type) and issubclass(gen, GraphGenerator)\n        ]\n        return resolver(graph_generators, {}, query, GraphGenerator, 'Graph',\n                        *args, **kwargs)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/datasets/graph_generator/er_graph.py",
    "content": "from torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import erdos_renyi_graph\n\n\nclass ERGraph(GraphGenerator):\n    r\"\"\"Generates random Erdos-Renyi (ER) graphs.\n    See :meth:`~torch_geometric.utils.erdos_renyi_graph` for more information.\n\n    Args:\n        num_nodes (int): The number of nodes.\n        edge_prob (float): Probability of an edge.\n    \"\"\"\n    def __init__(self, num_nodes: int, edge_prob: float):\n        super().__init__()\n        self.num_nodes = num_nodes\n        self.edge_prob = edge_prob\n\n    def __call__(self) -> Data:\n        edge_index = erdos_renyi_graph(self.num_nodes, self.edge_prob)\n        return Data(num_nodes=self.num_nodes, edge_index=edge_index)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '\n                f'edge_prob={self.edge_prob})')\n"
  },
  {
    "path": "torch_geometric/datasets/graph_generator/grid_graph.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import grid\n\n\nclass GridGraph(GraphGenerator):\n    r\"\"\"Generates two-dimensional grid graphs.\n    See :meth:`~torch_geometric.utils.grid` for more information.\n\n    Args:\n        height (int): The height of the grid.\n        width (int): The width of the grid.\n        dtype (:obj:`torch.dtype`, optional): The desired data type of the\n            returned position tensor. (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        height: int,\n        width: int,\n        dtype: Optional[torch.dtype] = None,\n    ):\n        super().__init__()\n        self.height = height\n        self.width = width\n        self.dtype = dtype\n\n    def __call__(self) -> Data:\n        edge_index, pos = grid(height=self.height, width=self.width,\n                               dtype=self.dtype)\n        return Data(edge_index=edge_index, pos=pos)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(height={self.height}, '\n                f'width={self.width})')\n"
  },
  {
    "path": "torch_geometric/datasets/graph_generator/tree_graph.py",
    "content": "from typing import List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.utils import to_undirected\n\n\ndef tree(\n    depth: int,\n    branch: int = 2,\n    undirected: bool = False,\n    device: Optional[torch.device] = None,\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"Generates a tree graph with the given depth and branch size, along with\n    node-level depth indicators.\n\n    Args:\n        depth (int): The depth of the tree.\n        branch (int, optional): The branch size of the tree.\n            (default: :obj:`2`)\n        undirected (bool, optional): If set to :obj:`True`, the tree graph will\n            be undirected. (default: :obj:`False`)\n        device (torch.device, optional): The desired device of the returned\n            tensors. (default: :obj:`None`)\n    \"\"\"\n    edges: List[Tuple[int, int]] = []\n    depths: List[int] = [0]\n\n    def add_edges(node: int, current_depth: int) -> None:\n        node_count = len(depths)\n\n        if current_depth < depth:\n            for i in range(branch):\n                edges.append((node, node_count + i))\n                depths.append(current_depth + 1)\n\n            for i in range(branch):\n                add_edges(node=node_count + i, current_depth=current_depth + 1)\n\n    add_edges(node=0, current_depth=0)\n\n    edge_index = torch.tensor(edges, device=device).t().contiguous()\n    if undirected:\n        edge_index = to_undirected(edge_index, num_nodes=len(depths))\n\n    return edge_index, torch.tensor(depths, device=device)\n\n\nclass TreeGraph(GraphGenerator):\n    r\"\"\"Generates tree graphs.\n\n    Args:\n        depth (int): The depth of the tree.\n        branch (int, optional): The branch size of the tree.\n            (default: :obj:`2`)\n        undirected (bool, optional): If set to :obj:`True`, the tree graph will\n           
 be undirected. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        depth: int,\n        branch: int = 2,\n        undirected: bool = False,\n    ) -> None:\n        super().__init__()\n        self.depth = depth\n        self.branch = branch\n        self.undirected = undirected\n\n    def __call__(self) -> Data:\n        edge_index, depth = tree(self.depth, self.branch, self.undirected)\n        num_nodes = depth.numel()\n        return Data(edge_index=edge_index, depth=depth, num_nodes=num_nodes)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(depth={self.depth}, '\n                f'branch={self.branch}, undirected={self.undirected})')\n"
  },
  {
    "path": "torch_geometric/datasets/heterophilous_graph_dataset.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import to_undirected\n\n\nclass HeterophilousGraphDataset(InMemoryDataset):\n    r\"\"\"The heterophilous graphs :obj:`\"Roman-empire\"`,\n    :obj:`\"Amazon-ratings\"`, :obj:`\"Minesweeper\"`, :obj:`\"Tolokers\"` and\n    :obj:`\"Questions\"` from the `\"A Critical Look at the Evaluation of GNNs\n    under Heterophily: Are We Really Making Progress?\"\n    <https://arxiv.org/abs/2302.11640>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"Roman-empire\"`,\n            :obj:`\"Amazon-ratings\"`, :obj:`\"Minesweeper\"`, :obj:`\"Tolokers\"`,\n            :obj:`\"Questions\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Roman-empire\n          - 22,662\n          - 32,927\n          - 300\n          - 18\n        * - Amazon-ratings\n          - 24,492\n          - 93,050\n          - 300\n          - 5\n        * - Minesweeper\n          - 10,000\n          - 39,402\n          - 7\n          - 2\n        * - Tolokers\n          - 11,758\n          - 519,000\n          - 10\n          - 2\n        * - Questions\n          - 48,921\n          - 153,540\n          - 301\n          - 2\n    \"\"\"\n    url = ('https://github.com/yandex-research/heterophilous-graphs/raw/'\n           'main/data')\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower().replace('-', '_')\n        assert self.name in [\n            'roman_empire',\n            'amazon_ratings',\n            'minesweeper',\n            'tolokers',\n            'questions',\n        ]\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'{self.name}.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(f'{self.url}/{self.name}.npz', self.raw_dir)\n\n    def process(self) -> None:\n        raw = np.load(self.raw_paths[0], 'r')\n        x = 
torch.from_numpy(raw['node_features'])\n        y = torch.from_numpy(raw['node_labels'])\n        edge_index = torch.from_numpy(raw['edges']).t().contiguous()\n        edge_index = to_undirected(edge_index, num_nodes=x.size(0))\n        train_mask = torch.from_numpy(raw['train_masks']).t().contiguous()\n        val_mask = torch.from_numpy(raw['val_masks']).t().contiguous()\n        test_mask = torch.from_numpy(raw['test_masks']).t().contiguous()\n\n        data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(name={self.name})'\n"
  },
  {
    "path": "torch_geometric/datasets/hgb_dataset.py",
    "content": "import json\nimport os\nimport os.path as osp\nfrom collections import defaultdict\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_google_url,\n    extract_zip,\n)\n\n\nclass HGBDataset(InMemoryDataset):\n    r\"\"\"A variety of heterogeneous graph benchmark datasets from the\n    `\"Are We Really Making Much Progress? Revisiting, Benchmarking, and\n    Refining Heterogeneous Graph Neural Networks\"\n    <http://keg.cs.tsinghua.edu.cn/jietang/publications/\n    KDD21-Lv-et-al-HeterGNN.pdf>`_ paper.\n\n    .. note::\n        Test labels are randomly given to prevent data leakage issues.\n        If you want to obtain final test performance, you will need to submit\n        your model predictions to the\n        `HGB leaderboard <https://www.biendata.xyz/hgb/>`_.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (one of :obj:`\"ACM\"`,\n            :obj:`\"DBLP\"`, :obj:`\"Freebase\"`, :obj:`\"IMDB\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :class:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :class:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    names = {\n        'acm': 'ACM',\n        'dblp': 'DBLP',\n        'freebase': 'Freebase',\n        'imdb': 'IMDB',\n    }\n\n    file_ids = {\n        'acm': '1xbJ4QE9pcDJOcALv7dYhHDCPITX2Iddz',\n        'dblp': '1fLLoy559V7jJaQ_9mQEsC06VKd6Qd3SC',\n        'freebase': '1vw-uqbroJZfFsWpriC1CWbtHCJMGdWJ7',\n        'imdb': '18qXmmwKJBrEJxVQaYwKTL3Ny3fPqJeJ2',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in set(self.names.keys())\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        x = ['info.dat', 'node.dat', 'link.dat', 'label.dat', 'label.dat.test']\n        return [osp.join(self.names[self.name], f) for f in x]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        id = self.file_ids[self.name]\n        path = download_google_url(id, self.raw_dir, 'data.zip')\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        data = HeteroData()\n\n        # node_types = {0: 'paper', 1: 'author', ...}\n        # edge_types = {0: ('paper', 'cite', 'paper'), ...}\n        if self.name in ['acm', 'dblp', 'imdb']:\n            with open(self.raw_paths[0]) 
as f:  # `info.dat`\n                info = json.load(f)\n            n_types = info['node.dat']['node type']\n            n_types = {int(k): v for k, v in n_types.items()}\n            e_types = info['link.dat']['link type']\n            e_types = {int(k): tuple(v.values()) for k, v in e_types.items()}\n            for key, (src, dst, rel) in e_types.items():\n                src, dst = n_types[int(src)], n_types[int(dst)]\n                rel = rel.split('-')[1]\n                rel = rel if rel != dst and rel[1:] != dst else 'to'\n                e_types[key] = (src, rel, dst)\n            num_classes = len(info['label.dat']['node type']['0'])\n        elif self.name in ['freebase']:\n            with open(self.raw_paths[0]) as f:  # `info.dat`\n                info = f.read().split('\\n')\n            start = info.index('TYPE\\tMEANING') + 1\n            end = info[start:].index('')\n            n_types = [v.split('\\t\\t') for v in info[start:start + end]]\n            n_types = {int(k): v.lower() for k, v in n_types}\n\n            e_types = {}\n            start = info.index('LINK\\tSTART\\tEND\\tMEANING') + 1\n            end = info[start:].index('')\n            for key, row in enumerate(info[start:start + end]):\n                edge = row.split('\\t')[1:]\n                src, dst, rel = (v for v in edge if v != '')\n                src, dst = n_types[int(src)], n_types[int(dst)]\n                rel = rel.split('-')[1]\n                e_types[key] = (src, rel, dst)\n        else:  # Link prediction:\n            raise NotImplementedError\n\n        # Extract node information:\n        mapping_dict = {}  # Maps global node indices to local ones.\n        x_dict = defaultdict(list)\n        num_nodes_dict: Dict[str, int] = defaultdict(int)\n        with open(self.raw_paths[1]) as f:  # `node.dat`\n            xs = [v.split('\\t') for v in f.read().split('\\n')[:-1]]\n        for x in xs:\n            n_id, n_type = int(x[0]), n_types[int(x[2])]\n         
   mapping_dict[n_id] = num_nodes_dict[n_type]\n            num_nodes_dict[n_type] += 1\n            if len(x) >= 4:  # Extract features (in case they are given).\n                x_dict[n_type].append([float(v) for v in x[3].split(',')])\n        for n_type in n_types.values():\n            if len(x_dict[n_type]) == 0:\n                data[n_type].num_nodes = num_nodes_dict[n_type]\n            else:\n                data[n_type].x = torch.tensor(x_dict[n_type])\n\n        edge_index_dict = defaultdict(list)\n        edge_weight_dict = defaultdict(list)\n        with open(self.raw_paths[2]) as f:  # `link.dat`\n            edges = [v.split('\\t') for v in f.read().split('\\n')[:-1]]\n        for src, dst, rel, weight in edges:\n            e_type = e_types[int(rel)]\n            src, dst = mapping_dict[int(src)], mapping_dict[int(dst)]\n            edge_index_dict[e_type].append([src, dst])\n            edge_weight_dict[e_type].append(float(weight))\n        for e_type in e_types.values():\n            edge_index = torch.tensor(edge_index_dict[e_type])\n            edge_weight = torch.tensor(edge_weight_dict[e_type])\n            data[e_type].edge_index = edge_index.t().contiguous()\n            # Only add \"weighted\" edges to the graph:\n            if not torch.allclose(edge_weight, torch.ones_like(edge_weight)):\n                data[e_type].edge_weight = edge_weight\n\n        # Node classification:\n        if self.name in ['acm', 'dblp', 'freebase', 'imdb']:\n            with open(self.raw_paths[3]) as f:  # `label.dat`\n                train_ys = [v.split('\\t') for v in f.read().split('\\n')[:-1]]\n            with open(self.raw_paths[4]) as f:  # `label.dat.test`\n                test_ys = [v.split('\\t') for v in f.read().split('\\n')[:-1]]\n            for y in train_ys:\n                n_id, n_type = mapping_dict[int(y[0])], n_types[int(y[2])]\n\n                if not hasattr(data[n_type], 'y'):\n                    num_nodes = 
data[n_type].num_nodes\n                    if self.name in ['imdb']:  # multi-label\n                        data[n_type].y = torch.zeros((num_nodes, num_classes))\n                    else:\n                        data[n_type].y = torch.full((num_nodes, ), -1).long()\n                    data[n_type].train_mask = torch.zeros(num_nodes).bool()\n                    data[n_type].test_mask = torch.zeros(num_nodes).bool()\n\n                if data[n_type].y.dim() > 1:  # multi-label\n                    for v in y[3].split(','):\n                        data[n_type].y[n_id, int(v)] = 1\n                else:\n                    data[n_type].y[n_id] = int(y[3])\n                data[n_type].train_mask[n_id] = True\n            for y in test_ys:\n                n_id, n_type = mapping_dict[int(y[0])], n_types[int(y[2])]\n                if data[n_type].y.dim() > 1:  # multi-label\n                    for v in y[3].split(','):\n                        data[n_type].y[n_id, int(v)] = 1\n                else:\n                    data[n_type].y[n_id] = int(y[3])\n                data[n_type].test_mask[n_id] = True\n\n        else:  # Link prediction:\n            raise NotImplementedError\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.names[self.name]}()'\n"
  },
  {
    "path": "torch_geometric/datasets/hm.py",
    "content": "from typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import HeteroData, InMemoryDataset\n\n\nclass HM(InMemoryDataset):\n    r\"\"\"The heterogeneous H&M dataset from the `Kaggle H&M Personalized Fashion\n    Recommendations\n    <https://www.kaggle.com/competitions/h-and-m-personalized-fashion-recommendations>`_\n    challenge.\n    The task is to develop product recommendations based on data from previous\n    transactions, as well as from customer and product meta data.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        use_all_tables_as_node_types (bool, optional): If set to :obj:`True`,\n            will use the transaction table as a distinct node type.\n            (default: :obj:`False`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = ('https://www.kaggle.com/competitions/'\n           'h-and-m-personalized-fashion-recommendations/data')\n\n    def __init__(\n        self,\n        root: str,\n        use_all_tables_as_node_types: bool = False,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.use_all_tables_as_node_types = use_all_tables_as_node_types\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'customers.csv.zip', 'articles.csv.zip',\n            'transactions_train.csv.zip'\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        if self.use_all_tables_as_node_types:\n            return 'data.pt'\n        else:\n            return 'data_merged.pt'\n\n    def download(self) -> None:\n        raise RuntimeError(\n            f\"Dataset not found. 
Please download {self.raw_file_names} from \"\n            f\"'{self.url}' and move it to '{self.raw_dir}'\")\n\n    def process(self) -> None:\n        import pandas as pd\n\n        data = HeteroData()\n\n        # Process customer data ###############################################\n        df = pd.read_csv(self.raw_paths[0], index_col='customer_id')\n        customer_map = {idx: i for i, idx in enumerate(df.index)}\n\n        xs = []\n        for name in [\n                'Active', 'FN', 'club_member_status', 'fashion_news_frequency'\n        ]:\n            x = pd.get_dummies(df[name]).values\n            xs.append(torch.from_numpy(x).to(torch.float))\n\n        x = torch.from_numpy(df['age'].values).to(torch.float).view(-1, 1)\n        x = x.nan_to_num(nan=x.nanmean())  # type: ignore\n        xs.append(x / x.max())\n\n        data['customer'].x = torch.cat(xs, dim=-1)\n\n        # Process article data ################################################\n        df = pd.read_csv(self.raw_paths[1], index_col='article_id')\n        article_map = {idx: i for i, idx in enumerate(df.index)}\n\n        xs = []\n        for name in [  # We drop a few columns here that are high cardinality.\n                # 'product_code',  # Drop.\n                # 'prod_name',  # Drop.\n                'product_type_no',\n                'product_type_name',\n                'product_group_name',\n                'graphical_appearance_no',\n                'graphical_appearance_name',\n                'colour_group_code',\n                'colour_group_name',\n                'perceived_colour_value_id',\n                'perceived_colour_value_name',\n                'perceived_colour_master_id',\n                'perceived_colour_master_name',\n                # 'department_no',  # Drop.\n                # 'department_name',  # Drop.\n                'index_code',\n                'index_name',\n                'index_group_no',\n                'index_group_name',\n            
    'section_no',\n                'section_name',\n                'garment_group_no',\n                'garment_group_name',\n                # 'detail_desc',  # Drop.\n        ]:\n            x = pd.get_dummies(df[name]).values\n            xs.append(torch.from_numpy(x).to(torch.float))\n\n        data['article'].x = torch.cat(xs, dim=-1)\n\n        # Process transaction data ############################################\n        df = pd.read_csv(self.raw_paths[2], parse_dates=['t_dat'])\n\n        x1 = pd.get_dummies(df['sales_channel_id']).values\n        x1 = torch.from_numpy(x1).to(torch.float)\n        x2 = torch.from_numpy(df['price'].values).to(torch.float).view(-1, 1)\n        x = torch.cat([x1, x2], dim=-1)\n\n        time = torch.from_numpy(df['t_dat'].values.astype(int))\n        time = time // (60 * 60 * 24 * 10**9)  # Convert nanoseconds to days.\n\n        src = torch.tensor([customer_map[idx] for idx in df['customer_id']])\n        dst = torch.tensor([article_map[idx] for idx in df['article_id']])\n\n        if self.use_all_tables_as_node_types:\n            data['transaction'].x = x\n            data['transaction'].time = time\n\n            edge_index = torch.stack([src, torch.arange(len(df))], dim=0)\n            data['customer', 'to', 'transaction'].edge_index = edge_index\n            edge_index = edge_index.flip([0])\n            data['transaction', 'rev_to', 'customer'].edge_index = edge_index\n\n            edge_index = torch.stack([dst, torch.arange(len(df))], dim=0)\n            data['article', 'to', 'transaction'].edge_index = edge_index\n            edge_index = edge_index.flip([0])\n            data['transaction', 'rev_to', 'article'].edge_index = edge_index\n        else:\n            edge_index = torch.stack([src, dst], dim=0)\n            data['customer', 'to', 'article'].edge_index = edge_index\n            data['customer', 'to', 'article'].time = time\n            data['customer', 'to', 'article'].edge_attr = x\n\n            
edge_index = edge_index.flip([0])\n            data['article', 'rev_to', 'customer'].edge_index = edge_index\n            data['article', 'rev_to', 'customer'].time = time\n            data['article', 'rev_to', 'customer'].edge_attr = x\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/hydro_net.py",
    "content": "import copy\nimport os\nimport os.path as osp\nfrom dataclasses import dataclass\nfrom functools import cached_property\nfrom glob import glob\nfrom pathlib import Path\nfrom typing import Callable, List, MutableSequence, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch.utils.data import ConcatDataset, Subset\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.data.data import BaseData\n\n\nclass HydroNet(InMemoryDataset):\n    r\"\"\"The HydroNet dataest from the\n    `\"HydroNet: Benchmark Tasks for Preserving Intermolecular Interactions and\n    Structural Motifs in Predictive and Generative Models for Molecular Data\"\n    <https://arxiv.org/abs/2012.00131>`_ paper, consisting of 5 million water\n    clusters held together by hydrogen bonding networks.  This dataset\n    provides atomic coordinates and total energy in kcal/mol for the cluster.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str, optional): Name of the subset of the full dataset to use:\n            :obj:`\"small\"` uses 500k graphs sampled from the :obj:`\"medium\"`\n            dataset, :obj:`\"medium\"` uses 2.7m graphs with maximum size of 75\n            nodes.\n            Mutually exclusive option with the clusters argument.\n            (default :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        num_workers (int): Number of multiprocessing workers to use for\n            pre-processing the dataset. (default :obj:`8`)\n        clusters (int or List[int], optional): Select a subset of clusters\n            from the full dataset. If set to :obj:`None`, will select all.\n            (default :obj:`None`)\n        use_processed (bool): Option to use a pre-processed version of the\n            original :obj:`xyz` dataset. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        name: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n        num_workers: int = 8,\n        clusters: Optional[Union[int, List[int]]] = None,\n        use_processed: bool = True,\n    ) -> None:\n        self.name = name\n        self.num_workers = num_workers\n        self.use_processed = use_processed\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n\n        self.select_clusters(clusters)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [f'W{c}_geoms_all.zip' for c in range(3, 31)]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return [f'W{c}_geoms_all.npz' for c in range(3, 31)]\n\n    def download(self) -> None:\n        token_file = Path(osp.join(self.raw_dir, 'use_processed'))\n        if self.use_processed and token_file.exists():\n            return\n\n        file = RemoteFile.hydronet_splits()\n        file.unpack_to(self.raw_dir)\n\n        if self.use_processed:\n            file = RemoteFile.processed_dataset()\n            file.unpack_to(self.raw_dir)\n            token_file.touch()\n            return\n\n        file = RemoteFile.raw_dataset()\n        
file.unpack_to(self.raw_dir)\n        folder_name, _ = osp.splitext(file.name)\n        files = glob(osp.join(self.raw_dir, folder_name, '*.zip'))\n\n        for f in files:\n            dst = osp.join(self.raw_dir, osp.basename(f))\n            os.rename(f, dst)\n\n        os.rmdir(osp.join(self.raw_dir, folder_name))\n\n    def process(self) -> None:\n        if self.use_processed:\n            return self._unpack_processed()\n\n        from tqdm.contrib.concurrent import process_map\n\n        self._partitions = process_map(\n            self._create_partitions,\n            self.raw_paths,\n            max_workers=self.num_workers,\n            position=0,\n            leave=True,\n        )\n\n    def _unpack_processed(self) -> None:\n        files = glob(osp.join(self.raw_dir, '*.npz'))\n        for f in files:\n            dst = osp.join(self.processed_dir, osp.basename(f))\n            os.rename(f, dst)\n\n    def _create_partitions(self, file: str) -> 'Partition':\n        name = osp.basename(file)\n        name, _ = osp.splitext(name)\n        return Partition(self.root, name, self.transform, self.pre_transform)\n\n    def select_clusters(\n        self,\n        clusters: Optional[Union[int, List[int]]],\n    ) -> None:\n        if self.name is not None:\n            clusters = self._validate_name(clusters)\n\n        self._partitions = [self._create_partitions(f) for f in self.raw_paths]\n\n        if clusters is None:\n            return\n\n        clusters = [clusters] if isinstance(clusters, int) else clusters\n\n        def is_valid_cluster(x: Union[int, List[int]]) -> bool:\n            return isinstance(x, int) and x >= 3 and x <= 30\n\n        if not all([is_valid_cluster(x) for x in clusters]):\n            raise ValueError(\n                \"Selected clusters must be an integer in the range [3, 30]\")\n\n        self._partitions = [self._partitions[c - 3] for c in clusters]\n\n    def _validate_name(\n        self,\n        clusters: 
Optional[Union[int, List[int]]],\n    ) -> List[int]:\n        if clusters is not None:\n            raise ValueError(\"'name' and 'clusters' are mutually exclusive\")\n\n        if self.name not in ['small', 'medium']:\n            raise ValueError(f\"Invalid subset name '{self.name}'. \"\n                             f\"Must be either 'small' or 'medium'\")\n\n        return list(range(3, 26))\n\n    @cached_property\n    def _dataset(self) -> Union[ConcatDataset, Subset]:\n        dataset: ConcatDataset = ConcatDataset(self._partitions)\n\n        if self.name == \"small\":\n            return self._load_small_split(dataset)\n\n        return dataset\n\n    def _load_small_split(self, dataset: ConcatDataset) -> Subset:\n        split_file = osp.join(self.processed_dir, 'split_00_small.npz')\n        with np.load(split_file) as split:\n            train_idx = split['train_idx']\n            val_idx = split['val_idx']\n        all_idx = np.concatenate([train_idx, val_idx])\n        return Subset(dataset, all_idx)\n\n    def len(self) -> int:\n        return len(self._dataset)\n\n    def get(self, idx: int) -> Data:\n        return self._dataset[idx]\n\n\ndef get_num_clusters(filepath: str) -> int:\n    name = osp.basename(filepath)\n    return int(name[1:name.find('_')])\n\n\ndef read_energy(file: str, chunk_size: int) -> np.ndarray:\n    import pandas as pd\n\n    def skipatoms(i: int) -> bool:\n        return (i - 1) % chunk_size != 0\n\n    if chunk_size - 2 == 11 * 3:\n        # Manually handle bad lines in W11\n        df = pd.read_table(file, header=None, dtype=\"string\",\n                           skiprows=skipatoms)\n        df = df[0].str.split().str[-1].astype(np.float32)\n    else:\n        df = pd.read_table(file, sep=r'\\s+', names=[\"label\", \"energy\"],\n                           dtype=np.float32, skiprows=skipatoms,\n                           usecols=['energy'], memory_map=True)\n\n    return df.to_numpy().squeeze()\n\n\ndef read_atoms(file: 
str, chunk_size: int) -> Tuple[np.ndarray, np.ndarray]:\n    import pandas as pd\n\n    def skipheaders(i: int) -> bool:\n        return i % chunk_size == 0 or (i - 1) % chunk_size == 0\n\n    dtypes = {\n        'atom': 'string',\n        'x': np.float16,\n        'y': np.float16,\n        'z': np.float16\n    }\n    df = pd.read_table(file, sep=r'\\s+', names=list(dtypes.keys()),\n                       dtype=dtypes, skiprows=skipheaders, memory_map=True)\n\n    z = np.ones(len(df), dtype=np.int8)\n    z[(df.atom == 'O').to_numpy(dtype=np.bool_)] = 8\n    pos = df.iloc[:, 1:4].to_numpy()\n    num_nodes = (chunk_size - 2)\n    num_graphs = z.shape[0] // num_nodes\n    z.shape = (num_graphs, num_nodes)\n    pos.shape = (num_graphs, num_nodes, 3)\n    return (z, pos)\n\n\n@dataclass\nclass RemoteFile:\n    url: str\n    name: str\n\n    def unpack_to(self, dest_folder: str) -> None:\n        file = download_url(self.url, dest_folder, filename=self.name)\n        extract_zip(file, dest_folder)\n        os.unlink(file)\n\n    @staticmethod\n    def raw_dataset() -> 'RemoteFile':\n        return RemoteFile(\n            url='https://figshare.com/ndownloader/files/38063847',\n            name='W3-W30_all_geoms_TTM2.1-F.zip')\n\n    @staticmethod\n    def processed_dataset() -> 'RemoteFile':\n        return RemoteFile(\n            url='https://figshare.com/ndownloader/files/38075781',\n            name='W3-W30_pyg_processed.zip')\n\n    @staticmethod\n    def hydronet_splits() -> 'RemoteFile':\n        return RemoteFile(\n            url=\"https://figshare.com/ndownloader/files/38075904\",\n            name=\"hydronet_splits.zip\")\n\n\nclass Partition(InMemoryDataset):\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n    ) -> None:\n        self.name = name\n        self.num_clusters = get_num_clusters(name)\n        super().__init__(root, 
transform, pre_transform, pre_filter=None,\n                         log=False)\n        self.is_loaded = False\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [self.name + \".zip\"]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return [self.name + '.npz']\n\n    def process(self) -> None:\n        num_nodes = self.num_clusters * 3\n        chunk_size = num_nodes + 2\n        z, pos = read_atoms(self.raw_paths[0], chunk_size)\n        y = read_energy(self.raw_paths[0], chunk_size)\n        np.savez(self.processed_paths[0], z=z, pos=pos, y=y,\n                 num_graphs=z.shape[0])\n\n    def _load(self) -> None:\n        if self.is_loaded:\n            return None\n\n        with np.load(self.processed_paths[0]) as npzfile:\n            self.z = npzfile['z']\n            self.pos = npzfile['pos']\n            self.y = npzfile['y']\n            numel = int(npzfile['num_graphs'])\n\n        self._data_list: MutableSequence[Optional[BaseData]] = [None] * numel\n        self.is_loaded = True\n\n    @cached_property\n    def num_graphs(self) -> int:\n        with np.load(self.processed_paths[0]) as npzfile:\n            return int(npzfile['num_graphs'])\n\n    def len(self) -> int:\n        return self.num_graphs\n\n    def get(self, idx: int) -> Data:\n        self._load()\n\n        if self._data_list[idx] is not None:\n            cached_data = self._data_list[idx]\n            assert isinstance(cached_data, Data)\n            return copy.copy(cached_data)\n\n        data = Data(\n            z=torch.from_numpy(self.z[idx, :]),\n            pos=torch.from_numpy(self.pos[idx, :, :]),\n            y=torch.tensor(self.y[idx]),\n        )\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self._data_list[idx] = copy.copy(data)\n        return data\n"
  },
  {
    "path": "torch_geometric/datasets/icews.py",
    "content": "from typing import Callable, List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.io import read_txt_array\n\n\nclass EventDataset(InMemoryDataset):\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n    @property\n    def num_nodes(self) -> int:\n        raise NotImplementedError\n\n    @property\n    def num_rels(self) -> int:\n        raise NotImplementedError\n\n    def process_events(self) -> Tensor:\n        raise NotImplementedError\n\n    def _process_data_list(self) -> List[Data]:\n        events = self.process_events()\n        events = events - events.min(dim=0, keepdim=True)[0]\n\n        data_list = []\n        for (sub, rel, obj, t) in events.tolist():\n            data = Data(sub=sub, rel=rel, obj=obj, t=t)\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n            data_list.append(data)\n\n        return data_list\n\n\nclass ICEWS18(EventDataset):\n    r\"\"\"The Integrated Crisis Early Warning System (ICEWS) dataset used in\n    the, *e.g.*, `\"Recurrent Event Network for Reasoning over Temporal\n    Knowledge Graphs\" <https://arxiv.org/abs/1904.05530>`_ paper, consisting of\n    events collected from 1/1/2018 to 10/31/2018 (24 hours time granularity).\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the 
validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://github.com/INK-USC/RE-Net/raw/master/data/ICEWS18'\n    splits = [0, 373018, 419013, 468558]  # Train/Val/Test splits.\n\n    def __init__(\n        self,\n        root: str,\n        split: str = 'train',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert split in ['train', 'val', 'test']\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        idx = self.processed_file_names.index(f'{split}.pt')\n        self.load(self.processed_paths[idx])\n\n    @property\n    def num_nodes(self) -> int:\n        return 23033\n\n    @property\n    def num_rels(self) -> int:\n        return 256\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return 
[f'{name}.txt' for name in ['train', 'valid', 'test']]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def download(self) -> None:\n        for filename in self.raw_file_names:\n            download_url(f'{self.url}/{filename}', self.raw_dir)\n\n    def process_events(self) -> Tensor:\n        events = []\n        for path in self.raw_paths:\n            data = read_txt_array(path, sep='\\t', end=4, dtype=torch.long)\n            data[:, 3] = data[:, 3] // 24\n            events += [data]\n        return torch.cat(events, dim=0)\n\n    def process(self) -> None:\n        s = self.splits\n        data_list = self._process_data_list()\n        self.save(data_list[s[0]:s[1]], self.processed_paths[0])\n        self.save(data_list[s[1]:s[2]], self.processed_paths[1])\n        self.save(data_list[s[2]:s[3]], self.processed_paths[2])\n"
  },
  {
    "path": "torch_geometric/datasets/igmc_dataset.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import HeteroData, InMemoryDataset, download_url\n\n\nclass IGMCDataset(InMemoryDataset):\n    r\"\"\"The user-item heterogeneous rating datasets :obj:`\"Douban\"`,\n    :obj:`\"Flixster\"` and :obj:`\"Yahoo-Music\"` from the `\"Inductive Matrix\n    Completion Based on Graph Neural Networks\"\n    <https://arxiv.org/abs/1904.12058>`_ paper.\n\n    Nodes represent users and items.\n    Edges and features between users and items represent a (training) rating of\n    the item given by the user.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"Douban\"`,\n            :obj:`\"Flixster\"`, :obj:`\"Yahoo-Music\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = 'https://github.com/muhanzhang/IGMC/raw/master/raw_data'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower().replace('-', '_')\n        assert self.name in ['flixster', 'douban', 'yahoo_music']\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'training_test_dataset.mat'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = f'{self.url}/{self.name}/training_test_dataset.mat'\n        download_url(path, self.raw_dir)\n\n    @staticmethod\n    def load_matlab_file(path_file: str, name: str) -> Tensor:\n        import h5py\n        import numpy as np\n\n        db = h5py.File(path_file, 'r')\n        out = torch.from_numpy(np.asarray(db[name])).to(torch.float).t()\n        db.close()\n\n        return out\n\n    def process(self) -> None:\n        data = HeteroData()\n\n        M = self.load_matlab_file(self.raw_paths[0], 'M')\n\n        if self.name == 'flixster':\n            user_x = self.load_matlab_file(self.raw_paths[0], 'W_users')\n            item_x = self.load_matlab_file(self.raw_paths[0], 'W_movies')\n        elif self.name == 'douban':\n            user_x = self.load_matlab_file(self.raw_paths[0], 
'W_users')\n            item_x = torch.eye(M.size(1))\n        elif self.name == 'yahoo_music':\n            user_x = torch.eye(M.size(0))\n            item_x = self.load_matlab_file(self.raw_paths[0], 'W_tracks')\n\n        data['user'].x = user_x\n        data['item'].x = item_x\n\n        train_mask = self.load_matlab_file(self.raw_paths[0], 'Otraining')\n        train_mask = train_mask.to(torch.bool)\n\n        edge_index = train_mask.nonzero().t()\n        rating = M[edge_index[0], edge_index[1]]\n\n        data['user', 'rates', 'item'].edge_index = edge_index\n        data['user', 'rates', 'item'].rating = rating\n\n        data['item', 'rated_by', 'user'].edge_index = edge_index.flip([0])\n        data['item', 'rated_by', 'user'].rating = rating\n\n        test_mask = self.load_matlab_file(self.raw_paths[0], 'Otest')\n        test_mask = test_mask.to(torch.bool)\n\n        edge_label_index = test_mask.nonzero().t()\n        edge_label = M[edge_label_index[0], edge_label_index[1]]\n\n        data['user', 'rates', 'item'].edge_label_index = edge_label_index\n        data['user', 'rates', 'item'].edge_label = edge_label\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(name={self.name})'\n"
  },
  {
    "path": "torch_geometric/datasets/imdb.py",
    "content": "import os\nimport os.path as osp\nfrom itertools import product\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\n\n\nclass IMDB(InMemoryDataset):\n    r\"\"\"A subset of the Internet Movie Database (IMDB), as collected in the\n    `\"MAGNN: Metapath Aggregated Graph Neural Network for Heterogeneous Graph\n    Embedding\" <https://arxiv.org/abs/2002.01680>`_ paper.\n    IMDB is a heterogeneous graph containing three types of entities - movies\n    (4,278 nodes), actors (5,257 nodes), and directors (2,081 nodes).\n    The movies are divided into three classes (action, comedy, drama) according\n    to their genre.\n    Movie features correspond to elements of a bag-of-words representation of\n    its plot keywords.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = 'https://www.dropbox.com/s/g0btk9ctr1es39x/IMDB_processed.zip?dl=1'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'adjM.npz', 'features_0.npz', 'features_1.npz', 'features_2.npz',\n            'labels.npy', 'train_val_test_idx.npz'\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.remove(path)\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        data = HeteroData()\n\n        node_types = ['movie', 'director', 'actor']\n        for i, node_type in enumerate(node_types):\n            x = sp.load_npz(osp.join(self.raw_dir, f'features_{i}.npz'))\n            data[node_type].x = torch.from_numpy(x.todense()).to(torch.float)\n\n        y = np.load(osp.join(self.raw_dir, 'labels.npy'))\n        data['movie'].y = torch.from_numpy(y).to(torch.long)\n\n        split = np.load(osp.join(self.raw_dir, 'train_val_test_idx.npz'))\n        for name in ['train', 'val', 'test']:\n            idx = split[f'{name}_idx']\n            idx = torch.from_numpy(idx).to(torch.long)\n            mask = torch.zeros(data['movie'].num_nodes, dtype=torch.bool)\n            mask[idx] = True\n            data['movie'][f'{name}_mask'] = mask\n\n        s = {}\n        N_m = data['movie'].num_nodes\n        N_d 
= data['director'].num_nodes\n        N_a = data['actor'].num_nodes\n        s['movie'] = (0, N_m)\n        s['director'] = (N_m, N_m + N_d)\n        s['actor'] = (N_m + N_d, N_m + N_d + N_a)\n\n        A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz'))\n        for src, dst in product(node_types, node_types):\n            A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo()\n            if A_sub.nnz > 0:\n                row = torch.from_numpy(A_sub.row).to(torch.long)\n                col = torch.from_numpy(A_sub.col).to(torch.long)\n                data[src, dst].edge_index = torch.stack([row, col], dim=0)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/datasets/infection_dataset.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset\nfrom torch_geometric.datasets.graph_generator import GraphGenerator\nfrom torch_geometric.explain import Explanation\nfrom torch_geometric.utils import k_hop_subgraph\n\n\nclass InfectionDataset(InMemoryDataset):\n    r\"\"\"Generates a synthetic infection dataset for evaluating explainability\n    algorithms, as described in the `\"Explainability Techniques for Graph\n    Convolutional Networks\" <https://arxiv.org/abs/1905.13686>`__ paper.\n    The :class:`~torch_geometric.datasets.InfectionDataset` creates synthetic\n    graphs coming from a\n    :class:`~torch_geometric.datasets.graph_generator.GraphGenerator` with\n    :obj:`num_infected_nodes` randomly assigned infected nodes.\n    The dataset describes a node classification task of predicting the length\n    of the shortest path to infected nodes, with corresponding ground-truth\n    edge-level masks.\n\n    For example, to generate a random Erdos-Renyi (ER) infection graph\n    with :obj:`500` nodes and :obj:`0.004` edge probability, write:\n\n    .. 
code-block:: python\n\n        from torch_geometric.datasets import InfectionDataset\n        from torch_geometric.datasets.graph_generator import ERGraph\n\n        dataset = InfectionDataset(\n            graph_generator=ERGraph(num_nodes=500, edge_prob=0.004),\n            num_infected_nodes=50,\n            max_path_length=3,\n        )\n\n    Args:\n        graph_generator (GraphGenerator or str): The graph generator to be\n            used, *e.g.*,\n            :class:`torch_geometric.datasets.graph_generator.BAGraph`\n            (or any string that automatically resolves to it).\n        num_infected_nodes (int or List[int]): The number of randomly\n            selected infected nodes in the graph.\n            If given as a list, will select a different number of infected\n            nodes for different graphs.\n        max_path_length (int or List[int]): The maximum shortest path length to\n            determine whether a node will be infected.\n            If given as a list, will apply different shortest path lengths for\n            different graphs.\n        num_graphs (int, optional): The number of graphs to generate.\n            The number of graphs will be automatically determined by\n            :obj:`len(num_infected_nodes)` or :obj:`len(max_path_length)` in\n            case either of them is given as a list, and should only be set in\n            case one wants to create multiple graphs while\n            :obj:`num_infected_nodes` and :obj:`max_path_length` are given as\n            an integer. (default: :obj:`None`)\n        graph_generator_kwargs (Dict[str, Any], optional): Arguments passed to\n            the respective graph generator module in case it gets automatically\n            resolved. (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        graph_generator: Union[GraphGenerator, str],\n        num_infected_nodes: Union[int, List[int]],\n        max_path_length: Union[int, List[int]],\n        num_graphs: Optional[int] = None,\n        graph_generator_kwargs: Optional[Dict[str, Any]] = None,\n        transform: Optional[Callable] = None,\n    ):\n        super().__init__(root=None, transform=transform)\n\n        assert isinstance(num_infected_nodes, (int, list))\n        assert isinstance(max_path_length, (int, list))\n\n        if (num_graphs is None and isinstance(num_infected_nodes, int)\n                and isinstance(max_path_length, int)):\n            num_graphs = 1\n\n        if num_graphs is None and isinstance(num_infected_nodes, list):\n            num_graphs = len(num_infected_nodes)\n\n        if num_graphs is None and isinstance(max_path_length, list):\n            num_graphs = len(max_path_length)\n\n        assert num_graphs is not None\n\n        self.graph_generator = GraphGenerator.resolve(\n            graph_generator,\n            **(graph_generator_kwargs or {}),\n        )\n        self.num_infected_nodes = num_infected_nodes\n        self.max_path_length = max_path_length\n        self.num_graphs = num_graphs\n\n        if isinstance(num_infected_nodes, int):\n            num_infected_nodes = [num_infected_nodes] * num_graphs\n\n        if isinstance(max_path_length, int):\n            max_path_length = [max_path_length] * num_graphs\n\n        if len(num_infected_nodes) != num_graphs:\n            raise ValueError(f\"The length of 'num_infected_nodes' \"\n                             f\"(got {len(num_infected_nodes)}) does not match \"\n                             f\"with the number of graphs (got {num_graphs})\")\n\n        if len(max_path_length) != num_graphs:\n            raise ValueError(f\"The length of 'max_path_length' \"\n 
                            f\"(got {len(max_path_length)}) does not match \"\n                             f\"with the number of graphs (got {num_graphs})\")\n\n        if any(n <= 0 for n in num_infected_nodes):\n            raise ValueError(f\"'num_infected_nodes' needs to be positive \"\n                             f\"(got {min(num_infected_nodes)})\")\n\n        if any(length <= 0 for length in max_path_length):\n            raise ValueError(f\"'max_path_length' needs to be positive \"\n                             f\"(got {min(max_path_length)})\")\n\n        data_list: List[Explanation] = []\n        for N, L in zip(num_infected_nodes, max_path_length):\n            data_list.append(self.get_graph(N, L))\n\n        self.data, self.slices = self.collate(data_list)\n\n    def get_graph(self, num_infected_nodes: int,\n                  max_path_length: int) -> Explanation:\n        data = self.graph_generator()\n\n        assert data.num_nodes is not None\n        perm = torch.randperm(data.num_nodes)\n        x = torch.zeros((data.num_nodes, 2))\n        x[perm[:num_infected_nodes], 1] = 1  # Infected\n        x[perm[num_infected_nodes:], 0] = 1  # Healthy\n\n        y = torch.empty(data.num_nodes, dtype=torch.long)\n        y.fill_(max_path_length + 1)\n        y[perm[:num_infected_nodes]] = 0  # Infected nodes have label `0`.\n\n        assert data.edge_index is not None\n        edge_mask = torch.zeros(data.num_edges, dtype=torch.bool)\n        for num_hops in range(1, max_path_length + 1):\n            sub_node_index, _, _, sub_edge_mask = k_hop_subgraph(\n                perm[:num_infected_nodes], num_hops, data.edge_index,\n                num_nodes=data.num_nodes, flow='target_to_source',\n                directed=True)\n\n            value = torch.full_like(sub_node_index, fill_value=num_hops)\n            y[sub_node_index] = torch.min(y[sub_node_index], value)\n            edge_mask |= sub_edge_mask\n\n        return Explanation(\n            x=x,\n            
edge_index=data.edge_index,\n            y=y,\n            edge_mask=edge_mask.to(torch.float),\n        )\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'graph_generator={self.graph_generator}, '\n                f'num_infected_nodes={self.num_infected_nodes}, '\n                f'max_path_length={self.max_path_length})')\n"
  },
  {
    "path": "torch_geometric/datasets/instruct_mol_dataset.py",
    "content": "import json\nimport sys\nfrom typing import Callable, List, Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import one_hot\n\n\nclass InstructMolDataset(InMemoryDataset):\n    r\"\"\"The dataset from the `\"InstructMol: Multi-Modal Integration for\n    Building a Versatile and Reliable Molecular Assistant in Drug Discovery\"\n    <https://arxiv.org/pdf/2311.16208>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ):\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['all_clean.json']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['data.pt']\n\n    def download(self) -> None:\n        print('downloading dataset...')\n        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)\n\n    def process(self) -> None:\n        try:\n            from rdkit import Chem\n            from rdkit.Chem.rdchem import BondType as BT\n            WITH_RDKIT = True\n\n        except ImportError:\n            WITH_RDKIT = False\n\n        if not WITH_RDKIT:\n            print((\"Using a pre-processed version of the dataset. 
Please \"\n                   \"install 'rdkit' to alternatively process the raw data.\"),\n                  file=sys.stderr)\n\n            data_list = fs.torch_load(self.raw_paths[0])\n            data_list = [Data(**data_dict) for data_dict in data_list]\n\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n\n            self.save(data_list, self.processed_paths[0])\n            return\n\n        # types of atom and bond\n        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}\n        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}\n\n        # load data\n        mols = json.load(open(f'{self.raw_dir}/all_clean.json'))\n\n        data_list = []\n        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):\n            mol = Chem.MolFromSmiles(smiles)\n            if mol is None:\n                continue\n\n            x: torch.Tensor = torch.tensor([\n                types[atom.GetSymbol()] if atom.GetSymbol() in types else 5\n                for atom in mol.GetAtoms()\n            ])\n            x = one_hot(x, num_classes=len(types), dtype=torch.float)\n\n            rows, cols, edge_types = [], [], []\n            for bond in mol.GetBonds():\n                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()\n                edge_types += [bonds[bond.GetBondType()]] * 2\n                rows += [i, j]\n                cols += [j, i]\n\n            edge_index = torch.tensor([rows, cols], dtype=torch.long)\n            edge_type = torch.tensor(edge_types, dtype=torch.long)\n            edge_attr = one_hot(edge_type, num_classes=len(bonds))\n\n            for question, answer in qa_pairs:\n                data = Data(\n                    x=x,\n                    edge_index=edge_index,\n                    edge_attr=edge_attr,\n   
                 smiles=smiles,\n                    instruction=question,\n                    y=answer,\n                )\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/jodie.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, TemporalData, download_url\n\n\nclass JODIEDataset(InMemoryDataset):\n    r\"\"\"The temporal graph datasets\n    from the `\"JODIE: Predicting Dynamic Embedding\n    Trajectory in Temporal Interaction Networks\"\n    <https://cs.stanford.edu/~srijan/pubs/jodie-kdd2019.pdf>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"Reddit\"`,\n            :obj:`\"Wikipedia\"`, :obj:`\"MOOC\"`, and :obj:`\"LastFM\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Reddit\n          - 6,509\n          - 25,470\n          - 172\n          - 1\n        * - Wikipedia\n          - 9,227\n          - 157,474\n          - 172\n          - 2\n        * - MOOC\n          - 7,144\n          - 411,749\n          - 4\n          - 2\n        * - LastFM\n          - 1,980\n          - 1,293,103\n          - 2\n          - 1\n    \"\"\"\n    url = 'http://snap.stanford.edu/jodie/{}.csv'\n    names = ['reddit', 'wikipedia', 'mooc', 'lastfm']\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in self.names\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=TemporalData)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'{self.name}.csv'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url.format(self.name), self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        df = pd.read_csv(self.raw_paths[0], skiprows=1, header=None)\n\n        src = torch.from_numpy(df.iloc[:, 0].values).to(torch.long)\n        dst = torch.from_numpy(df.iloc[:, 1].values).to(torch.long)\n        dst += int(src.max()) + 1\n        t = torch.from_numpy(df.iloc[:, 
2].values).to(torch.long)\n        y = torch.from_numpy(df.iloc[:, 3].values).to(torch.long)\n        msg = torch.from_numpy(df.iloc[:, 4:].values).to(torch.float)\n\n        data = TemporalData(src=src, dst=dst, t=t, msg=msg, y=y)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name.capitalize()}()'\n"
  },
  {
    "path": "torch_geometric/datasets/karate.py",
    "content": "from typing import Callable, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\n\n\nclass KarateClub(InMemoryDataset):\n    r\"\"\"Zachary's karate club network from the `\"An Information Flow Model for\n    Conflict and Fission in Small Groups\"\n    <https://www.journals.uchicago.edu/doi/abs/10.1086/jar.33.4.3629752>`_\n    paper, containing 34 nodes,\n    connected by 156 (undirected and unweighted) edges.\n    Every node is labeled by one of four classes obtained via modularity-based\n    clustering, following the `\"Semi-supervised Classification with Graph\n    Convolutional Networks\" <https://arxiv.org/abs/1609.02907>`_ paper.\n    Training is based on a single labeled example per class, *i.e.* a total\n    number of 4 labeled nodes.\n\n    Args:\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 34\n          - 156\n          - 34\n          - 4\n    \"\"\"\n    def __init__(self, transform: Optional[Callable] = None):\n        super().__init__(None, transform)\n\n        row = [\n            0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n            1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4,\n            5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 10, 10,\n            10, 11, 12, 12, 13, 13, 13, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17,\n            18, 18, 19, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 23, 24,\n            24, 24, 25, 25, 25, 26, 26, 27, 27, 27, 27, 28, 28, 28, 29, 29, 29,\n            29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32,\n            32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33,\n            33, 33, 33, 33, 33, 33\n        ]\n        col = [\n            1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31, 0, 2, 3, 7,\n            13, 17, 19, 21, 30, 0, 1, 3, 7, 8, 9, 13, 27, 28, 32, 0, 1, 2, 7,\n            12, 13, 0, 6, 10, 0, 6, 10, 16, 0, 4, 5, 16, 0, 1, 2, 3, 0, 2, 30,\n            32, 33, 2, 33, 0, 4, 5, 0, 0, 3, 0, 1, 2, 3, 33, 32, 33, 32, 33, 5,\n            6, 0, 1, 32, 33, 0, 1, 33, 32, 33, 0, 1, 32, 33, 25, 27, 29, 32,\n            33, 25, 27, 31, 23, 24, 31, 29, 33, 2, 23, 24, 33, 2, 31, 33, 23,\n            26, 32, 33, 1, 8, 32, 33, 0, 24, 25, 28, 32, 33, 2, 8, 14, 15, 18,\n            20, 22, 23, 29, 30, 31, 33, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23,\n            26, 27, 28, 29, 30, 31, 32\n        ]\n        edge_index = torch.tensor([row, col])\n\n        y = torch.tensor([  # Create communities.\n            1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1,\n            0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0\n        ])\n\n        
x = torch.eye(y.size(0), dtype=torch.float)\n\n        # Select a single training node for each community\n        # (we just use the first one).\n        train_mask = torch.zeros(y.size(0), dtype=torch.bool)\n        for i in range(int(y.max()) + 1):\n            train_mask[(y == i).nonzero(as_tuple=False)[0]] = True\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask)\n\n        self.data, self.slices = self.collate([data])\n"
  },
  {
    "path": "torch_geometric/datasets/last_fm.py",
    "content": "import os\nimport os.path as osp\nfrom itertools import product\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\n\n\nclass LastFM(InMemoryDataset):\n    r\"\"\"A subset of the last.fm music website keeping track of users' listening\n    information from various sources, as collected in the\n    `\"MAGNN: Metapath Aggregated Graph Neural Network for Heterogeneous Graph\n    Embedding\" <https://arxiv.org/abs/2002.01680>`_ paper.\n    last.fm is a heterogeneous graph containing three types of entities - users\n    (1,892 nodes), artists (17,632 nodes), and artist tags (1,088 nodes).\n    This dataset can be used for link prediction, and no labels or features are\n    provided.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = 'https://www.dropbox.com/s/jvlbs09pz6zwcka/LastFM_processed.zip?dl=1'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'adjM.npz', 'node_types.npy', 'train_val_test_neg_user_artist.npz',\n            'train_val_test_pos_user_artist.npz'\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.remove(path)\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        data = HeteroData()\n\n        node_type_idx = np.load(osp.join(self.raw_dir, 'node_types.npy'))\n        node_type_idx = torch.from_numpy(node_type_idx).to(torch.long)\n\n        node_types = ['user', 'artist', 'tag']\n        for i, node_type in enumerate(node_types):\n            data[node_type].num_nodes = int((node_type_idx == i).sum())\n\n        pos_split = np.load(\n            osp.join(self.raw_dir, 'train_val_test_pos_user_artist.npz'))\n        neg_split = np.load(\n            osp.join(self.raw_dir, 'train_val_test_neg_user_artist.npz'))\n\n        for name in ['train', 'val', 'test']:\n            if name != 'train':\n                edge_index = pos_split[f'{name}_pos_user_artist']\n                edge_index = torch.from_numpy(edge_index)\n                edge_index = edge_index.t().to(torch.long).contiguous()\n         
       data['user', 'artist'][f'{name}_pos_edge_index'] = edge_index\n\n            edge_index = neg_split[f'{name}_neg_user_artist']\n            edge_index = torch.from_numpy(edge_index)\n            edge_index = edge_index.t().to(torch.long).contiguous()\n            data['user', 'artist'][f'{name}_neg_edge_index'] = edge_index\n\n        s = {}\n        N_u = data['user'].num_nodes\n        N_a = data['artist'].num_nodes\n        N_t = data['tag'].num_nodes\n        s['user'] = (0, N_u)\n        s['artist'] = (N_u, N_u + N_a)\n        s['tag'] = (N_u + N_a, N_u + N_a + N_t)\n\n        A = sp.load_npz(osp.join(self.raw_dir, 'adjM.npz'))\n        for src, dst in product(node_types, node_types):\n            A_sub = A[s[src][0]:s[src][1], s[dst][0]:s[dst][1]].tocoo()\n            if A_sub.nnz > 0:\n                row = torch.from_numpy(A_sub.row).to(torch.long)\n                col = torch.from_numpy(A_sub.col).to(torch.long)\n                data[src, dst].edge_index = torch.stack([row, col], dim=0)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/datasets/lastfm_asia.py",
    "content": "from typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass LastFMAsia(InMemoryDataset):\n    r\"\"\"The LastFM Asia Network dataset introduced in the `\"Characteristic\n    Functions on Graphs: Birds of a Feather, from Statistical Descriptors to\n    Parametric Models\" <https://arxiv.org/abs/2005.07959>`_ paper.\n    Nodes represent LastFM users from Asia and edges are friendships.\n    It contains 7,624 nodes, 55,612 edges, 128 node features and 18 classes.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://graphmining.ai/datasets/ptg/lastfm_asia.npz'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'lastfm_asia.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url, self.raw_dir)\n\n    def process(self) -> None:\n        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n        x = torch.from_numpy(data['features']).to(torch.float)\n        y = torch.from_numpy(data['target']).to(torch.long)\n        edge_index = torch.from_numpy(data['edges']).to(torch.long)\n        edge_index = edge_index.t().contiguous()\n\n        data = Data(x=x, y=y, edge_index=edge_index)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/linkx_dataset.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import one_hot\n\n\nclass LINKXDataset(InMemoryDataset):\n    r\"\"\"A variety of non-homophilous graph datasets from the `\"Large Scale\n    Learning on Non-Homophilous Graphs: New Benchmarks and Strong Simple\n    Methods\" <https://arxiv.org/abs/2110.14446>`_ paper.\n\n    .. note::\n        Some of the datasets provided in :class:`LINKXDataset` are from other\n        sources, but have been updated with new features and/or labels.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"penn94\"`, :obj:`\"reed98\"`,\n            :obj:`\"amherst41\"`, :obj:`\"cornell5\"`, :obj:`\"johnshopkins55\"`,\n            :obj:`\"genius\"`, :obj:`\"wiki\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    github_url = ('https://github.com/CUAI/Non-Homophily-Large-Scale/'\n                  'raw/master/data')\n    gdrive_url = 'https://drive.usercontent.google.com/download?confirm=t'\n\n    facebook_datasets = [\n        'penn94', 'reed98', 'amherst41', 'cornell5', 'johnshopkins55'\n    ]\n\n    datasets = {\n        'penn94': {\n            'data.mat': f'{github_url}/facebook100/Penn94.mat'\n        },\n        'reed98': {\n            'data.mat': f'{github_url}/facebook100/Reed98.mat'\n        },\n        'amherst41': {\n            'data.mat': f'{github_url}/facebook100/Amherst41.mat',\n        },\n        'cornell5': {\n            'data.mat': f'{github_url}/facebook100/Cornell5.mat'\n        },\n        'johnshopkins55': {\n            'data.mat': f'{github_url}/facebook100/Johns%20Hopkins55.mat'\n        },\n        'genius': {\n            'data.mat': f'{github_url}/genius.mat'\n        },\n        'wiki': {\n            'wiki_views2M.pt':\n            f'{gdrive_url}&id=1p5DlVHrnFgYm3VsNIzahSsvCD424AyvP',\n            'wiki_edges2M.pt':\n            f'{gdrive_url}&id=14X7FlkjrlUgmnsYtPwdh-gGuFla4yb5u',\n            'wiki_features2M.pt':\n            f'{gdrive_url}&id=1ySNspxbK-snNoAZM7oxiWGvOnTRdSyEK'\n        }\n    }\n\n    splits = {\n        'penn94': f'{github_url}/splits/fb100-Penn94-splits.npy',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in self.datasets.keys()\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> 
str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        names = list(self.datasets[self.name].keys())\n        if self.name in self.splits:\n            names += [self.splits[self.name].split('/')[-1]]\n        return names\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for filename, path in self.datasets[self.name].items():\n            download_url(path, self.raw_dir, filename=filename)\n        if self.name in self.splits:\n            download_url(self.splits[self.name], self.raw_dir)\n\n    def _process_wiki(self) -> Data:\n        paths = {x.split('/')[-1]: x for x in self.raw_paths}\n        x = fs.torch_load(paths['wiki_features2M.pt'])\n        edge_index = fs.torch_load(paths['wiki_edges2M.pt']).t().contiguous()\n        y = fs.torch_load(paths['wiki_views2M.pt'])\n\n        return Data(x=x, edge_index=edge_index, y=y)\n\n    def _process_facebook(self) -> Data:\n        from scipy.io import loadmat\n\n        mat = loadmat(self.raw_paths[0])\n\n        A = mat['A'].tocsr().tocoo()\n        row = torch.from_numpy(A.row).to(torch.long)\n        col = torch.from_numpy(A.col).to(torch.long)\n        edge_index = torch.stack([row, col], dim=0)\n\n        metadata = torch.from_numpy(mat['local_info'].astype('int64'))\n\n        xs = []\n        y = metadata[:, 1] - 1  # gender label, -1 means unlabeled\n        x = torch.cat([metadata[:, :1], metadata[:, 2:]], dim=-1)\n        for i in range(x.size(1)):\n            _, out = x[:, i].unique(return_inverse=True)\n            xs.append(one_hot(out))\n        x = torch.cat(xs, dim=-1)\n\n        data = Data(x=x, edge_index=edge_index, y=y)\n\n        if self.name in self.splits:\n            splits = np.load(self.raw_paths[1], allow_pickle=True)\n  
          assert data.num_nodes is not None\n            sizes = (data.num_nodes, len(splits))\n            data.train_mask = torch.zeros(sizes, dtype=torch.bool)\n            data.val_mask = torch.zeros(sizes, dtype=torch.bool)\n            data.test_mask = torch.zeros(sizes, dtype=torch.bool)\n\n            for i, split in enumerate(splits):\n                data.train_mask[:, i][torch.tensor(split['train'])] = True\n                data.val_mask[:, i][torch.tensor(split['valid'])] = True\n                data.test_mask[:, i][torch.tensor(split['test'])] = True\n\n        return data\n\n    def _process_genius(self) -> Data:\n        from scipy.io import loadmat\n\n        mat = loadmat(self.raw_paths[0])\n        edge_index = torch.from_numpy(mat['edge_index']).to(torch.long)\n        x = torch.from_numpy(mat['node_feat']).to(torch.float)\n        y = torch.from_numpy(mat['label']).squeeze().to(torch.long)\n\n        return Data(x=x, edge_index=edge_index, y=y)\n\n    def process(self) -> None:\n        if self.name in self.facebook_datasets:\n            data = self._process_facebook()\n        elif self.name == 'genius':\n            data = self._process_genius()\n        elif self.name == 'wiki':\n            data = self._process_wiki()\n        else:\n            raise NotImplementedError(\n                f\"chosen dataset '{self.name}' is not implemented\")\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name.capitalize()}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/datasets/lrgb.py",
    "content": "import os\nimport os.path as osp\nimport pickle\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass LRGBDataset(InMemoryDataset):\n    r\"\"\"The `\"Long Range Graph Benchmark (LRGB)\"\n    <https://arxiv.org/abs/2206.08164>`_\n    datasets, a collection of 5 graph learning datasets with tasks\n    that are based on long-range dependencies in graphs. See the original\n    `source code <https://github.com/vijaydwivedi75/lrgb>`_ for more details\n    on the individual datasets.\n\n    +------------------------+-------------------+----------------------+\n    | Dataset                | Domain            | Task                 |\n    +========================+===================+======================+\n    | :obj:`PascalVOC-SP`    | Computer Vision   | Node Classification  |\n    +------------------------+-------------------+----------------------+\n    | :obj:`COCO-SP`         | Computer Vision   | Node Classification  |\n    +------------------------+-------------------+----------------------+\n    | :obj:`PCQM-Contact`    | Quantum Chemistry | Link Prediction      |\n    +------------------------+-------------------+----------------------+\n    | :obj:`Peptides-func`   | Chemistry         | Graph Classification |\n    +------------------------+-------------------+----------------------+\n    | :obj:`Peptides-struct` | Chemistry         | Graph Regression     |\n    +------------------------+-------------------+----------------------+\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (one of :obj:`\"PascalVOC-SP\"`,\n            :obj:`\"COCO-SP\"`, :obj:`\"PCQM-Contact\"`, :obj:`\"Peptides-func\"`,\n            :obj:`\"Peptides-struct\"`)\n        split (str, optional): If 
:obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 15 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #classes\n        * - PascalVOC-SP\n          - 11,355\n          - ~479.40\n          - ~2,710.48\n          - 21\n        * - COCO-SP\n          - 123,286\n          - ~476.88\n          - ~2,693.67\n          - 81\n        * - PCQM-Contact\n          - 529,434\n          - ~30.14\n          - ~61.09\n          - 1\n        * - Peptides-func\n          - 15,535\n          - ~150.94\n          - ~307.30\n          - 10\n        * - Peptides-struct\n          - 15,535\n          - ~150.94\n          - ~307.30\n          - 11\n    \"\"\"\n    names = [\n        'pascalvoc-sp', 'coco-sp', 'pcqm-contact', 'peptides-func',\n        'peptides-struct'\n    ]\n\n    urls = {\n        'pascalvoc-sp':\n        'https://www.dropbox.com/s/8x722ai272wqwl4/pascalvocsp.zip?dl=1',\n        'coco-sp':\n        'https://www.dropbox.com/s/r6ihg1f4pmyjjy0/cocosp.zip?dl=1',\n        'pcqm-contact':\n        'https://www.dropbox.com/s/qdag867u6h6i60y/pcqmcontact.zip?dl=1',\n        'peptides-func':\n        'https://www.dropbox.com/s/ycsq37q8sxs1ou8/peptidesfunc.zip?dl=1',\n        'peptides-struct':\n        'https://www.dropbox.com/s/zgv4z8fcpmknhs8/peptidesstruct.zip?dl=1'\n    }\n\n    dwnld_file_name = {\n        'pascalvoc-sp': 'voc_superpixels_edge_wt_region_boundary',\n        'coco-sp': 'coco_superpixels_edge_wt_region_boundary',\n        'pcqm-contact': 'pcqmcontact',\n        'peptides-func': 'peptidesfunc',\n        'peptides-struct': 'peptidesstruct'\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        split: str = \"train\",\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert 
self.name in self.names\n        assert split in ['train', 'val', 'test']\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = osp.join(self.processed_dir, f'{split}.pt')\n        self.load(path)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        if self.name.split('-')[1] == 'sp':\n            return ['train.pickle', 'val.pickle', 'test.pickle']\n        else:\n            return ['train.pt', 'val.pt', 'test.pt']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def download(self) -> None:\n        fs.rm(self.raw_dir)\n        path = download_url(self.urls[self.name], self.root)\n        extract_zip(path, self.root)\n        os.rename(osp.join(self.root, self.dwnld_file_name[self.name]),\n                  self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        if self.name == 'pcqm-contact':\n            # PCQM-Contact\n            self.process_pcqm_contact()\n        else:\n            if self.name == 'coco-sp':\n                # Label remapping for coco-sp.\n                # See self.label_remap_coco() func\n                label_map = self.label_remap_coco()\n\n            for split in ['train', 'val', 'test']:\n                if self.name.split('-')[1] == 'sp':\n                    # PascalVOC-SP and COCO-SP\n                    with open(osp.join(self.raw_dir, f'{split}.pickle'),\n                              'rb') as f:\n                        graphs = pickle.load(f)\n                elif self.name.split('-')[0] == 'peptides':\n                    # Peptides-func and Peptides-struct\n                    graphs = fs.torch_load(\n          
              osp.join(self.raw_dir, f'{split}.pt'))\n\n                data_list = []\n                for graph in tqdm(graphs, desc=f'Processing {split} dataset'):\n                    if self.name.split('-')[1] == 'sp':\n                        \"\"\"\n                        PascalVOC-SP and COCO-SP\n                        Each `graph` is a tuple (x, edge_attr, edge_index, y)\n                            Shape of x : [num_nodes, 14]\n                            Shape of edge_attr : [num_edges, 2]\n                            Shape of edge_index : [2, num_edges]\n                            Shape of y : [num_nodes]\n                        \"\"\"\n                        x = graph[0].to(torch.float)\n                        edge_attr = graph[1].to(torch.float)\n                        edge_index = graph[2]\n                        y = torch.LongTensor(graph[3])\n                    elif self.name.split('-')[0] == 'peptides':\n                        \"\"\"\n                        Peptides-func and Peptides-struct\n                        Each `graph` is a tuple (x, edge_attr, edge_index, y)\n                            Shape of x : [num_nodes, 9]\n                            Shape of edge_attr : [num_edges, 3]\n                            Shape of edge_index : [2, num_edges]\n                            Shape of y : [1, 10] for Peptides-func,  or\n                                         [1, 11] for Peptides-struct\n                        \"\"\"\n                        x = graph[0]\n                        edge_attr = graph[1]\n                        edge_index = graph[2]\n                        y = graph[3]\n\n                    if self.name == 'coco-sp':\n                        for i, label in enumerate(y):\n                            y[i] = label_map[label.item()]\n\n                    data = Data(x=x, edge_index=edge_index,\n                                edge_attr=edge_attr, y=y)\n\n                    if self.pre_filter is not None and not 
self.pre_filter(\n                            data):\n                        continue\n\n                    if self.pre_transform is not None:\n                        data = self.pre_transform(data)\n\n                    data_list.append(data)\n\n                path = osp.join(self.processed_dir, f'{split}.pt')\n                self.save(data_list, path)\n\n    def label_remap_coco(self) -> Dict[int, int]:\n        # Util function for name 'COCO-SP'\n        # to remap the labels as the original label idxs are not contiguous\n        original_label_idx = [\n            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19,\n            20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39,\n            40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57,\n            58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78,\n            79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90\n        ]\n\n        label_map = {}\n        for i, key in enumerate(original_label_idx):\n            label_map[key] = i\n\n        return label_map\n\n    def process_pcqm_contact(self) -> None:\n        for split in ['train', 'val', 'test']:\n            graphs = fs.torch_load(osp.join(self.raw_dir, f'{split}.pt'))\n\n            data_list = []\n            for graph in tqdm(graphs, desc=f'Processing {split} dataset'):\n                \"\"\"\n                PCQM-Contact\n                Each `graph` is a tuple (x, edge_attr, edge_index,\n                                        edge_label_index, edge_label)\n                    Shape of x : [num_nodes, 9]\n                    Shape of edge_attr : [num_edges, 3]\n                    Shape of edge_index : [2, num_edges]\n                    Shape of edge_label_index: [2, num_labeled_edges]\n                    Shape of edge_label : [num_labeled_edges]\n\n                    where,\n                    num_labeled_edges are negative edges and link pred labels,\n                    
https://github.com/vijaydwivedi75/lrgb/blob/main/graphgps/loader/dataset/pcqm4mv2_contact.py#L192\n                \"\"\"\n                x = graph[0]\n                edge_attr = graph[1]\n                edge_index = graph[2]\n                edge_label_index = graph[3]\n                edge_label = graph[4]\n\n                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,\n                            edge_label_index=edge_label_index,\n                            edge_label=edge_label)\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n            self.save(data_list, osp.join(self.processed_dir, f'{split}.pt'))\n"
  },
  {
    "path": "torch_geometric/datasets/malnet_tiny.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass MalNetTiny(InMemoryDataset):\n    r\"\"\"The MalNet Tiny dataset from the\n    `\"A Large-Scale Database for Graph Representation Learning\"\n    <https://openreview.net/pdf?id=1xDTDk3XPW>`_ paper.\n    :class:`MalNetTiny` contains 5,000 malicious and benign software function\n    call graphs across 5 different types. Each graph contains at most 5k nodes.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"trainval\"`, loads the training and validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            If :obj:`None`, loads the entire dataset.\n            (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    data_url = ('http://malnet.cc.gatech.edu/'\n                'graph-data/malnet-graphs-tiny.tar.gz')\n    split_url = 'http://malnet.cc.gatech.edu/split-info/split_info_tiny.zip'\n    splits = ['train', 'val', 'test']\n\n    def __init__(\n        self,\n        root: str,\n        split: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        if split not in {'train', 'val', 'trainval', 'test', None}:\n            raise ValueError(f'Split \"{split}\" found, but expected either '\n                             f'\"train\", \"val\", \"trainval\", \"test\" or None')\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n        if split is not None:\n            split_slices = fs.torch_load(self.processed_paths[1])\n            if split == 'train':\n                self._indices = range(split_slices[0], split_slices[1])\n            elif split == 'val':\n                self._indices = range(split_slices[1], split_slices[2])\n            elif split == 'trainval':\n                self._indices = range(split_slices[0], split_slices[2])\n            elif split == 'test':\n                self._indices = range(split_slices[2], split_slices[3])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['malnet-graphs-tiny', osp.join('split_info_tiny', 'type')]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['data.pt', 'split_slices.pt']\n\n    def download(self) -> None:\n        path = download_url(self.data_url, self.raw_dir)\n        extract_tar(path, self.raw_dir)\n        
os.unlink(path)\n\n        path = download_url(self.split_url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        y_map: Dict[str, int] = {}\n        data_list = []\n        split_slices = [0]\n\n        for split in ['train', 'val', 'test']:\n            with open(osp.join(self.raw_paths[1], f'{split}.txt')) as f:\n                filenames = f.read().split('\\n')[:-1]\n                split_slices.append(split_slices[-1] + len(filenames))\n\n            for filename in filenames:\n                path = osp.join(self.raw_paths[0], f'{filename}.edgelist')\n                malware_type = filename.split('/')[0]\n                y = y_map.setdefault(malware_type, len(y_map))\n\n                with open(path) as f:\n                    edges = f.read().split('\\n')[5:-1]\n\n                edge_indices = [[int(s) for s in e.split()] for e in edges]\n                edge_index = torch.tensor(edge_indices).t().contiguous()\n                num_nodes = int(edge_index.max()) + 1\n                data = Data(edge_index=edge_index, y=y, num_nodes=num_nodes)\n                data_list.append(data)\n\n        if self.pre_filter is not None:\n            data_list = [data for data in data_list if self.pre_filter(data)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(data) for data in data_list]\n\n        self.save(data_list, self.processed_paths[0])\n        torch.save(split_slices, self.processed_paths[1])\n"
  },
  {
    "path": "torch_geometric/datasets/md17.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n    extract_zip,\n)\n\n\nclass MD17(InMemoryDataset):\n    r\"\"\"A variety of ab-initio molecular dynamics trajectories from the authors\n    of `sGDML <http://quantum-machine.org/gdml>`_.\n    This class provides access to the original MD17 datasets, their revised\n    versions, and the CCSD(T) trajectories.\n\n    For every trajectory, the dataset contains the Cartesian positions of atoms\n    (in Angstrom), their atomic numbers, as well as the total energy\n    (in kcal/mol) and forces (kcal/mol/Angstrom) on each atom.\n    The latter two are the regression targets for this collection.\n\n    .. note::\n\n        Data objects contain no edge indices as these are most commonly\n        constructed via the :obj:`torch_geometric.transforms.RadiusGraph`\n        transform, with its cut-off being a hyperparameter.\n\n    The `original MD17 dataset <https://arxiv.org/abs/1611.04678>`_ contains\n    ten molecule trajectories.\n    This version of the dataset was found to suffer from high numerical noise.\n    The `revised MD17 dataset <https://arxiv.org/abs/2007.09593>`_ contains the\n    same molecules, but the energies and forces were recalculated at the\n    PBE/def2-SVP level of theory using very tight SCF convergence and a very\n    dense DFT integration grid.\n    The third version of the dataset contains fewer molecules, computed at the\n    CCSD(T) level of theory.\n    The benzene molecule at the DFT FHI-aims level of theory was\n    `released separately <https://arxiv.org/abs/1802.09238>`_.\n\n    Check the table below for detailed information on the molecule, level of\n    theory and number of data points contained in each dataset.\n    Which trajectory is loaded is determined by the :attr:`name` argument.\n    For 
the coupled cluster trajectories, the dataset comes with pre-defined\n    training and testing splits which are loaded separately via the\n    :attr:`train` argument.\n\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Molecule           | Level of Theory    | Name                          | #Examples |\n    +====================+====================+===============================+===========+\n    | Benzene            | DFT                | :obj:`benzene`                | 627,983   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Uracil             | DFT                | :obj:`uracil`                 | 133,770   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Naphthalene        | DFT                | :obj:`naphthalene`            | 326,250   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Aspirin            | DFT                | :obj:`aspirin`                | 211,762   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Salicylic acid     | DFT                | :obj:`salicylic acid`         | 320,231   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Malonaldehyde      | DFT                | :obj:`malonaldehyde`          | 993,237   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Ethanol            | DFT                | :obj:`ethanol`                | 555,092   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Toluene            | DFT                | :obj:`toluene`                | 442,790   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Paracetamol        | DFT                | 
:obj:`paracetamol`            | 106,490   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Azobenzene         | DFT                | :obj:`azobenzene`             | 99,999    |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Benzene (R)        | DFT (PBE/def2-SVP) | :obj:`revised benzene`        | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Uracil (R)         | DFT (PBE/def2-SVP) | :obj:`revised uracil`         | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Naphthalene (R)    | DFT (PBE/def2-SVP) | :obj:`revised naphthalene`    | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Aspirin (R)        | DFT (PBE/def2-SVP) | :obj:`revised aspirin`        | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Salicylic acid (R) | DFT (PBE/def2-SVP) | :obj:`revised salicylic acid` | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Malonaldehyde (R)  | DFT (PBE/def2-SVP) | :obj:`revised malonaldehyde`  | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Ethanol (R)        | DFT (PBE/def2-SVP) | :obj:`revised ethanol`        | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Toluene (R)        | DFT (PBE/def2-SVP) | :obj:`revised toluene`        | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Paracetamol (R)    | DFT (PBE/def2-SVP) | :obj:`revised paracetamol`    | 100,000   |\n    +--------------------+--------------------+-------------------------------+-----------+\n  
  | Azobenzene (R)     | DFT (PBE/def2-SVP) | :obj:`revised azobenzene`     | 99,988    |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Benzene            | CCSD(T)            | :obj:`benzene CCSD(T)`        | 1,500     |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Aspirin            | CCSD               | :obj:`aspirin CCSD`           | 1,500     |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Malonaldehyde      | CCSD(T)            | :obj:`malonaldehyde CCSD(T)`  | 1,500     |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Ethanol            | CCSD(T)            | :obj:`ethanol CCSD(T)`        | 2,000     |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Toluene            | CCSD(T)            | :obj:`toluene CCSD(T)`        | 1,501     |\n    +--------------------+--------------------+-------------------------------+-----------+\n    | Benzene            | DFT FHI-aims       | :obj:`benzene FHI-aims`       | 49,863    |\n    +--------------------+--------------------+-------------------------------+-----------+\n\n    .. warning::\n\n        It is advised to not train a model on more than 1,000 samples from the\n        original or revised MD17 dataset.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): Keyword of the trajectory that should be loaded.\n        train (bool, optional): Determines whether the train or test split\n            gets loaded for the coupled cluster trajectories.\n            (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 20 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #tasks\n        * - Benzene\n          - 627,983\n          - 12\n          - 0\n          - 1\n          - 2\n        * - Uracil\n          - 133,770\n          - 12\n          - 0\n          - 1\n          - 2\n        * - Naphthalene\n          - 326,250\n          - 10\n          - 0\n          - 1\n          - 2\n        * - Aspirin\n          - 211,762\n          - 21\n          - 0\n          - 1\n          - 2\n        * - Salicylic acid\n          - 320,231\n          - 16\n          - 0\n          - 1\n          - 2\n        * - Malonaldehyde\n          - 993,237\n          - 9\n          - 0\n          - 1\n          - 2\n        * - Ethanol\n          - 555,092\n          - 9\n          - 0\n          - 1\n          - 2\n        * - Toluene\n          - 442,790\n          - 15\n          - 0\n          - 1\n          - 2\n        * - Paracetamol\n          - 106,490\n          - 20\n          - 0\n          - 1\n          - 2\n        * - Azobenzene\n          - 99,999\n          - 24\n  
        - 0\n          - 1\n          - 2\n        * - Benzene (R)\n          - 100,000\n          - 12\n          - 0\n          - 1\n          - 2\n        * - Uracil (R)\n          - 100,000\n          - 12\n          - 0\n          - 1\n          - 2\n        * - Naphthalene (R)\n          - 100,000\n          - 10\n          - 0\n          - 1\n          - 2\n        * - Aspirin (R)\n          - 100,000\n          - 21\n          - 0\n          - 1\n          - 2\n        * - Salicylic acid (R)\n          - 100,000\n          - 16\n          - 0\n          - 1\n          - 2\n        * - Malonaldehyde (R)\n          - 100,000\n          - 9\n          - 0\n          - 1\n          - 2\n        * - Ethanol (R)\n          - 100,000\n          - 9\n          - 0\n          - 1\n          - 2\n        * - Toluene (R)\n          - 100,000\n          - 15\n          - 0\n          - 1\n          - 2\n        * - Paracetamol (R)\n          - 100,000\n          - 20\n          - 0\n          - 1\n          - 2\n        * - Azobenzene (R)\n          - 99,988\n          - 24\n          - 0\n          - 1\n          - 2\n        * - Benzene CCSD-T\n          - 1,500\n          - 12\n          - 0\n          - 1\n          - 2\n        * - Aspirin CCSD\n          - 1,500\n          - 21\n          - 0\n          - 1\n          - 2\n        * - Malonaldehyde CCSD-T\n          - 1,500\n          - 9\n          - 0\n          - 1\n          - 2\n        * - Ethanol CCSD-T\n          - 2,000\n          - 9\n          - 0\n          - 1\n          - 2\n        * - Toluene CCSD-T\n          - 1,501\n          - 15\n          - 0\n          - 1\n          - 2\n        * - Benzene FHI-aims\n          - 49,863\n          - 12\n          - 0\n          - 1\n          - 2\n    \"\"\"  # noqa: E501\n\n    gdml_url = 'http://quantum-machine.org/gdml/data/npz'\n    revised_url = ('https://archive.materialscloud.org/record/'\n                   
'file?filename=rmd17.tar.bz2&record_id=466')\n\n    file_names = {\n        'benzene': 'md17_benzene2017.npz',\n        'uracil': 'md17_uracil.npz',\n        'naphthalene': 'md17_naphthalene.npz',\n        'aspirin': 'md17_aspirin.npz',\n        'salicylic acid': 'md17_salicylic.npz',\n        'malonaldehyde': 'md17_malonaldehyde.npz',\n        'ethanol': 'md17_ethanol.npz',\n        'toluene': 'md17_toluene.npz',\n        'paracetamol': 'paracetamol_dft.npz',\n        'azobenzene': 'azobenzene_dft.npz',\n        'revised benzene': 'rmd17_benzene.npz',\n        'revised uracil': 'rmd17_uracil.npz',\n        'revised naphthalene': 'rmd17_naphthalene.npz',\n        'revised aspirin': 'rmd17_aspirin.npz',\n        'revised salicylic acid': 'rmd17_salicylic.npz',\n        'revised malonaldehyde': 'rmd17_malonaldehyde.npz',\n        'revised ethanol': 'rmd17_ethanol.npz',\n        'revised toluene': 'rmd17_toluene.npz',\n        'revised paracetamol': 'rmd17_paracetamol.npz',\n        'revised azobenzene': 'rmd17_azobenzene.npz',\n        'benzene CCSD(T)': 'benzene_ccsd_t.zip',\n        'aspirin CCSD': 'aspirin_ccsd.zip',\n        'malonaldehyde CCSD(T)': 'malonaldehyde_ccsd_t.zip',\n        'ethanol CCSD(T)': 'ethanol_ccsd_t.zip',\n        'toluene CCSD(T)': 'toluene_ccsd_t.zip',\n        'benzene FHI-aims': 'benzene2018_dft.npz',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        train: Optional[bool] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        if name not in self.file_names:\n            raise ValueError(f\"Unknown dataset name '{name}'\")\n\n        self.name = name\n        self.revised = 'revised' in name\n        self.ccsd = 'CCSD' in self.name\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         
force_reload=force_reload)\n\n        if len(self.processed_file_names) == 1 and train is not None:\n            raise ValueError(\n                f\"'{self.name}' dataset does not provide pre-defined splits \"\n                f\"but the 'train' argument is set to '{train}'\")\n        elif len(self.processed_file_names) == 2 and train is None:\n            raise ValueError(\n                f\"'{self.name}' dataset does provide pre-defined splits but \"\n                f\"the 'train' argument was not specified\")\n\n        idx = 0 if train is None or train else 1\n        self.load(self.processed_paths[idx])\n\n    def mean(self) -> float:\n        assert isinstance(self._data, Data)\n        return float(self._data.energy.mean())\n\n    @property\n    def raw_dir(self) -> str:\n        if self.revised:\n            return osp.join(self.root, 'raw')\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> Union[str, List[str]]:\n        name = self.file_names[self.name]\n        if self.revised:\n            return osp.join('rmd17', 'npz_data', name)\n        elif self.ccsd:\n            return [name[:-4] + '-train.npz', name[:-4] + '-test.npz']\n        return name\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        if self.ccsd:\n            return ['train.pt', 'test.pt']\n        else:\n            return ['data.pt']\n\n    def download(self) -> None:\n        if self.revised:\n            path = download_url(self.revised_url, self.raw_dir)\n            extract_tar(path, self.raw_dir, mode='r:bz2')\n            os.unlink(path)\n        else:\n            url = f'{self.gdml_url}/{self.file_names[self.name]}'\n            path = download_url(url, self.raw_dir)\n            if self.ccsd:\n                extract_zip(path, self.raw_dir)\n                os.unlink(path)\n\n    def 
process(self) -> None:\n        it = zip(self.raw_paths, self.processed_paths)\n        for raw_path, processed_path in it:\n            raw_data = np.load(raw_path)\n\n            if self.revised:\n                z = torch.from_numpy(raw_data['nuclear_charges']).long()\n                pos = torch.from_numpy(raw_data['coords']).float()\n                energy = torch.from_numpy(raw_data['energies']).float()\n                force = torch.from_numpy(raw_data['forces']).float()\n            else:\n                z = torch.from_numpy(raw_data['z']).long()\n                pos = torch.from_numpy(raw_data['R']).float()\n                energy = torch.from_numpy(raw_data['E']).float()\n                force = torch.from_numpy(raw_data['F']).float()\n\n            data_list = []\n            for i in range(pos.size(0)):\n                data = Data(z=z, pos=pos[i], energy=energy[i], force=force[i])\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n                data_list.append(data)\n\n            self.save(data_list, processed_path)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}({len(self)}, name='{self.name}')\"\n"
  },
  {
    "path": "torch_geometric/datasets/medshapenet.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\n\n\nclass MedShapeNet(InMemoryDataset):\n    r\"\"\"The MedShapeNet datasets from the `\"MedShapeNet -- A Large-Scale\n    Dataset of 3D Medical Shapes for Computer Vision\"\n    <https://arxiv.org/abs/2308.16139>`_ paper,\n    containing 8 different type of structures (classes).\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        size (int): Number of invividual 3D structures to download per\n            type (classes).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        size: int = 100,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.size = size\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        path = self.processed_paths[0]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            '3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',\n            'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'\n        ]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['dataset.pt']\n\n    @property\n    def raw_paths(self) -> List[str]:\n        r\"\"\"The absolute filepaths that must be present in order to skip\n        downloading.\n        \"\"\"\n        return [osp.join(self.raw_dir, f) for f in self.raw_file_names]\n\n    def process(self) -> None:\n        import urllib3\n        from MedShapeNet import MedShapeNet as msn\n\n        msn_instance = msn(timeout=120)\n\n        urllib3.HTTPConnectionPool(\"medshapenet.ddns.net\", maxsize=50)\n\n        list_of_datasets = msn_instance.datasets(False)\n        list_of_datasets = list(\n            filter(\n                lambda x: x not in [\n                    'medshapenetcore/ASOCA', 'medshapenetcore/AVT',\n                    'medshapenetcore/AutoImplantCraniotomy',\n                    'medshapenetcore/FaceVR'\n                ], list_of_datasets))\n\n        subset = []\n        for dataset in list_of_datasets:\n            parts = dataset.split(\"/\")\n            self.newpath = self.root + '/' + parts[1 if len(parts) > 1 
else 0]\n            if not os.path.exists(self.newpath):\n                os.makedirs(self.newpath)\n            stl_files = msn_instance.dataset_files(dataset, '.stl')\n            subset.extend(stl_files[:self.size])\n\n            for stl_file in stl_files[:self.size]:\n                msn_instance.download_stl_as_numpy(bucket_name=dataset,\n                                                   stl_file=stl_file,\n                                                   output_dir=self.newpath,\n                                                   print_output=False)\n\n        class_mapping = {\n            '3DTeethSeg': 0,\n            'CoronaryArteries': 1,\n            'FLARE': 2,\n            'KITS': 3,\n            'PULMONARY': 4,\n            'SurgicalInstruments': 5,\n            'ThoracicAorta_Saitta': 6,\n            'ToothFairy': 7\n        }\n\n        for dataset, path in zip([subset], self.processed_paths):\n            data_list = []\n            for item in dataset:\n                class_name = item.split(\"/\")[0]\n                item = item.split(\"stl\")[0]\n                target = class_mapping[class_name]\n                file = osp.join(self.root, item + 'npz')\n\n                data = np.load(file)\n                pre_data_list = Data(\n                    pos=torch.tensor(data[\"vertices\"], dtype=torch.float),\n                    face=torch.tensor(data[\"faces\"],\n                                      dtype=torch.long).t().contiguous())\n                pre_data_list.y = torch.tensor([target], dtype=torch.long)\n                data_list.append(pre_data_list)\n\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n\n            self.save(data_list, path)\n"
  },
  {
    "path": "torch_geometric/datasets/mixhop_synthetic_dataset.py",
    "content": "import os.path as osp\nimport pickle\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass MixHopSyntheticDataset(InMemoryDataset):\n    r\"\"\"The MixHop synthetic dataset from the `\"MixHop: Higher-Order\n    Graph Convolutional Architectures via Sparsified Neighborhood Mixing\"\n    <https://arxiv.org/abs/1905.00067>`_ paper, containing 10\n    graphs, each with varying degree of homophily (ranging from 0.0 to 0.9).\n    All graphs have 5,000 nodes, where each node corresponds to 1 out of 10\n    classes.\n    The feature values of the nodes are sampled from a 2D Gaussian\n    distribution, which are distinct for each class.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        homophily (float): The degree of homophily (one of :obj:`0.0`,\n            :obj:`0.1`, ..., :obj:`0.9`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = ('https://raw.githubusercontent.com/samihaija/mixhop/master/data'\n           '/synthetic')\n\n    def __init__(\n        self,\n        root: str,\n        homophily: float,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.homophily = homophily\n        assert homophily in [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, f'{self.homophily:0.1f}'[::2], 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        name = f'ind.n5000-h{self.homophily:0.1f}-c10'\n        return [f'{name}.allx', f'{name}.ally', f'{name}.graph']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for filename in self.raw_file_names:\n            download_url(f'{self.url}/{filename}', self.raw_dir)\n\n    def process(self) -> None:\n        x = torch.from_numpy(np.load(self.raw_paths[0]))\n        y = torch.from_numpy(np.load(self.raw_paths[1])).argmax(dim=-1)\n\n        edges = pickle.load(open(self.raw_paths[2], 'rb'), encoding='latin1')\n        row, col = [], []\n        for k, v in edges.items():\n            row += [k] * len(v)\n            col += v\n\n        edge_index = torch.tensor([row, col], dtype=torch.long)\n\n        N_s = x.size(0) // 3\n        train_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        train_mask[:N_s] = True\n        
val_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        val_mask[N_s:2 * N_s] = True\n        test_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        test_mask[2 * N_s:] = True\n\n        data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(homophily={self.homophily:.1f})'\n"
  },
  {
    "path": "torch_geometric/datasets/mnist_superpixels.py",
    "content": "import os\nfrom typing import Callable, List, Optional\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass MNISTSuperpixels(InMemoryDataset):\n    r\"\"\"MNIST superpixels dataset from the `\"Geometric Deep Learning on\n    Graphs and Manifolds Using Mixture Model CNNs\"\n    <https://arxiv.org/abs/1611.08402>`_ paper, containing 70,000 graphs with\n    75 nodes each.\n    Every graph is labeled by one of 10 classes.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 70,000\n          - 75\n          - ~1,393.0\n          - 1\n          - 10\n    \"\"\"\n\n    url = 'https://data.pyg.org/datasets/MNISTSuperpixels.zip'\n\n    def __init__(\n        self,\n        root: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'MNISTSuperpixels.pt'\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train_data.pt', 'test_data.pt']\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        inputs = fs.torch_load(self.raw_paths[0])\n        for i in range(len(inputs)):\n            data_list = [Data(**data_dict) for data_dict in inputs[i]]\n\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n\n            self.save(data_list, self.processed_paths[i])\n"
  },
  {
    "path": "torch_geometric/datasets/modelnet.py",
    "content": "import glob\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs, read_off\n\n\nclass ModelNet(InMemoryDataset):\n    r\"\"\"The ModelNet10/40 datasets from the `\"3D ShapeNets: A Deep\n    Representation for Volumetric Shapes\"\n    <https://people.csail.mit.edu/khosla/papers/cvpr2015_wu.pdf>`_ paper,\n    containing CAD models of 10 and 40 categories, respectively.\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str, optional): The name of the dataset (:obj:`\"10\"` for\n            ModelNet10, :obj:`\"40\"` for ModelNet40). (default: :obj:`\"10\"`)\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 20 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - ModelNet10\n          - 4,899\n          - ~9,508.2\n          - ~37,450.5\n          - 3\n          - 10\n        * - ModelNet40\n          - 12,311\n          - ~17,744.4\n          - ~66,060.9\n          - 3\n          - 40\n    \"\"\"\n\n    urls = {\n        '10':\n        'http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip',  # noqa\n        '40': 'http://modelnet.cs.princeton.edu/ModelNet40.zip'\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str = '10',\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert name in ['10', '40']\n        self.name = name\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'bathtub', 'bed', 'chair', 'desk', 'dresser', 'monitor',\n            'night_stand', 'sofa', 'table', 'toilet'\n        ]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['training.pt', 
'test.pt']\n\n    def download(self) -> None:\n        path = download_url(self.urls[self.name], self.root)\n        extract_zip(path, self.root)\n        os.unlink(path)\n        folder = osp.join(self.root, f'ModelNet{self.name}')\n        fs.rm(self.raw_dir)\n        os.rename(folder, self.raw_dir)\n\n        # Delete osx metadata generated during compression of ModelNet10\n        metadata_folder = osp.join(self.root, '__MACOSX')\n        if osp.exists(metadata_folder):\n            fs.rm(metadata_folder)\n\n    def process(self) -> None:\n        self.save(self.process_set('train'), self.processed_paths[0])\n        self.save(self.process_set('test'), self.processed_paths[1])\n\n    def process_set(self, dataset: str) -> List[Data]:\n        categories = glob.glob(osp.join(self.raw_dir, '*', ''))\n        categories = sorted([x.split(os.sep)[-2] for x in categories])\n\n        data_list = []\n        for target, category in enumerate(categories):\n            folder = osp.join(self.raw_dir, category, dataset)\n            paths = glob.glob(f'{folder}/{category}_*.off')\n            for path in paths:\n                data = read_off(path)\n                data.y = torch.tensor([target])\n                data_list.append(data)\n\n        if self.pre_filter is not None:\n            data_list = [d for d in data_list if self.pre_filter(d)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(d) for d in data_list]\n\n        return data_list\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}{self.name}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/datasets/molecule_gpt_dataset.py",
    "content": "import gzip\nimport json\nimport multiprocessing\nimport os\nimport sys\nfrom collections import defaultdict\nfrom multiprocessing import Pool\nfrom typing import Callable, List, Optional, Tuple\n\nimport numpy as np\nimport requests\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.io import fs\nfrom torch_geometric.llm.models import LLM\nfrom torch_geometric.utils import one_hot\n\n\ndef clean_up_description(description: str) -> str:\n    description = description + \" \"\n\n    # extra adj Pure\n    if description.startswith(\"Pure \"):\n        description = description.replace(\"Pure \", \"\")\n    # fix typo\n    if description.startswith(\"Mercurycombines\"):\n        description = description.replace(\"Mercurycombines\",\n                                          \"Mercury combines\")\n\n    # a special case\n    description = description.replace(\n        \"17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. \",\n        \"17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is \")\n\n    # a special case\n    description = description.replace(\"5-Thymidylic acid. \",\n                                      \"5-Thymidylic acid. is \")\n\n    # a special case\n    description = description.replace(\n        \"5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. \",\n        \"5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is \")\n\n    # a special case\n    description = description.replace(\n        (\"Guanosine 5'-(trihydrogen diphosphate), monoanhydride\"\n         \" with phosphorothioic acid. \"),\n        (\"Guanosine 5'-(trihydrogen diphosphate), monoanhydride\"\n         \" with phosphorothioic acid is \"))\n\n    # a special case\n    description = description.replace(\"5'-Uridylic acid. 
\",\n                                      \"5'-Uridylic acid is \")\n\n    # a special case\n    description = description.replace(\"5'-Adenylic acid, \",\n                                      \"5'-Adenylic acid is \")\n\n    # a special case\n    description = description.replace(\n        \"Uridine 5'-(tetrahydrogen triphosphate). \",\n        \"Uridine 5'-(tetrahydrogen triphosphate). is \")\n\n    # a special case\n    description = description.replace(\"Inosine 5'-Monophosphate. \",\n                                      \"Inosine 5'-Monophosphate. is \")\n\n    # a special case\n    description = description.replace(\"Pivaloyloxymethyl butyrate (AN-9), \",\n                                      \"Pivaloyloxymethyl butyrate (AN-9) is \")\n\n    # a special case\n    description = description.replace(\n        \"4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. \",\n        \"4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is \")\n\n    # a special case\n    description = description.replace(\n        \"Cardamonin (also known as Dihydroxymethoxychalcone), \",\n        \"Cardamonin (also known as Dihydroxymethoxychalcone) is \")\n\n    # a special case\n    description = description.replace(\"Lithium has been used to treat \",\n                                      \"Lithium is \")\n\n    # a special case\n    description = description.replace(\"4,4'-Methylenebis \",\n                                      \"4,4'-Methylenebis is \")\n\n    # a special case\n    description = description.replace(\n        \"2,3,7,8-Tetrachlorodibenzo-p-dioxin\",\n        \"2,3,7,8-Tetrachlorodibenzo-p-dioxin is \")\n\n    # a special case\n    description = description.replace(\"Exposure to 2,4,5-trichlorophenol \",\n                                      \"2,4,5-Trichlorophenol exposure \")\n\n    index = 0\n    L = len(description)\n    if description.startswith('C.I. '):\n        start_index = len('C.I. 
')\n    elif description.startswith('Nectriapyrone. D '):\n        start_index = len('Nectriapyrone. D ')\n    elif description.startswith(\n            'Salmonella enterica sv. Minnesota LPS core oligosaccharide'):\n        start_index = len(\n            'Salmonella enterica sv. Minnesota LPS core oligosaccharide')\n    else:\n        start_index = 0\n    for index in range(start_index, L - 1):\n        if index < L - 2:\n            if description[index] == '.' and description[\n                    index + 1] == ' ' and 'A' <= description[index + 2] <= 'Z':\n                break\n        elif index == L - 2:\n            break\n\n    first_sentence = description[:index + 1]\n    return first_sentence\n\n\ndef extract_name(\n    name_raw: str,\n    description: str,\n) -> Tuple[Optional[str], str, str]:\n    first_sentence = clean_up_description(description)\n\n    splitter = '  --  --  '\n    if ' are ' in first_sentence or ' were ' in first_sentence:\n        replaced_words = 'These molecules'\n    else:\n        replaced_words = 'This molecule'\n\n    first_sentence = first_sentence.replace(' is ', splitter)\n    first_sentence = first_sentence.replace(' are ', splitter)\n    first_sentence = first_sentence.replace(' was ', splitter)\n    first_sentence = first_sentence.replace(' were ', splitter)\n    first_sentence = first_sentence.replace(' appears ', splitter)\n    first_sentence = first_sentence.replace(' occurs ', splitter)\n    first_sentence = first_sentence.replace(' stands for ', splitter)\n    first_sentence = first_sentence.replace(' belongs to ', splitter)\n    first_sentence = first_sentence.replace(' exists ',\n                                            splitter)  # only for CID=11443\n    first_sentence = first_sentence.replace(' has been used in trials ',\n                                            splitter)\n    first_sentence = first_sentence.replace(' has been investigated ',\n                                            splitter)\n    
first_sentence = first_sentence.replace(' has many uses ', splitter)\n\n    if splitter in first_sentence:\n        extracted_name = first_sentence.split(splitter, 1)[0]\n    elif first_sentence.startswith(name_raw):\n        extracted_name = name_raw\n    elif name_raw in first_sentence:\n        extracted_name = name_raw\n        extracted_name = None\n        print(\"=====\", name_raw)\n        print(\"first sentence: \", first_sentence)\n    else:\n        extracted_name = None\n\n    if extracted_name is not None:\n        extracted_description = description.replace(extracted_name,\n                                                    replaced_words)\n    else:\n        extracted_description = description\n\n    return extracted_name, extracted_description, first_sentence\n\n\nclass MoleculeGPTDataset(InMemoryDataset):\n    r\"\"\"The dataset from the `\"MoleculeGPT: Instruction Following Large\n    Language Models for Molecular Property Prediction\"\n    <https://ai4d3.github.io/2023/papers/34.pdf>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        total_page_num (int, optional): The number of pages from PubChem.\n            (default: :obj:`10`)\n        total_block_num (int, optional): The blocks of SDF files from PubChem.\n            (default: :obj:`1`)\n        num_units (int, optional): Number of units of the sample.\n            (default: :obj:`-1`, which means all units will be used)\n    \"\"\"\n    description_url = (\n        'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/'\n        'heading/json?heading_type=Compound&heading=Record+Description&page={}'\n    )\n    compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/'\n                    'CURRENT-Full/SDF')\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n        total_page_num: int = 10,\n        total_block_num: int = 1,\n        num_units: int = -1,\n    ):\n        self.total_page_num = total_page_num\n        self.total_block_num = total_block_num\n        self.num_units = num_units\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['pubchem.csv']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['data.pt']\n\n    def download(self) -> None:\n        # Step 01. 
Extract description\n        step1_folder = f\"{self.raw_dir}/step_01_PubChemSTM_description\"\n        if not os.path.exists(step1_folder):\n            os.makedirs(step1_folder)\n            valid_CID_set = set()\n            CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict(\n                list)\n            CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict(\n                list)\n\n            for page_index in tqdm(range(self.total_page_num)):\n                page_num = page_index + 1\n                f_out = open(\n                    f\"{step1_folder}/Compound_description_{page_num}.txt\", \"w\")\n\n                description_data = requests.get(\n                    self.description_url.format(page_num)).json()\n\n                description_data = description_data[\"Annotations\"]\n                assert description_data[\"Page\"] == page_num\n\n                record_list = description_data[\"Annotation\"]\n\n                for record in record_list:\n                    try:\n                        CID = record[\"LinkedRecords\"][\"CID\"][0]\n                        if \"Name\" in record:\n                            name_raw = record[\"Name\"]\n                            CID2name_raw[CID].append(name_raw)\n                        else:\n                            name_raw = None\n\n                        data_list = record[\"Data\"]\n                        for data in data_list:\n                            description = data[\"Value\"][\"StringWithMarkup\"][0][\n                                \"String\"].strip()\n\n                            extracted_name, extracted_description, _ = extract_name(  # noqa: E501\n                                name_raw, description)\n                            if extracted_name is not None:\n                                CID2name_extracted[CID].append(extracted_name)\n\n                            CID2text_raw[CID].append(description)\n                            
CID2text_extracted[CID].append(\n                                extracted_description)\n\n                            valid_CID_set.add(CID)\n                            f_out.write(f\"{CID}\\n\")\n                            f_out.write(f\"{extracted_description}\\n\\n\")\n                    except Exception:\n                        continue\n\n            valid_CID_list = sorted(list(valid_CID_set))\n            print(f\"Total CID (with raw name) {len(CID2name_raw)}\")\n            print(f\"Total CID (with extracted name) {len(CID2name_extracted)}\")\n            print(f\"Total CID {len(valid_CID_list)}\")\n\n            with open(f\"{self.raw_dir}/CID2name_raw.json\", \"w\") as f:\n                json.dump(CID2name_raw, f)\n\n            with open(f\"{self.raw_dir}/CID2name.json\", \"w\") as f:\n                json.dump(CID2name_extracted, f)\n\n            with open(f\"{self.raw_dir}/CID2text_raw.json\", \"w\") as f:\n                json.dump(CID2text_raw, f)\n\n            with open(f\"{self.raw_dir}/CID2text.json\", \"w\") as f:\n                json.dump(CID2text_extracted, f)\n\n        # Step 02. 
Download SDF Files\n        step2_folder = f\"{self.raw_dir}/step_02_PubChemSTM_SDF\"\n        if not os.path.exists(step2_folder):\n            for block_id in tqdm(range(self.total_block_num)):\n                block_size = 500000\n                l_id = block_id * block_size + 1\n                r_id = (block_id + 1) * block_size\n\n                compound_file_name = f\"Compound_{l_id:09d}_{r_id:09d}.sdf.gz\"\n                download_url(f\"{self.compound_url}/{compound_file_name}\",\n                             step2_folder)\n\n    def process(self, use_mp: bool = False) -> None:\n        try:\n            from rdkit import Chem\n            from rdkit.Chem.rdchem import BondType as BT\n            WITH_RDKIT = True\n\n        except ImportError:\n            WITH_RDKIT = False\n\n        if not WITH_RDKIT:\n            print((\"Using a pre-processed version of the dataset. Please \"\n                   \"install 'rdkit' to alternatively process the raw data.\"),\n                  file=sys.stderr)\n\n            data_list = fs.torch_load(self.raw_paths[0])\n            data_list = [Data(**data_dict) for data_dict in data_list]\n\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n\n            self.save(data_list, self.processed_paths[0])\n            return\n\n        # Step 03. 
Filter out SDF\n        step2_folder = f\"{self.raw_dir}/step_02_PubChemSTM_SDF\"\n        step3_folder = f\"{self.raw_dir}/step_03_PubChemSTM_filtered\"\n        if not os.path.exists(step3_folder):\n            os.makedirs(step3_folder)\n            with open(f\"{self.raw_dir}/CID2text.json\") as f:\n                CID2text = json.load(f)\n            target_CID_list = set(CID2text.keys())\n\n            block_size = 500000\n\n            def extract_one_SDF_file(block_id: int) -> None:\n                valid_mol_count = 0\n\n                writer = Chem.SDWriter(\n                    f'{step3_folder}/filtered_{block_id}.sdf')\n                l_id = block_id * block_size + 1\n                r_id = (block_id + 1) * block_size\n\n                compound_file_name = f\"Compound_{l_id:09d}_{r_id:09d}.sdf.gz\"\n                gzip_loader = gzip.open(f\"{step2_folder}/{compound_file_name}\")\n                suppl = Chem.ForwardSDMolSupplier(gzip_loader)\n\n                for mol in tqdm(suppl):\n                    if mol is None:\n                        continue\n                    cid = mol.GetProp(\"PUBCHEM_COMPOUND_CID\")\n\n                    if cid not in target_CID_list:\n                        continue\n\n                    writer.write(mol)\n                    valid_mol_count += 1\n\n                writer.close()\n                print(f\"block id: {block_id}\\nfound {valid_mol_count}\\n\\n\")\n                sys.stdout.flush()\n                return\n\n            if use_mp:\n                num_process = multiprocessing.cpu_count()\n                print(f\"{num_process} CPUs\")\n                num_process = 8\n                p = Pool(num_process)\n\n                block_id_list = np.arange(self.total_block_num)\n                with p:\n                    p.map(extract_one_SDF_file, block_id_list)\n            else:\n                for block_id in range(self.total_block_num):\n                    extract_one_SDF_file(block_id)\n\n      
  # Step 04. Merge SDF\n        with open(f\"{self.raw_dir}/CID2text.json\") as f:\n            CID2text = json.load(f)\n        target_CID_list = set(CID2text.keys())\n        print(f'The length of target_CID_list: {len(target_CID_list)}')\n\n        writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf')\n\n        found_CID_set = set()\n        for block_id in range(self.total_block_num + 1):\n            compound_file_path = f\"{step3_folder}/filtered_{block_id}.sdf\"\n            try:\n                suppl = Chem.SDMolSupplier(compound_file_path)\n\n                for mol in tqdm(suppl):\n                    writer.write(mol)\n                    cid = mol.GetProp(\"PUBCHEM_COMPOUND_CID\")\n                    found_CID_set.add(cid)\n            except Exception:\n                print(f\"block id: {block_id} with 0 valid SDF file\")\n                continue\n\n        writer.close()\n        print(f\"In total: {len(found_CID_set)} molecules\")\n\n        # Step 05. Convert to PyG data format\n        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}\n        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}\n\n        data_list = []\n        # Real data\n        CID2text_file = f'{self.raw_dir}/CID2text.json'\n\n        with open(CID2text_file) as f:\n            CID2text_data = json.load(f)\n\n        suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf')\n\n        llm = LLM(\n            model_name='Qwen/Qwen3-0.6B',\n            num_params=1,\n            dtype=torch.bfloat16,\n            sys_prompt='You are an agent, answer my questions.',\n        )\n        prompt = (\"Propose a question regarding the molecule '∼' \"\n                  \"whose answer is: {}:\")\n        for mol in tqdm(suppl):\n            if mol.HasProp('PUBCHEM_COMPOUND_CID'):\n                CID = mol.GetProp(\"PUBCHEM_COMPOUND_CID\")\n                CAN_SMILES = mol.GetProp(\"PUBCHEM_SMILES\")\n\n                m: Chem.Mol = 
Chem.MolFromSmiles(CAN_SMILES)\n                if m is None:\n                    continue\n                RDKit_CAN_SMILES = Chem.MolToSmiles(m)\n\n                ground_truth = CID2text_data[CID][0]\n\n                instruction = llm.inference([prompt.format(ground_truth)])[0]\n\n                x: torch.Tensor = torch.tensor([\n                    types[atom.GetSymbol()] if atom.GetSymbol() in types else 5\n                    for atom in m.GetAtoms()\n                ])\n                x = one_hot(x, num_classes=len(types), dtype=torch.float)\n\n                rows, cols, edge_types = [], [], []\n                for bond in m.GetBonds():\n                    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()\n                    edge_types += [bonds[bond.GetBondType()]] * 2\n                    rows += [i, j]\n                    cols += [j, i]\n\n                edge_index = torch.tensor([rows, cols], dtype=torch.long)\n                edge_type = torch.tensor(edge_types, dtype=torch.long)\n                edge_attr = one_hot(edge_type, num_classes=len(bonds))\n\n                data = Data(\n                    x=x,\n                    edge_index=edge_index,\n                    edge_attr=edge_attr,\n                    smiles=RDKit_CAN_SMILES,\n                    instruction=instruction,\n                    y=ground_truth,\n                )\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n\n                if self.num_units > 0 and len(data_list) >= self.num_units:\n                    break\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/molecule_net.py",
    "content": "import os\nimport os.path as osp\nimport re\nimport warnings\nfrom typing import Callable, Dict, Optional, Tuple, Union\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, download_url, extract_gz\nfrom torch_geometric.utils import from_smiles as _from_smiles\n\n\nclass MoleculeNet(InMemoryDataset):\n    r\"\"\"The `MoleculeNet <http://moleculenet.org/datasets-1>`_ benchmark\n    collection  from the `\"MoleculeNet: A Benchmark for Molecular Machine\n    Learning\" <https://arxiv.org/abs/1703.00564>`_ paper, containing datasets\n    from physical chemistry, biophysics and physiology.\n    All datasets come with the additional node and edge features introduced by\n    the :ogb:`null`\n    `Open Graph Benchmark <https://ogb.stanford.edu/docs/graphprop/>`_.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"ESOL\"`, :obj:`\"FreeSolv\"`,\n            :obj:`\"Lipo\"`, :obj:`\"PCBA\"`, :obj:`\"MUV\"`, :obj:`\"HIV\"`,\n            :obj:`\"BACE\"`, :obj:`\"BBBP\"`, :obj:`\"Tox21\"`, :obj:`\"ToxCast\"`,\n            :obj:`\"SIDER\"`, :obj:`\"ClinTox\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        from_smiles (callable, optional): A custom function that takes a SMILES\n            string and outputs a :obj:`~torch_geometric.data.Data` object.\n            If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`.\n            (default: :obj:`None`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 20 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - ESOL\n          - 1,128\n          - ~13.3\n          - ~27.4\n          - 9\n          - 1\n        * - FreeSolv\n          - 642\n          - ~8.7\n          - ~16.8\n          - 9\n          - 1\n        * - Lipophilicity\n          - 4,200\n          - ~27.0\n          - ~59.0\n          - 9\n          - 1\n        * - PCBA\n          - 437,929\n          - ~26.0\n          - ~56.2\n          - 9\n          - 128\n        * - MUV\n          - 93,087\n          - ~24.2\n          - ~52.6\n          - 9\n          - 17\n        * - HIV\n          - 41,127\n          - ~25.5\n          - ~54.9\n          - 9\n          - 1\n        * - BACE\n          - 1513\n          - ~34.1\n          - ~73.7\n          - 9\n          - 1\n        * - BBBP\n          - 2,050\n          - ~23.9\n          - ~51.6\n          - 9\n          - 1\n        * - Tox21\n          - 7,831\n          - ~18.6\n          - ~38.6\n          - 9\n          - 12\n        * - ToxCast\n          - 8,597\n          - ~18.7\n          - ~38.4\n          - 9\n          - 617\n        * - SIDER\n          - 1,427\n          - ~33.6\n          - ~70.7\n          - 9\n          - 27\n        * - ClinTox\n          - 1,484\n          - ~26.1\n          - ~55.5\n          - 9\n          - 2\n    \"\"\"\n\n    url = 
'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/{}'\n\n    # Format: name: (display_name, url_name, csv_name, smiles_idx, y_idx)\n    names: Dict[str, Tuple[str, str, str, int, Union[int, slice]]] = {\n        'esol': ('ESOL', 'delaney-processed.csv', 'delaney-processed', -1, -2),\n        'freesolv': ('FreeSolv', 'SAMPL.csv', 'SAMPL', 1, 2),\n        'lipo': ('Lipophilicity', 'Lipophilicity.csv', 'Lipophilicity', 2, 1),\n        'pcba': ('PCBA', 'pcba.csv.gz', 'pcba', -1, slice(0, 128)),\n        'muv': ('MUV', 'muv.csv.gz', 'muv', -1, slice(0, 17)),\n        'hiv': ('HIV', 'HIV.csv', 'HIV', 0, -1),\n        'bace': ('BACE', 'bace.csv', 'bace', 0, 2),\n        'bbbp': ('BBBP', 'BBBP.csv', 'BBBP', -1, -2),\n        'tox21': ('Tox21', 'tox21.csv.gz', 'tox21', -1, slice(0, 12)),\n        'toxcast':\n        ('ToxCast', 'toxcast_data.csv.gz', 'toxcast_data', 0, slice(1, 618)),\n        'sider': ('SIDER', 'sider.csv.gz', 'sider', 0, slice(1, 28)),\n        'clintox': ('ClinTox', 'clintox.csv.gz', 'clintox', 0, slice(1, 3)),\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n        from_smiles: Optional[Callable] = None,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in self.names.keys()\n        self.from_smiles = from_smiles or _from_smiles\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return 
f'{self.names[self.name][2]}.csv'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        url = self.url.format(self.names[self.name][1])\n        path = download_url(url, self.raw_dir)\n        if self.names[self.name][1][-2:] == 'gz':\n            extract_gz(path, self.raw_dir)\n            os.unlink(path)\n\n    def process(self) -> None:\n        with open(self.raw_paths[0]) as f:\n            dataset = f.read().split('\\n')[1:-1]\n            dataset = [x for x in dataset if len(x) > 0]  # Filter empty lines.\n\n        data_list = []\n        for line in dataset:\n            line = re.sub(r'\\\".*\\\"', '', line)  # Replace \".*\" strings.\n            values = line.split(',')\n\n            smiles = values[self.names[self.name][3]]\n            labels = values[self.names[self.name][4]]\n            labels = labels if isinstance(labels, list) else [labels]\n\n            ys = [float(y) if len(y) > 0 else float('NaN') for y in labels]\n            y = torch.tensor(ys, dtype=torch.float).view(1, -1)\n\n            data = self.from_smiles(smiles)\n            data.y = y\n\n            if data.num_nodes == 0:\n                warnings.warn(\n                    f\"Skipping molecule '{smiles}' since it \"\n                    f\"resulted in zero atoms\", stacklevel=2)\n                continue\n\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.names[self.name][0]}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/datasets/motif_generator/__init__.py",
    "content": "from .base import MotifGenerator\nfrom .custom import CustomMotif\nfrom .house import HouseMotif\nfrom .cycle import CycleMotif\nfrom .grid import GridMotif\n\n__all__ = classes = [\n    'MotifGenerator',\n    'CustomMotif',\n    'HouseMotif',\n    'CycleMotif',\n    'GridMotif',\n]\n"
  },
  {
    "path": "torch_geometric/datasets/motif_generator/base.py",
    "content": "from abc import ABC, abstractmethod\nfrom typing import Any\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.resolver import resolver\n\n\nclass MotifGenerator(ABC):\n    r\"\"\"An abstract base class for generating a motif.\"\"\"\n    @abstractmethod\n    def __call__(self) -> Data:\n        r\"\"\"To be implemented by :class:`MotifGenerator` subclasses.\"\"\"\n\n    @staticmethod\n    def resolve(query: Any, *args: Any, **kwargs: Any) -> 'MotifGenerator':\n        import torch_geometric.datasets.motif_generator as _motif_generators\n        motif_generators = [\n            gen for gen in vars(_motif_generators).values()\n            if isinstance(gen, type) and issubclass(gen, MotifGenerator)\n        ]\n        return resolver(motif_generators, {}, query, MotifGenerator, 'Motif',\n                        *args, **kwargs)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/datasets/motif_generator/custom.py",
    "content": "from typing import Any, Optional\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import MotifGenerator\nfrom torch_geometric.utils import from_networkx\n\n\nclass CustomMotif(MotifGenerator):\n    r\"\"\"Generates a motif based on a custom structure coming from a\n    :class:`torch_geometric.data.Data` or :class:`networkx.Graph` object.\n\n    Args:\n        structure (torch_geometric.data.Data or networkx.Graph): The structure\n            to use as a motif.\n    \"\"\"\n    def __init__(self, structure: Any):\n        super().__init__()\n\n        self.structure: Optional[Data] = None\n\n        if isinstance(structure, Data):\n            self.structure = structure\n        else:\n            try:\n                import networkx as nx\n                if isinstance(structure, nx.Graph):\n                    self.structure = from_networkx(structure)\n            except ImportError:\n                pass\n\n        if self.structure is None:\n            raise ValueError(f\"Expected a motif structure of type \"\n                             f\"'torch_geometric.data.Data' or 'networkx.Graph'\"\n                             f\" (got {type(structure)})\")\n\n    def __call__(self) -> Data:\n        assert isinstance(self.structure, Data)\n        return self.structure\n"
  },
  {
    "path": "torch_geometric/datasets/motif_generator/cycle.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import CustomMotif\n\n\nclass CycleMotif(CustomMotif):\n    r\"\"\"Generates the cycle motif from the `\"GNNExplainer:\n    Generating Explanations for Graph Neural Networks\"\n    <https://arxiv.org/abs/1903.03894>`__ paper.\n\n    Args:\n        num_nodes (int): The number of nodes in the cycle.\n    \"\"\"\n    def __init__(self, num_nodes: int):\n        self.num_nodes = num_nodes\n\n        row = torch.arange(num_nodes).view(-1, 1).repeat(1, 2).view(-1)\n        col1 = torch.arange(-1, num_nodes - 1) % num_nodes\n        col2 = torch.arange(1, num_nodes + 1) % num_nodes\n        col = torch.stack([col1, col2], dim=1).sort(dim=-1)[0].view(-1)\n\n        structure = Data(\n            num_nodes=num_nodes,\n            edge_index=torch.stack([row, col], dim=0),\n        )\n        super().__init__(structure)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.num_nodes})'\n"
  },
  {
    "path": "torch_geometric/datasets/motif_generator/grid.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import CustomMotif\n\n\nclass GridMotif(CustomMotif):\n    r\"\"\"Generates the grid-structured motif from the\n    `\"GNNExplainer: Generating Explanations for Graph Neural Networks\"\n    <https://arxiv.org/abs/1903.03894>`__ paper.\n    \"\"\"\n    def __init__(self) -> None:\n        edge_indices = [\n            [0, 1],\n            [0, 3],\n            [1, 4],\n            [3, 4],\n            [1, 2],\n            [2, 5],\n            [4, 5],\n            [3, 6],\n            [6, 7],\n            [4, 7],\n            [5, 8],\n            [7, 8],\n            [1, 0],\n            [3, 0],\n            [4, 1],\n            [4, 3],\n            [2, 1],\n            [5, 2],\n            [5, 4],\n            [6, 3],\n            [7, 6],\n            [7, 4],\n            [8, 5],\n            [8, 7],\n        ]\n        structure = Data(\n            num_nodes=9,\n            edge_index=torch.tensor(edge_indices).t().contiguous(),\n            y=torch.tensor([0, 1, 0, 1, 2, 1, 0, 1, 0]),\n        )\n        super().__init__(structure)\n"
  },
  {
    "path": "torch_geometric/datasets/motif_generator/house.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.datasets.motif_generator import CustomMotif\n\n\nclass HouseMotif(CustomMotif):\n    r\"\"\"Generates the house-structured motif from the `\"GNNExplainer:\n    Generating Explanations for Graph Neural Networks\"\n    <https://arxiv.org/abs/1903.03894>`__ paper, containing 5 nodes and 6\n    undirected edges. Nodes are labeled according to their structural role:\n    the top, middle and bottom of the house.\n    \"\"\"\n    def __init__(self) -> None:\n        structure = Data(\n            num_nodes=5,\n            edge_index=torch.tensor([\n                [0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4],\n                [1, 3, 4, 4, 2, 0, 1, 3, 2, 0, 0, 1],\n            ]),\n            y=torch.tensor([0, 0, 1, 1, 2]),\n        )\n        super().__init__(structure)\n"
  },
  {
    "path": "torch_geometric/datasets/movie_lens.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\n\n\nclass MovieLens(InMemoryDataset):\n    r\"\"\"A heterogeneous rating dataset, assembled by GroupLens Research from\n    the `MovieLens web site <https://movielens.org>`_, consisting of nodes of\n    type :obj:`\"movie\"` and :obj:`\"user\"`.\n    User ratings for movies are available as ground truth labels for the edges\n    between the users and the movies :obj:`(\"user\", \"rates\", \"movie\")`.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        model_name (str): Name of model used to transform movie titles to node\n            features. 
The model comes from the `Huggingface SentenceTransformer\n            <https://huggingface.co/sentence-transformers>`_.\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = 'https://files.grouplens.org/datasets/movielens/ml-latest-small.zip'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        model_name: Optional[str] = 'all-MiniLM-L6-v2',\n        force_reload: bool = False,\n    ) -> None:\n        self.model_name = model_name\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            osp.join('ml-latest-small', 'movies.csv'),\n            osp.join('ml-latest-small', 'ratings.csv'),\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return f'data_{self.model_name}.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.remove(path)\n\n    def process(self) -> None:\n        import pandas as pd\n        from sentence_transformers import SentenceTransformer\n\n        data = HeteroData()\n\n        df = pd.read_csv(self.raw_paths[0], index_col='movieId')\n        movie_mapping = {idx: i for i, idx in enumerate(df.index)}\n\n        genres = df['genres'].str.get_dummies('|').values\n        genres = torch.from_numpy(genres).to(torch.float)\n\n        model = SentenceTransformer(self.model_name)\n        with torch.no_grad():\n            emb = model.encode(df['title'].values, show_progress_bar=True,\n                               convert_to_tensor=True).cpu()\n\n        data['movie'].x = torch.cat([emb, genres], dim=-1)\n\n        df = 
pd.read_csv(self.raw_paths[1])\n        user_mapping = {idx: i for i, idx in enumerate(df['userId'].unique())}\n        data['user'].num_nodes = len(user_mapping)\n\n        src = [user_mapping[idx] for idx in df['userId']]\n        dst = [movie_mapping[idx] for idx in df['movieId']]\n        edge_index = torch.tensor([src, dst])\n\n        rating = torch.from_numpy(df['rating'].values).to(torch.long)\n        time = torch.from_numpy(df['timestamp'].values).to(torch.long)\n\n        data['user', 'rates', 'movie'].edge_index = edge_index\n        data['user', 'rates', 'movie'].edge_label = rating\n        data['user', 'rates', 'movie'].time = time\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/movie_lens_100k.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\nMOVIE_HEADERS = [\n    \"movieId\", \"title\", \"releaseDate\", \"videoReleaseDate\", \"IMDb URL\",\n    \"unknown\", \"Action\", \"Adventure\", \"Animation\", \"Children's\", \"Comedy\",\n    \"Crime\", \"Documentary\", \"Drama\", \"Fantasy\", \"Film-Noir\", \"Horror\",\n    \"Musical\", \"Mystery\", \"Romance\", \"Sci-Fi\", \"Thriller\", \"War\", \"Western\"\n]\nUSER_HEADERS = [\"userId\", \"age\", \"gender\", \"occupation\", \"zipCode\"]\nRATING_HEADERS = [\"userId\", \"movieId\", \"rating\", \"timestamp\"]\n\n\nclass MovieLens100K(InMemoryDataset):\n    r\"\"\"The MovieLens 100K heterogeneous rating dataset, assembled by GroupLens\n    Research from the `MovieLens web site <https://movielens.org>`__,\n    consisting of movies (1,682 nodes) and users (943 nodes) with 100K\n    ratings between them.\n    User ratings for movies are available as ground truth labels.\n    Features of users and movies are encoded according to the `\"Inductive\n    Matrix Completion Based on Graph Neural Networks\"\n    <https://arxiv.org/abs/1904.12058>`__ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 20 10 10 10\n        :header-rows: 1\n\n        * - Node/Edge Type\n          - #nodes/#edges\n          - #features\n          - #tasks\n        * - Movie\n          - 1,682\n          - 18\n          -\n        * - User\n          - 943\n          - 24\n          -\n        * - User-Movie\n          - 80,000\n          - 1\n          - 1\n    \"\"\"\n    url = 'https://files.grouplens.org/datasets/movielens/ml-100k.zip'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['u.item', 'u.user', 'u1.base', 'u1.test']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.remove(path)\n        folder = osp.join(self.root, 'ml-100k')\n        fs.rm(self.raw_dir)\n        os.rename(folder, self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        data = HeteroData()\n\n        # Process movie data:\n        df = pd.read_csv(\n            self.raw_paths[0],\n            sep='|',\n            header=None,\n            names=MOVIE_HEADERS,\n            index_col='movieId',\n            encoding='ISO-8859-1',\n        )\n        movie_mapping = {idx: i for i, idx in enumerate(df.index)}\n\n        x = df[MOVIE_HEADERS[6:]].values\n        data['movie'].x = 
torch.from_numpy(x).to(torch.float)\n\n        # Process user data:\n        df = pd.read_csv(\n            self.raw_paths[1],\n            sep='|',\n            header=None,\n            names=USER_HEADERS,\n            index_col='userId',\n            encoding='ISO-8859-1',\n        )\n        user_mapping = {idx: i for i, idx in enumerate(df.index)}\n\n        age = df['age'].values / df['age'].values.max()\n        age = torch.from_numpy(age).to(torch.float).view(-1, 1)\n\n        gender = df['gender'].str.get_dummies().values\n        gender = torch.from_numpy(gender).to(torch.float)\n\n        occupation = df['occupation'].str.get_dummies().values\n        occupation = torch.from_numpy(occupation).to(torch.float)\n\n        data['user'].x = torch.cat([age, gender, occupation], dim=-1)\n\n        # Process rating data for training:\n        df = pd.read_csv(\n            self.raw_paths[2],\n            sep='\\t',\n            header=None,\n            names=RATING_HEADERS,\n        )\n\n        src = [user_mapping[idx] for idx in df['userId']]\n        dst = [movie_mapping[idx] for idx in df['movieId']]\n        edge_index = torch.tensor([src, dst])\n        data['user', 'rates', 'movie'].edge_index = edge_index\n\n        rating = torch.from_numpy(df['rating'].values).to(torch.long)\n        data['user', 'rates', 'movie'].rating = rating\n\n        time = torch.from_numpy(df['timestamp'].values)\n        data['user', 'rates', 'movie'].time = time\n\n        data['movie', 'rated_by', 'user'].edge_index = edge_index.flip([0])\n        data['movie', 'rated_by', 'user'].rating = rating\n        data['movie', 'rated_by', 'user'].time = time\n\n        # Process rating data for testing:\n        df = pd.read_csv(\n            self.raw_paths[3],\n            sep='\\t',\n            header=None,\n            names=RATING_HEADERS,\n        )\n\n        src = [user_mapping[idx] for idx in df['userId']]\n        dst = [movie_mapping[idx] for idx in df['movieId']]\n      
  edge_label_index = torch.tensor([src, dst])\n        data['user', 'rates', 'movie'].edge_label_index = edge_label_index\n\n        edge_label = torch.from_numpy(df['rating'].values).to(torch.float)\n        data['user', 'rates', 'movie'].edge_label = edge_label\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/movie_lens_1m.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\nMOVIE_HEADERS = [\"movieId\", \"title\", \"genres\"]\nUSER_HEADERS = [\"userId\", \"gender\", \"age\", \"occupation\", \"zipCode\"]\nRATING_HEADERS = ['userId', 'movieId', 'rating', 'timestamp']\n\n\nclass MovieLens1M(InMemoryDataset):\n    r\"\"\"The MovieLens 1M heterogeneous rating dataset, assembled by GroupLens\n    Research from the `MovieLens web site <https://movielens.org>`__,\n    consisting of movies (3,883 nodes) and users (6,040 nodes) with\n    approximately 1 million ratings between them.\n    User ratings for movies are available as ground truth labels.\n    Features of users and movies are encoded according to the `\"Inductive\n    Matrix Completion Based on Graph Neural Networks\"\n    <https://arxiv.org/abs/1904.12058>`__ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 20 10 10 10\n        :header-rows: 1\n\n        * - Node/Edge Type\n          - #nodes/#edges\n          - #features\n          - #tasks\n        * - Movie\n          - 3,883\n          - 18\n          -\n        * - User\n          - 6,040\n          - 30\n          -\n        * - User-Movie\n          - 1,000,209\n          - 1\n          - 1\n    \"\"\"\n    url = 'https://files.grouplens.org/datasets/movielens/ml-1m.zip'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['movies.dat', 'users.dat', 'ratings.dat']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.remove(path)\n        folder = osp.join(self.root, 'ml-1m')\n        fs.rm(self.raw_dir)\n        os.rename(folder, self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        data = HeteroData()\n\n        # Process movie data:\n        df = pd.read_csv(\n            self.raw_paths[0],\n            sep='::',\n            header=None,\n            index_col='movieId',\n            names=MOVIE_HEADERS,\n            encoding='ISO-8859-1',\n            engine='python',\n        )\n        movie_mapping = {idx: i for i, idx in enumerate(df.index)}\n\n        genres = df['genres'].str.get_dummies('|').values\n        genres = torch.from_numpy(genres).to(torch.float)\n\n        data['movie'].x = genres\n\n        # Process user data:\n        df = pd.read_csv(\n            
self.raw_paths[1],\n            sep='::',\n            header=None,\n            index_col='userId',\n            names=USER_HEADERS,\n            dtype='str',\n            encoding='ISO-8859-1',\n            engine='python',\n        )\n        user_mapping = {idx: i for i, idx in enumerate(df.index)}\n\n        age = df['age'].str.get_dummies().values\n        age = torch.from_numpy(age).to(torch.float)\n\n        gender = df['gender'].str.get_dummies().values\n        gender = torch.from_numpy(gender).to(torch.float)\n\n        occupation = df['occupation'].str.get_dummies().values\n        occupation = torch.from_numpy(occupation).to(torch.float)\n\n        data['user'].x = torch.cat([age, gender, occupation], dim=-1)\n\n        # Process rating data:\n        df = pd.read_csv(\n            self.raw_paths[2],\n            sep='::',\n            header=None,\n            names=RATING_HEADERS,\n            encoding='ISO-8859-1',\n            engine='python',\n        )\n\n        src = [user_mapping[idx] for idx in df['userId']]\n        dst = [movie_mapping[idx] for idx in df['movieId']]\n        edge_index = torch.tensor([src, dst])\n        data['user', 'rates', 'movie'].edge_index = edge_index\n\n        rating = torch.from_numpy(df['rating'].values).to(torch.long)\n        data['user', 'rates', 'movie'].rating = rating\n\n        time = torch.from_numpy(df['timestamp'].values)\n        data['user', 'rates', 'movie'].time = time\n\n        data['movie', 'rated_by', 'user'].edge_index = edge_index.flip([0])\n        data['movie', 'rated_by', 'user'].rating = rating\n        data['movie', 'rated_by', 'user'].time = time\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/myket.py",
    "content": "from typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, TemporalData, download_url\n\n\nclass MyketDataset(InMemoryDataset):\n    r\"\"\"The Myket Android Application Install dataset from the\n    `\"Effect of Choosing Loss Function when Using T-Batching for Representation\n    Learning on Dynamic Networks\" <https://arxiv.org/abs/2308.06862>`_ paper.\n    The dataset contains a temporal graph of application install interactions\n    in an Android application market.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Myket\n          - 17,988\n          - 694,121\n          - 33\n          - 1\n    \"\"\"\n    url = ('https://raw.githubusercontent.com/erfanloghmani/'\n           'myket-android-application-market-dataset/main/data_int_index')\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=TemporalData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['myket.csv', 'app_info_sample.npy']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for file_name in self.raw_file_names:\n            download_url(f'{self.url}/{file_name}', self.raw_dir)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        df = pd.read_csv(self.raw_paths[0], skiprows=1, header=None)\n\n        src = torch.from_numpy(df[0].values)\n        dst = torch.from_numpy(df[1].values)\n        t = torch.from_numpy(df[2].values)\n\n        x = torch.from_numpy(np.load(self.raw_paths[1])).to(torch.float)\n        msg = x[dst]\n\n        dst = dst + (int(src.max()) + 1)\n\n        data = TemporalData(src=src, dst=dst, t=t, msg=msg)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/nell.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nfrom torch_geometric.data import InMemoryDataset, download_url, extract_tar\nfrom torch_geometric.io import fs, read_planetoid_data\n\n\nclass NELL(InMemoryDataset):\n    r\"\"\"The NELL dataset, a knowledge graph from the\n    `\"Toward an Architecture for Never-Ending Language Learning\"\n    <https://www.cs.cmu.edu/~acarlson/papers/carlson-aaai10.pdf>`_ paper.\n    The dataset is processed as in the\n    `\"Revisiting Semi-Supervised Learning with Graph Embeddings\"\n    <https://arxiv.org/abs/1603.08861>`_ paper.\n\n    .. note::\n\n        Entity nodes are described by sparse feature vectors of type\n        :class:`torch.sparse_csr_tensor`.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 65,755\n          - 251,550\n          - 61,278\n          - 186\n    \"\"\"\n\n    url = 'http://www.cs.cmu.edu/~zhiliny/data/nell_data.tar.gz'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']\n        return [f'ind.nell.0.001.{name}' for name in names]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_tar(path, self.root)\n        os.unlink(path)\n        fs.rm(self.raw_dir)\n        os.rename(osp.join(self.root, 'nell_data'), self.raw_dir)\n\n    def process(self) -> None:\n        data = read_planetoid_data(self.raw_dir, 'nell.0.001')\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/neurograph.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass NeuroGraphDataset(InMemoryDataset):\n    r\"\"\"The NeuroGraph benchmark datasets from the\n    `\"NeuroGraph: Benchmarks for Graph Machine Learning in Brain Connectomics\"\n    <https://arxiv.org/abs/2306.06202>`_ paper.\n    :class:`NeuroGraphDataset` holds a collection of five neuroimaging graph\n    learning datasets that span multiple categories of demographics, mental\n    states, and cognitive traits.\n    See the `documentation\n    <https://neurograph.readthedocs.io/en/latest/NeuroGraph.html>`_ and the\n    `Github <https://github.com/Anwar-Said/NeuroGraph>`_ for more details.\n\n    +--------------------+---------+----------------------+\n    | Dataset            | #Graphs | Task                 |\n    +====================+=========+======================+\n    | :obj:`HCPTask`     | 7,443   | Graph Classification |\n    +--------------------+---------+----------------------+\n    | :obj:`HCPGender`   | 1,078   | Graph Classification |\n    +--------------------+---------+----------------------+\n    | :obj:`HCPAge`      | 1,065   | Graph Classification |\n    +--------------------+---------+----------------------+\n    | :obj:`HCPFI`       | 1,071   | Graph Regression     |\n    +--------------------+---------+----------------------+\n    | :obj:`HCPWM`       | 1,078   | Graph Regression     |\n    +--------------------+---------+----------------------+\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (one of :obj:`\"HCPGender\"`,\n            :obj:`\"HCPTask\"`, :obj:`\"HCPAge\"`, :obj:`\"HCPFI\"`,\n            :obj:`\"HCPWM\"`).\n        transform (callable, optional): A function/transform that takes in an\n            
:obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = 'https://vanderbilt.box.com/shared/static'\n    filenames = {\n        'HCPGender': 'r6hlz2arm7yiy6v6981cv2nzq3b0meax.zip',\n        'HCPTask': '8wzz4y17wpxg2stip7iybtmymnybwvma.zip',\n        'HCPAge': 'lzzks4472czy9f9vc8aikp7pdbknmtfe.zip',\n        'HCPWM': 'xtmpa6712fidi94x6kevpsddf9skuoxy.zip',\n        'HCPFI': 'g2md9h9snh7jh6eeay02k1kr9m4ido9f.zip',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert name in self.filenames.keys()\n        self.name = name\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'data.pt'\n\n    @property\n    def processed_dir(self) -> str:\n   
     return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        url = f'{self.url}/{self.filenames[self.name]}'\n        path = download_url(url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n        os.rename(\n            osp.join(self.raw_dir, self.name, 'processed', f'{self.name}.pt'),\n            osp.join(self.raw_dir, 'data.pt'))\n        fs.rm(osp.join(self.raw_dir, self.name))\n\n    def process(self) -> None:\n        data, slices = fs.torch_load(self.raw_paths[0])\n\n        num_samples = slices['x'].size(0) - 1\n        data_list: List[Data] = []\n        for i in range(num_samples):\n            x = data.x[slices['x'][i]:slices['x'][i + 1]]\n            start = slices['edge_index'][i]\n            end = slices['edge_index'][i + 1]\n            edge_index = data.edge_index[:, start:end]\n            sample = Data(x=x, edge_index=edge_index, y=data.y[i])\n\n            if self.pre_filter is not None and not self.pre_filter(sample):\n                continue\n\n            if self.pre_transform is not None:\n                sample = self.pre_transform(sample)\n\n            data_list.append(sample)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/ogb_mag.py",
    "content": "import os\nimport os.path as osp\nimport shutil\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass OGB_MAG(InMemoryDataset):\n    r\"\"\"The :obj:`ogbn-mag` dataset from the `\"Open Graph Benchmark: Datasets\n    for Machine Learning on Graphs\" <https://arxiv.org/abs/2005.00687>`_ paper.\n    :obj:`ogbn-mag` is a heterogeneous graph composed of a subset of the\n    Microsoft Academic Graph (MAG).\n    It contains four types of entities — papers (736,389 nodes), authors\n    (1,134,649 nodes), institutions (8,740 nodes), and fields of study\n    (59,965 nodes) — as well as four types of directed relations connecting two\n    types of entities.\n    Each paper is associated with a 128-dimensional :obj:`word2vec` feature\n    vector, while all other node types are not associated with any input\n    features.\n    The task is to predict the venue (conference or journal) of each paper.\n    In total, there are 349 different venues.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        preprocess (str, optional): Pre-processes the original dataset by\n            adding structural features (:obj:`\"metapath2vec\"`, :obj:`\"TransE\"`)\n            to featureless nodes. (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'http://snap.stanford.edu/ogb/data/nodeproppred/mag.zip'\n    urls = {\n        'metapath2vec': ('https://data.pyg.org/datasets/'\n                         'mag_metapath2vec_emb.zip'),\n        'transe': ('https://data.pyg.org/datasets/'\n                   'mag_transe_emb.zip'),\n    }\n\n    def __init__(\n        self,\n        root: str,\n        preprocess: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        preprocess = None if preprocess is None else preprocess.lower()\n        self.preprocess = preprocess\n        assert self.preprocess in [None, 'metapath2vec', 'transe']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def num_classes(self) -> int:\n        assert isinstance(self._data, HeteroData)\n        return int(self._data['paper'].y.max()) + 1\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, 'mag', 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, 'mag', 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        file_names = [\n            'node-feat', 'node-label', 'relations', 'split',\n            'num-node-dict.csv.gz'\n        ]\n\n        if self.preprocess is not None:\n            file_names += [f'mag_{self.preprocess}_emb.pt']\n\n        return file_names\n\n    @property\n    def processed_file_names(self) -> str:\n        if self.preprocess is not None:\n            return f'data_{self.preprocess}.pt'\n        else:\n            return 'data.pt'\n\n    def download(self) -> None:\n        if not 
all([osp.exists(f) for f in self.raw_paths[:5]]):\n            path = download_url(self.url, self.raw_dir)\n            extract_zip(path, self.raw_dir)\n            for file_name in ['node-feat', 'node-label', 'relations']:\n                path = osp.join(self.raw_dir, 'mag', 'raw', file_name)\n                shutil.move(path, self.raw_dir)\n            path = osp.join(self.raw_dir, 'mag', 'split')\n            shutil.move(path, self.raw_dir)\n            path = osp.join(self.raw_dir, 'mag', 'raw', 'num-node-dict.csv.gz')\n            shutil.move(path, self.raw_dir)\n            fs.rm(osp.join(self.raw_dir, 'mag'))\n            os.remove(osp.join(self.raw_dir, 'mag.zip'))\n        if self.preprocess is not None:\n            path = download_url(self.urls[self.preprocess], self.raw_dir)\n            extract_zip(path, self.raw_dir)\n            os.remove(path)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        data = HeteroData()\n\n        path = osp.join(self.raw_dir, 'node-feat', 'paper', 'node-feat.csv.gz')\n        x_paper = pd.read_csv(path, compression='gzip', header=None,\n                              dtype=np.float32).values\n        data['paper'].x = torch.from_numpy(x_paper)\n\n        path = osp.join(self.raw_dir, 'node-feat', 'paper', 'node_year.csv.gz')\n        year_paper = pd.read_csv(path, compression='gzip', header=None,\n                                 dtype=np.int64).values\n        data['paper'].year = torch.from_numpy(year_paper).view(-1)\n\n        path = osp.join(self.raw_dir, 'node-label', 'paper',\n                        'node-label.csv.gz')\n        y_paper = pd.read_csv(path, compression='gzip', header=None,\n                              dtype=np.int64).values.flatten()\n        data['paper'].y = torch.from_numpy(y_paper)\n\n        if self.preprocess is None:\n            path = osp.join(self.raw_dir, 'num-node-dict.csv.gz')\n            num_nodes_df = pd.read_csv(path, compression='gzip')\n            for 
node_type in ['author', 'institution', 'field_of_study']:\n                data[node_type].num_nodes = num_nodes_df[node_type].tolist()[0]\n        else:\n            emb_dict = fs.torch_load(self.raw_paths[-1])\n            for key, value in emb_dict.items():\n                if key != 'paper':\n                    data[key].x = value\n\n        for edge_type in [('author', 'affiliated_with', 'institution'),\n                          ('author', 'writes', 'paper'),\n                          ('paper', 'cites', 'paper'),\n                          ('paper', 'has_topic', 'field_of_study')]:\n\n            f = '___'.join(edge_type)\n            path = osp.join(self.raw_dir, 'relations', f, 'edge.csv.gz')\n            edge_index = pd.read_csv(path, compression='gzip', header=None,\n                                     dtype=np.int64).values\n            edge_index = torch.from_numpy(edge_index).t().contiguous()\n            data[edge_type].edge_index = edge_index\n\n        for f, v in [('train', 'train'), ('valid', 'val'), ('test', 'test')]:\n            path = osp.join(self.raw_dir, 'split', 'time', 'paper',\n                            f'{f}.csv.gz')\n            idx = pd.read_csv(path, compression='gzip', header=None,\n                              dtype=np.int64).values.flatten()\n            idx = torch.from_numpy(idx)\n            mask = torch.zeros(data['paper'].num_nodes, dtype=torch.bool)\n            mask[idx] = True\n            data['paper'][f'{v}_mask'] = mask\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return 'ogbn-mag()'\n"
  },
  {
    "path": "torch_geometric/datasets/omdb.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, extract_tar\n\n\nclass OMDB(InMemoryDataset):\n    r\"\"\"The `Organic Materials Database (OMDB)\n    <https://omdb.mathub.io/dataset>`__ of bulk organic crystals.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://omdb.mathub.io/dataset'\n\n    def __init__(\n        self,\n        root: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'OMDB-GAP1_v1.1.tar.gz'\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train_data.pt', 'test_data.pt']\n\n    def download(self) -> None:\n        raise RuntimeError(\n            f\"Dataset not found. 
Please download '{self.raw_file_names}' from \"\n            f\"'{self.url}' and move it to '{self.raw_dir}'\")\n\n    def process(self) -> None:\n        from ase.io import read\n\n        extract_tar(self.raw_paths[0], self.raw_dir, log=False)\n        materials = read(osp.join(self.raw_dir, 'structures.xyz'), index=':')\n        bandgaps = np.loadtxt(osp.join(self.raw_dir, 'bandgaps.csv'))\n\n        data_list = []\n        for material, bandgap in zip(materials, bandgaps):\n            pos = torch.from_numpy(material.get_positions()).to(torch.float)\n            z = torch.from_numpy(material.get_atomic_numbers()).to(torch.int64)\n            y = torch.tensor([float(bandgap)])\n            data_list.append(Data(z=z, pos=pos, y=y))\n\n        train_data = data_list[:10000]\n        test_data = data_list[10000:]\n\n        if self.pre_filter is not None:\n            train_data = [d for d in train_data if self.pre_filter(d)]\n            test_data = [d for d in test_data if self.pre_filter(d)]\n\n        if self.pre_transform is not None:\n            train_data = [self.pre_transform(d) for d in train_data]\n            test_data = [self.pre_transform(d) for d in test_data]\n\n        self.save(train_data, self.processed_paths[0])\n        self.save(test_data, self.processed_paths[1])\n"
  },
  {
    "path": "torch_geometric/datasets/opf.py",
    "content": "import json\nimport os\nimport os.path as osp\nfrom typing import Callable, Dict, List, Literal, Optional\n\nimport torch\nimport tqdm\nfrom torch import Tensor\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\n\n\nclass OPFDataset(InMemoryDataset):\n    r\"\"\"The heterogeneous OPF data from the `\"Large-scale Datasets for AC\n    Optimal Power Flow with Topological Perturbations\"\n    <https://arxiv.org/abs/2406.07234>`_ paper.\n\n    :class:`OPFDataset` is a large-scale dataset of solved optimal power flow\n    problems, derived from the\n    `pglib-opf <https://github.com/power-grid-lib/pglib-opf>`_ dataset.\n\n    The physical topology of the grid is represented by the :obj:`\"bus\"` node\n    type, and the connecting AC lines and transformers. Additionally,\n    :obj:`\"generator\"`, :obj:`\"load\"`, and :obj:`\"shunt\"` nodes are connected\n    to :obj:`\"bus\"` nodes using a dedicated edge type each, *e.g.*,\n    :obj:`\"generator_link\"`.\n\n    Edge direction corresponds to the properties of the line, *e.g.*,\n    :obj:`b_fr` is the line charging susceptance at the :obj:`from`\n    (source/sender) bus.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        case_name (str, optional): The name of the original pglib-opf case.\n            (default: :obj:`\"pglib_opf_case14_ieee\"`)\n        num_groups (int, optional): The dataset is divided into 20 groups with\n            each group containing 15,000 samples.\n            For large networks, this amount of data can be overwhelming.\n            The :obj:`num_groups` parameter controls the amount of data being\n            downloaded. 
Allowed values are :obj:`[1, 20]`.\n            (default: :obj:`20`)\n        topological_perturbations (bool, optional): Whether to use the dataset\n            with added topological perturbations. (default: :obj:`False`)\n        transform (callable, optional): A function/transform that takes in\n            a :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes\n            in a :obj:`torch_geometric.data.HeteroData` object and returns\n            a transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in a\n            :obj:`torch_geometric.data.HeteroData` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = 'https://storage.googleapis.com/gridopt-dataset'\n\n    def __init__(\n        self,\n        root: str,\n        split: Literal['train', 'val', 'test'] = 'train',\n        case_name: Literal[\n            'pglib_opf_case14_ieee',\n            'pglib_opf_case30_ieee',\n            'pglib_opf_case57_ieee',\n            'pglib_opf_case118_ieee',\n            'pglib_opf_case500_goc',\n            'pglib_opf_case2000_goc',\n            'pglib_opf_case6470_rte',\n            'pglib_opf_case4661_sdet',\n            'pglib_opf_case10000_goc',\n            'pglib_opf_case13659_pegase',\n        ] = 'pglib_opf_case14_ieee',\n        num_groups: int = 20,\n        topological_perturbations: bool = False,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n\n        self.split = split\n        self.case_name = case_name\n        self.num_groups = num_groups\n        self.topological_perturbations = topological_perturbations\n\n        self._release = 'dataset_release_1'\n        if topological_perturbations:\n            self._release += '_nminusone'\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        idx = self.processed_file_names.index(f'{split}.pt')\n        self.load(self.processed_paths[idx])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self._release, self.case_name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self._release, self.case_name,\n                        f'processed_{self.num_groups}')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return 
[f'{self.case_name}_{i}.tar.gz' for i in range(self.num_groups)]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def download(self) -> None:\n        for name in self.raw_file_names:\n            url = f'{self.url}/{self._release}/{name}'\n            path = download_url(url, self.raw_dir)\n            extract_tar(path, self.raw_dir)\n\n    def process(self) -> None:\n        train_data_list = []\n        val_data_list = []\n        test_data_list = []\n\n        for group in tqdm.tqdm(range(self.num_groups)):\n            tmp_dir = osp.join(\n                self.raw_dir,\n                'gridopt-dataset-tmp',\n                self._release,\n                self.case_name,\n                f'group_{group}',\n            )\n\n            for name in os.listdir(tmp_dir):\n                with open(osp.join(tmp_dir, name)) as f:\n                    obj = json.load(f)\n\n                grid = obj['grid']\n                solution = obj['solution']\n                metadata = obj['metadata']\n\n                # Graph-level properties:\n                data = HeteroData()\n                data.x = torch.tensor(grid['context']).view(-1)\n\n                data.objective = torch.tensor(metadata['objective'])\n\n                # Nodes (only some have a target):\n                data['bus'].x = torch.tensor(grid['nodes']['bus'])\n                data['bus'].y = torch.tensor(solution['nodes']['bus'])\n\n                data['generator'].x = torch.tensor(grid['nodes']['generator'])\n                data['generator'].y = torch.tensor(\n                    solution['nodes']['generator'])\n\n                data['load'].x = torch.tensor(grid['nodes']['load'])\n\n                data['shunt'].x = torch.tensor(grid['nodes']['shunt'])\n\n                # Edges (only ac lines and transformers have features):\n                data['bus', 'ac_line', 'bus'].edge_index = (  #\n                    
extract_edge_index(obj, 'ac_line'))\n                data['bus', 'ac_line', 'bus'].edge_attr = torch.tensor(\n                    grid['edges']['ac_line']['features'])\n                data['bus', 'ac_line', 'bus'].edge_label = torch.tensor(\n                    solution['edges']['ac_line']['features'])\n\n                data['bus', 'transformer', 'bus'].edge_index = (  #\n                    extract_edge_index(obj, 'transformer'))\n                data['bus', 'transformer', 'bus'].edge_attr = torch.tensor(\n                    grid['edges']['transformer']['features'])\n                data['bus', 'transformer', 'bus'].edge_label = torch.tensor(\n                    solution['edges']['transformer']['features'])\n\n                data['generator', 'generator_link', 'bus'].edge_index = (  #\n                    extract_edge_index(obj, 'generator_link'))\n                data['bus', 'generator_link', 'generator'].edge_index = (  #\n                    extract_edge_index_rev(obj, 'generator_link'))\n\n                data['load', 'load_link', 'bus'].edge_index = (  #\n                    extract_edge_index(obj, 'load_link'))\n                data['bus', 'load_link', 'load'].edge_index = (  #\n                    extract_edge_index_rev(obj, 'load_link'))\n\n                data['shunt', 'shunt_link', 'bus'].edge_index = (  #\n                    extract_edge_index(obj, 'shunt_link'))\n                data['bus', 'shunt_link', 'shunt'].edge_index = (  #\n                    extract_edge_index_rev(obj, 'shunt_link'))\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                i = int(name.split('.')[0].split('_')[1])\n                train_limit = int(15_000 * self.num_groups * 0.9)\n                val_limit = train_limit + int(15_000 * self.num_groups * 0.05)\n                if i < 
train_limit:\n                    train_data_list.append(data)\n                elif i < val_limit:\n                    val_data_list.append(data)\n                else:\n                    test_data_list.append(data)\n\n        self.save(train_data_list, self.processed_paths[0])\n        self.save(val_data_list, self.processed_paths[1])\n        self.save(test_data_list, self.processed_paths[2])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'split={self.split}, '\n                f'case_name={self.case_name}, '\n                f'topological_perturbations={self.topological_perturbations})')\n\n\ndef extract_edge_index(obj: Dict, edge_name: str) -> Tensor:\n    return torch.tensor([\n        obj['grid']['edges'][edge_name]['senders'],\n        obj['grid']['edges'][edge_name]['receivers'],\n    ])\n\n\ndef extract_edge_index_rev(obj: Dict, edge_name: str) -> Tensor:\n    return torch.tensor([\n        obj['grid']['edges'][edge_name]['receivers'],\n        obj['grid']['edges'][edge_name]['senders'],\n    ])\n"
  },
  {
    "path": "torch_geometric/datasets/ose_gvcs.py",
    "content": "import json\nimport os\nfrom collections import defaultdict\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\n\n\nclass OSE_GVCS(InMemoryDataset):\n    r\"\"\"A dataset describing the `Product ecology\n    <https://wiki.opensourceecology.org/wiki/Product_Ecologies>`_ of the Open\n    Source Ecology's iconoclastic `Global Village Construction Set\n    <https://wiki.opensourceecology.org/wiki/\n    Global_Village_Construction_Set>`_.\n    GVCS is a modular, DIY, low-cost set of blueprints that enables the\n    fabrication of the 50 different industrial machines that it takes to\n    build a small, sustainable civilization with modern comforts.\n\n    The dataset contains a heterogeneous graph with 50 :obj:`machine` nodes,\n    composing the GVCS, and 290 directed edges, each representing one out of\n    three relationships between machines.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    machines = [\n        '3D Printer', '3D Scanner', 'Aluminum Extractor', 'Backhoe',\n        'Bakery Oven', 'Baler', 'Bioplastic Extruder', 'Bulldozer', 'Car',\n        'CEB Press', 'Cement Mixer', 'Chipper Hammermill', 'CNC Circuit Mill',\n        'CNC Torch Table', 'Dairy Milker', 'Drill Press',\n        'Electric Motor Generator', 'Gasifier Burner', 'Hay Cutter',\n        'Hay Rake', 'Hydraulic Motor', 'Induction Furnace', 'Industrial Robot',\n        'Ironworker', 'Laser Cutter', 'Metal Roller', 'Microcombine',\n        'Microtractor', 'Multimachine', 'Nickel-Iron Battery', 'Pelletizer',\n        'Plasma Cutter', 'Power Cube', 'Press Forge', 'Rod and Wire Mill',\n        'Rototiller', 'Sawmill', 'Seeder', 'Solar Concentrator', 'Spader',\n        'Steam Engine', 'Steam Generator', 'Tractor', 'Trencher', 'Truck',\n        'Universal Power Supply', 'Universal Rotor', 'Welder',\n        'Well-Drilling Rig', 'Wind Turbine'\n    ]\n    categories = [\n        'habitat', 'agriculture', 'industry', 'energy', 'materials',\n        'transportation'\n    ]\n    relationships = ['from', 'uses', 'enables']\n\n    url = 'https://github.com/Wesxdz/ose_gvcs/raw/master/ose_gvcs.tar.gz'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            f\"{machine.lower().replace(' ', '_')}.json\"\n            for machine in self.machines\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 
'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_tar(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        data = HeteroData()\n\n        categories = []\n        edges = defaultdict(list)\n\n        for path in self.raw_paths:\n            with open(path) as f:\n                product = json.load(f)\n            categories.append(self.categories.index(product['category']))\n            for interaction in product['ecology']:\n                # NOTE Some ecology items are not GVCS machines or have other\n                # relationship types we don't want included.\n                rt = interaction['relationship']\n                if rt not in self.relationships:\n                    continue\n                dst = interaction['tool']\n                if dst not in self.machines:\n                    continue\n                # Machines are guaranteed to be sorted according to their order\n                # in `self.machines`, so we can use its index for the mapping:\n                src = self.machines.index(product['machine'])\n                dst = self.machines.index(dst)\n                edges[rt].append((src, dst))\n\n        data['machine'].num_nodes = len(categories)\n        data['machine'].category = torch.tensor(categories)\n\n        for rel, edge_indices in edges.items():\n            edge_index = torch.tensor(edge_indices).t().contiguous()\n            data['machine', rel, 'machine'].edge_index = edge_index\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/particle.py",
    "content": "import glob\nimport os.path as osp\nfrom typing import Any, Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, Dataset\nfrom torch_geometric.utils import index_sort, scatter\n\n\nclass TrackingData(Data):\n    def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any:\n        if key == 'y_index':\n            return torch.tensor([value[0].max().item() + 1, self.num_nodes])\n        else:\n            return super().__inc__(key, value, *args, **kwargs)\n\n\nclass TrackMLParticleTrackingDataset(Dataset):\n    r\"\"\"The `TrackML Particle Tracking Challenge\n    <https://www.kaggle.com/c/trackml-particle-identification>`_ dataset to\n    reconstruct particle tracks from 3D points left in the silicon detectors.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n    \"\"\"\n\n    url = 'https://www.kaggle.com/c/trackml-particle-identification'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n    ) -> None:\n        super().__init__(root, transform)\n        events = glob.glob(osp.join(self.raw_dir, 'event*-hits.csv'))\n        events = [e.split(osp.sep)[-1].split('-')[0][5:] for e in events]\n        self.events: List[str] = sorted(events)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        event_indices = ['000001000']\n        file_names = []\n        file_names += [f'event{idx}-cells.csv' for idx in event_indices]\n        file_names += [f'event{idx}-hits.csv' for idx in event_indices]\n        file_names += [f'event{idx}-particles.csv' for idx in event_indices]\n        file_names += [f'event{idx}-truth.csv' for idx in event_indices]\n        return file_names\n\n    def download(self) -> None:\n        raise RuntimeError(\n            f'Dataset not found. 
Please download it from {self.url} and move '\n            f'all *.csv files to {self.raw_dir}')\n\n    def len(self) -> int:\n        return len(glob.glob(osp.join(self.raw_dir, 'event*-hits.csv')))\n\n    def get(self, i: int) -> TrackingData:\n        import pandas as pd\n\n        idx = self.events[i]\n\n        # Get hit positions.\n        hits_path = osp.join(self.raw_dir, f'event{idx}-hits.csv')\n        pos = pd.read_csv(hits_path, usecols=['x', 'y', 'z'], dtype=np.float32)\n        pos = torch.from_numpy(pos.values).div_(1000.)\n\n        # Get hit features.\n        cells_path = osp.join(self.raw_dir, f'event{idx}-cells.csv')\n        cell = pd.read_csv(cells_path, usecols=['hit_id', 'value'])\n        hit_id = torch.from_numpy(cell['hit_id'].values).to(torch.long).sub_(1)\n        value = torch.from_numpy(cell['value'].values).to(torch.float)\n        ones = torch.ones(hit_id.size(0))\n        num_cells = scatter(ones, hit_id, 0, pos.size(0), 'sum').div_(10.)\n        value = scatter(value, hit_id, 0, pos.size(0), 'sum')\n        x = torch.stack([num_cells, value], dim=-1)\n\n        # Get ground-truth hit assignments.\n        truth_path = osp.join(self.raw_dir, f'event{idx}-truth.csv')\n        y = pd.read_csv(truth_path,\n                        usecols=['hit_id', 'particle_id', 'weight'])\n        hit_id = torch.from_numpy(y['hit_id'].values).to(torch.long).sub_(1)\n        particle_id = torch.from_numpy(y['particle_id'].values).to(torch.long)\n        particle_id = particle_id.unique(return_inverse=True)[1].sub_(1)\n        weight = torch.from_numpy(y['weight'].values).to(torch.float)\n\n        # Sort.\n        _, perm = index_sort(particle_id * hit_id.size(0) + hit_id)\n        hit_id = hit_id[perm]\n        particle_id = particle_id[perm]\n        weight = weight[perm]\n\n        # Remove invalid particle ids.\n        mask = particle_id >= 0\n        hit_id = hit_id[mask]\n        particle_id = particle_id[mask]\n        weight = 
weight[mask]\n\n        y_index = torch.stack([particle_id, hit_id], dim=0)\n\n        return TrackingData(x=x, pos=pos, y_index=y_index, y_weight=weight)\n"
  },
  {
    "path": "torch_geometric/datasets/pascal.py",
    "content": "import os\nimport os.path as osp\nfrom itertools import chain\nfrom typing import Callable, Dict, List, Optional\nfrom xml.dom import minidom\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\nfrom torch_geometric.io import fs\n\n\nclass PascalVOCKeypoints(InMemoryDataset):\n    r\"\"\"The Pascal VOC 2011 dataset with Berkeley annotations of keypoints from\n    the `\"Poselets: Body Part Detectors Trained Using 3D Human Pose\n    Annotations\" <https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/\n    human/poselets_iccv09.pdf>`_ paper, containing 0 to 23 keypoints per\n    example over 20 categories.\n    The dataset is pre-filtered to exclude difficult, occluded and truncated\n    objects.\n    The keypoints contain interpolated features from a pre-trained VGG16 model\n    on ImageNet (:obj:`relu4_2` and :obj:`relu5_1`).\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        category (str): The category of the images (one of\n            :obj:`\"Aeroplane\"`, :obj:`\"Bicycle\"`, :obj:`\"Bird\"`,\n            :obj:`\"Boat\"`, :obj:`\"Bottle\"`, :obj:`\"Bus\"`, :obj:`\"Car\"`,\n            :obj:`\"Cat\"`, :obj:`\"Chair\"`, :obj:`\"Cow\"`,\n            :obj:`\"Diningtable\"`, :obj:`\"Dog\"`, :obj:`\"Horse\"`,\n            :obj:`\"Motorbike\"`, :obj:`\"Person\"`, :obj:`\"Pottedplant\"`,\n            :obj:`\"Sheep\"`, :obj:`\"Sofa\"`, :obj:`\"Train\"`,\n            :obj:`\"TVMonitor\"`)\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        device (str or torch.device, optional): The device to use for\n            processing the raw data. If set to :obj:`None`, will utilize\n            GPU-processing if available. (default: :obj:`None`)\n    \"\"\"\n    image_url = ('http://host.robots.ox.ac.uk/pascal/VOC/voc2011/'\n                 'VOCtrainval_25-May-2011.tar')\n    annotation_url = ('https://www2.eecs.berkeley.edu/Research/Projects/CS/'\n                      'vision/shape/poselets/voc2011_keypoints_Feb2012.tgz')\n    # annotation_url = 'http://www.roemisch-drei.de/pascal_annotations.tar'\n    # split_url = 'http://cvgl.stanford.edu/projects/ucn/voc2011_pairs.npz'\n    split_url = ('https://github.com/Thinklab-SJTU/PCA-GM/raw/master/data/'\n                 'PascalVOC/voc2011_pairs.npz')\n\n    categories = [\n        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',\n        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',\n        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'\n    ]\n\n    batch_size = 32\n\n    def __init__(\n        self,\n        root: str,\n        category: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: 
Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n        device: Optional[str] = None,\n    ) -> None:\n        if device is None:\n            device = 'cuda' if torch.cuda.is_available() else 'cpu'\n\n        self.category = category.lower()\n        assert self.category in self.categories\n        self.device = device\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.category.capitalize(), 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['images', 'annotations', 'splits.npz']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['training.pt', 'test.pt']\n\n    def download(self) -> None:\n        path = download_url(self.image_url, self.raw_dir)\n        extract_tar(path, self.raw_dir, mode='r')\n        os.unlink(path)\n        image_path = osp.join(self.raw_dir, 'TrainVal', 'VOCdevkit', 'VOC2011')\n        os.rename(image_path, osp.join(self.raw_dir, 'images'))\n        fs.rm(osp.join(self.raw_dir, 'TrainVal'))\n\n        path = download_url(self.annotation_url, self.raw_dir)\n        extract_tar(path, self.raw_dir, mode='r')\n        os.unlink(path)\n\n        path = download_url(self.split_url, self.raw_dir)\n        os.rename(path, osp.join(self.raw_dir, 'splits.npz'))\n\n    def process(self) -> None:\n        import torchvision.models as models\n        import torchvision.transforms as T\n        from PIL import Image\n\n        splits = np.load(osp.join(self.raw_dir, 'splits.npz'),\n                         allow_pickle=True)\n        category_idx = 
self.categories.index(self.category)\n        train_split = list(splits['train'])[category_idx]\n        test_split = list(splits['test'])[category_idx]\n\n        image_path = osp.join(self.raw_dir, 'images', 'JPEGImages')\n        info_path = osp.join(self.raw_dir, 'images', 'Annotations')\n        annotation_path = osp.join(self.raw_dir, 'annotations')\n\n        labels: Dict[str, int] = {}\n\n        vgg16_outputs = []\n\n        def hook(module: torch.nn.Module, x: Tensor, y: Tensor) -> None:\n            vgg16_outputs.append(y)\n\n        vgg16 = models.vgg16(pretrained=True).to(self.device)\n        vgg16.eval()\n        vgg16.features[20].register_forward_hook(hook)  # relu4_2\n        vgg16.features[25].register_forward_hook(hook)  # relu5_1\n\n        transform = T.Compose([\n            T.ToTensor(),\n            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n        ])\n\n        train_set, test_set = [], []\n        for i, name in enumerate(chain(train_split, test_split)):\n            filename = '_'.join(name.split('/')[1].split('_')[:-1])\n            file_idx = int(name.split('_')[-1].split('.')[0]) - 1\n\n            path = osp.join(info_path, f'{filename}.xml')\n            obj = minidom.parse(path).getElementsByTagName('object')[file_idx]\n\n            child = obj.getElementsByTagName('truncated')[0].firstChild\n            assert child is not None\n            trunc = child.data  # type: ignore\n\n            elements = obj.getElementsByTagName('occluded')\n            if len(elements) == 0:\n                occ = '0'\n            else:\n                child = elements[0].firstChild\n                assert child is not None\n                occ = child.data  # type: ignore\n\n            child = obj.getElementsByTagName('difficult')[0].firstChild\n            diff = child.data  # type: ignore\n\n            if bool(int(trunc)) or bool(int(occ)) or bool(int(diff)):\n                continue\n\n            if self.category 
== 'person' and int(filename[:4]) > 2008:\n                continue\n\n            child = obj.getElementsByTagName('xmin')[0].firstChild\n            assert child is not None\n            xmin = int(child.data)  # type: ignore\n\n            child = obj.getElementsByTagName('xmax')[0].firstChild\n            assert child is not None\n            xmax = int(child.data)  # type: ignore\n\n            child = obj.getElementsByTagName('ymin')[0].firstChild\n            assert child is not None\n            ymin = int(child.data)  # type: ignore\n\n            child = obj.getElementsByTagName('ymax')[0].firstChild\n            assert child is not None\n            ymax = int(child.data)  # type: ignore\n\n            box = (xmin, ymin, xmax, ymax)\n\n            dom = minidom.parse(osp.join(annotation_path, name))\n            keypoints = dom.getElementsByTagName('keypoint')\n            poss, ys = [], []\n            for keypoint in keypoints:\n                label = keypoint.attributes['name'].value\n                if label not in labels:\n                    labels[label] = len(labels)\n                ys.append(labels[label])\n                _x = float(keypoint.attributes['x'].value)\n                _y = float(keypoint.attributes['y'].value)\n                poss += [_x, _y]\n            y = torch.tensor(ys, dtype=torch.long)\n            pos = torch.tensor(poss, dtype=torch.float).view(-1, 2)\n\n            if pos.numel() == 0:\n                continue  # These examples do not make any sense anyway...\n\n            # Add a small offset to the bounding because some keypoints lay\n            # outside the bounding box intervals.\n            box = (\n                min(int(pos[:, 0].min().floor()), box[0]) - 16,\n                min(int(pos[:, 1].min().floor()), box[1]) - 16,\n                max(int(pos[:, 0].max().ceil()), box[2]) + 16,\n                max(int(pos[:, 1].max().ceil()), box[3]) + 16,\n            )\n\n            # Rescale keypoints.\n      
      pos[:, 0] = (pos[:, 0] - box[0]) * 256.0 / (box[2] - box[0])\n            pos[:, 1] = (pos[:, 1] - box[1]) * 256.0 / (box[3] - box[1])\n\n            path = osp.join(image_path, f'{filename}.jpg')\n            with open(path, 'rb') as f:\n                img = Image.open(f).convert('RGB').crop(box)\n                img = img.resize((256, 256), resample=Image.Resampling.BICUBIC)\n\n            img = transform(img)\n\n            data = Data(img=img, pos=pos, y=y, name=filename)\n\n            if i < len(train_split):\n                train_set.append(data)\n            else:\n                test_set.append(data)\n\n        data_list = list(chain(train_set, test_set))\n        imgs = [data.img for data in data_list]\n        loader: DataLoader = DataLoader(\n            dataset=imgs,  # type: ignore\n            batch_size=self.batch_size,\n            shuffle=False,\n        )\n        for i, batch_img in enumerate(loader):\n            vgg16_outputs.clear()\n\n            with torch.no_grad():\n                vgg16(batch_img.to(self.device))\n\n            out1 = F.interpolate(vgg16_outputs[0], (256, 256), mode='bilinear',\n                                 align_corners=False)\n            out2 = F.interpolate(vgg16_outputs[1], (256, 256), mode='bilinear',\n                                 align_corners=False)\n\n            for j in range(out1.size(0)):\n                data = data_list[i * self.batch_size + j]\n                assert data.pos is not None\n                idx = data.pos.round().long().clamp(0, 255)\n                x_1 = out1[j, :, idx[:, 1], idx[:, 0]].to('cpu')\n                x_2 = out2[j, :, idx[:, 1], idx[:, 0]].to('cpu')\n                data.img = None\n                data.x = torch.cat([x_1.t(), x_2.t()], dim=-1)\n            del out1\n            del out2\n\n        if self.pre_filter is not None:\n            train_set = [data for data in train_set if self.pre_filter(data)]\n            test_set = [data for data in test_set if 
self.pre_filter(data)]\n\n        if self.pre_transform is not None:\n            train_set = [self.pre_transform(data) for data in train_set]\n            test_set = [self.pre_transform(data) for data in test_set]\n\n        self.save(train_set, self.processed_paths[0])\n        self.save(test_set, self.processed_paths[1])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'category={self.category})')\n"
  },
  {
    "path": "torch_geometric/datasets/pascal_pf.py",
    "content": "import glob\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass PascalPF(InMemoryDataset):\n    r\"\"\"The Pascal-PF dataset from the `\"Proposal Flow\"\n    <https://arxiv.org/abs/1511.05065>`_ paper, containing 4 to 16 keypoints\n    per example over 20 categories.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        category (str): The category of the images (one of\n            :obj:`\"Aeroplane\"`, :obj:`\"Bicycle\"`, :obj:`\"Bird\"`,\n            :obj:`\"Boat\"`, :obj:`\"Bottle\"`, :obj:`\"Bus\"`, :obj:`\"Car\"`,\n            :obj:`\"Cat\"`, :obj:`\"Chair\"`, :obj:`\"Diningtable\"`, :obj:`\"Dog\"`,\n            :obj:`\"Horse\"`, :obj:`\"Motorbike\"`, :obj:`\"Person\"`,\n            :obj:`\"Pottedplant\"`, :obj:`\"Sheep\"`, :obj:`\"Sofa\"`,\n            :obj:`\"Train\"`, :obj:`\"TVMonitor\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = ('https://www.di.ens.fr/willow/research/proposalflow/dataset/'\n           'PF-dataset-PASCAL.zip')\n\n    categories = [\n        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',\n        'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',\n        'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'\n    ]\n\n    def __init__(\n        self,\n        root: str,\n        category: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.category = category.lower()\n        assert self.category in self.categories\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n        self.pairs = fs.torch_load(self.processed_paths[1])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['Annotations', 'parsePascalVOC.mat']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return [f'{self.category}.pt', f'{self.category}_pairs.pt']\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        fs.rm(self.raw_dir)\n        os.rename(osp.join(self.root, 'PF-dataset-PASCAL'), self.raw_dir)\n\n    def process(self) -> None:\n        from scipy.io import loadmat\n\n        path = osp.join(self.raw_dir, 'Annotations', self.category, '*.mat')\n        filenames = glob.glob(path)\n\n        names = []\n        data_list = []\n        for filename in filenames:\n            name = osp.basename(filename).split('.')[0]\n\n            pos = torch.from_numpy(loadmat(filename)['kps']).to(torch.float)\n       
     mask = ~torch.isnan(pos[:, 0])\n            pos = pos[mask]\n\n            # Normalize points to unit sphere.\n            pos = pos - pos.mean(dim=0, keepdim=True)\n            pos = pos / pos.norm(dim=1).max()\n\n            y = mask.nonzero(as_tuple=False).flatten()\n\n            data = Data(pos=pos, y=y, name=name)\n\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            names.append(name)\n            data_list.append(data)\n\n        pairs = loadmat(osp.join(self.raw_dir, 'parsePascalVOC.mat'))\n        pairs = pairs['PascalVOC']['pair'][0, 0][\n            0, self.categories.index(self.category)]\n\n        pairs = [(names.index(x[0][0]), names.index(x[1][0])) for x in pairs]\n\n        self.save(data_list, self.processed_paths[0])\n        torch.save(pairs, self.processed_paths[1])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'category={self.category})')\n"
  },
  {
    "path": "torch_geometric/datasets/pcpnet_dataset.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import read_txt_array\n\n\nclass PCPNetDataset(InMemoryDataset):\n    r\"\"\"The PCPNet dataset from the `\"PCPNet: Learning Local Shape Properties\n    from Raw Point Clouds\" <https://arxiv.org/abs/1710.04954>`_ paper,\n    consisting of 30 shapes, each given as a point cloud, densely sampled with\n    100k points.\n    For each shape, surface normals and local curvatures are given as node\n    features.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        category (str): The training set category (one of :obj:`\"NoNoise\"`,\n            :obj:`\"Noisy\"`, :obj:`\"VarDensity\"`, :obj:`\"NoisyAndVarDensity\"`\n            for :obj:`split=\"train\"` or :obj:`split=\"val\"`,\n            or one of :obj:`\"All\"`, :obj:`\"LowNoise\"`, :obj:`\"MedNoise\"`,\n            :obj:`\"HighNoise\", :obj:`\"VarDensityStriped\",\n            :obj:`\"VarDensityGradient\"` for :obj:`split=\"test\"`).\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'http://geometry.cs.ucl.ac.uk/projects/2018/pcpnet/pclouds.zip'\n\n    category_files_train = {\n        'NoNoise': 'trainingset_no_noise.txt',\n        'Noisy': 'trainingset_whitenoise.txt',\n        'VarDensity': 'trainingset_vardensity.txt',\n        'NoisyAndVarDensity': 'trainingset_vardensity_whitenoise.txt'\n    }\n\n    category_files_val = {\n        'NoNoise': 'validationset_no_noise.txt',\n        'Noisy': 'validationset_whitenoise.txt',\n        'VarDensity': 'validationset_vardensity.txt',\n        'NoisyAndVarDensity': 'validationset_vardensity_whitenoise.txt'\n    }\n\n    category_files_test = {\n        'All': 'testset_all.txt',\n        'NoNoise': 'testset_no_noise.txt',\n        'LowNoise': 'testset_low_noise.txt',\n        'MedNoise': 'testset_med_noise.txt',\n        'HighNoise': 'testset_high_noise.txt',\n        'VarDensityStriped': 'testset_vardensity_striped.txt',\n        'VarDensityGradient': 'testset_vardensity_gradient.txt'\n    }\n\n    def __init__(\n        self,\n        root: str,\n        category: str,\n        split: str = 'train',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n\n        assert split in ['train', 'val', 'test']\n\n        if split == 'train':\n            assert category in self.category_files_train.keys()\n        elif split == 'val':\n            assert category in self.category_files_val.keys()\n        else:\n            assert 
category in self.category_files_test.keys()\n\n        self.category = category\n        self.split = split\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        if self.split == 'train':\n            return self.category_files_train[self.category]\n        elif self.split == 'val':\n            return self.category_files_val[self.category]\n        else:\n            return self.category_files_test[self.category]\n\n    @property\n    def processed_file_names(self) -> str:\n        return self.split + '_' + self.category + '.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        path_file = self.raw_paths\n        with open(path_file[0]) as f:\n            filenames = f.read().split('\\n')[:-1]\n        data_list = []\n        for filename in filenames:\n            pos_path = osp.join(self.raw_dir, filename + '.xyz')\n            normal_path = osp.join(self.raw_dir, filename + '.normals')\n            curv_path = osp.join(self.raw_dir, filename + '.curv')\n            idx_path = osp.join(self.raw_dir, filename + '.pidx')\n            pos = read_txt_array(pos_path)\n            normals = read_txt_array(normal_path)\n            curv = read_txt_array(curv_path)\n            normals_and_curv = torch.cat([normals, curv], dim=1)\n            test_idx = read_txt_array(idx_path, dtype=torch.long)\n            data = Data(pos=pos, x=normals_and_curv)\n            data.test_idx = test_idx\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n            data_list.append(data)\n\n        self.save(data_list, 
self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'category={self.category})')\n"
  },
  {
    "path": "torch_geometric/datasets/pcqm4m.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Any, Callable, Dict, List, Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import Data, OnDiskDataset, download_url, extract_zip\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import from_smiles as _from_smiles\n\n\nclass PCQM4Mv2(OnDiskDataset):\n    r\"\"\"The PCQM4Mv2 dataset from the `\"OGB-LSC: A Large-Scale Challenge for\n    Machine Learning on Graphs\" <https://arxiv.org/abs/2103.09430>`_ paper.\n    :class:`PCQM4Mv2` is a quantum chemistry dataset originally curated under\n    the `PubChemQC project\n    <https://pubs.acs.org/doi/10.1021/acs.jcim.7b00083>`_.\n    The task is to predict the DFT-calculated HOMO-LUMO energy gap of molecules\n    given their 2D molecular graphs.\n\n    .. note::\n        This dataset uses the :class:`OnDiskDataset` base class to load data\n        dynamically from disk.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            If :obj:`\"holdout\"`, loads the holdout dataset.\n            (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        backend (str): The :class:`Database` backend to use.\n            (default: :obj:`\"sqlite\"`)\n        from_smiles (callable, optional): A custom function that takes a SMILES\n            string and outputs a :obj:`~torch_geometric.data.Data` object.\n            If not set, defaults to :meth:`~torch_geometric.utils.from_smiles`.\n            (default: :obj:`None`)\n    \"\"\"\n    url = ('https://dgl-data.s3-accelerate.amazonaws.com/dataset/OGB-LSC/'\n           'pcqm4m-v2.zip')\n\n    split_mapping = {\n        'train': 'train',\n        'val': 'valid',\n        'test': 'test-dev',\n        'holdout': 'test-challenge',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        split: str = 'train',\n        transform: Optional[Callable] = None,\n        backend: str = 'sqlite',\n        from_smiles: Optional[Callable] = None,\n    ) -> None:\n        assert split in ['train', 'val', 'test', 'holdout']\n\n        schema = {\n            'x': dict(dtype=torch.int64, size=(-1, 9)),\n            'edge_index': dict(dtype=torch.int64, size=(2, -1)),\n            'edge_attr': dict(dtype=torch.int64, size=(-1, 3)),\n            'smiles': str,\n            'y': float,\n        }\n\n        self.from_smiles = from_smiles or _from_smiles\n        super().__init__(root, transform, backend=backend, schema=schema)\n\n        split_idx = fs.torch_load(self.raw_paths[1])\n        self._indices = split_idx[self.split_mapping[split]].tolist()\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            osp.join('pcqm4m-v2', 'raw', 'data.csv.gz'),\n            osp.join('pcqm4m-v2', 'split_dict.pt'),\n        ]\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        df 
= pd.read_csv(self.raw_paths[0])\n\n        data_list: List[Data] = []\n        iterator = enumerate(zip(df['smiles'], df['homolumogap']))\n        for i, (smiles, y) in tqdm(iterator, total=len(df)):\n            data = self.from_smiles(smiles)\n            data.y = y\n\n            data_list.append(data)\n            if i + 1 == len(df) or (i + 1) % 1000 == 0:  # Write batch-wise:\n                self.extend(data_list)\n                data_list = []\n\n    def serialize(self, data: BaseData) -> Dict[str, Any]:\n        assert isinstance(data, Data)\n        return dict(\n            x=data.x,\n            edge_index=data.edge_index,\n            edge_attr=data.edge_attr,\n            y=data.y,\n            smiles=data.smiles,\n        )\n\n    def deserialize(self, data: Dict[str, Any]) -> Data:\n        return Data.from_dict(data)\n"
  },
  {
    "path": "torch_geometric/datasets/planetoid.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset\nfrom torch_geometric.io import fs, read_planetoid_data\n\n\nclass Planetoid(InMemoryDataset):\n    r\"\"\"The citation network datasets :obj:`\"Cora\"`, :obj:`\"CiteSeer\"` and\n    :obj:`\"PubMed\"` from the `\"Revisiting Semi-Supervised Learning with Graph\n    Embeddings\" <https://arxiv.org/abs/1603.08861>`_ paper.\n    Nodes represent documents and edges represent citation links.\n    Training, validation and test splits are given by binary masks.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"Cora\"`, :obj:`\"CiteSeer\"`,\n            :obj:`\"PubMed\"`).\n        split (str, optional): The type of dataset split (:obj:`\"public\"`,\n            :obj:`\"full\"`, :obj:`\"geom-gcn\"`, :obj:`\"random\"`).\n            If set to :obj:`\"public\"`, the split will be the public fixed split\n            from the `\"Revisiting Semi-Supervised Learning with Graph\n            Embeddings\" <https://arxiv.org/abs/1603.08861>`_ paper.\n            If set to :obj:`\"full\"`, all nodes except those in the validation\n            and test sets will be used for training (as in the\n            `\"FastGCN: Fast Learning with Graph Convolutional Networks via\n            Importance Sampling\" <https://arxiv.org/abs/1801.10247>`_ paper).\n            If set to :obj:`\"geom-gcn\"`, the 10 public fixed splits from the\n            `\"Geom-GCN: Geometric Graph Convolutional Networks\"\n            <https://openreview.net/forum?id=S1e2agrFvS>`_ paper are given.\n            If set to :obj:`\"random\"`, train, validation, and test sets will be\n            randomly generated, according to :obj:`num_train_per_class`,\n            :obj:`num_val` and :obj:`num_test`. 
(default: :obj:`\"public\"`)\n        num_train_per_class (int, optional): The number of training samples\n            per class in case of :obj:`\"random\"` split. (default: :obj:`20`)\n        num_val (int, optional): The number of validation samples in case of\n            :obj:`\"random\"` split. (default: :obj:`500`)\n        num_test (int, optional): The number of test samples in case of\n            :obj:`\"random\"` split. (default: :obj:`1000`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Cora\n          - 2,708\n          - 10,556\n          - 1,433\n          - 7\n        * - CiteSeer\n          - 3,327\n          - 9,104\n          - 3,703\n          - 6\n        * - PubMed\n          - 19,717\n          - 88,648\n          - 500\n          - 3\n    \"\"\"\n    url = 'https://github.com/kimiyoung/planetoid/raw/master/data'\n    geom_gcn_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/'\n                    'geom-gcn/master')\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        split: str = \"public\",\n        num_train_per_class: int = 20,\n        num_val: int = 500,\n        num_test: int = 1000,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name\n\n        self.split = split.lower()\n        assert self.split in ['public', 'full', 'geom-gcn', 'random']\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n        # Compare against the lower-cased `self.split` so that inputs such as\n        # \"Full\" or \"Random\", which pass the assertion above, are honored.\n        if self.split == 'full':\n            data = self.get(0)\n            data.train_mask.fill_(True)\n            data.train_mask[data.val_mask | data.test_mask] = False\n            self.data, self.slices = self.collate([data])\n\n        elif self.split == 'random':\n            data = self.get(0)\n            data.train_mask.fill_(False)\n            for c in range(self.num_classes):\n                idx = (data.y == c).nonzero(as_tuple=False).view(-1)\n                idx = idx[torch.randperm(idx.size(0))[:num_train_per_class]]\n                data.train_mask[idx] = True\n\n            remaining = (~data.train_mask).nonzero(as_tuple=False).view(-1)\n            remaining = 
remaining[torch.randperm(remaining.size(0))]\n\n            data.val_mask.fill_(False)\n            data.val_mask[remaining[:num_val]] = True\n\n            data.test_mask.fill_(False)\n            data.test_mask[remaining[num_val:num_val + num_test]] = True\n\n            self.data, self.slices = self.collate([data])\n\n    @property\n    def raw_dir(self) -> str:\n        if self.split == 'geom-gcn':\n            return osp.join(self.root, self.name, 'geom-gcn', 'raw')\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        if self.split == 'geom-gcn':\n            return osp.join(self.root, self.name, 'geom-gcn', 'processed')\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']\n        return [f'ind.{self.name.lower()}.{name}' for name in names]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for name in self.raw_file_names:\n            fs.cp(f'{self.url}/{name}', self.raw_dir)\n        if self.split == 'geom-gcn':\n            for i in range(10):\n                url = f'{self.geom_gcn_url}/splits/{self.name.lower()}'\n                fs.cp(f'{url}_split_0.6_0.2_{i}.npz', self.raw_dir)\n\n    def process(self) -> None:\n        data = read_planetoid_data(self.raw_dir, self.name)\n\n        if self.split == 'geom-gcn':\n            train_masks, val_masks, test_masks = [], [], []\n            for i in range(10):\n                name = f'{self.name.lower()}_split_0.6_0.2_{i}.npz'\n                splits = np.load(osp.join(self.raw_dir, name))\n                train_masks.append(torch.from_numpy(splits['train_mask']))\n                val_masks.append(torch.from_numpy(splits['val_mask']))\n                test_masks.append(torch.from_numpy(splits['test_mask']))\n            
data.train_mask = torch.stack(train_masks, dim=1)\n            data.val_mask = torch.stack(val_masks, dim=1)\n            data.test_mask = torch.stack(test_masks, dim=1)\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name}()'\n"
  },
  {
    "path": "torch_geometric/datasets/polblogs.py",
    "content": "import os\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\n\n\nclass PolBlogs(InMemoryDataset):\n    r\"\"\"The Political Blogs dataset from the `\"The Political Blogosphere and\n    the 2004 US Election: Divided they Blog\"\n    <https://dl.acm.org/doi/10.1145/1134271.1134277>`_ paper.\n\n    :class:`Polblogs` is a graph with 1,490 vertices (representing political\n    blogs) and 19,025 edges (links between blogs).\n    The links are automatically extracted from a crawl of the front page of the\n    blog.\n    Each vertex receives a label indicating the political leaning of the blog:\n    liberal or conservative.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 1,490\n          - 19,025\n          - 0\n          - 2\n    \"\"\"\n\n    url = 'https://netset.telecom-paris.fr/datasets/polblogs.tar.gz'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['adjacency.tsv', 'labels.tsv']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_tar(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        edge_index = pd.read_csv(self.raw_paths[0], header=None, sep='\\t',\n                                 usecols=[0, 1])\n        edge_index = torch.from_numpy(edge_index.values).t().contiguous()\n\n        y = pd.read_csv(self.raw_paths[1], header=None, sep='\\t')\n        y = torch.from_numpy(y.values).view(-1)\n\n        data = Data(edge_index=edge_index, y=y, num_nodes=y.size(0))\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/ppi.py",
    "content": "import json\nimport os\nimport os.path as osp\nfrom itertools import product\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.utils import remove_self_loops\n\n\nclass PPI(InMemoryDataset):\n    r\"\"\"The protein-protein interaction networks from the `\"Predicting\n    Multicellular Function through Multi-layer Tissue Networks\"\n    <https://arxiv.org/abs/1707.04638>`_ paper, containing positional gene\n    sets, motif gene sets and immunological signatures as features (50 in\n    total) and gene ontology sets as labels (121 in total).\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #tasks\n        * - 20\n          - ~2,245.3\n          - ~61,318.4\n          - 50\n          - 121\n    \"\"\"\n\n    url = 'https://data.dgl.ai/dataset/ppi.zip'\n\n    def __init__(\n        self,\n        root: str,\n        split: str = 'train',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n\n        assert split in ['train', 'val', 'test']\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        if split == 'train':\n            self.load(self.processed_paths[0])\n        elif split == 'val':\n            self.load(self.processed_paths[1])\n        elif split == 'test':\n            self.load(self.processed_paths[2])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        splits = ['train', 'valid', 'test']\n        files = ['feats.npy', 'graph_id.npy', 'graph.json', 'labels.npy']\n        return [f'{split}_{name}' for split, name in product(splits, files)]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        import networkx as nx\n        from networkx.readwrite import json_graph\n\n        for s, split in enumerate(['train', 'valid', 'test']):\n            path = osp.join(self.raw_dir, f'{split}_graph.json')\n            with 
open(path) as f:\n                G = nx.DiGraph(\n                    json_graph.node_link_graph(json.load(f), edges=\"links\"))\n\n            x = np.load(osp.join(self.raw_dir, f'{split}_feats.npy'))\n            x = torch.from_numpy(x).to(torch.float)\n\n            y = np.load(osp.join(self.raw_dir, f'{split}_labels.npy'))\n            y = torch.from_numpy(y).to(torch.float)\n\n            data_list = []\n            path = osp.join(self.raw_dir, f'{split}_graph_id.npy')\n            idx = torch.from_numpy(np.load(path)).to(torch.long)\n            idx = idx - idx.min()\n\n            for i in range(int(idx.max()) + 1):\n                mask = idx == i\n\n                G_s = G.subgraph(\n                    mask.nonzero(as_tuple=False).view(-1).tolist())\n                edge_index = torch.tensor(list(G_s.edges)).t().contiguous()\n                edge_index = edge_index - edge_index.min()\n                edge_index, _ = remove_self_loops(edge_index)\n\n                data = Data(edge_index=edge_index, x=x[mask], y=y[mask])\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n            self.save(data_list, self.processed_paths[s])\n"
  },
  {
    "path": "torch_geometric/datasets/protein_mpnn_dataset.py",
    "content": "import os\nimport pickle\nimport random\nfrom collections import defaultdict\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\n\n\nclass ProteinMPNNDataset(InMemoryDataset):\n    r\"\"\"The ProteinMPNN dataset from the `\"Robust deep learning based protein\n    sequence design using ProteinMPNN\"\n    <https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        size (str): Size of the PDB information to train the model.\n            If :obj:`\"small\"`, loads the small dataset (229.4 MB).\n            If :obj:`\"large\"`, loads the large dataset (64.1 GB).\n            (default: :obj:`\"small\"`)\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"valid\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            (default: :obj:`\"train\"`)\n        datacut (str, optional): Date cutoff to filter the dataset.\n            (default: :obj:`\"2030-01-01\"`)\n        rescut (float, optional): PDB resolution cutoff.\n            (default: :obj:`3.5`)\n        homo (float, optional): Homology cutoff.\n            (default: :obj:`0.70`)\n        max_length (int, optional): Maximum length of the protein complex.\n            (default: :obj:`10000`)\n        num_units (int, optional): Number of units of the protein complex.\n            (default: :obj:`150`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    raw_url = {\n        'small':\n        'https://files.ipd.uw.edu/pub/training_sets/'\n        'pdb_2021aug02_sample.tar.gz',\n        'large':\n        'https://files.ipd.uw.edu/pub/training_sets/'\n        'pdb_2021aug02.tar.gz',\n    }\n\n    splits = {\n        'train': 1,\n        'valid': 2,\n        'test': 3,\n    }\n\n    def __init__(\n        self,\n        root: str,\n        size: str = 'small',\n        split: str = 'train',\n        datacut: str = '2030-01-01',\n        rescut: float = 3.5,\n        homo: float = 0.70,\n        max_length: int = 10000,\n        num_units: int = 150,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.size = size\n        self.split = split\n        self.datacut = datacut\n        self.rescut = rescut\n        self.homo = homo\n        self.max_length = max_length\n        self.num_units = num_units\n\n        self.sub_folder = self.raw_url[self.size].split('/')[-1].split('.')[0]\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         
force_reload=force_reload)\n        self.load(self.processed_paths[self.splits[self.split]])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            f'{self.sub_folder}/{f}'\n            for f in ['list.csv', 'valid_clusters.txt', 'test_clusters.txt']\n        ]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['splits.pkl', 'train.pt', 'valid.pt', 'test.pt']\n\n    def download(self) -> None:\n        file_path = download_url(self.raw_url[self.size], self.raw_dir)\n        extract_tar(file_path, self.raw_dir)\n        os.unlink(file_path)\n\n    def process(self) -> None:\n        alphabet_set = set(list('ACDEFGHIKLMNPQRSTVWYX'))\n        cluster_ids = self._process_split()\n        total_items = sum(len(items) for items in cluster_ids.values())\n        data_list = []\n\n        with tqdm(total=total_items, desc=\"Processing\") as pbar:\n            for _, items in cluster_ids.items():\n                for chain_id, _ in items:\n                    item = self._process_pdb1(chain_id)\n\n                    if 'label' not in item:\n                        pbar.update(1)\n                        continue\n                    if len(list(np.unique(item['idx']))) >= 352:\n                        pbar.update(1)\n                        continue\n\n                    my_dict = self._process_pdb2(item)\n\n                    if len(my_dict['seq']) > self.max_length:\n                        pbar.update(1)\n                        continue\n                    bad_chars = set(list(\n                        my_dict['seq'])).difference(alphabet_set)\n                    if len(bad_chars) > 0:\n                        pbar.update(1)\n                        continue\n\n                    x_chain_all, chain_seq_label_all, mask, chain_mask_all, residue_idx, chain_encoding_all = self._process_pdb3(  # noqa: E501\n                        my_dict)\n\n                    data = Data(\n                        
x=x_chain_all,  # [seq_len, 4, 3]\n                        chain_seq_label=chain_seq_label_all,  # [seq_len]\n                        mask=mask,  # [seq_len]\n                        chain_mask_all=chain_mask_all,  # [seq_len]\n                        residue_idx=residue_idx,  # [seq_len]\n                        chain_encoding_all=chain_encoding_all,  # [seq_len]\n                    )\n\n                    if self.pre_filter is not None and not self.pre_filter(\n                            data):\n                        continue\n                    if self.pre_transform is not None:\n                        data = self.pre_transform(data)\n\n                    data_list.append(data)\n\n                    if len(data_list) >= self.num_units:\n                        pbar.update(total_items - pbar.n)\n                        break\n                    pbar.update(1)\n                else:\n                    continue\n                break\n            self.save(data_list, self.processed_paths[self.splits[self.split]])\n\n    def _process_split(self) -> Dict[int, List[Tuple[str, int]]]:\n        import pandas as pd\n        save_path = self.processed_paths[0]\n\n        if os.path.exists(save_path):\n            print('Load split')\n            with open(save_path, 'rb') as f:\n                data = pickle.load(f)\n        else:\n            # CHAINID, DEPOSITION, RESOLUTION, HASH, CLUSTER, SEQUENCE\n            df = pd.read_csv(self.raw_paths[0])\n            df = df[(df['RESOLUTION'] <= self.rescut)\n                    & (df['DEPOSITION'] <= self.datacut)]\n\n            val_ids = pd.read_csv(self.raw_paths[1], header=None)[0].tolist()\n            test_ids = pd.read_csv(self.raw_paths[2], header=None)[0].tolist()\n\n            # compile training and validation sets\n            data = {\n                'train': defaultdict(list),\n                'valid': defaultdict(list),\n                'test': defaultdict(list),\n            }\n\n            for 
_, r in tqdm(df.iterrows(), desc='Processing split',\n                             total=len(df)):\n                cluster_id = r['CLUSTER']\n                hash_id = r['HASH']\n                chain_id = r['CHAINID']\n                if cluster_id in val_ids:\n                    data['valid'][cluster_id].append((chain_id, hash_id))\n                elif cluster_id in test_ids:\n                    data['test'][cluster_id].append((chain_id, hash_id))\n                else:\n                    data['train'][cluster_id].append((chain_id, hash_id))\n\n            with open(save_path, 'wb') as f:\n                pickle.dump(data, f)\n\n        return data[self.split]\n\n    def _process_pdb1(self, chain_id: str) -> Dict[str, Any]:\n        pdbid, chid = chain_id.split('_')\n        prefix = f'{self.raw_dir}/{self.sub_folder}/pdb/{pdbid[1:3]}/{pdbid}'\n        # load metadata\n        if not os.path.isfile(f'{prefix}.pt'):\n            return {'seq': np.zeros(5)}\n        meta = torch.load(f'{prefix}.pt')\n        asmb_ids = meta['asmb_ids']\n        asmb_chains = meta['asmb_chains']\n        chids = np.array(meta['chains'])\n\n        # find candidate assemblies which contain chid chain\n        asmb_candidates = {\n            a\n            for a, b in zip(asmb_ids, asmb_chains) if chid in b.split(',')\n        }\n\n        # if the chain is missing from all the assemblies\n        # then return this chain alone\n        if len(asmb_candidates) < 1:\n            chain = torch.load(f'{prefix}_{chid}.pt')\n            L = len(chain['seq'])\n            return {\n                'seq': chain['seq'],\n                'xyz': chain['xyz'],\n                'idx': torch.zeros(L).int(),\n                'masked': torch.Tensor([0]).int(),\n                'label': chain_id,\n            }\n\n        # randomly pick one assembly from candidates\n        asmb_i = random.sample(list(asmb_candidates), 1)\n\n        # indices of selected transforms\n        idx = 
np.where(np.array(asmb_ids) == asmb_i)[0]\n\n        # load relevant chains\n        chains = {\n            c: torch.load(f'{prefix}_{c}.pt')\n            for i in idx\n            for c in asmb_chains[i] if c in meta['chains']\n        }\n\n        # generate assembly\n        asmb = {}\n        for k in idx:\n\n            # pick k-th xform\n            xform = meta[f'asmb_xform{k}']\n            u = xform[:, :3, :3]\n            r = xform[:, :3, 3]\n\n            # select chains which k-th xform should be applied to\n            s1 = set(meta['chains'])\n            s2 = set(asmb_chains[k].split(','))\n            chains_k = s1 & s2\n\n            # transform selected chains\n            for c in chains_k:\n                try:\n                    xyz = chains[c]['xyz']\n                    xyz_ru = torch.einsum('bij,raj->brai', u, xyz) + r[:, None,\n                                                                       None, :]\n                    asmb.update({\n                        (c, k, i): xyz_i\n                        for i, xyz_i in enumerate(xyz_ru)\n                    })\n                except KeyError:\n                    return {'seq': np.zeros(5)}\n\n        # select chains which share considerable similarity to chid\n        seqid = meta['tm'][chids == chid][0, :, 1]\n        homo = {\n            ch_j\n            for seqid_j, ch_j in zip(seqid, chids) if seqid_j > self.homo\n        }\n        # stack all chains in the assembly together\n        seq: str = ''\n        xyz_all: List[torch.Tensor] = []\n        idx_all: List[torch.Tensor] = []\n        masked: List[int] = []\n        seq_list = []\n        for counter, (k, v) in enumerate(asmb.items()):\n            seq += chains[k[0]]['seq']\n            seq_list.append(chains[k[0]]['seq'])\n            xyz_all.append(v)\n            idx_all.append(torch.full((v.shape[0], ), counter))\n            if k[0] in homo:\n                masked.append(counter)\n\n        return {\n            
'seq': seq,\n            'xyz': torch.cat(xyz_all, dim=0),\n            'idx': torch.cat(idx_all, dim=0),\n            'masked': torch.Tensor(masked).int(),\n            'label': chain_id,\n        }\n\n    def _process_pdb2(self, t: Dict[str, Any]) -> Dict[str, Any]:\n        init_alphabet = list(\n            'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')\n        extra_alphabet = [str(item) for item in list(np.arange(300))]\n        chain_alphabet = init_alphabet + extra_alphabet\n        my_dict: Dict[str, Union[str, int, Dict[str, Any], List[Any]]] = {}\n        concat_seq = ''\n        mask_list = []\n        visible_list = []\n        for idx in list(np.unique(t['idx'])):\n            letter = chain_alphabet[idx]\n            res = np.argwhere(t['idx'] == idx)\n            initial_sequence = \"\".join(list(\n                np.array(list(t['seq']))[res][\n                    0,\n                ]))\n            if initial_sequence[-6:] == \"HHHHHH\":\n                res = res[:, :-6]\n            if initial_sequence[0:6] == \"HHHHHH\":\n                res = res[:, 6:]\n            if initial_sequence[-7:-1] == \"HHHHHH\":\n                res = res[:, :-7]\n            if initial_sequence[-8:-2] == \"HHHHHH\":\n                res = res[:, :-8]\n            if initial_sequence[-9:-3] == \"HHHHHH\":\n                res = res[:, :-9]\n            if initial_sequence[-10:-4] == \"HHHHHH\":\n                res = res[:, :-10]\n            if initial_sequence[1:7] == \"HHHHHH\":\n                res = res[:, 7:]\n            if initial_sequence[2:8] == \"HHHHHH\":\n                res = res[:, 8:]\n            if initial_sequence[3:9] == \"HHHHHH\":\n                res = res[:, 9:]\n            if initial_sequence[4:10] == \"HHHHHH\":\n                res = res[:, 10:]\n            if res.shape[1] >= 4:\n                chain_seq = \"\".join(list(np.array(list(t['seq']))[res][0]))\n                my_dict[f'seq_chain_{letter}'] = chain_seq\n          
      concat_seq += chain_seq\n                if idx in t['masked']:\n                    mask_list.append(letter)\n                else:\n                    visible_list.append(letter)\n                coords_dict_chain = {}\n                all_atoms = np.array(t['xyz'][res])[0]  # [L, 14, 3]\n                for i, c in enumerate(['N', 'CA', 'C', 'O']):\n                    coords_dict_chain[\n                        f'{c}_chain_{letter}'] = all_atoms[:, i, :].tolist()\n                my_dict[f'coords_chain_{letter}'] = coords_dict_chain\n        my_dict['name'] = t['label']\n        my_dict['masked_list'] = mask_list\n        my_dict['visible_list'] = visible_list\n        my_dict['num_of_chains'] = len(mask_list) + len(visible_list)\n        my_dict['seq'] = concat_seq\n        return my_dict\n\n    def _process_pdb3(\n        self, b: Dict[str, Any]\n    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,\n               torch.Tensor, torch.Tensor]:\n        L = len(b['seq'])\n        # residue idx with jumps across chains\n        residue_idx = -100 * np.ones([L], dtype=np.int32)\n        # get the list of masked / visible chains\n        masked_chains, visible_chains = b['masked_list'], b['visible_list']\n        visible_temp_dict, masked_temp_dict = {}, {}\n        for letter in masked_chains + visible_chains:\n            chain_seq = b[f'seq_chain_{letter}']\n            if letter in visible_chains:\n                visible_temp_dict[letter] = chain_seq\n            elif letter in masked_chains:\n                masked_temp_dict[letter] = chain_seq\n        # check for duplicate chains (same sequence but different identity)\n        for _, vm in masked_temp_dict.items():\n            for kv, vv in visible_temp_dict.items():\n                if vm == vv:\n                    if kv not in masked_chains:\n                        masked_chains.append(kv)\n                    if kv in visible_chains:\n                        
visible_chains.remove(kv)\n        # build protein data structures\n        all_chains = masked_chains + visible_chains\n        np.random.shuffle(all_chains)\n        x_chain_list = []\n        chain_mask_list = []\n        chain_seq_list = []\n        chain_encoding_list = []\n        c, l0, l1 = 1, 0, 0\n        for letter in all_chains:\n            chain_seq = b[f'seq_chain_{letter}']\n            chain_length = len(chain_seq)\n            chain_coords = b[f'coords_chain_{letter}']\n            x_chain = np.stack([\n                chain_coords[c] for c in [\n                    f'N_chain_{letter}', f'CA_chain_{letter}',\n                    f'C_chain_{letter}', f'O_chain_{letter}'\n                ]\n            ], 1)  # [chain_length, 4, 3]\n            x_chain_list.append(x_chain)\n            chain_seq_list.append(chain_seq)\n            if letter in visible_chains:\n                chain_mask = np.zeros(chain_length)  # 0 for visible chains\n            elif letter in masked_chains:\n                chain_mask = np.ones(chain_length)  # 1 for masked chains\n            chain_mask_list.append(chain_mask)\n            chain_encoding_list.append(c * np.ones(chain_length))\n            l1 += chain_length\n            residue_idx[l0:l1] = 100 * (c - 1) + np.arange(l0, l1)\n            l0 += chain_length\n            c += 1\n        x_chain_all = np.concatenate(x_chain_list, 0)  # [L, 4, 3]\n        chain_seq_all = \"\".join(chain_seq_list)\n        # [L,] 1.0 for places that need to be predicted\n        chain_mask_all = np.concatenate(chain_mask_list, 0)\n        chain_encoding_all = np.concatenate(chain_encoding_list, 0)\n\n        # Convert to labels\n        alphabet = 'ACDEFGHIKLMNPQRSTVWYX'\n        chain_seq_label_all = np.asarray(\n            [alphabet.index(a) for a in chain_seq_all], dtype=np.int32)\n\n        isnan = np.isnan(x_chain_all)\n        mask = np.isfinite(np.sum(x_chain_all, (1, 2))).astype(np.float32)\n        x_chain_all[isnan] = 
0.\n\n        # Conversion\n        return (\n            torch.from_numpy(x_chain_all).to(dtype=torch.float32),\n            torch.from_numpy(chain_seq_label_all).to(dtype=torch.long),\n            torch.from_numpy(mask).to(dtype=torch.float32),\n            torch.from_numpy(chain_mask_all).to(dtype=torch.float32),\n            torch.from_numpy(residue_idx).to(dtype=torch.long),\n            torch.from_numpy(chain_encoding_all).to(dtype=torch.long),\n        )\n"
  },
  {
    "path": "torch_geometric/datasets/qm7.py",
    "content": "from typing import Callable, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass QM7b(InMemoryDataset):\n    r\"\"\"The QM7b dataset from the `\"MoleculeNet: A Benchmark for Molecular\n    Machine Learning\" <https://arxiv.org/abs/1703.00564>`_ paper, consisting of\n    7,211 molecules with 14 regression targets.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #tasks\n        * - 7,211\n          - ~15.4\n          - ~245.0\n          - 0\n          - 14\n    \"\"\"\n\n    url = 'https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/qm7b.mat'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'qm7b.mat'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(self.url, self.raw_dir)\n\n    def process(self) -> None:\n        from scipy.io import loadmat\n\n        data = loadmat(self.raw_paths[0])\n        coulomb_matrix = torch.from_numpy(data['X'])\n        target = torch.from_numpy(data['T']).to(torch.float)\n\n        data_list = []\n        for i in range(target.shape[0]):\n            edge_index = coulomb_matrix[i].nonzero(\n                as_tuple=False).t().contiguous()\n            edge_attr = coulomb_matrix[i, edge_index[0], edge_index[1]]\n            y = target[i].view(1, -1)\n            data = Data(edge_index=edge_index, edge_attr=edge_attr, y=y)\n            data.num_nodes = int(edge_index.max()) + 1\n            data_list.append(data)\n\n        if self.pre_filter is not None:\n            data_list = [d for d in data_list if self.pre_filter(d)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(d) for d in data_list]\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/qm9.py",
    "content": "import os\nimport os.path as osp\nimport sys\nfrom typing import Callable, List, Optional\n\nimport torch\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import one_hot, scatter\n\nHAR2EV = 27.211386246\nKCALMOL2EV = 0.04336414\n\nconversion = torch.tensor([\n    1., 1., HAR2EV, HAR2EV, HAR2EV, 1., HAR2EV, HAR2EV, HAR2EV, HAR2EV, HAR2EV,\n    1., KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, KCALMOL2EV, 1., 1., 1.\n])\n\natomrefs = {\n    6: [0., 0., 0., 0., 0.],\n    7: [\n        -13.61312172, -1029.86312267, -1485.30251237, -2042.61123593,\n        -2713.48485589\n    ],\n    8: [\n        -13.5745904, -1029.82456413, -1485.26398105, -2042.5727046,\n        -2713.44632457\n    ],\n    9: [\n        -13.54887564, -1029.79887659, -1485.2382935, -2042.54701705,\n        -2713.42063702\n    ],\n    10: [\n        -13.90303183, -1030.25891228, -1485.71166277, -2043.01812778,\n        -2713.88796536\n    ],\n    11: [0., 0., 0., 0., 0.],\n}\n\n\nclass QM9(InMemoryDataset):\n    r\"\"\"The QM9 dataset from the `\"MoleculeNet: A Benchmark for Molecular\n    Machine Learning\" <https://arxiv.org/abs/1703.00564>`_ paper, consisting of\n    about 130,000 molecules with 19 regression targets.\n    Each molecule includes complete spatial information for the single low\n    energy conformation of the atoms in the molecule.\n    In addition, we provide the atom features from the `\"Neural Message\n    Passing for Quantum Chemistry\" <https://arxiv.org/abs/1704.01212>`_ paper.\n\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | Target | Property                         | Description                                                                       | Unit        
                                |\n    +========+==================================+===================================================================================+=============================================+\n    | 0      | :math:`\\mu`                      | Dipole moment                                                                     | :math:`\\textrm{D}`                          |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 1      | :math:`\\alpha`                   | Isotropic polarizability                                                          | :math:`{a_0}^3`                             |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 2      | :math:`\\epsilon_{\\textrm{HOMO}}` | Highest occupied molecular orbital energy                                         | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 3      | :math:`\\epsilon_{\\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 4      | :math:`\\Delta \\epsilon`          | Gap between :math:`\\epsilon_{\\textrm{HOMO}}` and :math:`\\epsilon_{\\textrm{LUMO}}` | :math:`\\textrm{eV}`                         |\n    
+--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 5      | :math:`\\langle R^2 \\rangle`      | Electronic spatial extent                                                         | :math:`{a_0}^2`                             |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 6      | :math:`\\textrm{ZPVE}`            | Zero point vibrational energy                                                     | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 7      | :math:`U_0`                      | Internal energy at 0K                                                             | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 8      | :math:`U`                        | Internal energy at 298.15K                                                        | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 9      | :math:`H`                        | Enthalpy at 298.15K                                                               | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 
10     | :math:`G`                        | Free energy at 298.15K                                                            | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 11     | :math:`c_{\\textrm{v}}`           | Heat capacity at 298.15K                                                          | :math:`\\frac{\\textrm{cal}}{\\textrm{mol K}}` |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 12     | :math:`U_0^{\\textrm{ATOM}}`      | Atomization energy at 0K                                                          | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 13     | :math:`U^{\\textrm{ATOM}}`        | Atomization energy at 298.15K                                                     | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 14     | :math:`H^{\\textrm{ATOM}}`        | Atomization enthalpy at 298.15K                                                   | :math:`\\textrm{eV}`                         |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 15     | :math:`G^{\\textrm{ATOM}}`        | Atomization free energy at 298.15K                                                | :math:`\\textrm{eV}`                         |\n  
  +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 16     | :math:`A`                        | Rotational constant                                                               | :math:`\\textrm{GHz}`                        |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 17     | :math:`B`                        | Rotational constant                                                               | :math:`\\textrm{GHz}`                        |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n    | 18     | :math:`C`                        | Rotational constant                                                               | :math:`\\textrm{GHz}`                        |\n    +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+\n\n    .. note::\n\n        We also provide a pre-processed version of the dataset in case\n        :class:`rdkit` is not installed. The pre-processed version matches\n        the manually processed version as outlined in :meth:`process`.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #tasks\n        * - 130,831\n          - ~18.0\n          - ~37.3\n          - 11\n          - 19\n    \"\"\"  # noqa: E501\n\n    raw_url = ('https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/'\n               'molnet_publish/qm9.zip')\n    raw_url2 = 'https://ndownloader.figshare.com/files/3195404'\n    processed_url = 'https://data.pyg.org/datasets/qm9_v3.zip'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    def mean(self, target: int) -> float:\n        y = torch.cat([self.get(i).y for i in range(len(self))], dim=0)\n        return float(y[:, target].mean())\n\n    def std(self, target: int) -> float:\n        y = 
torch.cat([self.get(i).y for i in range(len(self))], dim=0)\n        return float(y[:, target].std())\n\n    def atomref(self, target: int) -> Optional[Tensor]:\n        if target in atomrefs:\n            out = torch.zeros(100)\n            out[torch.tensor([1, 6, 7, 8, 9])] = torch.tensor(atomrefs[target])\n            return out.view(-1, 1)\n        return None\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        try:\n            import rdkit  # noqa\n            return ['gdb9.sdf', 'gdb9.sdf.csv', 'uncharacterized.txt']\n        except ImportError:\n            return ['qm9_v3.pt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data_v3.pt'\n\n    def download(self) -> None:\n        try:\n            import rdkit  # noqa\n            file_path = download_url(self.raw_url, self.raw_dir)\n            extract_zip(file_path, self.raw_dir)\n            os.unlink(file_path)\n\n            file_path = download_url(self.raw_url2, self.raw_dir)\n            os.rename(osp.join(self.raw_dir, '3195404'),\n                      osp.join(self.raw_dir, 'uncharacterized.txt'))\n        except ImportError:\n            path = download_url(self.processed_url, self.raw_dir)\n            extract_zip(path, self.raw_dir)\n            os.unlink(path)\n\n    def process(self) -> None:\n        try:\n            from rdkit import Chem, RDLogger\n            from rdkit.Chem.rdchem import BondType as BT\n            from rdkit.Chem.rdchem import HybridizationType\n            RDLogger.DisableLog('rdApp.*')  # type: ignore[attr-defined]\n            WITH_RDKIT = True\n\n        except ImportError:\n            WITH_RDKIT = False\n\n        if not WITH_RDKIT:\n            print((\"Using a pre-processed version of the dataset. 
Please \"\n                   \"install 'rdkit' to alternatively process the raw data.\"),\n                  file=sys.stderr)\n\n            data_list = fs.torch_load(self.raw_paths[0])\n            data_list = [Data(**data_dict) for data_dict in data_list]\n\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n\n            self.save(data_list, self.processed_paths[0])\n            return\n\n        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}\n        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}\n\n        with open(self.raw_paths[1]) as f:\n            target = [[float(x) for x in line.split(',')[1:20]]\n                      for line in f.read().split('\\n')[1:-1]]\n            y = torch.tensor(target, dtype=torch.float)\n            y = torch.cat([y[:, 3:], y[:, :3]], dim=-1)\n            y = y * conversion.view(1, -1)\n\n        with open(self.raw_paths[2]) as f:\n            skip = [int(x.split()[0]) - 1 for x in f.read().split('\\n')[9:-2]]\n\n        suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False,\n                                   sanitize=False)\n\n        data_list = []\n        for i, mol in enumerate(tqdm(suppl)):\n            if i in skip:\n                continue\n\n            N = mol.GetNumAtoms()\n\n            conf = mol.GetConformer()\n            pos = conf.GetPositions()\n            pos = torch.tensor(pos, dtype=torch.float)\n\n            type_idx = []\n            atomic_number = []\n            aromatic = []\n            sp = []\n            sp2 = []\n            sp3 = []\n            num_hs = []\n            for atom in mol.GetAtoms():\n                type_idx.append(types[atom.GetSymbol()])\n                atomic_number.append(atom.GetAtomicNum())\n                aromatic.append(1 if atom.GetIsAromatic() 
else 0)\n                hybridization = atom.GetHybridization()\n                sp.append(1 if hybridization == HybridizationType.SP else 0)\n                sp2.append(1 if hybridization == HybridizationType.SP2 else 0)\n                sp3.append(1 if hybridization == HybridizationType.SP3 else 0)\n\n            z = torch.tensor(atomic_number, dtype=torch.long)\n\n            rows, cols, edge_types = [], [], []\n            for bond in mol.GetBonds():\n                start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()\n                rows += [start, end]\n                cols += [end, start]\n                edge_types += 2 * [bonds[bond.GetBondType()]]\n\n            edge_index = torch.tensor([rows, cols], dtype=torch.long)\n            edge_type = torch.tensor(edge_types, dtype=torch.long)\n            edge_attr = one_hot(edge_type, num_classes=len(bonds))\n\n            perm = (edge_index[0] * N + edge_index[1]).argsort()\n            edge_index = edge_index[:, perm]\n            edge_type = edge_type[perm]\n            edge_attr = edge_attr[perm]\n\n            row, col = edge_index\n            hs = (z == 1).to(torch.float)\n            num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist()\n\n            x1 = one_hot(torch.tensor(type_idx), num_classes=len(types))\n            x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs],\n                              dtype=torch.float).t().contiguous()\n            x = torch.cat([x1, x2], dim=-1)\n\n            name = mol.GetProp('_Name')\n            smiles = Chem.MolToSmiles(mol, isomericSmiles=True)\n\n            data = Data(\n                x=x,\n                z=z,\n                pos=pos,\n                edge_index=edge_index,\n                smiles=smiles,\n                edge_attr=edge_attr,\n                y=y[i].unsqueeze(0),\n                name=name,\n                idx=i,\n            )\n\n            if self.pre_filter is not None and not 
self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/rcdd.py",
    "content": "import os\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.utils import index_to_mask\n\n\nclass RCDD(InMemoryDataset):\n    r\"\"\"The risk commodity detection dataset (RCDD) from the\n    `\"Datasets and Interfaces for Benchmarking Heterogeneous Graph\n    Neural Networks\" <https://dl.acm.org/doi/10.1145/3583780.3615117>`_ paper.\n    RCDD is an industrial-scale heterogeneous graph dataset based on a\n    real risk detection scenario from Alibaba's e-commerce platform.\n    It consists of 13,806,619 nodes and 157,814,864 edges across 7 node types\n    and 7 edge types, respectively.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    url = ('https://s3.cn-north-1.amazonaws.com.cn/dgl-data/dataset/'\n           'openhgnn/AliRCD_ICDM.zip')\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'AliRCD_ICDM_nodes.csv',\n            'AliRCD_ICDM_edges.csv',\n            'AliRCD_ICDM_train_labels.csv',\n            'AliRCD_ICDM_test_labels.csv',\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    @property\n    def num_classes(self) -> int:\n        return 2\n\n    def process(self) -> None:\n        import pandas as pd\n\n        data = HeteroData()\n\n        node_df = pd.read_csv(  # AliRCD_ICDM_nodes.csv:\n            self.raw_paths[0],\n            header=None,\n            names=['node_id', 'node_type', 'node_feat'],\n        )\n        # Map global node IDs to local ones for each node type:\n        mapping = torch.empty(len(node_df), dtype=torch.long)\n        for node_type in node_df['node_type'].unique():\n            mask = node_df['node_type'] == node_type\n            node_id = torch.from_numpy(node_df['node_id'][mask].values)\n            num_nodes = mask.sum()\n            mapping[node_id] = torch.arange(num_nodes)\n            data[node_type].num_nodes = num_nodes\n            x = np.vstack([\n                
np.asarray(f.split(':'), dtype=np.float32)\n                for f in node_df['node_feat'][mask]\n            ])\n            data[node_type].x = torch.from_numpy(x)\n\n        edge_df = pd.read_csv(  # AliRCD_ICDM_edges.csv:\n            self.raw_paths[1],\n            header=None,\n            names=['src_id', 'dst_id', 'src_type', 'dst_type', 'edge_type'],\n        )\n        for edge_type in edge_df['edge_type'].unique():\n            edge_type_df = edge_df[edge_df['edge_type'] == edge_type]\n            src_type = edge_type_df['src_type'].iloc[0]\n            dst_type = edge_type_df['dst_type'].iloc[0]\n            src = mapping[torch.from_numpy(edge_type_df['src_id'].values)]\n            dst = mapping[torch.from_numpy(edge_type_df['dst_id'].values)]\n            edge_index = torch.stack([src, dst], dim=0)\n            data[src_type, edge_type, dst_type].edge_index = edge_index\n\n        train_df = pd.read_csv(  # AliRCD_ICDM_train_labels.csv:\n            self.raw_paths[2],\n            header=None,\n            names=['node_id', 'label'],\n            dtype=int,\n        )\n        test_df = pd.read_csv(  # AliRCD_ICDM_test_labels.csv:\n            self.raw_paths[3],\n            header=None,\n            sep='\\t',\n            names=['node_id', 'label'],\n            dtype=int,\n        )\n\n        train_idx = mapping[torch.from_numpy(train_df['node_id'].values)]\n        test_idx = mapping[torch.from_numpy(test_df['node_id'].values)]\n\n        y = torch.full((data['item'].num_nodes, ), -1, dtype=torch.long)\n        y[train_idx] = torch.from_numpy(train_df['label'].values)\n        y[test_idx] = torch.from_numpy(test_df['label'].values)\n\n        train_mask = index_to_mask(train_idx, data['item'].num_nodes)\n        test_mask = index_to_mask(test_idx, data['item'].num_nodes)\n\n        data['item'].y = y\n        data['item'].train_mask = train_mask\n        data['item'].test_mask = test_mask\n\n        if self.pre_transform is not None:\n            
data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/reddit.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.utils import coalesce\n\n\nclass Reddit(InMemoryDataset):\n    r\"\"\"The Reddit dataset from the `\"Inductive Representation Learning on\n    Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper, containing\n    Reddit posts belonging to different communities.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 232,965\n          - 114,615,892\n          - 602\n          - 41\n    \"\"\"\n\n    url = 'https://data.dgl.ai/dataset/reddit.zip'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['reddit_data.npz', 'reddit_graph.npz']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        data = np.load(osp.join(self.raw_dir, 'reddit_data.npz'))\n        x = torch.from_numpy(data['feature']).to(torch.float)\n        y = torch.from_numpy(data['label']).to(torch.long)\n        split = torch.from_numpy(data['node_types'])\n\n        adj = sp.load_npz(osp.join(self.raw_dir, 'reddit_graph.npz'))\n        row = torch.from_numpy(adj.row).to(torch.long)\n        col = torch.from_numpy(adj.col).to(torch.long)\n        edge_index = torch.stack([row, col], dim=0)\n        edge_index = coalesce(edge_index, num_nodes=x.size(0))\n\n        data = Data(x=x, edge_index=edge_index, y=y)\n        data.train_mask = split == 1\n        data.val_mask = split == 2\n        data.test_mask = split == 3\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/reddit2.py",
    "content": "import json\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_google_url\n\n\nclass Reddit2(InMemoryDataset):\n    r\"\"\"The Reddit dataset from the `\"GraphSAINT: Graph Sampling Based\n    Inductive Learning Method\" <https://arxiv.org/abs/1907.04931>`_ paper,\n    containing Reddit posts belonging to different communities.\n\n    .. note::\n\n        This is a sparser version of the original\n        :obj:`~torch_geometric.datasets.Reddit` dataset (~23M edges instead of\n        ~114M edges), and is used in papers such as\n        `SGC <https://arxiv.org/abs/1902.07153>`_ and\n        `GraphSAINT <https://arxiv.org/abs/1907.04931>`_.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 232,965\n          - 23,213,838\n          - 602\n          - 41\n    \"\"\"\n    adj_full_id = '1sncK996BM5lpuDf75lDFqCiDZyErc1c2'\n    feats_id = '1ZsHaJ0ussP1W722krmEIp_8pwKAoi5b3'\n    class_map_id = '1JF3Pjv9OboMNYs2aXRQGbJbc4t_nDd5u'\n    role_id = '1nJIKd77lcAGU4j-kVNx_AIGEkveIKz3A'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz')\n        download_google_url(self.feats_id, self.raw_dir, 'feats.npy')\n        download_google_url(self.class_map_id, self.raw_dir, 'class_map.json')\n        download_google_url(self.role_id, self.raw_dir, 'role.json')\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))\n        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])\n        adj = adj.tocoo()\n        row = torch.from_numpy(adj.row).to(torch.long)\n        col = torch.from_numpy(adj.col).to(torch.long)\n        edge_index = torch.stack([row, col], dim=0)\n\n        x = np.load(osp.join(self.raw_dir, 'feats.npy'))\n        x = torch.from_numpy(x).to(torch.float)\n\n        ys = [-1] * x.size(0)\n        with open(osp.join(self.raw_dir, 'class_map.json')) as f:\n      
      class_map = json.load(f)\n            for key, item in class_map.items():\n                ys[int(key)] = item\n        y = torch.tensor(ys)\n\n        with open(osp.join(self.raw_dir, 'role.json')) as f:\n            role = json.load(f)\n\n        train_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        train_mask[torch.tensor(role['tr'])] = True\n\n        val_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        val_mask[torch.tensor(role['va'])] = True\n\n        test_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        test_mask[torch.tensor(role['te'])] = True\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/rel_link_pred_dataset.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass RelLinkPredDataset(InMemoryDataset):\n    r\"\"\"The relational link prediction datasets from the\n    `\"Modeling Relational Data with Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1703.06103>`_ paper.\n    Training and test splits are given by sets of triplets.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"FB15k-237\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 14,541\n          - 544,230\n          - 0\n          - 0\n    \"\"\"\n\n    urls = {\n        'FB15k-237': ('https://raw.githubusercontent.com/MichSchli/'\n                      'RelationPrediction/master/data/FB-Toutanova')\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name\n        assert name in ['FB15k-237']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def num_relations(self) -> int:\n        return int(self._data.edge_type.max()) + 1  # type: ignore\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'entities.dict', 'relations.dict', 'test.txt', 'train.txt',\n            'valid.txt'\n        ]\n\n    def download(self) -> None:\n        for file_name in self.raw_file_names:\n            download_url(f'{self.urls[self.name]}/{file_name}', self.raw_dir)\n\n    def process(self) -> None:\n        with open(osp.join(self.raw_dir, 'entities.dict')) as f:\n            lines = [row.split('\\t') for row in f.read().split('\\n')[:-1]]\n            entities_dict = {key: int(value) for value, key in lines}\n\n        with open(osp.join(self.raw_dir, 'relations.dict')) as f:\n            lines = [row.split('\\t') for row in 
f.read().split('\\n')[:-1]]\n            relations_dict = {key: int(value) for value, key in lines}\n\n        kwargs = {}\n        for split in ['train', 'valid', 'test']:\n            with open(osp.join(self.raw_dir, f'{split}.txt')) as f:\n                lines = [row.split('\\t') for row in f.read().split('\\n')[:-1]]\n                src = [entities_dict[row[0]] for row in lines]\n                rel = [relations_dict[row[1]] for row in lines]\n                dst = [entities_dict[row[2]] for row in lines]\n                kwargs[f'{split}_edge_index'] = torch.tensor([src, dst])\n                kwargs[f'{split}_edge_type'] = torch.tensor(rel)\n\n        # For message passing, we add reverse edges and types to the graph:\n        row, col = kwargs['train_edge_index']\n        edge_type = kwargs['train_edge_type']\n        row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)\n        edge_index = torch.stack([row, col], dim=0)\n        edge_type = torch.cat([edge_type, edge_type + len(relations_dict)])\n\n        data = Data(num_nodes=len(entities_dict), edge_index=edge_index,\n                    edge_type=edge_type, **kwargs)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name}()'\n"
  },
  {
    "path": "torch_geometric/datasets/s3dis.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass S3DIS(InMemoryDataset):\n    r\"\"\"The (pre-processed) Stanford Large-Scale 3D Indoor Spaces dataset from\n    the `\"3D Semantic Parsing of Large-Scale Indoor Spaces\"\n    <https://openaccess.thecvf.com/content_cvpr_2016/papers/Armeni_3D_Semantic_Parsing_CVPR_2016_paper.pdf>`_\n    paper, containing point clouds of six large-scale indoor parts in three\n    buildings with 12 semantic elements (and one clutter class).\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        test_area (int, optional): Which area to use for testing (1-6).\n            (default: :obj:`6`)\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = ('https://shapenet.cs.stanford.edu/media/'\n           'indoor3d_sem_seg_hdf5_data.zip')\n\n    # In case `shapenet.cs.stanford.edu` is offline, try to download the data\n    # from here:\n    # https://cvg-data.inf.ethz.ch/s3dis/\n\n    def __init__(\n        self,\n        root: str,\n        test_area: int = 6,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert test_area >= 1 and test_area <= 6\n        self.test_area = test_area\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = self.processed_paths[0] if train else self.processed_paths[1]\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['all_files.txt', 'room_filelist.txt']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return [f'{split}_{self.test_area}.pt' for split in ['train', 'test']]\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.unlink(path)\n        fs.rm(self.raw_dir)\n        name = self.url.split('/')[-1].split('.')[0]\n        os.rename(osp.join(self.root, name), self.raw_dir)\n\n    def process(self) -> None:\n        import h5py\n\n        with open(self.raw_paths[0]) as f:\n            filenames = [x.split('/')[-1] for x in f.read().split('\\n')[:-1]]\n\n        with open(self.raw_paths[1]) as f:\n            rooms = f.read().split('\\n')[:-1]\n\n        xs: List[Tensor] = []\n        ys: List[Tensor] = []\n        for filename in filenames:\n            h5 = h5py.File(osp.join(self.raw_dir, 
filename))\n            xs += torch.from_numpy(h5['data'][:]).unbind(0)\n            ys += torch.from_numpy(h5['label'][:]).to(torch.long).unbind(0)\n\n        test_area = f'Area_{self.test_area}'\n        train_data_list, test_data_list = [], []\n        for i, (x, y) in enumerate(zip(xs, ys)):\n            data = Data(pos=x[:, :3], x=x[:, 3:], y=y)\n\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            if test_area not in rooms[i]:\n                train_data_list.append(data)\n            else:\n                test_data_list.append(data)\n\n        self.save(train_data_list, self.processed_paths[0])\n        self.save(test_data_list, self.processed_paths[1])\n"
  },
  {
    "path": "torch_geometric/datasets/sbm_dataset.py",
    "content": "import os.path as osp\nfrom typing import Any, Callable, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.utils import stochastic_blockmodel_graph\n\n\nclass StochasticBlockModelDataset(InMemoryDataset):\n    r\"\"\"A synthetic graph dataset generated by the stochastic block model.\n    The node features of each block are sampled from normal distributions where\n    the centers of clusters are vertices of a hypercube, as computed by the\n    :meth:`sklearn.datasets.make_classification` method.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        block_sizes ([int] or LongTensor): The sizes of blocks.\n        edge_probs ([[float]] or FloatTensor): The density of edges going from\n            each block to each other block. Must be symmetric if the graph is\n            undirected.\n        num_graphs (int, optional): The number of graphs. (default: :obj:`1`)\n        num_channels (int, optional): The number of node features. If given\n            as :obj:`None`, node features are not generated.\n            (default: :obj:`None`)\n        is_undirected (bool, optional): Whether the graph to generate is\n            undirected. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes\n            in an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed\n            before being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        **kwargs (optional): The keyword arguments that are passed down to the\n            :meth:`sklearn.datasets.make_classification` method for drawing\n            node features.\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        block_sizes: Union[List[int], Tensor],\n        edge_probs: Union[List[List[float]], Tensor],\n        num_graphs: int = 1,\n        num_channels: Optional[int] = None,\n        is_undirected: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n        **kwargs: Any,\n    ) -> None:\n        if not isinstance(block_sizes, torch.Tensor):\n            block_sizes = torch.tensor(block_sizes, dtype=torch.long)\n        if not isinstance(edge_probs, torch.Tensor):\n            edge_probs = torch.tensor(edge_probs, dtype=torch.float)\n\n        assert num_graphs > 0\n\n        self.block_sizes = block_sizes\n        self.edge_probs = edge_probs\n        self.num_graphs = num_graphs\n        self.num_channels = num_channels\n        self.is_undirected = is_undirected\n\n        self.kwargs = {\n            'n_informative': num_channels,\n            'n_redundant': 0,\n            'flip_y': 0.0,\n            'shuffle': False,\n        }\n        self.kwargs.update(kwargs)\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.__class__.__name__, 'processed')\n\n    @property\n    def processed_file_names(self) -> str:\n        block_sizes = self.block_sizes.view(-1).tolist()\n        hash1 = '-'.join([f'{x:.1f}' for x in block_sizes])\n\n        edge_probs = self.edge_probs.view(-1).tolist()\n     
   hash2 = '-'.join([f'{x:.1f}' for x in edge_probs])\n\n        return f'data_{self.num_channels}_{hash1}_{hash2}_{self.num_graphs}.pt'\n\n    def process(self) -> None:\n        from sklearn.datasets import make_classification\n\n        edge_index = stochastic_blockmodel_graph(\n            self.block_sizes, self.edge_probs, directed=not self.is_undirected)\n\n        num_samples = int(self.block_sizes.sum())\n        num_classes = self.block_sizes.size(0)\n\n        data_list = []\n        for _ in range(self.num_graphs):\n            x = None\n            if self.num_channels is not None:\n                x, y_not_sorted = make_classification(\n                    n_samples=num_samples,\n                    n_features=self.num_channels,\n                    n_classes=num_classes,\n                    weights=self.block_sizes / num_samples,\n                    **self.kwargs,\n                )\n                x = x[np.argsort(y_not_sorted)]\n                x = torch.from_numpy(x).to(torch.float)\n\n            y = torch.arange(num_classes).repeat_interleave(self.block_sizes)\n\n            data = Data(x=x, edge_index=edge_index, y=y)\n\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n\n            data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n\n\nclass RandomPartitionGraphDataset(StochasticBlockModelDataset):\n    r\"\"\"The random partition graph dataset from the `\"How to Find Your\n    Friendly Neighborhood: Graph Attention Design with Self-Supervision\"\n    <https://openreview.net/forum?id=Wi5KUNlqWty>`_ paper.\n    This is a synthetic graph of communities controlled by the node homophily\n    and the average degree, and each community is considered as a class.\n    The node features are sampled from normal distributions where the centers\n    of clusters are vertices of a hypercube, as computed by the\n    :meth:`sklearn.datasets.make_classification` method.\n\n    
Args:\n        root (str): Root directory where the dataset should be saved.\n        num_classes (int): The number of classes.\n        num_nodes_per_class (int): The number of nodes per class.\n        node_homophily_ratio (float): The degree of node homophily.\n        average_degree (float): The average degree of the graph.\n        num_graphs (int, optional): The number of graphs. (default: :obj:`1`)\n        num_channels (int, optional): The number of node features. If given\n            as :obj:`None`, node features are not generated.\n            (default: :obj:`None`)\n        is_undirected (bool, optional): Whether the graph to generate is\n            undirected. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes\n            in an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        **kwargs (optional): The keyword arguments that are passed down\n            to :meth:`sklearn.datasets.make_classification` method in\n            drawing node features.\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        num_classes: int,\n        num_nodes_per_class: int,\n        node_homophily_ratio: float,\n        average_degree: float,\n        num_graphs: int = 1,\n        num_channels: Optional[int] = None,\n        is_undirected: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        **kwargs: Any,\n    ) -> None:\n\n        self._num_classes = num_classes\n        self.num_nodes_per_class = num_nodes_per_class\n        self.node_homophily_ratio = node_homophily_ratio\n        self.average_degree = average_degree\n\n        # (p_in + (C - 1) * p_out) / C = |E|/|V|^2\n        # i.e., p_in + (C - 1) * p_out = average_degree / num_nodes_per_class\n        ec_over_v2 = average_degree / num_nodes_per_class\n        p_in = node_homophily_ratio * ec_over_v2\n        p_out = (ec_over_v2 - p_in) / (num_classes - 1)\n\n        block_sizes = [num_nodes_per_class for _ in range(num_classes)]\n        edge_probs = [[p_out for _ in range(num_classes)]\n                      for _ in range(num_classes)]\n        for r in range(num_classes):\n            edge_probs[r][r] = p_in\n\n        super().__init__(root, block_sizes, edge_probs, num_graphs,\n                         num_channels, is_undirected, transform, pre_transform,\n                         **kwargs)\n\n    @property\n    def processed_file_names(self) -> str:\n        return (f'data_{self.num_channels}_{self._num_classes}_'\n                f'{self.num_nodes_per_class}_{self.node_homophily_ratio:.1f}_'\n                f'{self.average_degree:.1f}_{self.num_graphs}.pt')\n\n    def process(self) -> None:\n        return super().process()\n"
  },
  {
    "path": "torch_geometric/datasets/shapenet.py",
    "content": "import json\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional, Union\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs, read_txt_array\n\n\nclass ShapeNet(InMemoryDataset):\n    r\"\"\"The ShapeNet part level segmentation dataset from the `\"A Scalable\n    Active Framework for Region Annotation in 3D Shape Collections\"\n    <http://web.stanford.edu/~ericyi/papers/part_annotation_16_small.pdf>`_\n    paper, containing about 17,000 3D shape point clouds from 16 shape\n    categories.\n    Each category is annotated with 2 to 6 parts.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        categories (str or [str], optional): The category of the CAD models\n            (one or a combination of :obj:`\"Airplane\"`, :obj:`\"Bag\"`,\n            :obj:`\"Cap\"`, :obj:`\"Car\"`, :obj:`\"Chair\"`, :obj:`\"Earphone\"`,\n            :obj:`\"Guitar\"`, :obj:`\"Knife\"`, :obj:`\"Lamp\"`, :obj:`\"Laptop\"`,\n            :obj:`\"Motorbike\"`, :obj:`\"Mug\"`, :obj:`\"Pistol\"`, :obj:`\"Rocket\"`,\n            :obj:`\"Skateboard\"`, :obj:`\"Table\"`).\n            Can be explicitly set to :obj:`None` to load all categories.\n            (default: :obj:`None`)\n        include_normals (bool, optional): If set to :obj:`False`, will not\n            include normal vectors as input features to :obj:`data.x`.\n            As a result, :obj:`data.x` will be :obj:`None`.\n            (default: :obj:`True`)\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"trainval\"`, loads the training and validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            (default: :obj:`\"trainval\"`)\n        transform (callable, optional): A function/transform that takes 
in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - 16,881\n          - ~2,616.2\n          - 0\n          - 3\n          - 50\n    \"\"\"\n\n    url = ('https://shapenet.cs.stanford.edu/media/'\n           'shapenetcore_partanno_segmentation_benchmark_v0_normal.zip')\n\n    # In case `shapenet.cs.stanford.edu` is offline, try to download the data\n    # from Kaggle instead (requires login):\n    # https://www.kaggle.com/datasets/mitkir/shapenet/download?datasetVersionNumber=1\n\n    category_ids = {\n        'Airplane': '02691156',\n        'Bag': '02773838',\n        'Cap': '02954340',\n        'Car': '02958343',\n        'Chair': '03001627',\n        'Earphone': '03261776',\n        'Guitar': '03467517',\n        'Knife': '03624134',\n        'Lamp': '03636649',\n        'Laptop': '03642806',\n        'Motorbike': '03790512',\n        'Mug': '03797390',\n        'Pistol': '03948459',\n        'Rocket': '04099429',\n        'Skateboard': 
'04225987',\n        'Table': '04379243',\n    }\n\n    seg_classes = {\n        'Airplane': [0, 1, 2, 3],\n        'Bag': [4, 5],\n        'Cap': [6, 7],\n        'Car': [8, 9, 10, 11],\n        'Chair': [12, 13, 14, 15],\n        'Earphone': [16, 17, 18],\n        'Guitar': [19, 20, 21],\n        'Knife': [22, 23],\n        'Lamp': [24, 25, 26, 27],\n        'Laptop': [28, 29],\n        'Motorbike': [30, 31, 32, 33, 34, 35],\n        'Mug': [36, 37],\n        'Pistol': [38, 39, 40],\n        'Rocket': [41, 42, 43],\n        'Skateboard': [44, 45, 46],\n        'Table': [47, 48, 49],\n    }\n\n    def __init__(\n        self,\n        root: str,\n        categories: Optional[Union[str, List[str]]] = None,\n        include_normals: bool = True,\n        split: str = 'trainval',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        if categories is None:\n            categories = list(self.category_ids.keys())\n        if isinstance(categories, str):\n            categories = [categories]\n        assert all(category in self.category_ids for category in categories)\n        self.categories = categories\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        if split == 'train':\n            path = self.processed_paths[0]\n        elif split == 'val':\n            path = self.processed_paths[1]\n        elif split == 'test':\n            path = self.processed_paths[2]\n        elif split == 'trainval':\n            path = self.processed_paths[3]\n        else:\n            raise ValueError(f'Split {split} found, but expected either '\n                             'train, val, trainval or test')\n\n        self.load(path)\n\n        assert isinstance(self._data, Data)\n        self._data.x = self._data.x if include_normals else None\n\n     
   self.y_mask = torch.zeros((len(self.seg_classes.keys()), 50),\n                                  dtype=torch.bool)\n        for i, labels in enumerate(self.seg_classes.values()):\n            self.y_mask[i, labels] = 1\n\n    @property\n    def num_classes(self) -> int:\n        return self.y_mask.size(-1)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return list(self.category_ids.values()) + ['train_test_split']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        cats = '_'.join([cat[:3].lower() for cat in self.categories])\n        return [\n            osp.join(f'{cats}_{split}.pt')\n            for split in ['train', 'val', 'test', 'trainval']\n        ]\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.unlink(path)\n        fs.rm(self.raw_dir)\n        name = self.url.split('/')[-1].split('.')[0]\n        os.rename(osp.join(self.root, name), self.raw_dir)\n\n    def process_filenames(self, filenames: List[str]) -> List[Data]:\n        data_list = []\n        categories_ids = [self.category_ids[cat] for cat in self.categories]\n        cat_idx = {categories_ids[i]: i for i in range(len(categories_ids))}\n\n        for name in filenames:\n            cat = name.split(osp.sep)[0]\n            if cat not in categories_ids:\n                continue\n\n            tensor = read_txt_array(osp.join(self.raw_dir, name))\n            pos = tensor[:, :3]\n            x = tensor[:, 3:6]\n            y = tensor[:, -1].type(torch.long)\n            data = Data(pos=pos, x=x, y=y, category=cat_idx[cat])\n            if self.pre_filter is not None and not self.pre_filter(data):\n                continue\n            if self.pre_transform is not None:\n                data = self.pre_transform(data)\n            data_list.append(data)\n\n        return data_list\n\n    def process(self) -> None:\n        trainval = []\n        for i, split in 
enumerate(['train', 'val', 'test']):\n            path = osp.join(self.raw_dir, 'train_test_split',\n                            f'shuffled_{split}_file_list.json')\n            with open(path) as f:\n                filenames = [\n                    osp.sep.join(name.split('/')[1:]) + '.txt'\n                    for name in json.load(f)\n                ]  # Removing first directory.\n            data_list = self.process_filenames(filenames)\n            if split == 'train' or split == 'val':\n                trainval += data_list\n            self.save(data_list, self.processed_paths[i])\n        self.save(trainval, self.processed_paths[3])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'categories={self.categories})')\n"
  },
  {
    "path": "torch_geometric/datasets/shrec2016.py",
    "content": "import glob\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import InMemoryDataset, download_url, extract_zip\nfrom torch_geometric.io import fs, read_off, read_txt_array\n\n\nclass SHREC2016(InMemoryDataset):\n    r\"\"\"The SHREC 2016 partial matching dataset from the `\"SHREC'16: Partial\n    Matching of Deformable Shapes\"\n    <http://www.dais.unive.it/~shrec2016/shrec16-partial.pdf>`_ paper.\n    The reference shape can be referenced via :obj:`dataset.ref`.\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        partiality (str): The partiality of the dataset (one of :obj:`\"Holes\"`,\n            :obj:`\"Cuts\"`).\n        category (str): The category of the dataset (one of\n            :obj:`\"Cat\"`, :obj:`\"Centaur\"`, :obj:`\"David\"`, :obj:`\"Dog\"`,\n            :obj:`\"Horse\"`, :obj:`\"Michael\"`, :obj:`\"Victoria\"`,\n            :obj:`\"Wolf\"`).\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    train_url = ('http://www.dais.unive.it/~shrec2016/data/'\n                 'shrec2016_PartialDeformableShapes.zip')\n    test_url = ('http://www.dais.unive.it/~shrec2016/data/'\n                'shrec2016_PartialDeformableShapes_TestSet.zip')\n\n    categories = [\n        'cat', 'centaur', 'david', 'dog', 'horse', 'michael', 'victoria',\n        'wolf'\n    ]\n    partialities = ['holes', 'cuts']\n\n    def __init__(\n        self,\n        root: str,\n        partiality: str,\n        category: str,\n        train: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert partiality.lower() in self.partialities\n        self.part = partiality.lower()\n        assert category.lower() in self.categories\n        self.cat = category.lower()\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.__ref__ = fs.torch_load(self.processed_paths[0])\n        path = self.processed_paths[1] if train else self.processed_paths[2]\n        self.load(path)\n\n   
 @property\n    def ref(self) -> str:\n        ref = self.__ref__\n        if self.transform is not None:\n            ref = self.transform(ref)\n        return ref\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['training', 'test']\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        name = f'{self.part}_{self.cat}.pt'\n        return [f'{i}_{name}' for i in ['ref', 'training', 'test']]\n\n    def download(self) -> None:\n        path = download_url(self.train_url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n        path = osp.join(self.raw_dir, 'shrec2016_PartialDeformableShapes')\n        os.rename(path, osp.join(self.raw_dir, 'training'))\n\n        path = download_url(self.test_url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n        path = osp.join(self.raw_dir,\n                        'shrec2016_PartialDeformableShapes_TestSet')\n        os.rename(path, osp.join(self.raw_dir, 'test'))\n\n    def process(self) -> None:\n        ref_data = read_off(\n            osp.join(self.raw_paths[0], 'null', f'{self.cat}.off'))\n\n        train_list = []\n        name = f'{self.part}_{self.cat}_*.off'\n        paths = glob.glob(osp.join(self.raw_paths[0], self.part, name))\n        paths = [path[:-4] for path in paths]\n        paths = sorted(paths, key=lambda e: (len(e), e))\n\n        for path in paths:\n            data = read_off(f'{path}.off')\n            y = read_txt_array(f'{path}.baryc_gt')\n            data.y = y[:, 0].to(torch.long) - 1\n            data.y_baryc = y[:, 1:]\n            train_list.append(data)\n\n        test_list = []\n        name = f'{self.part}_{self.cat}_*.off'\n        paths = glob.glob(osp.join(self.raw_paths[1], self.part, name))\n        paths = [path[:-4] for path in paths]\n        paths = sorted(paths, key=lambda e: (len(e), e))\n\n        for path in paths:\n            
test_list.append(read_off(f'{path}.off'))\n\n        if self.pre_filter is not None:\n            train_list = [d for d in train_list if self.pre_filter(d)]\n            test_list = [d for d in test_list if self.pre_filter(d)]\n\n        if self.pre_transform is not None:\n            ref_data = self.pre_transform(ref_data)\n            train_list = [self.pre_transform(d) for d in train_list]\n            test_list = [self.pre_transform(d) for d in test_list]\n\n        torch.save(ref_data, self.processed_paths[0])\n        self.save(train_list, self.processed_paths[1])\n        self.save(test_list, self.processed_paths[2])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'partiality={self.part}, category={self.cat})')\n"
  },
  {
    "path": "torch_geometric/datasets/snap_dataset.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport fsspec\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs\nfrom torch_geometric.utils import coalesce\n\n\nclass EgoData(Data):\n    def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any:\n        if key == 'circle':\n            return self.num_nodes\n        elif key == 'circle_batch':\n            return int(value.max()) + 1 if value.numel() > 0 else 0\n        return super().__inc__(key, value, *args, **kwargs)\n\n\ndef read_ego(files: List[str], name: str) -> List[EgoData]:\n    import pandas as pd\n    import tqdm\n\n    files = sorted(files)\n\n    all_featnames = []\n    files = [\n        x for x in files if x.split('.')[-1] in\n        ['circles', 'edges', 'egofeat', 'feat', 'featnames']\n    ]\n    for i in range(4, len(files), 5):\n        featnames_file = files[i]\n        with fsspec.open(featnames_file, 'r') as f:\n            featnames = f.read().split('\\n')[:-1]\n            featnames = [' '.join(x.split(' ')[1:]) for x in featnames]\n            all_featnames += featnames\n    all_featnames = sorted(list(set(all_featnames)))\n    all_featnames_dict = {key: i for i, key in enumerate(all_featnames)}\n\n    data_list = []\n    for i in tqdm.tqdm(range(0, len(files), 5)):\n        circles_file = files[i]\n        edges_file = files[i + 1]\n        egofeat_file = files[i + 2]\n        feat_file = files[i + 3]\n        featnames_file = files[i + 4]\n\n        x = None\n        if name != 'gplus':  # Don't read node features on g-plus:\n            x_ego = pd.read_csv(egofeat_file, sep=' ', header=None,\n                                dtype=np.float32)\n            x_ego = torch.from_numpy(x_ego.values)\n\n            x = pd.read_csv(feat_file, sep=' ', header=None, dtype=np.float32)\n            x = torch.from_numpy(x.values)[:, 1:]\n\n  
          x = torch.cat([x, x_ego], dim=0)\n\n            # Reorder `x` according to `featnames` ordering.\n            x_all = torch.zeros(x.size(0), len(all_featnames))\n            with fsspec.open(featnames_file, 'r') as f:\n                featnames = f.read().split('\\n')[:-1]\n                featnames = [' '.join(x.split(' ')[1:]) for x in featnames]\n            indices = [all_featnames_dict[featname] for featname in featnames]\n            x_all[:, torch.tensor(indices)] = x\n            x = x_all\n\n            if x.size(1) > 100_000:\n                x = x.to_sparse_csr()\n\n        idx = pd.read_csv(feat_file, sep=' ', header=None, dtype=str,\n                          usecols=[0]).squeeze()\n\n        idx_assoc: Dict[str, int] = {}\n        for i, j in enumerate(idx):\n            idx_assoc[j] = i\n\n        circles: List[int] = []\n        circles_batch: List[int] = []\n        with fsspec.open(circles_file, 'r') as f:\n            for i, line in enumerate(f.read().split('\\n')[:-1]):\n                circle_indices = [idx_assoc[c] for c in line.split()[1:]]\n                circles += circle_indices\n                circles_batch += [i] * len(circle_indices)\n        circle = torch.tensor(circles)\n        circle_batch = torch.tensor(circles_batch)\n\n        try:\n            row = pd.read_csv(edges_file, sep=' ', header=None, dtype=str,\n                              usecols=[0]).squeeze()\n            col = pd.read_csv(edges_file, sep=' ', header=None, dtype=str,\n                              usecols=[1]).squeeze()\n        except Exception:\n            continue\n\n        row = torch.tensor([idx_assoc[i] for i in row])\n        col = torch.tensor([idx_assoc[i] for i in col])\n\n        N = max(int(row.max()), int(col.max())) + 2\n        N = x.size(0) if x is not None else N\n\n        row_ego = torch.full((N - 1, ), N - 1, dtype=torch.long)\n        col_ego = torch.arange(N - 1)\n\n        # Ego node should be connected to every other 
node.\n        row = torch.cat([row, row_ego, col_ego], dim=0)\n        col = torch.cat([col, col_ego, row_ego], dim=0)\n        edge_index = torch.stack([row, col], dim=0)\n        edge_index = coalesce(edge_index, num_nodes=int(N))\n\n        data = EgoData(x=x, edge_index=edge_index, circle=circle,\n                       circle_batch=circle_batch)\n\n        data_list.append(data)\n\n    return data_list\n\n\ndef read_soc(files: List[str], name: str) -> List[Data]:\n    import pandas as pd\n\n    skiprows = 4\n    if name == 'pokec':\n        skiprows = 0\n\n    edge_index = pd.read_csv(files[0], sep='\\t', header=None,\n                             skiprows=skiprows, dtype=np.int64)\n    edge_index = torch.from_numpy(edge_index.values).t()\n    num_nodes = int(edge_index.max()) + 1\n    edge_index = coalesce(edge_index, num_nodes=num_nodes)\n\n    return [Data(edge_index=edge_index, num_nodes=num_nodes)]\n\n\ndef read_wiki(files: List[str], name: str) -> List[Data]:\n    import pandas as pd\n\n    edge_index = pd.read_csv(files[0], sep='\\t', header=None, skiprows=4,\n                             dtype=np.int64)\n    edge_index = torch.from_numpy(edge_index.values).t()\n\n    idx = torch.unique(edge_index.flatten())\n    idx_assoc = torch.full(\n        (edge_index.max() + 1, ),  # type: ignore\n        -1,\n        dtype=torch.long,\n    )\n    idx_assoc[idx] = torch.arange(idx.size(0))\n\n    edge_index = idx_assoc[edge_index]\n    num_nodes = int(edge_index.max()) + 1\n    edge_index = coalesce(edge_index, num_nodes=num_nodes)\n\n    return [Data(edge_index=edge_index, num_nodes=num_nodes)]\n\n\nclass SNAPDataset(InMemoryDataset):\n    r\"\"\"A variety of graph datasets collected from `SNAP at Stanford University\n    <https://snap.stanford.edu/data>`_.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset.\n        transform (callable, optional): A function/transform that takes in 
an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://snap.stanford.edu/data'\n\n    available_datasets = {\n        'ego-facebook': ['facebook.tar.gz'],\n        'ego-gplus': ['gplus.tar.gz'],\n        'ego-twitter': ['twitter.tar.gz'],\n        'soc-ca-astroph': ['ca-AstroPh.txt.gz'],\n        'soc-ca-grqc': ['ca-GrQc.txt.gz'],\n        'soc-epinions1': ['soc-Epinions1.txt.gz'],\n        'soc-livejournal1': ['soc-LiveJournal1.txt.gz'],\n        'soc-pokec': ['soc-pokec-relationships.txt.gz'],\n        'soc-slashdot0811': ['soc-Slashdot0811.txt.gz'],\n        'soc-slashdot0922': ['soc-Slashdot0902.txt.gz'],\n        'wiki-vote': ['wiki-Vote.txt.gz'],\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in self.available_datasets.keys()\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         
force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def _download(self) -> None:\n        if osp.isdir(self.raw_dir) and len(os.listdir(self.raw_dir)) > 0:\n            return\n\n        fs.makedirs(self.raw_dir, exist_ok=True)\n        self.download()\n\n    def download(self) -> None:\n        for name in self.available_datasets[self.name]:\n            fs.cp(f'{self.url}/{name}', self.raw_dir, extract=True)\n\n    def process(self) -> None:\n        raw_dir = self.raw_dir\n        filenames = fs.ls(self.raw_dir)\n        if len(filenames) == 1 and fs.isdir(filenames[0]):\n            raw_dir = filenames[0]\n\n        raw_files = fs.ls(raw_dir)\n\n        data_list: Union[List[Data], List[EgoData]]\n        if self.name[:4] == 'ego-':\n            data_list = read_ego(raw_files, self.name[4:])\n        elif self.name[:4] == 'soc-':\n            data_list = read_soc(raw_files, self.name[4:])\n        elif self.name[:5] == 'wiki-':\n            data_list = read_wiki(raw_files, self.name[5:])\n        else:\n            raise NotImplementedError\n\n        if len(data_list) > 1 and self.pre_filter is not None:\n            data_list = [data for data in data_list if self.pre_filter(data)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(data) for data in data_list]\n\n        self.save(data_list, self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'SNAP-{self.name}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/datasets/suite_sparse.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nimport fsspec\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs\n\n\nclass SuiteSparseMatrixCollection(InMemoryDataset):\n    r\"\"\"A suite of sparse matrix benchmarks known as the `Suite Sparse Matrix\n    Collection <https://sparse.tamu.edu>`_ collected from a wide range of\n    applications.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        group (str): The group of the sparse matrix.\n        name (str): The name of the sparse matrix.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://sparse.tamu.edu/mat/{}/{}.mat'\n\n    def __init__(\n        self,\n        root: str,\n        group: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.group = group\n        self.name = name\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.group, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.group, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'{self.name}.mat'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        fs.cp(self.url.format(self.group, self.name), self.raw_dir)\n\n    def process(self) -> None:\n        from scipy.io import loadmat\n\n        with fsspec.open(self.raw_paths[0], 'rb') as f:\n            mat = loadmat(f)['Problem'][0][0][2].tocsr().tocoo()\n\n        row = torch.from_numpy(mat.row).to(torch.long)\n        col = torch.from_numpy(mat.col).to(torch.long)\n        edge_index = torch.stack([row, col], dim=0)\n\n        value = torch.from_numpy(mat.data).to(torch.float)\n        edge_attr = None if torch.all(value == 1.0) else value\n\n        size: Optional[torch.Size] = torch.Size(mat.shape)\n        if mat.shape[0] == mat.shape[1]:\n            size = None\n\n        num_nodes = mat.shape[0]\n\n        data = Data(edge_index=edge_index, edge_attr=edge_attr, size=size,\n                    num_nodes=num_nodes)\n\n        if self.pre_transform is not 
None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(group={self.group}, '\n                f'name={self.name})')\n"
  },
  {
    "path": "torch_geometric/datasets/tag_dataset.py",
    "content": "import csv\nimport os\nimport os.path as osp\nfrom collections.abc import Sequence\nfrom typing import Dict, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import InMemoryDataset, download_google_url\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.io import fs\n\ntry:\n    from pandas import DataFrame, read_csv\n    WITH_PANDAS = True\nexcept ImportError:\n    WITH_PANDAS = False\n\nIndexType = Union[slice, Tensor, np.ndarray, Sequence]\n\n\nclass TAGDataset(InMemoryDataset):\n    r\"\"\"The Text Attributed Graph datasets from the\n    `\"Learning on Large-scale Text-attributed Graphs via Variational Inference\"\n    <https://arxiv.org/abs/2210.14709>`_ paper and `\"Harnessing Explanations:\n    LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation\n    Learning\" <https://arxiv.org/abs/2305.19523>`_ paper.\n    This dataset is aiming on transform `ogbn products`, `ogbn arxiv`\n    into Text Attributed Graph that each node in graph is associate with a\n    raw text, LLM prediction and explanation, that dataset can be adapt to\n    DataLoader (for LM training) and NeighborLoader(for GNN training).\n    In addition, this class can be use as a wrapper class by convert a\n    InMemoryDataset with Tokenizer and text into Text Attributed Graph.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        dataset (InMemoryDataset): The name of the dataset\n            (:obj:`\"ogbn-products\"`, :obj:`\"ogbn-arxiv\"`).\n        tokenizer_name (str): The tokenizer name for language model,\n            Be sure to use same tokenizer name as your `model id` of model repo\n            on huggingface.co.\n        text (List[str]): list of raw text associate with node, the order of\n            list should be align with node list\n        split_idx (Optional[Dict[str, torch.Tensor]]): Optional 
dictionary\n            holding the split indices; required if your dataset doesn't\n            have a get_idx_split function\n        tokenize_batch_size (int): batch size used when tokenizing text;\n            tokenization runs on the CPU, default: 256\n        token_on_disk (bool): whether to save tokens as .pt files on disk,\n            default: False\n        text_on_disk (bool): whether to save the given text (list of str) as a\n            dataframe on disk, default: False\n        force_reload (bool): whether to re-process the dataset, default: False\n    .. note::\n        See `examples/llm/glem.py` for example usage\n    \"\"\"\n    raw_text_id = {\n        'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',\n        'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'\n    }\n\n    llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'\n\n    llm_explanation_id = {\n        'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        dataset: InMemoryDataset,\n        tokenizer_name: str,\n        text: Optional[List[str]] = None,\n        split_idx: Optional[Dict[str, Tensor]] = None,\n        tokenize_batch_size: int = 256,\n        token_on_disk: bool = False,\n        text_on_disk: bool = False,\n        force_reload: bool = False,\n    ) -> None:\n        # Set the attributes needed before download & process run\n        self.name = dataset.name\n        self.text = text\n        self.llm_prediction_topk = 5\n        self.tokenizer_name = tokenizer_name\n        from transformers import AutoTokenizer\n        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)\n        if self.tokenizer.pad_token_id is None:\n            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n        if self.tokenizer.pad_token is None:\n            self.tokenizer.pad_token = self.tokenizer.eos_token\n\n        self.dir_name = '_'.join(dataset.name.split('-'))\n        self.root = osp.join(root, 
self.dir_name)\n        missing_str_list = []\n        if not WITH_PANDAS:\n            missing_str_list.append('pandas')\n        if len(missing_str_list) > 0:\n            missing_str = ' '.join(missing_str_list)\n            error_out = f\"`pip install {missing_str}` to use this dataset.\"\n            raise ImportError(error_out)\n        if hasattr(dataset, 'get_idx_split'):\n            self.split_idx = dataset.get_idx_split()\n        elif split_idx is not None:\n            self.split_idx = split_idx\n        else:\n            raise ValueError(\"TAGDataset needs split indices to generate \"\n                             \"the is_gold mask; please pass a dictionary \"\n                             \"with 'train', 'valid', and 'test' index \"\n                             \"tensors to 'split_idx'\")\n        if text_on_disk:\n            if text is not None:\n                self.save_node_text(text)\n        self.text_on_disk = text_on_disk\n        # init will call download and process\n        super().__init__(self.root, transform=None, pre_transform=None,\n                         pre_filter=None, force_reload=force_reload)\n        # after processing and download\n        # Dataset has to have BaseData as _data\n        assert dataset._data is not None\n        self._data = dataset._data  # reassign reference\n        assert self._data is not None\n        assert dataset._data.y is not None\n        assert isinstance(self._data, BaseData)\n        assert self._data.num_nodes is not None\n        assert isinstance(dataset._data.num_nodes, int)\n        assert isinstance(self._data.num_nodes, int)\n        self._n_id = torch.arange(self._data.num_nodes)\n        is_good_tensor = self.load_gold_mask()\n        self._is_gold = is_good_tensor.squeeze()\n        self._data['is_gold'] = is_good_tensor\n        if self.text is not None and len(self.text) != self._data.num_nodes:\n            raise ValueError(\"The number of text sequences in 'text' 
should be \"\n                             \"equal to number of nodes!\")\n        self.token_on_disk = token_on_disk\n        self.tokenize_batch_size = tokenize_batch_size\n        self._token = self.tokenize_graph(self.tokenize_batch_size)\n        self._llm_explanation_token: Dict[str, Tensor] = {}\n        self._all_token: Dict[str, Tensor] = {}\n        if self.name in self.llm_explanation_id:\n            self._llm_explanation_token = self.tokenize_graph(\n                self.tokenize_batch_size, text_type='llm_explanation')\n            self._all_token = self.tokenize_graph(self.tokenize_batch_size,\n                                                  text_type='all')\n        self.__num_classes__ = dataset.num_classes\n\n    @property\n    def num_classes(self) -> int:\n        return self.__num_classes__\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        file_names = []\n        for _, _, files in os.walk(osp.join(self.root, 'raw')):\n            for file in files:\n                file_names.append(file)\n        return file_names\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return [\n            'geometric_data_processed.pt', 'pre_filter.pt',\n            'pre_transformed.pt'\n        ]\n\n    @property\n    def token(self) -> Dict[str, Tensor]:\n        if self._token is None:  # lazy load\n            self._token = self.tokenize_graph()\n        return self._token\n\n    @property\n    def llm_explanation_token(self) -> Dict[str, Tensor]:\n        if self._llm_explanation_token is None and \\\n                self.name in self.llm_explanation_id:\n            self._llm_explanation_token = self.tokenize_graph(\n                text_type='llm_explanation')\n        return self._llm_explanation_token\n\n    @property\n    def all_token(self) -> Dict[str, Tensor]:\n        if self._all_token is None and \\\n                self.name in self.llm_explanation_id:\n            self._all_token = 
self.tokenize_graph(text_type='all')\n        return self._all_token\n\n    # load is_gold after init\n    @property\n    def is_gold(self) -> Tensor:\n        if self._is_gold is None:\n            print('lazy load is_gold!!')\n            self._is_gold = self.load_gold_mask()\n        return self._is_gold\n\n    def get_n_id(self, node_idx: IndexType) -> Tensor:\n        if self._n_id is None:\n            assert self._data is not None\n            assert self._data.num_nodes is not None\n            assert isinstance(self._data.num_nodes, int)\n            self._n_id = torch.arange(self._data.num_nodes)\n        return self._n_id[node_idx]\n\n    def load_gold_mask(self) -> Tensor:\n        r\"\"\"Use original train split as gold split, generating is_gold mask\n        for picking ground truth labels and pseudo labels.\n        \"\"\"\n        train_split_idx = self.get_idx_split()['train']\n        assert self._data is not None\n        assert self._data.num_nodes is not None\n        assert isinstance(self._data.num_nodes, int)\n        is_good_tensor = torch.zeros(self._data.num_nodes,\n                                     dtype=torch.bool).view(-1, 1)\n        is_good_tensor[train_split_idx] = True\n        return is_good_tensor\n\n    def get_gold(self, node_idx: IndexType) -> Tensor:\n        r\"\"\"Get gold mask for given node_idx.\n\n        Args:\n            node_idx (torch.tensor): a tensor contain node idx\n        \"\"\"\n        if self._is_gold is None:\n            self._is_gold = self.is_gold\n        return self._is_gold[node_idx]\n\n    def get_idx_split(self) -> Dict[str, Tensor]:\n        return self.split_idx\n\n    def download(self) -> None:\n        print('downloading raw text')\n        raw_text_path = download_google_url(id=self.raw_text_id[self.name],\n                                            folder=f'{self.root}/raw',\n                                            filename='node-text.csv.gz',\n                                        
    log=True)\n        self.text = list(read_csv(raw_text_path)['text'])\n        if self.name in self.llm_explanation_id:\n            print('downloading llm explanations')\n            llm_explanation_path = download_google_url(\n                id=self.llm_explanation_id[self.name],\n                folder=f'{self.root}/raw', filename='node-gpt-response.csv.gz',\n                log=True)\n            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])\n            print('downloading llm predictions')\n            fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)\n\n    def process(self) -> None:\n        # process Title and Abstraction\n        if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):\n            text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))\n            self.text = list(text_df['text'])\n        elif self.name in self.raw_text_id:\n            self.download()\n        else:\n            print('The dataset is not ogbn-products nor ogbn-arxiv,'\n                  'please pass in your raw text string list to `text`')\n        if self.text is None:\n            raise ValueError(\"The TAGDataset only have ogbn-products and \"\n                             \"ogbn-arxiv raw text in default \"\n                             \"The raw text of each node is not specified\"\n                             \"Please pass in 'text' when convert your dataset \"\n                             \"to Text Attribute Graph Dataset\")\n        # process LLM explanation and prediction\n        llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'\n        llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'\n        if osp.exists(llm_explanation_path) and osp.exists(\n                llm_prediction_path):\n            # load LLM explanation\n            self.llm_explanation = list(read_csv(llm_explanation_path)['text'])\n            # load LLM prediction\n            preds = []\n            with 
open(llm_prediction_path) as file:\n                reader = csv.reader(file)\n                for row in reader:\n                    inner_list = []\n                    for value in row:\n                        inner_list.append(int(value))\n                    preds.append(inner_list)\n\n            pl = torch.zeros(len(preds), self.llm_prediction_topk,\n                             dtype=torch.long)\n            for i, pred in enumerate(preds):\n                pred = pred[:self.llm_prediction_topk]\n                pl[i][:len(pred)] = torch.tensor(pred, dtype=torch.long) + 1\n\n            if self.llm_explanation is None or pl is None:\n                raise ValueError(\n                    \"The TAGDataset only has ogbn-arxiv LLM explanations \"\n                    \"and predictions by default. The LLM explanation and \"\n                    \"prediction of each node is not specified. Please pass \"\n                    \"in 'llm_explanation' and 'llm_prediction' when \"\n                    \"converting your dataset to a Text Attributed Graph.\")\n        elif self.name in self.llm_explanation_id:\n            self.download()\n        else:\n            print(\n                'The dataset is not ogbn-arxiv, '\n                'please pass in your llm explanation list to `llm_explanation` '\n                'and llm prediction list to `llm_prediction`')\n\n    def save_node_text(self, text: List[str]) -> None:\n        node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')\n        if osp.exists(node_text_path):\n            print(f'The raw text already exists at {node_text_path}')\n        else:\n            print(f'Saving raw text file at {node_text_path}')\n            os.makedirs(f'{self.root}/raw', exist_ok=True)\n            text_df = DataFrame(text, columns=['text'])\n            text_df.to_csv(osp.join(node_text_path), compression='gzip',\n                           index=False)\n\n    def tokenize_graph(self, batch_size: int = 256,\n                     
  text_type: str = 'raw_text') -> Dict[str, Tensor]:\n        r\"\"\"Tokenizing the text associate with each node, running in cpu.\n\n        Args:\n            batch_size (Optional[int]): batch size of list of text for\n                generating emebdding\n            text_type (Optional[str]): type of text\n        Returns:\n            Dict[str, torch.Tensor]: tokenized graph\n        \"\"\"\n        assert text_type in ['raw_text', 'llm_explanation', 'all']\n        if text_type == 'raw_text':\n            _text = self.text\n        elif text_type == 'llm_explanation':\n            _text = self.llm_explanation\n        elif text_type == 'all':\n            if self.text is None or self.llm_explanation is None:\n                raise ValueError(\"The TAGDataset need text and llm explanation\"\n                                 \"for tokenizing all text\")\n            _text = [\n                f'{raw_txt} Explanation: {exp_txt}'\n                for raw_txt, exp_txt in zip(self.text, self.llm_explanation)\n            ]\n\n        data_len = 0\n        if _text is not None:\n            data_len = len(_text)\n        else:\n            raise ValueError(\"The TAGDataset need text for tokenization\")\n        token_keys = ['input_ids', 'token_type_ids', 'attention_mask']\n        path = os.path.join(self.processed_dir, 'token', text_type,\n                            self.tokenizer_name)\n        # Check if the .pt files already exist\n        token_files_exist = any(\n            os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)\n\n        if token_files_exist and self.token_on_disk:\n            print('Found tokenized file, loading may take several minutes...')\n            all_encoded_token = {\n                k: torch.load(os.path.join(path, f'{k}.pt'), weights_only=True)\n                for k in token_keys\n                if os.path.exists(os.path.join(path, f'{k}.pt'))\n            }\n            return all_encoded_token\n\n        
all_encoded_token = {k: [] for k in token_keys}\n        pbar = tqdm(total=data_len)\n\n        pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')\n        for i in range(0, data_len, batch_size):\n            end_index = min(data_len, i + batch_size)\n            token = self.tokenizer(_text[i:end_index], padding='max_length',\n                                   truncation=True, max_length=512,\n                                   return_tensors=\"pt\")\n            for k in token.keys():\n                all_encoded_token[k].append(token[k])\n            pbar.update(end_index - i)\n        pbar.close()\n\n        all_encoded_token = {\n            k: torch.cat(v)\n            for k, v in all_encoded_token.items() if len(v) > 0\n        }\n        if self.token_on_disk:\n            os.makedirs(path, exist_ok=True)\n            print('Saving tokens on Disk')\n            for k, tensor in all_encoded_token.items():\n                torch.save(tensor, os.path.join(path, f'{k}.pt'))\n                print('Token saved:', os.path.join(path, f'{k}.pt'))\n        os.environ[\"TOKENIZERS_PARALLELISM\"] = 'true'  # suppressing warning\n        return all_encoded_token\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n\n    class TextDataset(torch.utils.data.Dataset):\n        r\"\"\"This nested dataset provides textual data for each node in\n        the graph. 
Factory method to create TextDataset from TAGDataset.\n\n        Args:\n            tag_dataset (TAGDataset): the parent dataset\n            text_type (str): type of text\n        \"\"\"\n        def __init__(self, tag_dataset: 'TAGDataset',\n                     text_type: str = 'raw_text') -> None:\n            assert text_type in ['raw_text', 'llm_explanation', 'all']\n            self.tag_dataset = tag_dataset\n            if text_type == 'raw_text':\n                self.token = tag_dataset.token\n            elif text_type == 'llm_explanation':\n                self.token = tag_dataset.llm_explanation_token\n            elif text_type == 'all':\n                self.token = tag_dataset.all_token\n            assert tag_dataset._data is not None\n            self._data = tag_dataset._data\n\n            assert tag_dataset._data.y is not None\n            self.labels = tag_dataset._data.y\n\n        def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:\n            r\"\"\"This function will be called in __getitem__().\n\n            Args:\n                node_idx (IndexType): selected node idx in each batch\n            Returns:\n                items (Dict[str, Tensor]): input for LM\n            \"\"\"\n            items = {k: v[node_idx] for k, v in self.token.items()}\n            return items\n\n        # for LM training\n        def __getitem__(\n            self,\n            node_id: IndexType,\n        ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:\n            r\"\"\"This function will override the function in\n            torch.utils.data.Dataset, and will be called when you\n            iterate batch in the dataloader, make sure all following\n            key value pairs are present in the return dict.\n\n            Args:\n                node_id (List[int]): list of node idx for selecting tokens,\n                    labels etc. 
when iterating over the data loader for LM\n            Returns:\n                items (dict): input key-value pairs for language model\n                    training and inference\n            \"\"\"\n            item: Dict[str, Union[Tensor, Dict[str, Tensor]]] = {}\n            item['input'] = self.get_token(node_id)\n            item['labels'] = self.labels[node_id]\n            item['is_gold'] = self.tag_dataset.get_gold(node_id)\n            item['n_id'] = self.tag_dataset.get_n_id(node_id)\n            return item\n\n        def __len__(self) -> int:\n            assert self._data.num_nodes is not None\n            return self._data.num_nodes\n\n        def get(self, idx: int) -> BaseData:\n            return self._data\n\n        def __repr__(self) -> str:\n            return f'{self.__class__.__name__}()'\n\n    def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:\n        r\"\"\"Build a text dataset from the Text Attributed Graph dataset,\n        where each data point is a node's associated text tokens.\n        \"\"\"\n        return TAGDataset.TextDataset(self, text_type)\n"
  },
  {
    "path": "torch_geometric/datasets/taobao.py",
    "content": "import os\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    HeteroData,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\n\n\nclass Taobao(InMemoryDataset):\n    r\"\"\"Taobao is a dataset of user behaviors from Taobao offered by Alibaba,\n    provided by the `Tianchi Alicloud platform\n    <https://tianchi.aliyun.com/dataset/649>`_.\n\n    Taobao is a heterogeneous graph for recommendation.\n    Nodes represent users with user IDs, items with item IDs, and categories\n    with category ID.\n    Edges between users and items represent different types of user behaviors\n    towards items (alongside with timestamps).\n    Edges between items and categories assign each item to its set of\n    categories.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            every access. (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.HeteroData` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    \"\"\"\n    url = ('https://alicloud-dev.oss-cn-hangzhou.aliyuncs.com/'\n           'UserBehavior.csv.zip')\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0], data_cls=HeteroData)\n\n    @property\n    def raw_file_names(self) -> str:\n        return 'UserBehavior.csv'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.remove(path)\n\n    def process(self) -> None:\n        import pandas as pd\n\n        cols = ['userId', 'itemId', 'categoryId', 'behaviorType', 'timestamp']\n        df = pd.read_csv(self.raw_paths[0], names=cols)\n\n        # Time representation (YYYY.MM.DD-HH:MM:SS -> Integer)\n        # start: 1511539200 = 2017.11.25-00:00:00\n        # end:   1512316799 = 2017.12.03-23:59:59\n        start = 1511539200\n        end = 1512316799\n        df = df[(df[\"timestamp\"] >= start) & (df[\"timestamp\"] <= end)]\n\n        df = df.drop_duplicates()\n\n        behavior_dict = {'pv': 0, 'cart': 1, 'buy': 2, 'fav': 3}\n        df['behaviorType'] = df['behaviorType'].map(behavior_dict)\n\n        num_entries = {}\n        for name in ['userId', 'itemId', 'categoryId']:\n            # Map IDs to consecutive integers:\n            value, df[name] = np.unique(df[[name]].values, return_inverse=True)\n            num_entries[name] = value.shape[0]\n\n        data = HeteroData()\n\n        data['user'].num_nodes = num_entries['userId']\n        
data['item'].num_nodes = num_entries['itemId']\n        data['category'].num_nodes = num_entries['categoryId']\n\n        row = torch.from_numpy(df['userId'].values)\n        col = torch.from_numpy(df['itemId'].values)\n        data['user', 'item'].edge_index = torch.stack([row, col], dim=0)\n        data['user', 'item'].time = torch.from_numpy(df['timestamp'].values)\n        behavior = torch.from_numpy(df['behaviorType'].values)\n        data['user', 'item'].behavior = behavior\n\n        df = df[['itemId', 'categoryId']].drop_duplicates()\n        row = torch.from_numpy(df['itemId'].values)\n        col = torch.from_numpy(df['categoryId'].values)\n        data['item', 'category'].edge_index = torch.stack([row, col], dim=0)\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/teeth3ds.py",
    "content": "import json\nimport os\nimport os.path as osp\nfrom glob import glob\nfrom typing import Callable, Dict, List, Optional\n\nimport numpy as np\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\n\n\nclass Teeth3DS(InMemoryDataset):\n    r\"\"\"The Teeth3DS+ dataset from the `\"An Extended Benchmark for Intra-oral\n    3D Scans Analysis\" <https://crns-smartvision.github.io/teeth3ds/>`_ paper.\n\n    This dataset is the first comprehensive public benchmark designed to\n    advance the field of intra-oral 3D scan analysis developed as part of the\n    3DTeethSeg 2022 and 3DTeethLand 2024 MICCAI challenges, aiming to drive\n    research in teeth identification, segmentation, labeling, 3D modeling,\n    and dental landmark identification.\n    The dataset includes at least 1,800 intra-oral scans (containing 23,999\n    annotated teeth) collected from 900 patients, covering both upper and lower\n    jaws separately.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str): The split name (one of :obj:`\"Teeth3DS\"`,\n            :obj:`\"3DTeethSeg22_challenge\"` or :obj:`\"3DTeethLand_challenge\"`).\n        train (bool, optional): If :obj:`True`, loads the training dataset,\n            otherwise the test dataset. (default: :obj:`True`)\n        num_samples (int, optional): Number of points to sample from each mesh.\n            (default: :obj:`30000`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. 
The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    urls = {\n        'data_part_1.zip':\n        'https://osf.io/download/qhprs/',\n        'data_part_2.zip':\n        'https://osf.io/download/4pwnr/',\n        'data_part_3.zip':\n        'https://osf.io/download/frwdp/',\n        'data_part_4.zip':\n        'https://osf.io/download/2arn4/',\n        'data_part_5.zip':\n        'https://osf.io/download/xrz5f/',\n        'data_part_6.zip':\n        'https://osf.io/download/23hgq/',\n        'data_part_7.zip':\n        'https://osf.io/download/u83ad/',\n        'train_test_split':\n        'https://files.de-1.osf.io/v1/'\n        'resources/xctdy/providers/osfstorage/?zip='\n    }\n\n    sample_url = {\n        'teeth3ds_sample': 'https://osf.io/download/vr38s/',\n    }\n\n    landmarks_urls = {\n        '3DTeethLand_landmarks_train.zip': 'https://osf.io/download/k5hbj/',\n        '3DTeethLand_landmarks_test.zip': 'https://osf.io/download/sqw5e/',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        split:\n        str = 'Teeth3DS',  # [3DTeethSeg22_challenge, 3DTeethLand_challenge]\n        train: bool = True,\n        num_samples: int = 30000,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n\n        self.mode = 'training' if train else 'testing'\n        self.split = split\n        self.num_samples = num_samples\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n\n    @property\n    def processed_dir(self) -> str:\n        return os.path.join(self.root, f'processed_{self.split}_{self.mode}')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['license.txt']\n\n    @property\n    
def processed_file_names(self) -> List[str]:\n        # Directory containing train/test split files:\n        split_subdir = 'teeth3ds_sample' if self.split == 'sample' else ''\n        split_dir = osp.join(\n            self.raw_dir,\n            split_subdir,\n            f'{self.split}_train_test_split',\n        )\n\n        split_files = glob(osp.join(split_dir, f'{self.mode}*.txt'))\n\n        # Collect all file names from the split files:\n        combined_list = []\n        for file_path in split_files:\n            with open(file_path) as file:\n                combined_list.extend(file.read().splitlines())\n\n        # Generate the list of processed file paths:\n        return [f'{file_name}.pt' for file_name in combined_list]\n\n    def download(self) -> None:\n        if self.split == 'sample':\n            for key, url in self.sample_url.items():\n                path = download_url(url, self.root, filename=key)\n                extract_zip(path, self.raw_dir)\n                os.unlink(path)\n        else:\n            for key, url in self.urls.items():\n                path = download_url(url, self.root, filename=key)\n                extract_zip(path, self.raw_dir)\n                os.unlink(path)\n            for key, url in self.landmarks_urls.items():\n                path = download_url(url, self.root, filename=key)\n                extract_zip(path, self.raw_dir)  # Extract each downloaded part\n                os.unlink(path)\n\n    def process_file(self, file_path: str) -> Optional[Data]:\n        \"\"\"Processes the input file path to load mesh data, annotations,\n        and prepare the input features for a graph-based deep learning model.\n        \"\"\"\n        import trimesh\n        from fpsample import bucket_fps_kdline_sampling\n\n        mesh = trimesh.load_mesh(file_path)\n\n        if isinstance(mesh, list):\n            # Handle the case where a list of Geometry objects is returned\n            mesh = mesh[0]\n\n        vertices 
= mesh.vertices\n        vertex_normals = mesh.vertex_normals\n\n        # Perform sampling on mesh vertices:\n        if len(vertices) < self.num_samples:\n            sampled_indices = np.random.choice(\n                len(vertices),\n                self.num_samples,\n                replace=True,\n            )\n        else:\n            sampled_indices = bucket_fps_kdline_sampling(\n                vertices,\n                self.num_samples,\n                h=5,\n                start_idx=0,\n            )\n\n        if len(sampled_indices) != self.num_samples:\n            raise RuntimeError(f\"Sampled points mismatch, expected \"\n                               f\"{self.num_samples} points, but got \"\n                               f\"{len(sampled_indices)} for '{file_path}'\")\n\n        # Extract features and annotations for the sampled points:\n        pos = torch.tensor(vertices[sampled_indices], dtype=torch.float)\n        x = torch.tensor(vertex_normals[sampled_indices], dtype=torch.float)\n\n        # Load segmentation annotations:\n        seg_annotation_path = file_path.replace('.obj', '.json')\n        if osp.exists(seg_annotation_path):\n            with open(seg_annotation_path) as f:\n                seg_annotations = json.load(f)\n            y = torch.tensor(\n                np.asarray(seg_annotations['labels'])[sampled_indices],\n                dtype=torch.float)\n            instances = torch.tensor(\n                np.asarray(seg_annotations['instances'])[sampled_indices],\n                dtype=torch.float)\n        else:\n            y = torch.empty(0, 3)\n            instances = torch.empty(0, 3)\n\n        # Load landmarks annotations:\n        landmarks_annotation_path = file_path.replace('.obj', '__kpt.json')\n\n        # Parse keypoint annotations into structured tensors:\n        keypoints_dict: Dict[str, List] = {\n            key: []\n            for key in [\n                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 
'OuterPoint',\n                'FacialPoint'\n            ]\n        }\n        keypoint_tensors: Dict[str, torch.Tensor] = {\n            key: torch.empty(0, 3)\n            for key in [\n                'Mesial', 'Distal', 'Cusp', 'InnerPoint', 'OuterPoint',\n                'FacialPoint'\n            ]\n        }\n        if osp.exists(landmarks_annotation_path):\n            with open(landmarks_annotation_path) as f:\n                landmarks_annotations = json.load(f)\n\n            for keypoint in landmarks_annotations['objects']:\n                keypoints_dict[keypoint['class']].extend(keypoint['coord'])\n\n            keypoint_tensors = {\n                k: torch.tensor(np.asarray(v),\n                                dtype=torch.float).reshape(-1, 3)\n                for k, v in keypoints_dict.items()\n            }\n\n        data = Data(\n            pos=pos,\n            x=x,\n            y=y,\n            instances=instances,\n            jaw=file_path.split('.obj')[0].split('_')[1],\n            mesial=keypoint_tensors['Mesial'],\n            distal=keypoint_tensors['Distal'],\n            cusp=keypoint_tensors['Cusp'],\n            inner_point=keypoint_tensors['InnerPoint'],\n            outer_point=keypoint_tensors['OuterPoint'],\n            facial_point=keypoint_tensors['FacialPoint'],\n        )\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        return data\n\n    def process(self) -> None:\n        for file in tqdm(self.processed_file_names):\n            name = file.split('.')[0]\n            path = osp.join(self.raw_dir, '**', '*', name + '.obj')\n            paths = glob(path)\n            if len(paths) == 1:\n                data = self.process_file(paths[0])\n                torch.save(data, osp.join(self.processed_dir, file))\n\n    def len(self) -> int:\n        return len(self.processed_file_names)\n\n    def get(self, idx: int) -> Data:\n        return torch.load(\n            
osp.join(self.processed_dir, self.processed_file_names[idx]),\n            weights_only=False,\n        )\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'mode={self.mode}, split={self.split})')\n"
  },
  {
    "path": "torch_geometric/datasets/tosca.py",
    "content": "import glob\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import read_txt_array\n\n\nclass TOSCA(InMemoryDataset):\n    r\"\"\"The TOSCA dataset from the `\"Numerical Geometry of Non-Ridig Shapes\"\n    <https://www.amazon.com/Numerical-Geometry-Non-Rigid-Monographs-Computer/\n    dp/0387733000>`_ book, containing 80 meshes.\n    Meshes within the same category have the same triangulation and an equal\n    number of vertices numbered in a compatible way.\n\n    .. note::\n\n        Data objects hold mesh faces instead of edge indices.\n        To convert the mesh to a graph, use the\n        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.\n        To convert the mesh to a point cloud, use the\n        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to\n        sample a fixed number of points on the mesh faces according to their\n        face area.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        categories (list, optional): List of categories to include in the\n            dataset. Can include the categories :obj:`\"Cat\"`, :obj:`\"Centaur\"`,\n            :obj:`\"David\"`, :obj:`\"Dog\"`, :obj:`\"Gorilla\"`, :obj:`\"Horse\"`,\n            :obj:`\"Michael\"`, :obj:`\"Victoria\"`, :obj:`\"Wolf\"`. If set to\n            :obj:`None`, the dataset will contain all categories. (default:\n            :obj:`None`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. 
The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'http://tosca.cs.technion.ac.il/data/toscahires-asci.zip'\n\n    categories = [\n        'cat', 'centaur', 'david', 'dog', 'gorilla', 'horse', 'michael',\n        'victoria', 'wolf'\n    ]\n\n    def __init__(\n        self,\n        root: str,\n        categories: Optional[List[str]] = None,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        categories = self.categories if categories is None else categories\n        categories = [cat.lower() for cat in categories]\n        for cat in categories:\n            assert cat in self.categories\n        self.categories = categories\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['cat0.vert', 'cat0.tri']\n\n    @property\n    def processed_file_names(self) -> str:\n        name = '_'.join([cat[:2] for cat in self.categories])\n        return f'{name}.pt'\n\n    def download(self) -> None:\n        
path = download_url(self.url, self.raw_dir)\n        extract_zip(path, self.raw_dir)\n        os.unlink(path)\n\n    def process(self) -> None:\n        data_list = []\n        for cat in self.categories:\n            paths = glob.glob(osp.join(self.raw_dir, f'{cat}*.tri'))\n            paths = [path[:-4] for path in paths]\n            paths = sorted(paths, key=lambda e: (len(e), e))\n\n            for path in paths:\n                pos = read_txt_array(f'{path}.vert')\n                face = read_txt_array(f'{path}.tri', dtype=torch.long)\n                face = face - face.min()  # Ensure zero-based index.\n                data = Data(pos=pos, face=face.t().contiguous())\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    continue\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n                data_list.append(data)\n\n        self.save(data_list, self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/tu_dataset.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nfrom torch_geometric.data import Data, InMemoryDataset\nfrom torch_geometric.io import fs, read_tu_data\n\n\nclass TUDataset(InMemoryDataset):\n    r\"\"\"A variety of graph kernel benchmark datasets, *.e.g.*,\n    :obj:`\"IMDB-BINARY\"`, :obj:`\"REDDIT-BINARY\"` or :obj:`\"PROTEINS\"`,\n    collected from the `TU Dortmund University\n    <https://chrsmrrs.github.io/datasets>`_.\n    In addition, this dataset wrapper provides `cleaned dataset versions\n    <https://github.com/nd7141/graph_datasets>`_ as motivated by the\n    `\"Understanding Isomorphism Bias in Graph Data Sets\"\n    <https://arxiv.org/abs/1910.12091>`_ paper, containing only non-isomorphic\n    graphs.\n\n    .. note::\n        Some datasets may not come with any node labels.\n        You can then either make use of the argument :obj:`use_node_attr`\n        to load additional continuous node attributes (if present) or provide\n        synthetic node features using transforms such as\n        :class:`torch_geometric.transforms.Constant` or\n        :class:`torch_geometric.transforms.OneHotDegree`.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The `name\n            <https://chrsmrrs.github.io/datasets/docs/datasets/>`_ of the\n            dataset.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        use_node_attr (bool, optional): If :obj:`True`, the dataset will\n            contain additional continuous node attributes (if present).\n            (default: :obj:`False`)\n        use_edge_attr (bool, optional): If :obj:`True`, the dataset will\n            contain additional continuous edge attributes (if present).\n            (default: :obj:`False`)\n        cleaned (bool, optional): If :obj:`True`, the dataset will\n            contain only non-isomorphic graphs. (default: :obj:`False`)\n\n    **STATS:**\n\n    .. list-table::\n        :widths: 20 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - MUTAG\n          - 188\n          - ~17.9\n          - ~39.6\n          - 7\n          - 2\n        * - ENZYMES\n          - 600\n          - ~32.6\n          - ~124.3\n          - 3\n          - 6\n        * - PROTEINS\n          - 1,113\n          - ~39.1\n          - ~145.6\n          - 3\n          - 2\n        * - COLLAB\n          - 5,000\n          - ~74.5\n          - ~4914.4\n          - 0\n          - 3\n        * - IMDB-BINARY\n          - 1,000\n          - ~19.8\n          - ~193.1\n          - 0\n          - 2\n        * - REDDIT-BINARY\n          - 2,000\n          - ~429.6\n          - ~995.5\n          - 0\n          - 2\n        * - ...\n          -\n          -\n          -\n          -\n          -\n    \"\"\"\n\n    url = 'https://www.chrsmrrs.com/graphkerneldatasets'\n    cleaned_url = 
('https://raw.githubusercontent.com/nd7141/'\n                   'graph_datasets/master/datasets')\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n        use_node_attr: bool = False,\n        use_edge_attr: bool = False,\n        cleaned: bool = False,\n    ) -> None:\n        self.name = name\n        self.cleaned = cleaned\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        out = fs.torch_load(self.processed_paths[0])\n        if not isinstance(out, tuple) or len(out) < 3:\n            raise RuntimeError(\n                \"The 'data' object was created by an older version of PyG. \"\n                \"If this error occurred while loading an already existing \"\n                \"dataset, remove the 'processed/' directory in the dataset's \"\n                \"root folder and try again.\")\n        assert len(out) == 3 or len(out) == 4\n\n        if len(out) == 3:  # Backward compatibility.\n            data, self.slices, self.sizes = out\n            data_cls = Data\n        else:\n            data, self.slices, self.sizes, data_cls = out\n\n        if not isinstance(data, dict):  # Backward compatibility.\n            self.data = data\n        else:\n            self.data = data_cls.from_dict(data)\n\n        assert isinstance(self._data, Data)\n        if self._data.x is not None and not use_node_attr:\n            num_node_attributes = self.num_node_attributes\n            self._data.x = self._data.x[:, num_node_attributes:]\n        if self._data.edge_attr is not None and not use_edge_attr:\n            num_edge_attrs = self.num_edge_attributes\n            self._data.edge_attr = self._data.edge_attr[:, num_edge_attrs:]\n\n    @property\n    def raw_dir(self) -> str:\n 
       name = f'raw{\"_cleaned\" if self.cleaned else \"\"}'\n        return osp.join(self.root, self.name, name)\n\n    @property\n    def processed_dir(self) -> str:\n        name = f'processed{\"_cleaned\" if self.cleaned else \"\"}'\n        return osp.join(self.root, self.name, name)\n\n    @property\n    def num_node_labels(self) -> int:\n        return self.sizes['num_node_labels']\n\n    @property\n    def num_node_attributes(self) -> int:\n        return self.sizes['num_node_attributes']\n\n    @property\n    def num_edge_labels(self) -> int:\n        return self.sizes['num_edge_labels']\n\n    @property\n    def num_edge_attributes(self) -> int:\n        return self.sizes['num_edge_attributes']\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        names = ['A', 'graph_indicator']\n        return [f'{self.name}_{name}.txt' for name in names]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        url = self.cleaned_url if self.cleaned else self.url\n        fs.cp(f'{url}/{self.name}.zip', self.raw_dir, extract=True)\n        for filename in fs.ls(osp.join(self.raw_dir, self.name)):\n            fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))\n        fs.rm(osp.join(self.raw_dir, self.name))\n\n    def process(self) -> None:\n        self.data, self.slices, sizes = read_tu_data(self.raw_dir, self.name)\n\n        if self.pre_filter is not None or self.pre_transform is not None:\n            data_list = [self.get(idx) for idx in range(len(self))]\n\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n\n            self.data, self.slices = self.collate(data_list)\n            self._data_list = None  # Reset cache.\n\n        assert isinstance(self._data, Data)\n        
fs.torch_save(\n            (self._data.to_dict(), self.slices, sizes, self._data.__class__),\n            self.processed_paths[0],\n        )\n\n    def __repr__(self) -> str:\n        return f'{self.name}({len(self)})'\n"
  },
  {
    "path": "torch_geometric/datasets/twitch.py",
    "content": "import os.path as osp\nfrom typing import Callable, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\n\n\nclass Twitch(InMemoryDataset):\n    r\"\"\"The Twitch Gamer networks introduced in the\n    `\"Multi-scale Attributed Node Embedding\"\n    <https://arxiv.org/abs/1909.13021>`_ paper.\n    Nodes represent gamers on Twitch and edges are followerships between them.\n    Node features represent embeddings of games played by the Twitch users.\n    The task is to predict whether a user streams mature content.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"DE\"`, :obj:`\"EN\"`,\n            :obj:`\"ES\"`, :obj:`\"FR\"`, :obj:`\"PT\"`, :obj:`\"RU\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - DE\n          - 9,498\n          - 315,774\n          - 128\n          - 2\n        * - EN\n          - 7,126\n          - 77,774\n          - 128\n          - 2\n        * - ES\n          - 4,648\n          - 123,412\n          - 128\n          - 2\n        * - FR\n          - 6,551\n          - 231,883\n          - 128\n          - 2\n        * - PT\n          - 1,912\n          - 64,510\n          - 128\n          - 2\n        * - RU\n          - 4,385\n          - 78,993\n          - 128\n          - 2\n    \"\"\"\n\n    url = 'https://graphmining.ai/datasets/ptg/twitch'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name\n        assert self.name in ['DE', 'EN', 'ES', 'FR', 'PT', 'RU']\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> str:\n        return f'{self.name}.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_url(f'{self.url}/{self.name}.npz', self.raw_dir)\n\n    def process(self) -> None:\n        data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n        x = torch.from_numpy(data['features']).to(torch.float)\n        y = torch.from_numpy(data['target']).to(torch.long)\n\n        edge_index = 
torch.from_numpy(data['edges']).to(torch.long)\n        edge_index = edge_index.t().contiguous()\n\n        data = Data(x=x, y=y, edge_index=edge_index)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/upfd.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_google_url,\n    extract_zip,\n)\nfrom torch_geometric.io import read_txt_array\nfrom torch_geometric.utils import coalesce, cumsum\n\n\nclass UPFD(InMemoryDataset):\n    r\"\"\"The tree-structured fake news propagation graph classification dataset\n    from the `\"User Preference-aware Fake News Detection\"\n    <https://arxiv.org/abs/2104.12259>`_ paper.\n    It includes two sets of tree-structured fake & real news propagation graphs\n    extracted from Twitter.\n    For a single graph, the root node represents the source news, and leaf\n    nodes represent Twitter users who retweeted the same root news.\n    A user node has an edge to the news node if and only if the user retweeted\n    the root news directly.\n    Two user nodes have an edge if and only if one user retweeted the root news\n    from the other user.\n    Four different node features are encoded using different encoders.\n    Please refer to `GNN-FakeNews\n    <https://github.com/safe-graph/GNN-FakeNews>`_ repo for more details.\n\n    .. 
note::\n\n        For an example of using UPFD, see `examples/upfd.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        upfd.py>`_.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the graph set (:obj:`\"politifact\"`,\n            :obj:`\"gossipcop\"`).\n        feature (str): The node feature type (:obj:`\"profile\"`, :obj:`\"spacy\"`,\n            :obj:`\"bert\"`, :obj:`\"content\"`).\n            If set to :obj:`\"profile\"`, the 10-dimensional node feature\n            is composed of ten Twitter user profile attributes.\n            If set to :obj:`\"spacy\"`, the 300-dimensional node feature is\n            composed of Twitter user historical tweets encoded by\n            the `spaCy word2vec encoder\n            <https://spacy.io/models/en#en_core_web_lg>`_.\n            If set to :obj:`\"bert\"`, the 768-dimensional node feature is\n            composed of Twitter user historical tweets encoded by the\n            `bert-as-service <https://github.com/hanxiao/bert-as-service>`_.\n            If set to :obj:`\"content\"`, the 310-dimensional node feature is\n            composed of a 300-dimensional \"spacy\" vector plus a\n            10-dimensional \"profile\" vector.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. 
The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    file_ids = {\n        'politifact': '1toou2GO0agoY_OS54LaCWEECQfe93nuq',\n        'gossipcop': '1DkMAzC7XUUciAxsSujRJt3sq1MqaVI3g',\n    }\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        feature: str,\n        split: str = \"train\",\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        assert name in ['politifact', 'gossipcop']\n        assert split in ['train', 'val', 'test']\n\n        self.root = root\n        self.name = name\n        self.feature = feature\n\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n\n        path = self.processed_paths[['train', 'val', 'test'].index(split)]\n        self.load(path)\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed', self.feature)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'node_graph_id.npy', 'graph_labels.npy', 'A.txt', 'train_idx.npy',\n            'val_idx.npy', 'test_idx.npy', f'new_{self.feature}_feature.npz'\n        ]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def 
download(self) -> None:\n        id = self.file_ids[self.name]\n        path = download_google_url(id, self.raw_dir, 'data.zip')\n        extract_zip(path, self.raw_dir)\n        os.remove(path)\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        x = sp.load_npz(\n            osp.join(self.raw_dir, f'new_{self.feature}_feature.npz'))\n        x = torch.from_numpy(x.todense()).to(torch.float)\n\n        edge_index = read_txt_array(osp.join(self.raw_dir, 'A.txt'), sep=',',\n                                    dtype=torch.long).t()\n        edge_index = coalesce(edge_index, num_nodes=x.size(0))\n\n        y = np.load(osp.join(self.raw_dir, 'graph_labels.npy'))\n        y = torch.from_numpy(y).to(torch.long)\n        _, y = y.unique(sorted=True, return_inverse=True)\n\n        batch = np.load(osp.join(self.raw_dir, 'node_graph_id.npy'))\n        batch = torch.from_numpy(batch).to(torch.long)\n\n        node_slice = cumsum(batch.bincount())\n        edge_slice = cumsum(batch[edge_index[0]].bincount())\n        graph_slice = torch.arange(y.size(0) + 1)\n        self.slices = {\n            'x': node_slice,\n            'edge_index': edge_slice,\n            'y': graph_slice\n        }\n\n        edge_index -= node_slice[batch[edge_index[0]]].view(1, -1)\n        self.data = Data(x=x, edge_index=edge_index, y=y)\n\n        for path, split in zip(self.processed_paths, ['train', 'val', 'test']):\n            idx = np.load(osp.join(self.raw_dir, f'{split}_idx.npy')).tolist()\n            data_list = [self.get(i) for i in idx]\n            if self.pre_filter is not None:\n                data_list = [d for d in data_list if self.pre_filter(d)]\n            if self.pre_transform is not None:\n                data_list = [self.pre_transform(d) for d in data_list]\n            self.save(data_list, path)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, name={self.name}, '\n                
f'feature={self.feature})')\n"
  },
  {
    "path": "torch_geometric/datasets/utils/__init__.py",
    "content": "from .cheatsheet import paper_link, has_stats, get_stat, get_children, get_type\n\n__all__ = [\n    'paper_link',\n    'has_stats',\n    'get_stat',\n    'get_children',\n    'get_type',\n]\n"
  },
  {
    "path": "torch_geometric/datasets/utils/cheatsheet.py",
    "content": "import importlib\nimport inspect\nimport re\nfrom typing import Any, List, Optional\n\n\ndef paper_link(cls: str) -> Optional[str]:\n    cls = importlib.import_module('torch_geometric.datasets').__dict__[cls]\n    doc = inspect.getdoc(cls)\n    assert doc is not None\n    match = re.search('<.+?>', doc, flags=re.DOTALL)\n    return None if match is None else match.group().replace('\\n', ' ')[1:-1]\n\n\ndef get_stats_table(cls: str) -> str:\n    cls = importlib.import_module('torch_geometric.datasets').__dict__[cls]\n    doc = inspect.getdoc(cls)\n    assert doc is not None\n    match = re.search(r'\\*\\*STATS:\\*\\*\\n.*$', doc, flags=re.DOTALL)\n    return '' if match is None else match.group()\n\n\ndef has_stats(cls: str) -> bool:\n    return len(get_stats_table(cls)) > 0\n\n\ndef get_type(cls: str) -> str:\n    return 'Edge' if '-' in cls else 'Node'\n\n\ndef get_stat(cls: str, name: str, child: Optional[str] = None,\n             default: Any = None) -> str:\n    if child is None and len(get_children(cls)) > 0:\n        return ''\n\n    stats_table = get_stats_table(cls)\n\n    if len(stats_table) > 0:\n        stats_table = '\\n'.join(stats_table.split('\\n')[2:])\n\n    match = re.search(f'^.*- {name}', stats_table, flags=re.DOTALL)\n    if match is None:\n        return default\n\n    column = match.group().count(' -')\n\n    if child is not None:\n        child = child.replace('(', r'\\(').replace(')', r'\\)')\n        match = re.search(f'[*] - {child}\\n.*$', stats_table, flags=re.DOTALL)\n        assert match is not None\n        stats_row = match.group()\n    else:\n        stats_row = '*' + stats_table.split('*')[2]\n\n    return stats_row.split(' -')[column].split('\\n')[0].strip()\n\n\ndef get_children(cls: str) -> List[str]:\n    matches = re.findall('[*] -.*', get_stats_table(cls))\n    return [match[4:] for match in matches[1:]] if len(matches) > 2 else []\n"
  },
  {
    "path": "torch_geometric/datasets/web_qsp_dataset.py",
    "content": "# Code adapted from the G-Retriever paper: https://arxiv.org/abs/2402.07630\nimport gc\nimport os\nfrom itertools import chain\nfrom typing import Any, Dict, Iterator, List, Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import InMemoryDataset\nfrom torch_geometric.llm.large_graph_indexer import (\n    EDGE_RELATION,\n    LargeGraphIndexer,\n    TripletLike,\n    get_features_for_triplets_groups,\n)\nfrom torch_geometric.llm.models import SentenceTransformer\nfrom torch_geometric.llm.utils.backend_utils import (\n    preprocess_triplet,\n    retrieval_via_pcst,\n)\n\n\nclass KGQABaseDataset(InMemoryDataset):\n    r\"\"\"Base class for the 2 KGQA datasets used in `\"Reasoning on Graphs:\n    Faithful and Interpretable Large Language Model Reasoning\"\n    <https://arxiv.org/pdf/2310.01061>`_ paper.\n\n    Args:\n        dataset_name (str): HuggingFace `dataset` name.\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        verbose (bool, optional): Whether to print output. Defaults to False.\n        use_pcst (bool, optional): Whether to preprocess the dataset's graph\n            with PCST or return the full graphs. 
(default: :obj:`True`)\n        load_dataset_kwargs (dict, optional):\n            Keyword arguments for the `datasets.load_dataset` function.\n            (default: :obj:`{}`)\n        retrieval_kwargs (dict, optional):\n            Keyword arguments for the\n            `get_features_for_triplets_groups` function.\n            (default: :obj:`{}`)\n    \"\"\"\n    def __init__(\n        self,\n        dataset_name: str,\n        root: str,\n        split: str = \"train\",\n        force_reload: bool = False,\n        verbose: bool = False,\n        use_pcst: bool = True,\n        load_dataset_kwargs: Optional[Dict[str, Any]] = None,\n        retrieval_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> None:\n        self.split = split\n        self.dataset_name = dataset_name\n        self.use_pcst = use_pcst\n        self.load_dataset_kwargs = load_dataset_kwargs or {}\n        \"\"\"\n        NOTE: If running into memory issues,\n        try reducing this batch size for the LargeGraphIndexer\n        used to build our KG.\n        Example: self.retrieval_kwargs = {\"batch_size\": 64}\n        \"\"\"\n        self.retrieval_kwargs = retrieval_kwargs or {}\n\n        # Caching custom subsets of the dataset results in unsupported behavior\n        if 'split' in self.load_dataset_kwargs:\n            print(\"WARNING: Caching custom subsets of the dataset \\\n                results in unsupported behavior.\\\n                Please specify a separate root directory for each split,\\\n                or set force_reload=True on subsequent instantiations\\\n                of the dataset.\")\n\n        self.required_splits = ['train', 'validation', 'test']\n\n        self.verbose = verbose\n        self.force_reload = force_reload\n        super().__init__(root, force_reload=force_reload)\n        \"\"\"\n        NOTE: Current behavior is to process the entire dataset,\n        and only return the split specified by the user.\n        \"\"\"\n        if 
f'{split}_data.pt' not in set(self.processed_file_names):\n            raise ValueError(f\"Invalid 'split' argument (got {split})\")\n        if split == 'val':\n            split = 'validation'\n\n        self.load(self.processed_paths[self.required_splits.index(split)])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\"raw.pt\"]\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return [\"train_data.pt\", \"val_data.pt\", \"test_data.pt\"]\n\n    def download(self) -> None:\n        import datasets\n\n        # HF Load Dataset by dataset name if no path is specified\n        self.load_dataset_kwargs['path'] = self.load_dataset_kwargs.get(\n            'path', self.dataset_name)\n        raw_dataset = datasets.load_dataset(**self.load_dataset_kwargs)\n\n        # Assert that the dataset contains the required splits\n        assert all(split in raw_dataset for split in self.required_splits), \\\n            f\"Dataset '{self.dataset_name}' is missing required splits: \\\n            {self.required_splits}\"\n\n        raw_dataset.save_to_disk(self.raw_paths[0])\n\n    def _get_trips(self) -> Iterator[TripletLike]:\n        # Iterate over each element's graph in each split of the dataset\n        # Using chain to lazily iterate without storing all trips in memory\n        split_iterators = []\n\n        for split in self.required_splits:\n            # Create an iterator for each element's graph in the current split\n            split_graphs = (element['graph']\n                            for element in self.raw_dataset[split])\n            split_iterators.append(chain.from_iterable(split_graphs))\n\n        # Chain all split iterators together\n        return chain.from_iterable(split_iterators)\n\n    def _build_graph(self) -> None:\n        print(\"Encoding graph...\")\n        trips = self._get_trips()\n        self.indexer: LargeGraphIndexer = LargeGraphIndexer.from_triplets(\n            trips, 
pre_transform=preprocess_triplet)\n\n        # Nodes:\n        print(\"\\tEncoding nodes...\")\n        nodes = self.indexer.get_unique_node_features()\n        x = self.model.encode(nodes, batch_size=256, output_device='cpu')\n        self.indexer.add_node_feature(new_feature_name=\"x\", new_feature_vals=x)\n\n        # Edges:\n        print(\"\\tEncoding edges...\")\n        edges = self.indexer.get_unique_edge_features(\n            feature_name=EDGE_RELATION)\n        edge_attr = self.model.encode(edges, batch_size=256,\n                                      output_device='cpu')\n        self.indexer.add_edge_feature(\n            new_feature_name=\"edge_attr\",\n            new_feature_vals=edge_attr,\n            map_from_feature=EDGE_RELATION,\n        )\n\n        print(\"\\tSaving graph...\")\n        self.indexer.save(self.indexer_path)\n\n    def _retrieve_subgraphs(self) -> None:\n        raw_splits = [\n            self.raw_dataset[split] for split in self.required_splits\n        ]\n        zipped = zip(\n            self.required_splits,\n            raw_splits,  # noqa\n            self.processed_paths,\n        )\n        for split_name, dataset, path in zipped:\n            print(f\"Processing {split_name} split...\")\n\n            print(\"\\tEncoding questions...\")\n            split_questions = [str(element['question']) for element in dataset]\n            split_q_embs = self.model.encode(split_questions, batch_size=256,\n                                             output_device='cpu')\n\n            print(\"\\tRetrieving subgraphs...\")\n            results_graphs = []\n            retrieval_kwargs = {\n                **self.retrieval_kwargs,\n                **{\n                    'pre_transform': preprocess_triplet,\n                    'verbose': self.verbose,\n                }\n            }\n            graph_gen = get_features_for_triplets_groups(\n                self.indexer, (element['graph'] for element in dataset),\n           
     **retrieval_kwargs)\n\n            for index in tqdm(range(len(dataset)), disable=not self.verbose):\n                data_i = dataset[index]\n                graph = next(graph_gen)\n                textual_nodes = self.textual_nodes.iloc[\n                    graph[\"node_idx\"]].reset_index()\n                textual_edges = self.textual_edges.iloc[\n                    graph[\"edge_idx\"]].reset_index()\n                if self.use_pcst and len(textual_nodes) > 0 and len(\n                        textual_edges) > 0:\n                    subgraph, desc = retrieval_via_pcst(\n                        graph,\n                        split_q_embs[index],\n                        textual_nodes,\n                        textual_edges,\n                    )\n                else:\n                    desc = textual_nodes.to_csv(\n                        index=False) + \"\\n\" + textual_edges.to_csv(\n                            index=False,\n                            columns=[\"src\", \"edge_attr\", \"dst\"],\n                        )\n                    subgraph = graph\n                question = f\"Question: {data_i['question']}\\nAnswer: \"\n                label = (\"|\").join(data_i[\"answer\"]).lower()\n\n                subgraph[\"question\"] = question\n                subgraph[\"label\"] = label\n                subgraph[\"desc\"] = desc\n                results_graphs.append(subgraph.to(\"cpu\"))\n            print(\"\\tSaving subgraphs...\")\n            self.save(results_graphs, path)\n\n    def process(self) -> None:\n        import datasets\n        from pandas import DataFrame\n        self.raw_dataset = datasets.load_from_disk(self.raw_paths[0])\n\n        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n        model_name = 'sentence-transformers/all-roberta-large-v1'\n        self.model: SentenceTransformer = SentenceTransformer(model_name).to(\n            device)\n        self.model.eval()\n        self.indexer_path 
= os.path.join(self.processed_dir,\n                                         \"large_graph_indexer\")\n        if self.force_reload or not os.path.exists(self.indexer_path):\n            self._build_graph()\n        else:\n            print(\"Loading graph...\")\n            self.indexer = LargeGraphIndexer.from_disk(self.indexer_path)\n        self.textual_nodes = DataFrame.from_dict(\n            {\"node_attr\": self.indexer.get_node_features()})\n        self.textual_nodes[\"node_id\"] = self.textual_nodes.index\n        self.textual_nodes = self.textual_nodes[[\"node_id\", \"node_attr\"]]\n        self.textual_edges = DataFrame(self.indexer.get_edge_features(),\n                                       columns=[\"src\", \"edge_attr\", \"dst\"])\n        self.textual_edges[\"src\"] = [\n            self.indexer._nodes[h] for h in self.textual_edges[\"src\"]\n        ]\n        self.textual_edges[\"dst\"] = [\n            self.indexer._nodes[h] for h in self.textual_edges[\"dst\"]\n        ]\n        self._retrieve_subgraphs()\n\n        gc.collect()\n        torch.cuda.empty_cache()\n\n\nclass WebQSPDataset(KGQABaseDataset):\n    r\"\"\"The WebQuestionsSP dataset of the `\"The Value of Semantic Parse\n    Labeling for Knowledge Base Question Answering\"\n    <https://aclanthology.org/P16-2033/>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. (default: :obj:`\"train\"`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        verbose (bool, optional): Whether to print output. Defaults to False.\n        use_pcst (bool, optional): Whether to preprocess the dataset's graph\n            with PCST or return the full graphs. 
(default: :obj:`True`)\n        load_dataset_kwargs (dict, optional):\n            Keyword arguments for the `datasets.load_dataset` function.\n            (default: :obj:`{}`)\n        retrieval_kwargs (dict, optional):\n            Keyword arguments for the\n            `get_features_for_triplets_groups` function.\n            (default: :obj:`{}`)\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        split: str = \"train\",\n        force_reload: bool = False,\n        verbose: bool = False,\n        use_pcst: bool = True,\n        load_dataset_kwargs: Optional[Dict[str, Any]] = None,\n        retrieval_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> None:\n        load_dataset_kwargs = load_dataset_kwargs or {}\n        retrieval_kwargs = retrieval_kwargs or {}\n        # Modify these parameters if running into memory/compute issues\n        default_retrieval_kwargs = {\n            'max_batch_size': 250,  # Lower batch size to reduce memory usage\n            'num_workers':\n            None,  # Use all available workers, or set to number of threads\n        }\n        retrieval_kwargs = {**default_retrieval_kwargs, **retrieval_kwargs}\n        dataset_name = 'rmanluo/RoG-webqsp'\n        super().__init__(dataset_name, root, split, force_reload, verbose,\n                         use_pcst, load_dataset_kwargs=load_dataset_kwargs,\n                         retrieval_kwargs=retrieval_kwargs)\n\n\nclass CWQDataset(KGQABaseDataset):\n    r\"\"\"The ComplexWebQuestions (CWQ) dataset of the `\"The Web as a\n    Knowledge-base for Answering Complex Questions\"\n    <https://arxiv.org/pdf/1803.06643>`_ paper.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset. 
(default: :obj:`\"train\"`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        verbose (bool, optional): Whether to print output. Defaults to False.\n        use_pcst (bool, optional): Whether to preprocess the dataset's graph\n            with PCST or return the full graphs. (default: :obj:`True`)\n        load_dataset_kwargs (dict, optional):\n            Keyword arguments for the `datasets.load_dataset` function.\n            (default: :obj:`{}`)\n        retrieval_kwargs (dict, optional):\n            Keyword arguments for the\n            `get_features_for_triplets_groups` function.\n            (default: :obj:`{}`)\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        split: str = \"train\",\n        force_reload: bool = False,\n        verbose: bool = False,\n        use_pcst: bool = True,\n        load_dataset_kwargs: Optional[Dict[str, Any]] = None,\n        retrieval_kwargs: Optional[Dict[str, Any]] = None,\n    ) -> None:\n        load_dataset_kwargs = load_dataset_kwargs or {}\n        retrieval_kwargs = retrieval_kwargs or {}\n        dataset_name = 'rmanluo/RoG-cwq'\n        super().__init__(dataset_name, root, split, force_reload, verbose,\n                         use_pcst, load_dataset_kwargs=load_dataset_kwargs,\n                         retrieval_kwargs=retrieval_kwargs)\n"
  },
  {
    "path": "torch_geometric/datasets/webkb.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import coalesce\n\n\nclass WebKB(InMemoryDataset):\n    r\"\"\"The WebKB datasets used in the\n    `\"Geom-GCN: Geometric Graph Convolutional Networks\"\n    <https://openreview.net/forum?id=S1e2agrFvS>`_ paper.\n    Nodes represent web pages and edges represent hyperlinks between them.\n    Node features are the bag-of-words representation of web pages.\n    The task is to classify the nodes into one of the five categories, student,\n    project, course, staff, and faculty.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"Cornell\"`, :obj:`\"Texas\"`,\n            :obj:`\"Wisconsin\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - Cornell\n          - 183\n          - 298\n          - 1,703\n          - 5\n        * - Texas\n          - 183\n          - 325\n          - 1,703\n          - 5\n        * - Wisconsin\n          - 251\n          - 515\n          - 1,703\n          - 5\n    \"\"\"\n\n    url = 'https://raw.githubusercontent.com/graphdml-uiuc-jlu/geom-gcn/master'\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        assert self.name in ['cornell', 'texas', 'wisconsin']\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        out = ['out1_node_feature_label.txt', 'out1_graph_edges.txt']\n        out += [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)]\n        return out\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for f in self.raw_file_names[:2]:\n            download_url(f'{self.url}/new_data/{self.name}/{f}', self.raw_dir)\n        for f in self.raw_file_names[2:]:\n            download_url(f'{self.url}/splits/{f}', self.raw_dir)\n\n    def process(self) -> None:\n        with open(self.raw_paths[0]) as f:\n            lines = f.read().split('\\n')[1:-1]\n            xs = [[float(value) for value in 
line.split('\\t')[1].split(',')]\n                  for line in lines]\n            x = torch.tensor(xs, dtype=torch.float)\n\n            ys = [int(line.split('\\t')[2]) for line in lines]\n            y = torch.tensor(ys, dtype=torch.long)\n\n        with open(self.raw_paths[1]) as f:\n            lines = f.read().split('\\n')[1:-1]\n            edge_indices = [[int(value) for value in line.split('\\t')]\n                            for line in lines]\n            edge_index = torch.tensor(edge_indices).t().contiguous()\n            edge_index = coalesce(edge_index, num_nodes=x.size(0))\n\n        train_masks, val_masks, test_masks = [], [], []\n        for path in self.raw_paths[2:]:\n            tmp = np.load(path)\n            train_masks += [torch.from_numpy(tmp['train_mask']).to(torch.bool)]\n            val_masks += [torch.from_numpy(tmp['val_mask']).to(torch.bool)]\n            test_masks += [torch.from_numpy(tmp['test_mask']).to(torch.bool)]\n        train_mask = torch.stack(train_masks, dim=1)\n        val_mask = torch.stack(val_masks, dim=1)\n        test_mask = torch.stack(test_masks, dim=1)\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n        data = data if self.pre_transform is None else self.pre_transform(data)\n        self.save([data], self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return f'{self.name}()'\n"
  },
  {
    "path": "torch_geometric/datasets/wikics.py",
    "content": "import json\nimport warnings\nfrom itertools import chain\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import to_undirected\n\n\nclass WikiCS(InMemoryDataset):\n    r\"\"\"The semi-supervised Wikipedia-based dataset from the\n    `\"Wiki-CS: A Wikipedia-Based Benchmark for Graph Neural Networks\"\n    <https://arxiv.org/abs/2007.02901>`_ paper, containing 11,701 nodes,\n    216,123 edges, 10 classes and 20 different training splits.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        is_undirected (bool, optional): Whether the graph is undirected.\n            (default: :obj:`True`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = 'https://github.com/pmernyei/wiki-cs-dataset/raw/master/dataset'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        is_undirected: Optional[bool] = None,\n        force_reload: bool = False,\n    ) -> None:\n        if is_undirected is None:\n            warnings.warn(\n                f\"The {self.__class__.__name__} dataset now returns an \"\n                f\"undirected graph by default. 
Please explicitly specify \"\n                f\"'is_undirected=False' to restore the old behavior.\",\n                stacklevel=2)\n            is_undirected = True\n        self.is_undirected = is_undirected\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['data.json']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data_undirected.pt' if self.is_undirected else 'data.pt'\n\n    def download(self) -> None:\n        for name in self.raw_file_names:\n            download_url(f'{self.url}/{name}', self.raw_dir)\n\n    def process(self) -> None:\n        with open(self.raw_paths[0]) as f:\n            data = json.load(f)\n\n        x = torch.tensor(data['features'], dtype=torch.float)\n        y = torch.tensor(data['labels'], dtype=torch.long)\n\n        edges = [[(i, j) for j in js] for i, js in enumerate(data['links'])]\n        edges = list(chain(*edges))  # type: ignore\n        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()\n        if self.is_undirected:\n            edge_index = to_undirected(edge_index, num_nodes=x.size(0))\n\n        train_mask = torch.tensor(data['train_masks'], dtype=torch.bool)\n        train_mask = train_mask.t().contiguous()\n\n        val_mask = torch.tensor(data['val_masks'], dtype=torch.bool)\n        val_mask = val_mask.t().contiguous()\n\n        test_mask = torch.tensor(data['test_mask'], dtype=torch.bool)\n\n        stopping_mask = torch.tensor(data['stopping_masks'], dtype=torch.bool)\n        stopping_mask = stopping_mask.t().contiguous()\n\n        data = Data(x=x, y=y, edge_index=edge_index, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask,\n                    stopping_mask=stopping_mask)\n\n        if self.pre_transform is not None:\n            data = 
self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/wikidata.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Callable, Dict, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_tar,\n)\nfrom torch_geometric.io import fs\n\n\nclass Wikidata5M(InMemoryDataset):\n    r\"\"\"The Wikidata-5M dataset from the `\"KEPLER: A Unified Model for\n    Knowledge Embedding and Pre-trained Language Representation\"\n    <https://arxiv.org/abs/1911.06136>`_ paper,\n    containing 4,594,485 entities, 822 relations,\n    20,614,279 train triples, 5,163 validation triples, and 5,133 test triples.\n\n    `Wikidata-5M <https://deepgraphlearning.github.io/project/wikidata5m>`_\n    is a large-scale knowledge graph dataset with aligned corpus\n    extracted form Wikidata.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        setting (str, optional):\n            If :obj:`\"transductive\"`, loads the transductive dataset.\n            If :obj:`\"inductive\"`, loads the inductive dataset.\n            (default: :obj:`\"transductive\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        root: str,\n        setting: str = 'transductive',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        if setting not in {'transductive', 'inductive'}:\n            raise ValueError(f\"Invalid 'setting' argument (got '{setting}')\")\n\n        self.setting = setting\n\n        self.urls = [\n            ('https://www.dropbox.com/s/7jp4ib8zo3i6m10/'\n             'wikidata5m_text.txt.gz?dl=1'),\n            'https://uni-bielefeld.sciebo.de/s/yuBKzBxsEc9j3hy/download',\n        ]\n        if self.setting == 'inductive':\n            self.urls.append('https://www.dropbox.com/s/csed3cgal3m7rzo/'\n                             'wikidata5m_inductive.tar.gz?dl=1')\n        else:\n            self.urls.append('https://www.dropbox.com/s/6sbhm0rwo4l73jq/'\n                             'wikidata5m_transductive.tar.gz?dl=1')\n\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'wikidata5m_text.txt.gz',\n            'download',\n            f'wikidata5m_{self.setting}_train.txt',\n            f'wikidata5m_{self.setting}_valid.txt',\n            f'wikidata5m_{self.setting}_test.txt',\n        ]\n\n    @property\n    def processed_file_names(self) -> str:\n        return f'{self.setting}_data.pt'\n\n    def download(self) -> None:\n        for url in self.urls:\n            download_url(url, self.raw_dir)\n        path = osp.join(self.raw_dir, f'wikidata5m_{self.setting}.tar.gz')\n        extract_tar(path, self.raw_dir)\n        os.remove(path)\n\n    def process(self) -> None:\n        import 
gzip\n\n        entity_to_id: Dict[str, int] = {}\n        with gzip.open(self.raw_paths[0], 'rt') as f:\n            for i, line in enumerate(f):\n                values = line.strip().split('\\t')\n                entity_to_id[values[0]] = i\n\n        x = fs.torch_load(self.raw_paths[1])\n\n        edge_indices = []\n        edge_types = []\n        split_indices = []\n\n        rel_to_id: Dict[str, int] = {}\n        for split, path in enumerate(self.raw_paths[2:]):\n            with open(path) as f:\n                for line in f:\n                    head, rel, tail = line[:-1].split('\\t')\n                    src = entity_to_id[head]\n                    dst = entity_to_id[tail]\n                    edge_indices.append([src, dst])\n                    if rel not in rel_to_id:\n                        rel_to_id[rel] = len(rel_to_id)\n                    edge_types.append(rel_to_id[rel])\n                    split_indices.append(split)\n\n        edge_index = torch.tensor(edge_indices).t().contiguous()\n        edge_type = torch.tensor(edge_types)\n        split_index = torch.tensor(split_indices)\n\n        data = Data(\n            x=x,\n            edge_index=edge_index,\n            edge_type=edge_type,\n            train_mask=split_index == 0,\n            val_mask=split_index == 1,\n            test_mask=split_index == 2,\n        )\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/wikipedia_network.py",
    "content": "import os.path as osp\nfrom typing import Callable, List, Optional, Union\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import coalesce\n\n\nclass WikipediaNetwork(InMemoryDataset):\n    r\"\"\"The Wikipedia networks introduced in the\n    `\"Multi-scale Attributed Node Embedding\"\n    <https://arxiv.org/abs/1909.13021>`_ paper.\n    Nodes represent web pages and edges represent hyperlinks between them.\n    Node features represent several informative nouns in the Wikipedia pages.\n    The task is to predict the average daily traffic of the web page.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        name (str): The name of the dataset (:obj:`\"chameleon\"`,\n            :obj:`\"crocodile\"`, :obj:`\"squirrel\"`).\n        geom_gcn_preprocess (bool): If set to :obj:`True`, will load the\n            pre-processed data as introduced in the `\"Geom-GCN: Geometric\n            Graph Convolutional Networks\" <https://arxiv.org/abs/2002.05287>_`,\n            in which the average monthly traffic of the web page is converted\n            into five categories to predict.\n            If set to :obj:`True`, the dataset :obj:`\"crocodile\"` is not\n            available.\n            If set to :obj:`True`, train/validation/test splits will be\n            available as masks for multiple splits with shape\n            :obj:`[num_nodes, num_splits]`. (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. 
The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    \"\"\"\n\n    raw_url = 'https://graphmining.ai/datasets/ptg/wiki'\n    processed_url = ('https://raw.githubusercontent.com/graphdml-uiuc-jlu/'\n                     'geom-gcn/f1fc0d14b3b019c562737240d06ec83b07d16a8f')\n\n    def __init__(\n        self,\n        root: str,\n        name: str,\n        geom_gcn_preprocess: bool = True,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.name = name.lower()\n        self.geom_gcn_preprocess = geom_gcn_preprocess\n        assert self.name in ['chameleon', 'crocodile', 'squirrel']\n        if geom_gcn_preprocess and self.name == 'crocodile':\n            raise AttributeError(\"The dataset 'crocodile' is not available in \"\n                                 \"case 'geom_gcn_preprocess=True'\")\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        if self.geom_gcn_preprocess:\n            return osp.join(self.root, self.name, 'geom_gcn', 'raw')\n        else:\n            return osp.join(self.root, self.name, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        if self.geom_gcn_preprocess:\n            return osp.join(self.root, self.name, 'geom_gcn', 'processed')\n        else:\n            return osp.join(self.root, self.name, 'processed')\n\n    @property\n    def raw_file_names(self) -> Union[List[str], str]:\n        if self.geom_gcn_preprocess:\n            return (['out1_node_feature_label.txt', 'out1_graph_edges.txt'] +\n                    [f'{self.name}_split_0.6_0.2_{i}.npz' for i in range(10)])\n        else:\n 
           return f'{self.name}.npz'\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        if self.geom_gcn_preprocess:\n            for filename in self.raw_file_names[:2]:\n                url = f'{self.processed_url}/new_data/{self.name}/{filename}'\n                download_url(url, self.raw_dir)\n            for filename in self.raw_file_names[2:]:\n                url = f'{self.processed_url}/splits/{filename}'\n                download_url(url, self.raw_dir)\n        else:\n            download_url(f'{self.raw_url}/{self.name}.npz', self.raw_dir)\n\n    def process(self) -> None:\n        if self.geom_gcn_preprocess:\n            with open(self.raw_paths[0]) as f:\n                lines = f.read().split('\\n')[1:-1]\n            xs = [[float(value) for value in line.split('\\t')[1].split(',')]\n                  for line in lines]\n            x = torch.tensor(xs, dtype=torch.float)\n            ys = [int(line.split('\\t')[2]) for line in lines]\n            y = torch.tensor(ys, dtype=torch.long)\n\n            with open(self.raw_paths[1]) as f:\n                lines = f.read().split('\\n')[1:-1]\n                edge_indices = [[int(value) for value in line.split('\\t')]\n                                for line in lines]\n            edge_index = torch.tensor(edge_indices).t().contiguous()\n            edge_index = coalesce(edge_index, num_nodes=x.size(0))\n\n            train_masks, val_masks, test_masks = [], [], []\n            for filepath in self.raw_paths[2:]:\n                masks = np.load(filepath)\n                train_masks += [torch.from_numpy(masks['train_mask'])]\n                val_masks += [torch.from_numpy(masks['val_mask'])]\n                test_masks += [torch.from_numpy(masks['test_mask'])]\n            train_mask = torch.stack(train_masks, dim=1).to(torch.bool)\n            val_mask = torch.stack(val_masks, dim=1).to(torch.bool)\n            test_mask = 
torch.stack(test_masks, dim=1).to(torch.bool)\n\n            data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                        val_mask=val_mask, test_mask=test_mask)\n\n        else:\n            raw_data = np.load(self.raw_paths[0], 'r', allow_pickle=True)\n            x = torch.from_numpy(raw_data['features']).to(torch.float)\n            edge_index = torch.from_numpy(raw_data['edges']).to(torch.long)\n            edge_index = edge_index.t().contiguous()\n            edge_index = coalesce(edge_index, num_nodes=x.size(0))\n            y = torch.from_numpy(raw_data['target']).to(torch.float)\n\n            data = Data(x=x, edge_index=edge_index, y=y)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/willow_object_class.py",
    "content": "import glob\nimport os\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass WILLOWObjectClass(InMemoryDataset):\n    r\"\"\"The WILLOW-ObjectClass dataset from the `\"Learning Graphs to Match\"\n    <https://www.di.ens.fr/willow/pdfscurrent/cho2013.pdf>`_ paper,\n    containing 10 equal keypoints of at least 40 images in each category.\n    The keypoints contain interpolated features from a pre-trained VGG16 model\n    on ImageNet (:obj:`relu4_2` and :obj:`relu5_1`).\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        category (str): The category of the images (one of :obj:`\"Car\"`,\n            :obj:`\"Duck\"`, :obj:`\"Face\"`, :obj:`\"Motorbike\"`,\n            :obj:`\"Winebottle\"`).\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n        device (str or torch.device, optional): The device to use for\n            processing the raw data. If set to :obj:`None`, will utilize\n            GPU-processing if available. (default: :obj:`None`)\n    \"\"\"\n    url = ('http://www.di.ens.fr/willow/research/graphlearning/'\n           'WILLOW-ObjectClass_dataset.zip')\n\n    categories = ['face', 'motorbike', 'car', 'duck', 'winebottle']\n\n    batch_size = 32\n\n    def __init__(\n        self,\n        root: str,\n        category: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n        device: Optional[str] = None,\n    ) -> None:\n        if device is None:\n            device = 'cuda' if torch.cuda.is_available() else 'cpu'\n\n        assert category.lower() in self.categories\n        self.category = category\n        self.device = device\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_dir(self) -> str:\n        return osp.join(self.root, 'raw')\n\n    @property\n    def processed_dir(self) -> str:\n        return osp.join(self.root, self.category.capitalize(), 'processed')\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [category.capitalize() for category in self.categories]\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.unlink(path)\n        os.unlink(osp.join(self.root, 'README'))\n        os.unlink(osp.join(self.root, 'demo_showAnno.m'))\n        fs.rm(self.raw_dir)\n        
os.rename(osp.join(self.root, 'WILLOW-ObjectClass'), self.raw_dir)\n\n    def process(self) -> None:\n        import torchvision.models as models\n        import torchvision.transforms as T\n        from PIL import Image\n        from scipy.io import loadmat\n\n        category = self.category.capitalize()\n        names = glob.glob(osp.join(self.raw_dir, category, '*.png'))\n        names = sorted([name[:-4] for name in names])\n\n        vgg16_outputs = []\n\n        def hook(module: torch.nn.Module, x: Tensor, y: Tensor) -> None:\n            vgg16_outputs.append(y.to('cpu'))\n\n        vgg16 = models.vgg16(pretrained=True).to(self.device)\n        vgg16.eval()\n        vgg16.features[20].register_forward_hook(hook)  # relu4_2\n        vgg16.features[25].register_forward_hook(hook)  # relu5_1\n\n        transform = T.Compose([\n            T.ToTensor(),\n            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n        ])\n\n        data_list = []\n        for name in names:\n            pos = loadmat(f'{name}.mat')['pts_coord']\n            x, y = torch.from_numpy(pos).to(torch.float)\n            pos = torch.stack([x, y], dim=1)\n\n            # The \"face\" category contains a single image with fewer than 10\n            # keypoints, so we need to skip it.\n            if pos.size(0) != 10:\n                continue\n\n            with open(f'{name}.png', 'rb') as f:\n                img = Image.open(f).convert('RGB')\n\n            # Rescale keypoints.\n            pos[:, 0] = pos[:, 0] * 256.0 / (img.size[0])\n            pos[:, 1] = pos[:, 1] * 256.0 / (img.size[1])\n\n            img = img.resize((256, 256), resample=Image.Resampling.BICUBIC)\n            img = transform(img)\n\n            data = Data(img=img, pos=pos, name=name)\n            data_list.append(data)\n\n        imgs = [data.img for data in data_list]\n        loader = DataLoader(\n            dataset=imgs,  # type: ignore\n            batch_size=self.batch_size,\n      
      shuffle=False,\n        )\n        for i, batch_img in enumerate(loader):\n            vgg16_outputs.clear()\n\n            with torch.no_grad():\n                vgg16(batch_img.to(self.device))\n\n            out1 = F.interpolate(vgg16_outputs[0], (256, 256), mode='bilinear',\n                                 align_corners=False)\n            out2 = F.interpolate(vgg16_outputs[1], (256, 256), mode='bilinear',\n                                 align_corners=False)\n\n            for j in range(out1.size(0)):\n                data = data_list[i * self.batch_size + j]\n                assert data.pos is not None\n                idx = data.pos.round().long().clamp(0, 255)\n                x_1 = out1[j, :, idx[:, 1], idx[:, 0]].to('cpu')\n                x_2 = out2[j, :, idx[:, 1], idx[:, 0]].to('cpu')\n                data.img = None\n                data.x = torch.cat([x_1.t(), x_2.t()], dim=-1)\n            del out1\n            del out2\n\n        if self.pre_filter is not None:\n            data_list = [data for data in data_list if self.pre_filter(data)]\n\n        if self.pre_transform is not None:\n            data_list = [self.pre_transform(data) for data in data_list]\n\n        self.save(data_list, self.processed_paths[0])\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({len(self)}, '\n                f'category={self.category})')\n"
  },
  {
    "path": "torch_geometric/datasets/word_net.py",
    "content": "from itertools import chain\nfrom typing import Callable, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_url\nfrom torch_geometric.utils import index_sort\n\n\nclass WordNet18(InMemoryDataset):\n    r\"\"\"The WordNet18 dataset from the `\"Translating Embeddings for Modeling\n    Multi-Relational Data\"\n    <https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling\n    -multi-relational-data>`_ paper,\n    containing 40,943 entities, 18 relations and 151,442 fact triplets,\n    *e.g.*, furniture includes bed.\n\n    .. note::\n\n        The original :obj:`WordNet18` dataset suffers from test leakage, *i.e.*\n        more than 80% of test triplets can be found in the training set with\n        another relation type.\n        Therefore, it should not be used for research evaluation anymore.\n        We recommend using its cleaned version\n        :class:`~torch_geometric.datasets.WordNet18RR` instead.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = ('https://raw.githubusercontent.com/villmow/'\n           'datasets_knowledge_embedding/master/WN18/original')\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['train.txt', 'valid.txt', 'test.txt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for filename in self.raw_file_names:\n            download_url(f'{self.url}/{filename}', self.raw_dir)\n\n    def process(self) -> None:\n        srcs, dsts, edge_types = [], [], []\n        for path in self.raw_paths:\n            with open(path) as f:\n                edges = [int(x) for x in f.read().split()[1:]]\n                edge = torch.tensor(edges, dtype=torch.long)\n                srcs.append(edge[::3])\n                dsts.append(edge[1::3])\n                edge_types.append(edge[2::3])\n\n        src = torch.cat(srcs, dim=0)\n        dst = torch.cat(dsts, dim=0)\n        edge_type = torch.cat(edge_types, dim=0)\n\n        train_mask = torch.zeros(src.size(0), dtype=torch.bool)\n        train_mask[:srcs[0].size(0)] = True\n        val_mask = torch.zeros(src.size(0), dtype=torch.bool)\n        val_mask[srcs[0].size(0):srcs[0].size(0) + srcs[1].size(0)] = True\n        test_mask = torch.zeros(src.size(0), dtype=torch.bool)\n        test_mask[srcs[0].size(0) + srcs[1].size(0):] = True\n\n        num_nodes = max(int(src.max()), int(dst.max())) + 1\n        _, perm = index_sort(num_nodes * 
src + dst)\n\n        edge_index = torch.stack([src[perm], dst[perm]], dim=0)\n        edge_type = edge_type[perm]\n        train_mask = train_mask[perm]\n        val_mask = val_mask[perm]\n        test_mask = test_mask[perm]\n\n        data = Data(\n            edge_index=edge_index,\n            edge_type=edge_type,\n            train_mask=train_mask,\n            val_mask=val_mask,\n            test_mask=test_mask,\n            num_nodes=num_nodes,\n        )\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n\n\nclass WordNet18RR(InMemoryDataset):\n    r\"\"\"The WordNet18RR dataset from the `\"Convolutional 2D Knowledge Graph\n    Embeddings\" <https://arxiv.org/abs/1707.01476>`_ paper, containing 40,943\n    entities, 11 relations and 93,003 fact triplets.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. 
(default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n    \"\"\"\n\n    url = ('https://raw.githubusercontent.com/villmow/'\n           'datasets_knowledge_embedding/master/WN18RR/original')\n\n    edge2id = {\n        '_also_see': 0,\n        '_derivationally_related_form': 1,\n        '_has_part': 2,\n        '_hypernym': 3,\n        '_instance_hypernym': 4,\n        '_member_meronym': 5,\n        '_member_of_domain_region': 6,\n        '_member_of_domain_usage': 7,\n        '_similar_to': 8,\n        '_synset_domain_topic_of': 9,\n        '_verb_group': 10,\n    }\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['train.txt', 'valid.txt', 'test.txt']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        for filename in self.raw_file_names:\n            download_url(f'{self.url}/{filename}', self.raw_dir)\n\n    def process(self) -> None:\n        node2id, idx = {}, 0\n\n        srcs, dsts, edge_types = [], [], []\n        for path in self.raw_paths:\n            with open(path) as f:\n                edges = f.read().split()\n\n                _src = edges[::3]\n                _dst = edges[2::3]\n                _edge_type = edges[1::3]\n\n                for i in chain(_src, _dst):\n                    if i not in node2id:\n                        node2id[i] = idx\n                        idx += 1\n\n                srcs.append(torch.tensor([node2id[i] for i in _src]))\n                dsts.append(torch.tensor([node2id[i] 
for i in _dst]))\n                edge_types.append(\n                    torch.tensor([self.edge2id[i] for i in _edge_type]))\n\n        src = torch.cat(srcs, dim=0)\n        dst = torch.cat(dsts, dim=0)\n        edge_type = torch.cat(edge_types, dim=0)\n\n        train_mask = torch.zeros(src.size(0), dtype=torch.bool)\n        train_mask[:srcs[0].size(0)] = True\n        val_mask = torch.zeros(src.size(0), dtype=torch.bool)\n        val_mask[srcs[0].size(0):srcs[0].size(0) + srcs[1].size(0)] = True\n        test_mask = torch.zeros(src.size(0), dtype=torch.bool)\n        test_mask[srcs[0].size(0) + srcs[1].size(0):] = True\n\n        num_nodes = max(int(src.max()), int(dst.max())) + 1\n        _, perm = index_sort(num_nodes * src + dst)\n\n        edge_index = torch.stack([src[perm], dst[perm]], dim=0)\n        edge_type = edge_type[perm]\n        train_mask = train_mask[perm]\n        val_mask = val_mask[perm]\n        test_mask = test_mask[perm]\n\n        data = Data(edge_index=edge_index, edge_type=edge_type,\n                    train_mask=train_mask, val_mask=val_mask,\n                    test_mask=test_mask, num_nodes=num_nodes)\n\n        if self.pre_transform is not None:\n            data = self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/yelp.py",
    "content": "import json\nimport os.path as osp\nfrom typing import Callable, List, Optional\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data, InMemoryDataset, download_google_url\n\n\nclass Yelp(InMemoryDataset):\n    r\"\"\"The Yelp dataset from the `\"GraphSAINT: Graph Sampling Based\n    Inductive Learning Method\" <https://arxiv.org/abs/1907.04931>`_ paper,\n    containing customer reviewers and their friendship.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 10 10 10 10\n        :header-rows: 1\n\n        * - #nodes\n          - #edges\n          - #features\n          - #tasks\n        * - 716,847\n          - 13,954,819\n          - 300\n          - 100\n    \"\"\"\n    adj_full_id = '1Juwx8HtDwSzmVIJ31ooVa1WljI4U5JnA'\n    feats_id = '1Zy6BZH_zLEjKlEFSduKE5tV9qqA_8VtM'\n    class_map_id = '1VUcBGr0T0-klqerjAjxRmAqFuld_SMWU'\n    role_id = '1NI5pa5Chpd-52eSmLW60OnB3WS5ikxq_'\n\n    def __init__(\n        self,\n        root: str,\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        super().__init__(root, transform, pre_transform,\n                         force_reload=force_reload)\n        self.load(self.processed_paths[0])\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return ['adj_full.npz', 'feats.npy', 'class_map.json', 'role.json']\n\n    @property\n    def processed_file_names(self) -> str:\n        return 'data.pt'\n\n    def download(self) -> None:\n        download_google_url(self.adj_full_id, self.raw_dir, 'adj_full.npz')\n        download_google_url(self.feats_id, self.raw_dir, 'feats.npy')\n        download_google_url(self.class_map_id, self.raw_dir, 'class_map.json')\n        download_google_url(self.role_id, self.raw_dir, 'role.json')\n\n    def process(self) -> None:\n        import scipy.sparse as sp\n\n        f = np.load(osp.join(self.raw_dir, 'adj_full.npz'))\n        adj = sp.csr_matrix((f['data'], f['indices'], f['indptr']), f['shape'])\n        adj = adj.tocoo()\n        row = torch.from_numpy(adj.row).to(torch.long)\n        col = torch.from_numpy(adj.col).to(torch.long)\n        edge_index = torch.stack([row, col], dim=0)\n\n        x = np.load(osp.join(self.raw_dir, 'feats.npy'))\n        x = torch.from_numpy(x).to(torch.float)\n\n        ys = [-1] * x.size(0)\n        with open(osp.join(self.raw_dir, 'class_map.json')) as f:\n       
     class_map = json.load(f)\n            for key, item in class_map.items():\n                ys[int(key)] = item\n        y = torch.tensor(ys)\n\n        with open(osp.join(self.raw_dir, 'role.json')) as f:\n            role = json.load(f)\n\n        train_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        train_mask[torch.tensor(role['tr'])] = True\n\n        val_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        val_mask[torch.tensor(role['va'])] = True\n\n        test_mask = torch.zeros(x.size(0), dtype=torch.bool)\n        test_mask[torch.tensor(role['te'])] = True\n\n        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask,\n                    val_mask=val_mask, test_mask=test_mask)\n\n        data = data if self.pre_transform is None else self.pre_transform(data)\n\n        self.save([data], self.processed_paths[0])\n"
  },
  {
    "path": "torch_geometric/datasets/zinc.py",
    "content": "import os\nimport os.path as osp\nimport pickle\nfrom typing import Callable, List, Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import (\n    Data,\n    InMemoryDataset,\n    download_url,\n    extract_zip,\n)\nfrom torch_geometric.io import fs\n\n\nclass ZINC(InMemoryDataset):\n    r\"\"\"The ZINC dataset from the `ZINC database\n    <https://pubs.acs.org/doi/abs/10.1021/acs.jcim.5b00559>`_ and the\n    `\"Automatic Chemical Design Using a Data-Driven Continuous Representation\n    of Molecules\" <https://arxiv.org/abs/1610.02415>`_ paper, containing about\n    250,000 molecular graphs with up to 38 heavy atoms.\n    The task is to regress the penalized :obj:`logP` (also called constrained\n    solubility in some works), given by :obj:`y = logP - SAS - cycles`, where\n    :obj:`logP` is the water-octanol partition coefficient, :obj:`SAS` is the\n    synthetic accessibility score, and :obj:`cycles` denotes the number of\n    cycles with more than six atoms.\n    Penalized :obj:`logP` is a score commonly used for training molecular\n    generation models, see, *e.g.*, the\n    `\"Junction Tree Variational Autoencoder for Molecular Graph Generation\"\n    <https://proceedings.mlr.press/v80/jin18a.html>`_ and\n    `\"Grammar Variational Autoencoder\"\n    <https://proceedings.mlr.press/v70/kusner17a.html>`_ papers.\n\n    Args:\n        root (str): Root directory where the dataset should be saved.\n        subset (bool, optional): If set to :obj:`True`, will only load a\n            subset of the dataset (12,000 molecular graphs), following the\n            `\"Benchmarking Graph Neural Networks\"\n            <https://arxiv.org/abs/2003.00982>`_ paper. 
(default: :obj:`False`)\n        split (str, optional): If :obj:`\"train\"`, loads the training dataset.\n            If :obj:`\"val\"`, loads the validation dataset.\n            If :obj:`\"test\"`, loads the test dataset.\n            (default: :obj:`\"train\"`)\n        transform (callable, optional): A function/transform that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a transformed\n            version. The data object will be transformed before every access.\n            (default: :obj:`None`)\n        pre_transform (callable, optional): A function/transform that takes in\n            an :obj:`torch_geometric.data.Data` object and returns a\n            transformed version. The data object will be transformed before\n            being saved to disk. (default: :obj:`None`)\n        pre_filter (callable, optional): A function that takes in an\n            :obj:`torch_geometric.data.Data` object and returns a boolean\n            value, indicating whether the data object should be included in the\n            final dataset. (default: :obj:`None`)\n        force_reload (bool, optional): Whether to re-process the dataset.\n            (default: :obj:`False`)\n\n    **STATS:**\n\n    .. 
list-table::\n        :widths: 20 10 10 10 10 10\n        :header-rows: 1\n\n        * - Name\n          - #graphs\n          - #nodes\n          - #edges\n          - #features\n          - #classes\n        * - ZINC Full\n          - 249,456\n          - ~23.2\n          - ~49.8\n          - 1\n          - 1\n        * - ZINC Subset\n          - 12,000\n          - ~23.2\n          - ~49.8\n          - 1\n          - 1\n    \"\"\"\n\n    url = 'https://www.dropbox.com/s/feo9qle74kg48gy/molecules.zip?dl=1'\n    split_url = ('https://raw.githubusercontent.com/graphdeeplearning/'\n                 'benchmarking-gnns/master/data/molecules/{}.index')\n\n    def __init__(\n        self,\n        root: str,\n        subset: bool = False,\n        split: str = 'train',\n        transform: Optional[Callable] = None,\n        pre_transform: Optional[Callable] = None,\n        pre_filter: Optional[Callable] = None,\n        force_reload: bool = False,\n    ) -> None:\n        self.subset = subset\n        assert split in ['train', 'val', 'test']\n        super().__init__(root, transform, pre_transform, pre_filter,\n                         force_reload=force_reload)\n        path = osp.join(self.processed_dir, f'{split}.pt')\n        self.load(path)\n\n    @property\n    def raw_file_names(self) -> List[str]:\n        return [\n            'train.pickle', 'val.pickle', 'test.pickle', 'train.index',\n            'val.index', 'test.index'\n        ]\n\n    @property\n    def processed_dir(self) -> str:\n        name = 'subset' if self.subset else 'full'\n        return osp.join(self.root, name, 'processed')\n\n    @property\n    def processed_file_names(self) -> List[str]:\n        return ['train.pt', 'val.pt', 'test.pt']\n\n    def download(self) -> None:\n        fs.rm(self.raw_dir)\n        path = download_url(self.url, self.root)\n        extract_zip(path, self.root)\n        os.rename(osp.join(self.root, 'molecules'), self.raw_dir)\n        os.unlink(path)\n\n        for 
split in ['train', 'val', 'test']:\n            download_url(self.split_url.format(split), self.raw_dir)\n\n    def process(self) -> None:\n        for split in ['train', 'val', 'test']:\n            with open(osp.join(self.raw_dir, f'{split}.pickle'), 'rb') as f:\n                mols = pickle.load(f)\n\n            indices = list(range(len(mols)))\n\n            if self.subset:\n                with open(osp.join(self.raw_dir, f'{split}.index')) as f:\n                    indices = [int(x) for x in f.read()[:-1].split(',')]\n\n            pbar = tqdm(total=len(indices))\n            pbar.set_description(f'Processing {split} dataset')\n\n            data_list = []\n            for idx in indices:\n                mol = mols[idx]\n\n                x = mol['atom_type'].to(torch.long).view(-1, 1)\n                y = mol['logP_SA_cycle_normalized'].to(torch.float)\n\n                adj = mol['bond_type']\n                edge_index = adj.nonzero(as_tuple=False).t().contiguous()\n                edge_attr = adj[edge_index[0], edge_index[1]].to(torch.long)\n\n                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,\n                            y=y)\n\n                if self.pre_filter is not None and not self.pre_filter(data):\n                    # Count filtered-out molecules so the bar reaches 100%.\n                    pbar.update(1)\n                    continue\n\n                if self.pre_transform is not None:\n                    data = self.pre_transform(data)\n\n                data_list.append(data)\n                pbar.update(1)\n\n            pbar.close()\n\n            self.save(data_list, osp.join(self.processed_dir, f'{split}.pt'))\n"
  },
  {
    "path": "torch_geometric/debug.py",
    "content": "from typing import Any\n\n__debug_flag__ = {'enabled': False}\n\n\ndef is_debug_enabled() -> bool:\n    r\"\"\"Returns :obj:`True` if the debug mode is enabled.\"\"\"\n    return __debug_flag__['enabled']\n\n\ndef set_debug_enabled(mode: bool) -> None:\n    __debug_flag__['enabled'] = mode\n\n\nclass debug:\n    r\"\"\"Context-manager that enables the debug mode to help track down errors\n    and separate usage errors from real bugs.\n\n    .. code-block:: python\n\n        with torch_geometric.debug():\n            out = model(data.x, data.edge_index)\n    \"\"\"\n    def __init__(self) -> None:\n        self.prev = is_debug_enabled()\n\n    def __enter__(self) -> None:\n        set_debug_enabled(True)\n\n    def __exit__(self, *args: Any) -> None:\n        set_debug_enabled(self.prev)\n\n\nclass set_debug:\n    r\"\"\"Context-manager that sets the debug mode on or off.\n\n    :class:`set_debug` will enable or disable the debug mode based on its\n    argument :attr:`mode`.\n    It can be used as a context-manager or as a function.\n\n    See :class:`debug` above for more details.\n    \"\"\"\n    def __init__(self, mode: bool) -> None:\n        self.prev = is_debug_enabled()\n        set_debug_enabled(mode)\n\n    def __enter__(self) -> None:\n        pass\n\n    def __exit__(self, *args: Any) -> None:\n        set_debug_enabled(self.prev)\n"
  },
  {
    "path": "torch_geometric/deprecation.py",
    "content": "import functools\nimport inspect\nimport warnings\nfrom typing import Any, Callable, Optional\n\n\ndef deprecated(\n    details: Optional[str] = None,\n    func_name: Optional[str] = None,\n) -> Callable:\n    def decorator(func: Callable) -> Callable:\n        name = func_name or func.__name__\n\n        if inspect.isclass(func):\n            cls = type(func.__name__, (func, ), {})\n            cls.__init__ = deprecated(details, name)(  # type: ignore\n                func.__init__)\n            cls.__doc__ = func.__doc__\n            return cls\n\n        @functools.wraps(func)\n        def wrapper(*args: Any, **kwargs: Any) -> Any:\n            out = f\"'{name}' is deprecated\"\n            if details is not None:\n                out += f\", {details}\"\n            warnings.warn(out, stacklevel=2)\n            return func(*args, **kwargs)\n\n        return wrapper\n\n    return decorator\n"
  },
  {
    "path": "torch_geometric/device.py",
    "content": "from typing import Any\n\nimport torch\n\n\ndef is_mps_available() -> bool:\n    r\"\"\"Returns a bool indicating if MPS is currently available.\"\"\"\n    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():\n        try:  # Github CI may not have access to MPS hardware. Confirm:\n            torch.empty(1, device='mps')\n            return True\n        except Exception:\n            return False\n    return False\n\n\ndef is_xpu_available() -> bool:\n    r\"\"\"Returns a bool indicating if XPU is currently available.\"\"\"\n    if hasattr(torch, 'xpu') and torch.xpu.is_available():\n        return True\n    try:\n        import intel_extension_for_pytorch as ipex\n        return ipex.xpu.is_available()\n    except ImportError:\n        return False\n\n\ndef device(device: Any) -> torch.device:\n    r\"\"\"Returns a :class:`torch.device`.\n\n    If :obj:`\"auto\"` is specified, returns the optimal device depending on\n    available hardware.\n    \"\"\"\n    if device != 'auto':\n        return torch.device(device)\n    if torch.cuda.is_available():\n        return torch.device('cuda')\n    if is_mps_available():\n        return torch.device('mps')\n    if is_xpu_available():\n        return torch.device('xpu')\n    return torch.device('cpu')\n"
  },
  {
    "path": "torch_geometric/distributed/__init__.py",
    "content": "from .dist_context import DistContext\nfrom .local_feature_store import LocalFeatureStore\nfrom .local_graph_store import LocalGraphStore\nfrom .partition import Partitioner\nfrom .dist_neighbor_sampler import DistNeighborSampler\nfrom .dist_loader import DistLoader\nfrom .dist_neighbor_loader import DistNeighborLoader\nfrom .dist_link_neighbor_loader import DistLinkNeighborLoader\n\n__all__ = classes = [\n    'DistContext',\n    'LocalFeatureStore',\n    'LocalGraphStore',\n    'Partitioner',\n    'DistNeighborSampler',\n    'DistLoader',\n    'DistNeighborLoader',\n    'DistLinkNeighborLoader',\n]\n"
  },
  {
    "path": "torch_geometric/distributed/dist_context.py",
    "content": "from dataclasses import dataclass\nfrom enum import Enum\n\n\nclass DistRole(Enum):\n    WORKER = 1\n\n\n@dataclass\nclass DistContext:\n    r\"\"\"Context information of the current process.\"\"\"\n    rank: int\n    global_rank: int\n    world_size: int\n    global_world_size: int\n    group_name: str\n    role: DistRole = DistRole.WORKER\n\n    @property\n    def worker_name(self) -> str:\n        return f'{self.group_name}-{self.rank}'\n"
  },
  {
    "path": "torch_geometric/distributed/dist_link_neighbor_loader.py",
    "content": "from typing import Callable, Dict, List, Optional, Tuple, Union\nfrom warnings import warn\n\nimport torch\n\nfrom torch_geometric.distributed import (\n    DistContext,\n    DistLoader,\n    DistNeighborSampler,\n    LocalFeatureStore,\n    LocalGraphStore,\n)\nfrom torch_geometric.loader import LinkLoader\nfrom torch_geometric.sampler.base import NegativeSampling, SubgraphType\nfrom torch_geometric.typing import EdgeType, InputEdges, OptTensor\n\n\nclass DistLinkNeighborLoader(LinkLoader, DistLoader):\n    r\"\"\"A distributed loader that performs sampling from edges.\n\n    Args:\n        data (tuple): A (:class:`~torch_geometric.data.FeatureStore`,\n            :class:`~torch_geometric.data.GraphStore`) data object.\n        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]):\n            The number of neighbors to sample for each node in each iteration.\n            If an entry is set to :obj:`-1`, all neighbors will be included.\n            In heterogeneous graphs, may also take in a dictionary denoting\n            the amount of neighbors to sample for each individual edge type.\n        master_addr (str): RPC address for distributed loader communication,\n            *i.e.* the IP address of the master node.\n        master_port (Union[int, str]): Open port for RPC communication with\n            the master node.\n        current_ctx (DistContext): Distributed context information of the\n            current process.\n        concurrency (int, optional): RPC concurrency used for defining the\n            maximum size of the asynchronous processing queue.\n            (default: :obj:`1`)\n\n    All other arguments follow the interface of\n    :class:`torch_geometric.loader.LinkNeighborLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Tuple[LocalFeatureStore, LocalGraphStore],\n        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],\n        master_addr: str,\n        master_port: Union[int, 
str],\n        current_ctx: DistContext,\n        edge_label_index: InputEdges = None,\n        edge_label: OptTensor = None,\n        edge_label_time: OptTensor = None,\n        dist_sampler: Optional[DistNeighborSampler] = None,\n        replace: bool = False,\n        subgraph_type: Union[SubgraphType, str] = \"directional\",\n        disjoint: bool = False,\n        temporal_strategy: str = \"uniform\",\n        neg_sampling: Optional[NegativeSampling] = None,\n        neg_sampling_ratio: Optional[Union[int, float]] = None,\n        time_attr: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        concurrency: int = 1,\n        num_rpc_threads: int = 16,\n        filter_per_worker: Optional[bool] = False,\n        async_sampling: bool = True,\n        device: Optional[torch.device] = None,\n        **kwargs,\n    ):\n        assert isinstance(data[0], LocalFeatureStore)\n        assert isinstance(data[1], LocalGraphStore)\n        assert concurrency >= 1, \"RPC concurrency must be at least 1\"\n\n        if (edge_label_time is not None) != (time_attr is not None):\n            raise ValueError(\n                f\"Received conflicting 'edge_label_time' and 'time_attr' \"\n                f\"arguments: 'edge_label_time' is \"\n                f\"{'set' if edge_label_time is not None else 'not set'} \"\n                f\"while 'time_attr' is \"\n                f\"{'set' if time_attr is not None else 'not set'}. 
\"\n                f\"Both arguments must be provided for temporal sampling.\")\n\n        channel = torch.multiprocessing.Queue() if async_sampling else None\n\n        if dist_sampler is None:\n            dist_sampler = DistNeighborSampler(\n                data=data,\n                current_ctx=current_ctx,\n                num_neighbors=num_neighbors,\n                replace=replace,\n                subgraph_type=subgraph_type,\n                disjoint=disjoint,\n                temporal_strategy=temporal_strategy,\n                time_attr=time_attr,\n                device=device,\n                channel=channel,\n                concurrency=concurrency,\n            )\n        else:\n            warn(  # noqa: B028\n                \"`torch_geometric.distributed` has been deprecated since 2.7.0 and will \"  # noqa: E501\n                \"no longer be maintained. For distributed training, refer to our \"  # noqa: E501\n                \"tutorials on distributed training at \"\n                \"https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html \"  # noqa: E501\n                \"or cuGraph examples at \"\n                \"https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples\",  # noqa: E501\n                stacklevel=2)\n\n        DistLoader.__init__(\n            self,\n            channel=channel,\n            master_addr=master_addr,\n            master_port=master_port,\n            current_ctx=current_ctx,\n            dist_sampler=dist_sampler,\n            num_rpc_threads=num_rpc_threads,\n            **kwargs,\n        )\n        LinkLoader.__init__(\n            self,\n            data=data,\n            link_sampler=dist_sampler,\n            edge_label_index=edge_label_index,\n            edge_label=edge_label,\n            edge_label_time=edge_label_time,\n            neg_sampling=neg_sampling,\n            neg_sampling_ratio=neg_sampling_ratio,\n            
transform=transform,\n            filter_per_worker=filter_per_worker,\n            worker_init_fn=self.worker_init_fn,\n            transform_sampler_output=self.channel_get if channel else None,\n            **kwargs,\n        )\n\n    def __repr__(self) -> str:\n        return DistLoader.__repr__(self)\n"
  },
  {
    "path": "torch_geometric/distributed/dist_loader.py",
    "content": "import atexit\nimport logging\nimport os\nfrom typing import Any, Optional, Union\n\nimport torch.distributed\nimport torch.multiprocessing as mp\n\nfrom torch_geometric.distributed import DistNeighborSampler\nfrom torch_geometric.distributed.dist_context import DistContext\nfrom torch_geometric.distributed.rpc import (\n    global_barrier,\n    init_rpc,\n    shutdown_rpc,\n)\nfrom torch_geometric.loader.base import DataLoaderIterator\n\n\nclass DistLoader:\n    r\"\"\"A base class for creating distributed data loading routines.\n\n    Args:\n        current_ctx (DistContext): Distributed context info of the current\n            process.\n        master_addr (str, optional): RPC address for distributed loader\n            communication.\n            Refers to the IP address of the master node. (default: :obj:`None`)\n        master_port (int or str, optional): The open port for RPC communication\n            with the master node. (default: :obj:`None`)\n        channel (mp.Queue, optional): A communication channel for messages.\n            (default: :obj:`None`)\n        num_rpc_threads (int, optional): The number of threads in the\n            thread-pool used by\n            :class:`~torch.distributed.rpc.TensorPipeAgent` to execute\n            requests. 
(default: :obj:`16`)\n        rpc_timeout (int, optional): The default timeout in seconds for RPC\n            requests.\n            If the RPC has not completed in this timeframe, an exception will\n            be raised.\n            Callers can override this timeout for\n            individual RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and\n            :meth:`~torch.distributed.rpc.rpc_async` if necessary.\n            (default: :obj:`180`)\n    \"\"\"\n    def __init__(\n        self,\n        current_ctx: DistContext,\n        master_addr: Optional[str] = None,\n        master_port: Optional[Union[int, str]] = None,\n        channel: Optional[mp.Queue] = None,\n        num_rpc_threads: int = 16,\n        rpc_timeout: int = 180,\n        dist_sampler: DistNeighborSampler = None,\n        **kwargs,\n    ):\n        if master_addr is None and os.environ.get('MASTER_ADDR') is not None:\n            master_addr = os.environ['MASTER_ADDR']\n        if master_addr is None:\n            raise ValueError(f\"Missing master address for RPC communication \"\n                             f\"in '{self.__class__.__name__}'. Try to provide \"\n                             f\"it or set it via the 'MASTER_ADDR' environment \"\n                             f\"variable.\")\n\n        if master_port is None and os.environ.get('MASTER_PORT') is not None:\n            # Select next port to MASTER_PORT used for DDP.\n            # If multiple loaders are launched in the same script, please\n            # provide distinct ports for each.\n            master_port = int(os.environ['MASTER_PORT']) + 1\n        if master_port is None:\n            raise ValueError(f\"Missing master port for RPC communication in \"\n                             f\"'{self.__class__.__name__}'. 
Try to provide it \"\n                             f\"or set it via the 'MASTER_PORT' environment \"\n                             f\"variable.\")\n\n        assert num_rpc_threads > 0\n        assert rpc_timeout > 0\n\n        self.dist_sampler = dist_sampler\n        self.current_ctx = current_ctx\n        self.master_addr = master_addr\n        self.master_port = master_port\n        self.channel = channel\n        self.pid = mp.current_process().pid\n        self.num_rpc_threads = num_rpc_threads\n        self.rpc_timeout = rpc_timeout\n        self.num_workers = kwargs.get('num_workers', 0)\n\n        logging.info(f\"[{self}] MASTER_ADDR={master_addr}, \"\n                     f\"MASTER_PORT={master_port}\")\n\n        if self.num_workers == 0:  # Initialize RPC in main process:\n            self.worker_init_fn(0)\n\n    def channel_get(self, out: Any) -> Any:\n        if self.channel:\n            out = self.channel.get()\n            logging.debug(f\"[{self}] Retrieved message\")\n        return out\n\n    def reset_channel(self, channel=None):\n        # clean remaining queue items and restart new queue\n        logging.debug(f'{self} Resetting msg channel')\n        while not self.channel.empty():\n            self.channel.get_nowait()\n\n        torch.distributed.barrier()\n\n        self.channel = channel or mp.Queue()\n        self.dist_sampler.channel = self.channel\n\n    def worker_init_fn(self, worker_id: int):\n        try:\n            num_sampler_proc = self.num_workers if self.num_workers > 0 else 1\n            self.current_ctx_worker = DistContext(\n                world_size=self.current_ctx.world_size * num_sampler_proc,\n                rank=self.current_ctx.rank * num_sampler_proc + worker_id,\n                global_world_size=self.current_ctx.world_size *\n                num_sampler_proc,\n                global_rank=self.current_ctx.rank * num_sampler_proc +\n                worker_id,\n                
group_name='mp_sampling_worker',\n            )\n\n            init_rpc(\n                current_ctx=self.current_ctx_worker,\n                master_addr=self.master_addr,\n                master_port=self.master_port,\n                num_rpc_threads=self.num_rpc_threads,\n                rpc_timeout=self.rpc_timeout,\n            )\n            logging.info(\n                f\"RPC initiated in worker-{worker_id} \"\n                f\"(current_ctx_worker={self.current_ctx_worker.worker_name})\")\n            self.dist_sampler.init_sampler_instance()\n            self.dist_sampler.register_sampler_rpc()\n            global_barrier(timeout=10)  # Wait for all workers to initialize.\n\n            # close RPC & worker group at exit:\n            atexit.register(shutdown_rpc, self.current_ctx_worker.worker_name)\n\n        except RuntimeError as e:\n            raise RuntimeError(f\"`{self}.init_fn()` could not initialize the \"\n                               f\"worker loop of the neighbor sampler\") from e\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(pid={self.pid})'\n\n    def __enter__(self) -> DataLoaderIterator:\n        # fetch a single batch for init\n        self._prefetch_old = self.prefetch_factor\n        self.prefetch_factor = 1\n        self._iterator = self._get_iterator()\n        return self._iterator\n\n    def __exit__(self, *args) -> None:\n        if self.channel:\n            self.reset_channel()\n        if self._iterator:\n            del self._iterator\n            torch.distributed.barrier()\n            self._iterator = None\n            self.prefetch_factor = self._prefetch_old\n"
  },
  {
    "path": "torch_geometric/distributed/dist_neighbor_loader.py",
    "content": "from typing import Callable, Dict, List, Optional, Tuple, Union\nfrom warnings import warn\n\nimport torch\n\nfrom torch_geometric.distributed import (\n    DistContext,\n    DistLoader,\n    DistNeighborSampler,\n    LocalFeatureStore,\n    LocalGraphStore,\n)\nfrom torch_geometric.loader import NodeLoader\nfrom torch_geometric.sampler.base import SubgraphType\nfrom torch_geometric.typing import EdgeType, InputNodes, OptTensor\n\n\nclass DistNeighborLoader(NodeLoader, DistLoader):\n    r\"\"\"A distributed loader that performs sampling from nodes.\n\n    Args:\n        data (tuple): A (:class:`~torch_geometric.data.FeatureStore`,\n            :class:`~torch_geometric.data.GraphStore`) data object.\n        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]):\n            The number of neighbors to sample for each node in each iteration.\n            If an entry is set to :obj:`-1`, all neighbors will be included.\n            In heterogeneous graphs, may also take in a dictionary denoting\n            the amount of neighbors to sample for each individual edge type.\n        master_addr (str): RPC address for distributed loader communication,\n            *i.e.* the IP address of the master node.\n        master_port (Union[int, str]): Open port for RPC communication with\n            the master node.\n        current_ctx (DistContext): Distributed context information of the\n            current process.\n        concurrency (int, optional): RPC concurrency used for defining the\n            maximum size of the asynchronous processing queue.\n            (default: :obj:`1`)\n\n    All other arguments follow the interface of\n    :class:`torch_geometric.loader.NeighborLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Tuple[LocalFeatureStore, LocalGraphStore],\n        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],\n        master_addr: str,\n        master_port: Union[int, str],\n        current_ctx: 
DistContext,\n        input_nodes: InputNodes = None,\n        input_time: OptTensor = None,\n        dist_sampler: Optional[DistNeighborSampler] = None,\n        replace: bool = False,\n        subgraph_type: Union[SubgraphType, str] = \"directional\",\n        disjoint: bool = False,\n        temporal_strategy: str = \"uniform\",\n        time_attr: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        concurrency: int = 1,\n        num_rpc_threads: int = 16,\n        filter_per_worker: Optional[bool] = False,\n        async_sampling: bool = True,\n        device: Optional[torch.device] = None,\n        **kwargs,\n    ):\n        assert isinstance(data[0], LocalFeatureStore)\n        assert isinstance(data[1], LocalGraphStore)\n        assert concurrency >= 1, \"RPC concurrency must be at least 1\"\n\n        if input_time is not None and time_attr is None:\n            raise ValueError(\"Received conflicting 'input_time' and \"\n                             \"'time_attr' arguments: 'input_time' is set \"\n                             \"while 'time_attr' is not set.\")\n\n        channel = torch.multiprocessing.Queue() if async_sampling else None\n\n        if dist_sampler is None:\n            dist_sampler = DistNeighborSampler(\n                data=data,\n                current_ctx=current_ctx,\n                num_neighbors=num_neighbors,\n                replace=replace,\n                subgraph_type=subgraph_type,\n                disjoint=disjoint,\n                temporal_strategy=temporal_strategy,\n                time_attr=time_attr,\n                device=device,\n                channel=channel,\n                concurrency=concurrency,\n            )\n        else:\n            warn(  # noqa: B028\n                \"`torch_geometric.distributed` has been deprecated since 2.7.0 and will \"  # noqa: E501\n                \"no longer be maintained. 
For distributed training, refer to our \"  # noqa: E501\n                \"tutorials on distributed training at \"\n                \"https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html \"  # noqa: E501\n                \"or cuGraph examples at \"\n                \"https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples\",  # noqa: E501\n                stacklevel=2)\n\n        DistLoader.__init__(\n            self,\n            channel=channel,\n            master_addr=master_addr,\n            master_port=master_port,\n            current_ctx=current_ctx,\n            dist_sampler=dist_sampler,\n            num_rpc_threads=num_rpc_threads,\n            **kwargs,\n        )\n        NodeLoader.__init__(\n            self,\n            data=data,\n            node_sampler=dist_sampler,\n            input_nodes=input_nodes,\n            input_time=input_time,\n            transform=transform,\n            filter_per_worker=filter_per_worker,\n            transform_sampler_output=self.channel_get if channel else None,\n            worker_init_fn=self.worker_init_fn,\n            **kwargs,\n        )\n\n    def __repr__(self) -> str:\n        return DistLoader.__repr__(self)\n"
  },
  {
    "path": "torch_geometric/distributed/dist_neighbor_sampler.py",
    "content": "import itertools\nimport logging\nimport math\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\nfrom warnings import warn\n\nimport numpy as np\nimport torch\nimport torch.multiprocessing as mp\nfrom torch import Tensor\n\nfrom torch_geometric.distributed import (\n    DistContext,\n    LocalFeatureStore,\n    LocalGraphStore,\n)\nfrom torch_geometric.distributed.event_loop import (\n    ConcurrentEventLoop,\n    to_asyncio_future,\n)\nfrom torch_geometric.distributed.rpc import (\n    RPCCallBase,\n    RPCRouter,\n    rpc_async,\n    rpc_partition_to_workers,\n    rpc_register,\n)\nfrom torch_geometric.distributed.utils import (\n    BatchDict,\n    DistEdgeHeteroSamplerInput,\n    NodeDict,\n    remove_duplicates,\n)\nfrom torch_geometric.sampler import (\n    EdgeSamplerInput,\n    HeteroSamplerOutput,\n    NegativeSampling,\n    NeighborSampler,\n    NodeSamplerInput,\n    SamplerOutput,\n)\nfrom torch_geometric.sampler.base import NumNeighbors, SubgraphType\nfrom torch_geometric.sampler.neighbor_sampler import neg_sample\nfrom torch_geometric.sampler.utils import remap_keys\nfrom torch_geometric.typing import EdgeType, NodeType\n\nNumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]\n\n\nclass RPCSamplingCallee(RPCCallBase):\n    r\"\"\"A wrapper for RPC callee that will perform RPC sampling from remote\n    processes.\n    \"\"\"\n    def __init__(self, sampler: NeighborSampler):\n        super().__init__()\n        self.sampler = sampler\n\n    def rpc_async(self, *args, **kwargs) -> Any:\n        return self.sampler._sample_one_hop(*args, **kwargs)\n\n    def rpc_sync(self, *args, **kwargs) -> Any:\n        pass\n\n\nclass DistNeighborSampler:\n    r\"\"\"An implementation of a distributed and asynchronised neighbor sampler\n    used by :class:`~torch_geometric.distributed.DistNeighborLoader` and\n    :class:`~torch_geometric.distributed.DistLinkNeighborLoader`.\n    \"\"\"\n    def __init__(\n    
    self,\n        current_ctx: DistContext,\n        data: Tuple[LocalFeatureStore, LocalGraphStore],\n        num_neighbors: NumNeighborsType,\n        channel: Optional[mp.Queue] = None,\n        replace: bool = False,\n        subgraph_type: Union[SubgraphType, str] = 'directional',\n        disjoint: bool = False,\n        temporal_strategy: str = 'uniform',\n        time_attr: Optional[str] = None,\n        concurrency: int = 1,\n        device: Optional[torch.device] = None,\n        **kwargs,\n    ):\n        warn(  # noqa: B028\n            \"`torch_geometric.distributed` has been deprecated since 2.7.0 and will \"  # noqa: E501\n            \"no longer be maintained. For distributed training, refer to our \"  # noqa: E501\n            \"tutorials on distributed training at \"\n            \"https://pytorch-geometric.readthedocs.io/en/latest/tutorial/distributed.html \"  # noqa: E501\n            \"or cuGraph examples at \"\n            \"https://github.com/rapidsai/cugraph-gnn/tree/main/python/cugraph-pyg/cugraph_pyg/examples\",  # noqa: E501\n            stacklevel=2)\n        self.current_ctx = current_ctx\n\n        self.feature_store, self.graph_store = data\n        assert isinstance(self.graph_store, LocalGraphStore)\n        assert isinstance(self.feature_store, LocalFeatureStore)\n        self.is_hetero = self.graph_store.meta['is_hetero']\n\n        self.num_neighbors = num_neighbors\n        self.channel = channel\n        self.concurrency = concurrency\n        self.device = device\n        self.event_loop = None\n        self.replace = replace\n        self.subgraph_type = SubgraphType(subgraph_type)\n        self.disjoint = disjoint\n        self.temporal_strategy = temporal_strategy\n        self.time_attr = time_attr\n        self.temporal = time_attr is not None\n        self.with_edge_attr = self.feature_store.has_edge_attr()\n        self.csc = True\n\n    def init_sampler_instance(self):\n        self._sampler = NeighborSampler(\n      
      data=(self.feature_store, self.graph_store),\n            num_neighbors=self.num_neighbors,\n            subgraph_type=self.subgraph_type,\n            replace=self.replace,\n            disjoint=self.disjoint,\n            temporal_strategy=self.temporal_strategy,\n            time_attr=self.time_attr,\n        )\n\n        self.num_hops = self._sampler.num_neighbors.num_hops\n        self.node_types = self._sampler.node_types\n        self.edge_types = self._sampler.edge_types\n        self.node_time = self._sampler.node_time\n        self.edge_time = self._sampler.edge_time\n\n    def register_sampler_rpc(self) -> None:\n        partition2workers = rpc_partition_to_workers(\n            current_ctx=self.current_ctx,\n            num_partitions=self.graph_store.num_partitions,\n            current_partition_idx=self.graph_store.partition_idx,\n        )\n        self.rpc_router = RPCRouter(partition2workers)\n        self.feature_store.set_rpc_router(self.rpc_router)\n\n        rpc_sample_callee = RPCSamplingCallee(self)\n        self.rpc_sample_callee_id = rpc_register(rpc_sample_callee)\n\n    def init_event_loop(self) -> None:\n        if self.event_loop is None:\n            self.event_loop = ConcurrentEventLoop(self.concurrency)\n            self.event_loop.start_loop()\n            logging.info(f'{self} uses {self.event_loop}')\n\n    # Node-based distributed sampling #########################################\n\n    def sample_from_nodes(\n        self,\n        inputs: NodeSamplerInput,\n        **kwargs,\n    ) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]:\n        self.init_event_loop()\n\n        inputs = NodeSamplerInput.cast(inputs)\n        if self.channel is None:\n            # synchronous sampling\n            return self.event_loop.run_task(\n                coro=self._sample_from(self.node_sample, inputs))\n\n        # asynchronous sampling\n        cb = kwargs.get(\"callback\", None)\n        self.event_loop.add_task(\n          
  coro=self._sample_from(self.node_sample, inputs), callback=cb)\n        return None\n\n    # Edge-based distributed sampling #########################################\n\n    def sample_from_edges(\n        self,\n        inputs: EdgeSamplerInput,\n        neg_sampling: Optional[NegativeSampling] = None,\n        **kwargs,\n    ) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]:\n        self.init_event_loop()\n\n        if self.channel is None:\n            # synchronous sampling\n            return self.event_loop.run_task(coro=self._sample_from(\n                self.edge_sample, inputs, self.node_sample, self._sampler.\n                num_nodes, self.disjoint, self.node_time, neg_sampling))\n\n        # asynchronous sampling\n        cb = kwargs.get(\"callback\", None)\n        self.event_loop.add_task(\n            coro=self._sample_from(self.edge_sample, inputs, self.node_sample,\n                                   self._sampler.num_nodes, self.disjoint,\n                                   self.node_time, neg_sampling), callback=cb)\n        return None\n\n    async def _sample_from(\n        self,\n        async_func,\n        *args,\n        **kwargs,\n    ) -> Optional[Union[SamplerOutput, HeteroSamplerOutput]]:\n        sampler_output = await async_func(*args, **kwargs)\n\n        if self.subgraph_type == SubgraphType.bidirectional:\n            sampler_output = sampler_output.to_bidirectional()\n\n        res = await self._collate_fn(sampler_output)\n\n        if self.channel is None:\n            return res\n        self.channel.put(res)\n        return None\n\n    async def node_sample(\n        self,\n        inputs: Union[NodeSamplerInput, DistEdgeHeteroSamplerInput],\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        r\"\"\"Performs layer-by-layer distributed sampling from a\n        :class:`NodeSamplerInput` or :class:`DistEdgeHeteroSamplerInput` and\n        returns the output of the sampling procedure.\n\n        .. 
note::\n            In case of distributed training it is required to synchronize the\n            results between machines after each layer.\n        \"\"\"\n        input_type = inputs.input_type\n        self.input_type = input_type\n\n        if isinstance(inputs, NodeSamplerInput):\n            seed = inputs.node.to(self.device)\n            batch_size = len(inputs.node)\n            seed_batch = torch.arange(batch_size) if self.disjoint else None\n\n            metadata = (inputs.input_id, inputs.time, batch_size)\n\n            seed_time: Optional[Tensor] = None\n            if self.temporal:\n                if inputs.time is not None:\n                    seed_time = inputs.time.to(self.device)\n                elif self.node_time is not None:\n                    if not self.is_hetero:\n                        seed_time = self.node_time[seed]\n                    else:\n                        seed_time = self.node_time[input_type][seed]\n                else:\n                    raise ValueError(\"Seed time needs to be specified\")\n        else:  # `DistEdgeHeteroSamplerInput`:\n            metadata = None  # Metadata is added during `edge_sample`.\n\n        # Heterogeneous Neighborhood Sampling #################################\n\n        if self.is_hetero:\n            if input_type is None:\n                raise ValueError(\"Input type should be defined\")\n\n            node_dict = NodeDict(self.node_types, self.num_hops)\n            batch_dict = BatchDict(self.node_types, self.num_hops)\n\n            if isinstance(inputs, NodeSamplerInput):\n                seed_dict: Dict[NodeType, Tensor] = {input_type: seed}\n                if self.temporal:\n                    node_dict.seed_time[input_type][0] = seed_time.clone()\n\n            else:  # `DistEdgeHeteroSamplerInput`:\n                seed_dict = inputs.node_dict\n                if self.temporal:\n                    for k, v in inputs.node_dict.items():\n                        if 
inputs.time_dict is not None:\n                            node_dict.seed_time[k][0] = inputs.time_dict[k]\n                        elif self.node_time is not None:\n                            node_dict.seed_time[k][0] = self.node_time[k][v]\n                        else:\n                            raise ValueError(\"Seed time needs to be specified\")\n\n            edge_dict: Dict[EdgeType, Tensor] = {\n                k: torch.empty(0, dtype=torch.int64)\n                for k in self.edge_types\n            }\n            sampled_nbrs_per_node_dict: Dict[EdgeType, List[List]] = {\n                k: [[] for _ in range(self.num_hops)]\n                for k in self.edge_types\n            }\n            num_sampled_edges_dict: Dict[EdgeType, List[int]] = {\n                k: []\n                for k in self.edge_types\n            }\n            num_sampled_nodes_dict: Dict[NodeType, List[int]] = {\n                k: [0]\n                for k in self.node_types\n            }\n\n            # Fill in node_dict and batch_dict with input data:\n            batch_len = 0\n            for k, v in seed_dict.items():\n                node_dict.src[k][0] = v\n                node_dict.out[k] = v\n                num_sampled_nodes_dict[k][0] = len(v)\n\n                if self.disjoint:\n                    src_batch = torch.arange(batch_len, batch_len + len(v))\n                    batch_dict.src[k][0] = src_batch\n                    batch_dict.out[k] = src_batch\n\n                    batch_len = len(src_batch)\n\n            # Loop over the layers:\n            for i in range(self.num_hops):\n                # Sample neighbors per edge type:\n                for edge_type in self.edge_types:\n                    # `src` is a destination node type of a given edge.\n                    src = edge_type[0] if not self.csc else edge_type[2]\n\n                    if node_dict.src[src][i].numel() == 0:\n                        # No source nodes of this type in the 
current layer.\n                        num_sampled_edges_dict[edge_type].append(0)\n                        continue\n\n                    if isinstance(self.num_neighbors, list):\n                        one_hop_num = self.num_neighbors[i]\n                    else:\n                        one_hop_num = self.num_neighbors[edge_type][i]\n\n                    # Sample neighbors:\n                    out = await self.sample_one_hop(\n                        node_dict.src[src][i],\n                        one_hop_num,\n                        node_dict.seed_time[src][i],\n                        batch_dict.src[src][i],\n                        edge_type,\n                    )\n\n                    if out.node.numel() == 0:  # No neighbors were sampled.\n                        num_sampled_edges_dict[edge_type].append(0)\n                        continue\n\n                    # `dst` is a destination node type of a given edge.\n                    dst = edge_type[2] if not self.csc else edge_type[0]\n\n                    # Remove duplicates:\n                    (\n                        src_node,\n                        node_dict.out[dst],\n                        src_batch,\n                        batch_dict.out[dst],\n                    ) = remove_duplicates(\n                        out,\n                        node_dict.out[dst],\n                        batch_dict.out[dst],\n                        self.disjoint,\n                    )\n\n                    # Create src nodes for the next layer:\n                    node_dict.src[dst][i + 1] = torch.cat(\n                        [node_dict.src[dst][i + 1], src_node])\n                    if self.disjoint:\n                        batch_dict.src[dst][i + 1] = torch.cat(\n                            [batch_dict.src[dst][i + 1], src_batch])\n\n                    # Save sampled nodes with duplicates to be able to create\n                    # local edge indices:\n                    
node_dict.with_dupl[dst] = torch.cat(\n                        [node_dict.with_dupl[dst], out.node])\n\n                    edge_dict[edge_type] = torch.cat(\n                        [edge_dict[edge_type], out.edge])\n\n                    if self.disjoint:\n                        batch_dict.with_dupl[dst] = torch.cat(\n                            [batch_dict.with_dupl[dst], out.batch])\n\n                    if self.temporal and i < self.num_hops - 1:\n                        # Assign seed time based on source node subgraph ID:\n                        if isinstance(inputs, NodeSamplerInput):\n                            src_seed_time = [\n                                seed_time[(seed_batch == batch_idx).nonzero()]\n                                for batch_idx in src_batch\n                            ]\n                            src_seed_time = torch.as_tensor(\n                                src_seed_time, dtype=torch.int64)\n\n                        else:  # `DistEdgeHeteroSamplerInput`:\n                            src_seed_time = torch.empty(0, dtype=torch.int64)\n                            for k, v in batch_dict.src.items():\n                                time = [\n                                    node_dict.seed_time[k][0][(\n                                        v[0] == batch_idx).nonzero()]\n                                    for batch_idx in src_batch\n                                ]\n                                try:\n                                    time = torch.as_tensor(\n                                        time, dtype=torch.int64)\n                                    src_seed_time = torch.cat(\n                                        [src_seed_time, time])\n                                except Exception:\n                                    # `time`  may be an empty tensors, because\n                                    # no nodes of this type were sampled.\n                                    pass\n\n                    
    node_dict.seed_time[dst][i + 1] = torch.cat(\n                            [node_dict.seed_time[dst][i + 1], src_seed_time])\n\n                    # Collect sampled neighbors per node for each layer:\n                    sampled_nbrs_per_node_dict[edge_type][i] += out.metadata[0]\n\n                    num_sampled_edges_dict[edge_type].append(len(out.node))\n\n                for node_type in self.node_types:\n                    num_sampled_nodes_dict[node_type].append(\n                        len(node_dict.src[node_type][i + 1]))\n\n            sampled_nbrs_per_node_dict = remap_keys(sampled_nbrs_per_node_dict,\n                                                    self._sampler.to_rel_type)\n\n            # Create local edge indices for a batch:\n            row_dict, col_dict = torch.ops.pyg.hetero_relabel_neighborhood(\n                self.node_types,\n                self.edge_types,\n                seed_dict,\n                node_dict.with_dupl,\n                sampled_nbrs_per_node_dict,\n                self._sampler.num_nodes,\n                batch_dict.with_dupl,\n                self.csc,\n                self.disjoint,\n            )\n\n            sampler_output = HeteroSamplerOutput(\n                node=node_dict.out,\n                row=remap_keys(row_dict, self._sampler.to_edge_type),\n                col=remap_keys(col_dict, self._sampler.to_edge_type),\n                edge=edge_dict,\n                batch=batch_dict.out if self.disjoint else None,\n                num_sampled_nodes=num_sampled_nodes_dict,\n                num_sampled_edges=num_sampled_edges_dict,\n                metadata=metadata,\n            )\n\n        # Homogeneous Neighborhood Sampling ###################################\n\n        else:\n            src = seed\n            node = src.clone()\n\n            src_batch = seed_batch.clone() if self.disjoint else None\n            batch = seed_batch.clone() if self.disjoint else None\n\n            src_seed_time = 
seed_time.clone() if self.temporal else None\n\n            node_with_dupl = [torch.empty(0, dtype=torch.int64)]\n            batch_with_dupl = [torch.empty(0, dtype=torch.int64)]\n            edge = [torch.empty(0, dtype=torch.int64)]\n\n            sampled_nbrs_per_node = []\n            num_sampled_nodes = [seed.numel()]\n            num_sampled_edges = []\n\n            # Loop over the layers:\n            for i, one_hop_num in enumerate(self.num_neighbors):\n                out = await self.sample_one_hop(src, one_hop_num,\n                                                src_seed_time, src_batch)\n                if out.node.numel() == 0:\n                    # No neighbors were sampled:\n                    num_zero_layers = self.num_hops - i\n                    num_sampled_nodes += num_zero_layers * [0]\n                    num_sampled_edges += num_zero_layers * [0]\n                    break\n\n                # Remove duplicates:\n                src, node, src_batch, batch = remove_duplicates(\n                    out, node, batch, self.disjoint)\n\n                node_with_dupl.append(out.node)\n                edge.append(out.edge)\n\n                if self.disjoint:\n                    batch_with_dupl.append(out.batch)\n\n                if self.temporal and i < self.num_hops - 1:\n                    # Assign seed time based on src nodes subgraph IDs.\n                    src_seed_time = [\n                        seed_time[(seed_batch == batch_idx).nonzero()]\n                        for batch_idx in src_batch\n                    ]\n                    src_seed_time = torch.as_tensor(src_seed_time,\n                                                    dtype=torch.int64)\n\n                num_sampled_nodes.append(len(src))\n                num_sampled_edges.append(len(out.node))\n                sampled_nbrs_per_node += out.metadata[0]\n\n            row, col = torch.ops.pyg.relabel_neighborhood(\n                seed,\n                
torch.cat(node_with_dupl),\n                sampled_nbrs_per_node,\n                self._sampler.num_nodes,\n                torch.cat(batch_with_dupl) if self.disjoint else None,\n                self.csc,\n                self.disjoint,\n            )\n\n            sampler_output = SamplerOutput(\n                node=node,\n                row=row,\n                col=col,\n                edge=torch.cat(edge),\n                batch=batch if self.disjoint else None,\n                num_sampled_nodes=num_sampled_nodes,\n                num_sampled_edges=num_sampled_edges,\n                metadata=metadata,\n            )\n\n        return sampler_output\n\n    async def edge_sample(\n        self,\n        inputs: EdgeSamplerInput,\n        sample_fn: Callable,\n        num_nodes: Union[int, Dict[NodeType, int]],\n        disjoint: bool,\n        node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None,\n        neg_sampling: Optional[NegativeSampling] = None,\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        r\"\"\"Performs layer-by-layer distributed sampling from an\n        :class:`EdgeSamplerInput` and returns the output of the sampling\n        procedure.\n\n        .. 
note::\n            In case of distributed training it is required to synchronize the\n            results between machines after each layer.\n        \"\"\"\n        input_id = inputs.input_id\n        src = inputs.row\n        dst = inputs.col\n        edge_label = inputs.label\n        edge_label_time = inputs.time\n        input_type = inputs.input_type\n\n        src_time = dst_time = edge_label_time\n        assert edge_label_time is None or disjoint\n\n        assert isinstance(num_nodes, (dict, int))\n        if not isinstance(num_nodes, dict):\n            num_src_nodes = num_dst_nodes = num_nodes\n        else:\n            num_src_nodes = num_nodes[input_type[0]]\n            num_dst_nodes = num_nodes[input_type[-1]]\n\n        num_pos = src.numel()\n        num_neg = 0\n\n        # Negative Sampling ###################################################\n\n        if neg_sampling is not None:\n            # When we are doing negative sampling, we append negative\n            # information of nodes/edges to `src`, `dst`, `src_time`,\n            # `dst_time`. 
Later on, we can easily reconstruct what belongs to\n            # positive and negative examples by slicing via `num_pos`.\n            num_neg = math.ceil(num_pos * neg_sampling.amount)\n\n            if neg_sampling.is_binary():\n                # In the \"binary\" case, we randomly sample negative pairs of\n                # nodes.\n                if isinstance(node_time, dict):\n                    src_node_time = node_time.get(input_type[0])\n                else:\n                    src_node_time = node_time\n\n                src_neg = neg_sample(src, neg_sampling, num_src_nodes,\n                                     src_time, src_node_time)\n                src = torch.cat([src, src_neg], dim=0)\n\n                if isinstance(node_time, dict):\n                    dst_node_time = node_time.get(input_type[-1])\n                else:\n                    dst_node_time = node_time\n\n                dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes,\n                                     dst_time, dst_node_time)\n                dst = torch.cat([dst, dst_neg], dim=0)\n\n                if edge_label is None:\n                    edge_label = torch.ones(num_pos)\n                size = (num_neg, ) + edge_label.size()[1:]\n                edge_neg_label = edge_label.new_zeros(size)\n                edge_label = torch.cat([edge_label, edge_neg_label])\n\n                if edge_label_time is not None:\n                    src_time = dst_time = edge_label_time.repeat(\n                        1 + math.ceil(neg_sampling.amount))[:num_pos + num_neg]\n\n            elif neg_sampling.is_triplet():\n                # In the \"triplet\" case, we randomly sample negative\n                # destinations.\n                if isinstance(node_time, dict):\n                    dst_node_time = node_time.get(input_type[-1])\n                else:\n                    dst_node_time = node_time\n\n                dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes,\n    
                                 dst_time, dst_node_time)\n                dst = torch.cat([dst, dst_neg], dim=0)\n\n                assert edge_label is None\n\n                if edge_label_time is not None:\n                    dst_time = edge_label_time.repeat(1 + neg_sampling.amount)\n\n        # Heterogeneous Neighborhood Sampling #################################\n\n        if input_type is not None:\n            if input_type[0] != input_type[-1]:  # Two distinct node types:\n\n                if not disjoint:\n                    src, inverse_src = src.unique(return_inverse=True)\n                    dst, inverse_dst = dst.unique(return_inverse=True)\n\n                seed_dict = {input_type[0]: src, input_type[-1]: dst}\n\n                seed_time_dict = None\n                if edge_label_time is not None:  # Always disjoint.\n                    seed_time_dict = {\n                        input_type[0]: src_time,\n                        input_type[-1]: dst_time,\n                    }\n\n                out = await sample_fn(\n                    DistEdgeHeteroSamplerInput(\n                        input_id=inputs.input_id,\n                        node_dict=seed_dict,\n                        time_dict=seed_time_dict,\n                        input_type=input_type,\n                    ))\n\n            else:\n                # Only a single node type: Merge both source and destination.\n                seed = torch.cat([src, dst], dim=0)\n\n                if not disjoint:\n                    seed, inverse_seed = seed.unique(return_inverse=True)\n\n                seed_dict = {input_type[0]: seed}\n\n                seed_time = None\n                if edge_label_time is not None:  # Always disjoint.\n                    seed_time = torch.cat([src_time, dst_time], dim=0)\n\n                out = await sample_fn(\n                    NodeSamplerInput(\n                        input_id=inputs.input_id,\n                        node=seed,\n           
             time=seed_time,\n                        input_type=input_type[0],\n                    ))\n\n            # Enhance `out` by label information ##############################\n            if disjoint:\n                for key, batch in out.batch.items():\n                    out.batch[key] = batch % num_pos\n\n            if neg_sampling is None or neg_sampling.is_binary():\n                if disjoint:\n                    if input_type[0] != input_type[-1]:\n                        edge_label_index = torch.arange(num_pos + num_neg)\n                        edge_label_index = edge_label_index.repeat(2)\n                        edge_label_index = edge_label_index.view(2, -1)\n                    else:\n                        num_labels = num_pos + num_neg\n                        edge_label_index = torch.arange(2 * (num_labels))\n                        edge_label_index = edge_label_index.view(2, -1)\n                else:\n                    if input_type[0] != input_type[-1]:\n                        edge_label_index = torch.stack([\n                            inverse_src,\n                            inverse_dst,\n                        ], dim=0)\n                    else:\n                        edge_label_index = inverse_seed.view(2, -1)\n\n                out.metadata = (input_id, edge_label_index, edge_label,\n                                src_time)\n\n            elif neg_sampling.is_triplet():\n                if disjoint:\n                    src_index = torch.arange(num_pos)\n                    if input_type[0] != input_type[-1]:\n                        dst_pos_index = torch.arange(num_pos)\n                        # `dst_neg_index` needs to be offset such that indices\n                        # with offset `num_pos` belong to the same triplet:\n                        dst_neg_index = torch.arange(\n                            num_pos, seed_dict[input_type[-1]].numel())\n                        dst_neg_index = dst_neg_index.view(-1, 
num_pos).t()\n                    else:\n                        dst_pos_index = torch.arange(num_pos, 2 * num_pos)\n                        dst_neg_index = torch.arange(\n                            2 * num_pos, seed_dict[input_type[-1]].numel())\n                        dst_neg_index = dst_neg_index.view(-1, num_pos).t()\n                else:\n                    if input_type[0] != input_type[-1]:\n                        src_index = inverse_src\n                        dst_pos_index = inverse_dst[:num_pos]\n                        dst_neg_index = inverse_dst[num_pos:]\n                    else:\n                        src_index = inverse_seed[:num_pos]\n                        dst_pos_index = inverse_seed[num_pos:2 * num_pos]\n                        dst_neg_index = inverse_seed[2 * num_pos:]\n\n                dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)\n\n                out.metadata = (\n                    input_id,\n                    src_index,\n                    dst_pos_index,\n                    dst_neg_index,\n                    src_time,\n                )\n\n        # Homogeneous Neighborhood Sampling ###################################\n\n        else:\n\n            seed = torch.cat([src, dst], dim=0)\n            seed_time = None\n\n            if not disjoint:\n                seed, inverse_seed = seed.unique(return_inverse=True)\n\n            if edge_label_time is not None:  # Always disjoint.\n                seed_time = torch.cat([src_time, dst_time])\n\n            out = await sample_fn(\n                NodeSamplerInput(\n                    input_id=inputs.input_id,\n                    node=seed,\n                    time=seed_time,\n                    input_type=None,\n                ))\n\n            # Enhance `out` by label information ##############################\n            if neg_sampling is None or neg_sampling.is_binary():\n                if disjoint:\n                    out.batch = out.batch % 
num_pos\n                    edge_label_index = torch.arange(seed.numel()).view(2, -1)\n                else:\n                    edge_label_index = inverse_seed.view(2, -1)\n\n                out.metadata = (input_id, edge_label_index, edge_label,\n                                src_time)\n\n            elif neg_sampling.is_triplet():\n                if disjoint:\n                    out.batch = out.batch % num_pos\n                    src_index = torch.arange(num_pos)\n                    dst_pos_index = torch.arange(num_pos, 2 * num_pos)\n                    # `dst_neg_index` needs to be offset such that indices with\n                    # offset `num_pos` belong to the same triplet:\n                    dst_neg_index = torch.arange(2 * num_pos, seed.numel())\n                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()\n                else:\n                    src_index = inverse_seed[:num_pos]\n                    dst_pos_index = inverse_seed[num_pos:2 * num_pos]\n                    dst_neg_index = inverse_seed[2 * num_pos:]\n                dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)\n\n                out.metadata = (\n                    input_id,\n                    src_index,\n                    dst_pos_index,\n                    dst_neg_index,\n                    src_time,\n                )\n\n        return out\n\n    def _get_sampler_output(\n        self,\n        outputs: List[SamplerOutput],\n        seed_size: int,\n        p_id: int,\n        src_batch: Optional[Tensor] = None,\n    ) -> SamplerOutput:\n        r\"\"\"Used when all seed nodes belong to one partition. Its purpose is to\n        remove seed nodes from sampled nodes and calculate how many neighbors\n        were sampled by each src node based on the\n        :obj:`cumsum_neighbors_per_node`. 
Returns the updated sampler output.\n        \"\"\"\n        cumsum_neighbors_per_node = outputs[p_id].metadata[0]\n\n        # Do not include seed nodes:\n        outputs[p_id].node = outputs[p_id].node[seed_size:]\n\n        begin = np.array(cumsum_neighbors_per_node[1:])\n        end = np.array(cumsum_neighbors_per_node[:-1])\n\n        sampled_nbrs_per_node = list(np.subtract(begin, end))\n\n        outputs[p_id].metadata = (sampled_nbrs_per_node, )\n\n        if self.disjoint:\n            batch = [[src_batch[i]] * nbrs_per_node\n                     for i, nbrs_per_node in enumerate(sampled_nbrs_per_node)]\n            outputs[p_id].batch = Tensor(\n                list(itertools.chain.from_iterable(batch))).type(torch.int64)\n\n        return outputs[p_id]\n\n    def _merge_sampler_outputs(\n        self,\n        partition_ids: Tensor,\n        partition_orders: Tensor,\n        outputs: List[SamplerOutput],\n        one_hop_num: int,\n        src_batch: Optional[Tensor] = None,\n    ) -> SamplerOutput:\n        r\"\"\"Merges sampler outputs from different partitions, so that they\n        are sorted according to the sampling order. Removes seed nodes from\n        sampled nodes and calculates how many neighbors were sampled by each\n        src node based on the :obj:`cumsum_neighbors_per_node`. Leverages the\n        :obj:`pyg-lib` :meth:`merge_sampler_outputs` function.\n\n        Args:\n            partition_ids (torch.Tensor): Contains information on which\n                partition seed nodes are located on.\n            partition_orders (torch.Tensor): Contains information about the\n                order of seed nodes in each partition.\n            outputs (List[SamplerOutput]): List of all sampler outputs.\n            one_hop_num (int): Max number of neighbors sampled in the current\n                layer.\n            src_batch (torch.Tensor, optional): The batch assignment of seed\n                nodes. 
(default: :obj:`None`)\n\n        Returns :obj:`SamplerOutput` containing all merged outputs.\n        \"\"\"\n        sampled_nodes_with_dupl = [\n            o.node if o is not None else torch.empty(0, dtype=torch.int64)\n            for o in outputs\n        ]\n        edge_ids = [\n            o.edge if o is not None else torch.empty(0, dtype=torch.int64)\n            for o in outputs\n        ]\n        cumm_sampled_nbrs_per_node = [\n            o.metadata[0] if o is not None else [] for o in outputs\n        ]\n\n        partition_ids = partition_ids.tolist()\n        partition_orders = partition_orders.tolist()\n\n        partitions_num = self.graph_store.meta[\"num_parts\"]\n\n        out = torch.ops.pyg.merge_sampler_outputs(\n            sampled_nodes_with_dupl,\n            edge_ids,\n            cumm_sampled_nbrs_per_node,\n            partition_ids,\n            partition_orders,\n            partitions_num,\n            one_hop_num,\n            src_batch,\n            self.disjoint,\n        )\n        (\n            out_node_with_dupl,\n            out_edge,\n            out_batch,\n            out_sampled_nbrs_per_node,\n        ) = out\n\n        return SamplerOutput(\n            out_node_with_dupl,\n            None,\n            None,\n            out_edge,\n            out_batch if self.disjoint else None,\n            metadata=(out_sampled_nbrs_per_node, ),\n        )\n\n    async def sample_one_hop(\n        self,\n        srcs: Tensor,\n        one_hop_num: int,\n        seed_time: Optional[Tensor] = None,\n        src_batch: Optional[Tensor] = None,\n        edge_type: Optional[EdgeType] = None,\n    ) -> SamplerOutput:\n        r\"\"\"Samples one-hop neighbors for a set of seed nodes in :obj:`srcs`.\n        If seed nodes are located on a local partition, evaluates the sampling\n        function on the current machine. 
If seed nodes are from a remote\n        partition, sends a request to a remote machine that contains this\n        partition.\n        \"\"\"\n        src_node_type = None if not self.is_hetero else edge_type[2]\n        partition_ids = self.graph_store.get_partition_ids_from_nids(\n            srcs, src_node_type)\n        partition_orders = torch.zeros(len(partition_ids), dtype=torch.long)\n\n        p_outputs: List[SamplerOutput] = [\n            None\n        ] * self.graph_store.meta[\"num_parts\"]\n        futs: List[torch.futures.Future] = []\n\n        local_only = True\n        single_partition = len(set(partition_ids.tolist())) == 1\n\n        for i in range(self.graph_store.num_partitions):\n            p_id = (self.graph_store.partition_idx +\n                    i) % self.graph_store.num_partitions\n            p_mask = partition_ids == p_id\n            p_srcs = torch.masked_select(srcs, p_mask)\n            p_seed_time = (torch.masked_select(seed_time, p_mask)\n                           if self.temporal else None)\n\n            p_indices = torch.arange(len(p_srcs), dtype=torch.long)\n            partition_orders[p_mask] = p_indices\n\n            if p_srcs.shape[0] > 0:\n                if p_id == self.graph_store.partition_idx:\n                    # Sample for one hop on a local machine:\n                    p_nbr_out = self._sample_one_hop(p_srcs, one_hop_num,\n                                                     p_seed_time, edge_type)\n                    p_outputs.pop(p_id)\n                    p_outputs.insert(p_id, p_nbr_out)\n\n                else:  # Sample on a remote machine:\n                    local_only = False\n                    to_worker = self.rpc_router.get_to_worker(p_id)\n                    futs.append(\n                        rpc_async(\n                            to_worker,\n                            self.rpc_sample_callee_id,\n                            args=(p_srcs, one_hop_num, p_seed_time, edge_type),\n         
               ))\n\n        if not local_only:\n            # Src nodes are remote\n            res_fut_list = await to_asyncio_future(\n                torch.futures.collect_all(futs))\n            for i, res_fut in enumerate(res_fut_list):\n                p_id = (self.graph_store.partition_idx + i +\n                        1) % self.graph_store.num_partitions\n                p_outputs.pop(p_id)\n                p_outputs.insert(p_id, res_fut.wait())\n\n        # All src nodes are in the same partition\n        if single_partition:\n            return self._get_sampler_output(p_outputs, len(srcs),\n                                            partition_ids[0], src_batch)\n\n        return self._merge_sampler_outputs(partition_ids, partition_orders,\n                                           p_outputs, one_hop_num, src_batch)\n\n    def _sample_one_hop(\n        self,\n        input_nodes: Tensor,\n        num_neighbors: int,\n        seed_time: Optional[Tensor] = None,\n        edge_type: Optional[EdgeType] = None,\n    ) -> SamplerOutput:\n        r\"\"\"Implements one-hop neighbor sampling for a set of input nodes for a\n        specific edge type.\n        \"\"\"\n        if not self.is_hetero:\n            colptr = self._sampler.colptr\n            row = self._sampler.row\n            node_time = self.node_time\n            edge_time = self.edge_time\n        else:\n            # Given edge type, get input data and evaluate sample function:\n            rel_type = '__'.join(edge_type)\n            colptr = self._sampler.colptr_dict[rel_type]\n            row = self._sampler.row_dict[rel_type]\n            # `node_time` is a destination node time:\n            node_time = (self.node_time or {}).get(edge_type[0], None)\n            edge_time = (self.edge_time or {}).get(edge_type, None)\n\n        out = torch.ops.pyg.dist_neighbor_sample(\n            colptr,\n            row,\n            input_nodes.to(colptr.dtype),\n            num_neighbors,\n           
 node_time,\n            edge_time,\n            seed_time,\n            None,  # TODO: edge_weight\n            True,  # csc\n            self.replace,\n            self.subgraph_type != SubgraphType.induced,\n            self.disjoint and self.temporal,\n            self.temporal_strategy,\n        )\n        node, edge, cumsum_neighbors_per_node = out\n\n        if self.disjoint and self.temporal:\n            # We create a batch during the step of merging sampler outputs.\n            _, node = node.t().contiguous()\n\n        return SamplerOutput(\n            node=node,\n            row=None,\n            col=None,\n            edge=edge,\n            batch=None,\n            metadata=(cumsum_neighbors_per_node, ),\n        )\n\n    async def _collate_fn(\n        self, output: Union[SamplerOutput, HeteroSamplerOutput]\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        r\"\"\"Collect labels and features for the sampled subgraph if necessary,\n        and put them into a sample message.\n        \"\"\"\n        if self.is_hetero:\n            labels = {}\n            nfeats = {}\n            efeats = {}\n            labels = self.feature_store.labels\n            if labels is not None:\n                if isinstance(self.input_type, tuple):  # Edge labels.\n                    labels = {\n                        self.input_type: labels[output.edge[self.input_type]]\n                    }\n                else:  # Node labels.\n                    labels = {\n                        self.input_type: labels[output.node[self.input_type]]\n                    }\n            # Collect node features.\n            if output.node is not None:\n                for ntype in output.node.keys():\n                    if output.node[ntype].numel() > 0:\n                        fut = self.feature_store.lookup_features(\n                            is_node_feat=True,\n                            index=output.node[ntype],\n                            
input_type=ntype,\n                        )\n                        nfeat = await to_asyncio_future(fut)\n                        nfeat = nfeat.to(torch.device(\"cpu\"))\n                        nfeats[ntype] = nfeat\n                    else:\n                        nfeats[ntype] = None\n            # Collect edge features\n            if output.edge is not None and self.with_edge_attr:\n                for edge_type in output.edge.keys():\n                    if output.edge[edge_type].numel() > 0:\n                        fut = self.feature_store.lookup_features(\n                            is_node_feat=False,\n                            index=output.edge[edge_type],\n                            input_type=edge_type,\n                        )\n                        efeat = await to_asyncio_future(fut)\n                        efeat = efeat.to(torch.device(\"cpu\"))\n                        efeats[edge_type] = efeat\n                    else:\n                        efeats[edge_type] = None\n\n        else:  # Homogeneous:\n            # Collect node labels.\n            if self.feature_store.labels is not None:\n                labels = self.feature_store.labels[output.node]\n            else:\n                labels = None\n            # Collect node features.\n            if output.node is not None:\n                fut = self.feature_store.lookup_features(\n                    is_node_feat=True, index=output.node)\n                nfeats = await to_asyncio_future(fut)\n                nfeats = nfeats.to(torch.device(\"cpu\"))\n            else:\n                nfeats = None\n            # Collect edge features.\n            if output.edge is not None and self.with_edge_attr:\n                fut = self.feature_store.lookup_features(\n                    is_node_feat=False, index=output.edge)\n                efeats = await to_asyncio_future(fut)\n                efeats = efeats.to(torch.device(\"cpu\"))\n            else:\n                efeats = 
None\n\n        output.metadata = (*output.metadata, nfeats, labels, efeats)\n        return output\n\n    @property\n    def edge_permutation(self) -> None:\n        return None\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(pid={mp.current_process().pid})'\n"
  },
  {
    "path": "torch_geometric/distributed/event_loop.py",
    "content": "import asyncio\nimport atexit\nimport logging\nfrom threading import BoundedSemaphore, Thread\nfrom typing import Callable, Optional\n\nimport torch\n\n# Based on graphlearn-for-pytorch repository python/distributed/event_loop.py\n# https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/\n# LICENSE: Apache v2\n\n\ndef to_asyncio_future(future: torch.futures.Future) -> asyncio.futures.Future:\n    r\"\"\"Convert a :class:`torch.futures.Future` to a :obj:`asyncio` future.\"\"\"\n    loop = asyncio.get_event_loop()\n    asyncio_future = loop.create_future()\n\n    def on_done(*_):\n        try:\n            result = future.wait()\n        except Exception as e:\n            loop.call_soon_threadsafe(asyncio_future.set_exception, e)\n        else:\n            loop.call_soon_threadsafe(asyncio_future.set_result, result)\n\n    future.add_done_callback(on_done)\n\n    return asyncio_future\n\n\nclass ConcurrentEventLoop:\n    r\"\"\"Concurrent event loop context.\n\n    Args:\n        concurrency: max processing concurrency.\n    \"\"\"\n    def __init__(self, concurrency: int):\n        self._concurrency = concurrency\n        self._sem = BoundedSemaphore(concurrency)\n        self._loop = asyncio.new_event_loop()\n        self._runner_t = Thread(target=self._run_loop)\n        self._runner_t.daemon = True\n\n        def cleanup():\n            for _ in range(self._concurrency):\n                self._sem.acquire()\n            for _ in range(self._concurrency):\n                self._sem.release()\n            if self._runner_t.is_alive():\n                self._loop.stop()\n                self._runner_t.join(timeout=1)\n                logging.debug(f'{self}: Closed `ConcurrentEventLoop`')\n\n        atexit.register(cleanup)\n\n    def start_loop(self):\n        if not self._runner_t.is_alive():\n            self._runner_t.start()\n\n    def wait_all(self):\n        r\"\"\"Wait for all pending tasks to be finished.\"\"\"\n       
 for _ in range(self._concurrency):\n            self._sem.acquire()\n        for _ in range(self._concurrency):\n            self._sem.release()\n\n    def add_task(self, coro, callback: Optional[Callable] = None):\n        r\"\"\"Adds an asynchronized coroutine task to run.\n\n        Args:\n            coro: The asynchronous coroutine function.\n            callback (callable, optional): The callback function applied on the\n                returned results after the coroutine task is finished.\n                (default: :obj:`None`)\n\n        Note that any result returned by :obj:`callback` will be ignored.\n        \"\"\"\n        def on_done(f: asyncio.futures.Future):\n            try:\n                res = f.result()\n                if callback is not None:\n                    callback(res)\n            except Exception as e:\n                logging.error(f\"Coroutine task failed with error: {e}\")\n            self._sem.release()\n\n        self._sem.acquire()\n        fut = asyncio.run_coroutine_threadsafe(coro, self._loop)\n        fut.add_done_callback(on_done)\n\n    def run_task(self, coro):\n        r\"\"\"Runs a coroutine task synchronously.\n\n        Args:\n            coro: The synchronous coroutine function.\n        \"\"\"\n        with self._sem:\n            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)\n            return fut.result()\n\n    def _run_loop(self):\n        self._loop.run_forever()\n"
  },
  {
    "path": "torch_geometric/distributed/local_feature_store.py",
    "content": "import copy\nimport os.path as osp\nfrom dataclasses import dataclass\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import FeatureStore, TensorAttr\nfrom torch_geometric.data.feature_store import _FieldStatus\nfrom torch_geometric.distributed.partition import load_partition_info\nfrom torch_geometric.distributed.rpc import (\n    RPCCallBase,\n    RPCRouter,\n    rpc_async,\n    rpc_register,\n)\nfrom torch_geometric.io import fs\nfrom torch_geometric.typing import EdgeType, NodeOrEdgeType, NodeType\n\n\nclass RPCCallFeatureLookup(RPCCallBase):\n    r\"\"\"A wrapper for RPC calls to the feature store.\"\"\"\n    def __init__(self, dist_feature: FeatureStore):\n        super().__init__()\n        self.dist_feature = dist_feature\n\n    def rpc_async(self, *args, **kwargs):\n        return self.dist_feature._rpc_local_feature_get(*args, **kwargs)\n\n    def rpc_sync(self, *args, **kwargs):\n        raise NotImplementedError\n\n\n@dataclass\nclass LocalTensorAttr(TensorAttr):\n    r\"\"\"Tensor attribute for storing features without :obj:`index`.\"\"\"\n    def __init__(\n        self,\n        group_name: Optional[Union[NodeType, EdgeType]] = _FieldStatus.UNSET,\n        attr_name: Optional[str] = _FieldStatus.UNSET,\n        index=None,\n    ):\n        super().__init__(group_name, attr_name, index)\n\n\nclass LocalFeatureStore(FeatureStore):\n    r\"\"\"Implements the :class:`~torch_geometric.data.FeatureStore` interface to\n    act as a local feature store for distributed training.\n    \"\"\"\n    def __init__(self):\n        super().__init__(tensor_attr_cls=LocalTensorAttr)\n        self._feat: Dict[Tuple[Union[NodeType, EdgeType], str], Tensor] = {}\n        # Save the global node/edge IDs:\n        self._global_id: Dict[Union[NodeType, EdgeType], Tensor] = {}\n        # Save the mapping from global node/edge IDs to indices in `_feat`:\n        
self._global_id_to_index: Dict[Union[NodeType, EdgeType], Tensor] = {}\n        # For partition/RPC info related to distributed features:\n        self.num_partitions: int = 1\n        self.partition_idx: int = 0\n        # Mapping between node ID and partition ID:\n        self.node_feat_pb: Union[Tensor, Dict[NodeType, Tensor]]\n        # Mapping between edge ID and partition ID:\n        self.edge_feat_pb: Union[Tensor, Dict[EdgeType, Tensor]]\n        # Node labels:\n        self.labels: Optional[Tensor] = None\n\n        self.local_only: bool = False\n        self.rpc_router: Optional[RPCRouter] = None\n        self.meta: Optional[Dict] = None\n        self.rpc_call_id: Optional[int] = None\n\n    @staticmethod\n    def key(attr: TensorAttr) -> Tuple[str, str]:\n        return (attr.group_name, attr.attr_name)\n\n    def put_global_id(\n        self,\n        global_id: Tensor,\n        group_name: Union[NodeType, EdgeType],\n    ) -> bool:\n        self._global_id[group_name] = global_id\n        self._set_global_id_to_index(group_name)\n        return True\n\n    def get_global_id(\n        self,\n        group_name: Union[NodeType, EdgeType],\n    ) -> Optional[Tensor]:\n        return self._global_id.get(group_name)\n\n    def remove_global_id(self, group_name: Union[NodeType, EdgeType]) -> bool:\n        # Drop the derived ID-to-index mapping together with the global IDs:\n        self._global_id_to_index.pop(group_name, None)\n        return self._global_id.pop(group_name, None) is not None\n\n    def _set_global_id_to_index(self, group_name: Union[NodeType, EdgeType]):\n        global_id = self.get_global_id(group_name)\n\n        if global_id is None:\n            return\n\n        # TODO Compute this mapping without materializing a full-sized tensor:\n        global_id_to_index = global_id.new_full((int(global_id.max()) + 1, ),\n                                                fill_value=-1)\n        global_id_to_index[global_id] = torch.arange(global_id.numel())\n        self._global_id_to_index[group_name] = global_id_to_index\n\n    def _put_tensor(self, tensor: Tensor, attr: 
TensorAttr) -> bool:\n        assert attr.index is None\n        self._feat[self.key(attr)] = tensor\n        return True\n\n    def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]:\n        tensor = self._feat.get(self.key(attr))\n\n        if tensor is None:\n            return None\n\n        if attr.index is None:  # Empty indices return the full tensor:\n            return tensor\n\n        return tensor[attr.index]\n\n    def _remove_tensor(self, attr: TensorAttr) -> bool:\n        assert attr.index is None\n        return self._feat.pop(self.key(attr), None) is not None\n\n    def get_tensor_from_global_id(self, *args, **kwargs) -> Optional[Tensor]:\n        attr = self._tensor_attr_cls.cast(*args, **kwargs)\n        assert attr.index is not None\n\n        attr = copy.copy(attr)\n        attr.index = self._global_id_to_index[attr.group_name][attr.index]\n\n        return self.get_tensor(attr)\n\n    def _get_tensor_size(self, attr: TensorAttr) -> Tuple[int, ...]:\n        return self._get_tensor(attr).size()\n\n    def get_all_tensor_attrs(self) -> List[LocalTensorAttr]:\n        return [self._tensor_attr_cls.cast(*key) for key in self._feat.keys()]\n\n    def set_rpc_router(self, rpc_router: RPCRouter):\n        self.rpc_router = rpc_router\n\n        if not self.local_only:\n            if self.rpc_router is None:\n                raise ValueError(\"An RPC router must be provided\")\n            rpc_call = RPCCallFeatureLookup(self)\n            self.rpc_call_id = rpc_register(rpc_call)\n        else:\n            self.rpc_call_id = None\n\n    def has_edge_attr(self) -> bool:\n        has_edge_attr = False\n        for k in [key for key in self._feat.keys() if 'edge_attr' in key]:\n            try:\n                self.get_tensor(k[0], 'edge_attr')\n                has_edge_attr = True\n            except KeyError:\n                pass\n        return has_edge_attr\n\n    def lookup_features(\n        self,\n        index: Tensor,\n        
is_node_feat: bool = True,\n        input_type: Optional[NodeOrEdgeType] = None,\n    ) -> torch.futures.Future:\n        r\"\"\"Lookup of local/remote features.\"\"\"\n        remote_fut = self._remote_lookup_features(index, is_node_feat,\n                                                  input_type)\n        local_feature = self._local_lookup_features(index, is_node_feat,\n                                                    input_type)\n        res_fut = torch.futures.Future()\n\n        def when_finish(*_):\n            try:\n                remote_feature_list = remote_fut.wait()\n                # combine the feature from remote and local\n                result = torch.zeros(\n                    index.size(0),\n                    local_feature[0].size(1),\n                    dtype=local_feature[0].dtype,\n                )\n                result[local_feature[1]] = local_feature[0]\n                for remote in remote_feature_list:\n                    result[remote[1]] = remote[0]\n            except Exception as e:\n                res_fut.set_exception(e)\n            else:\n                res_fut.set_result(result)\n\n        remote_fut.add_done_callback(when_finish)\n        return res_fut\n\n    def _local_lookup_features(\n        self,\n        index: Tensor,\n        is_node_feat: bool = True,\n        input_type: Optional[Union[NodeType, EdgeType]] = None,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Lookup the features in local nodes based on node/edge IDs.\"\"\"\n        pb = self.node_feat_pb if is_node_feat else self.edge_feat_pb\n\n        input_order = torch.arange(index.size(0), dtype=torch.long)\n        if self.meta['is_hetero']:\n            partition_ids = pb[input_type][index]\n        else:\n            partition_ids = pb[index]\n\n        local_mask = partition_ids == self.partition_idx\n        local_ids = torch.masked_select(index, local_mask)\n        local_index = torch.masked_select(input_order, local_mask)\n\n        if 
self.meta['is_hetero']:\n            if is_node_feat:\n                kwargs = dict(group_name=input_type, attr_name='x')\n                ret_feat = self.get_tensor_from_global_id(\n                    index=local_ids, **kwargs)\n            else:\n                kwargs = dict(group_name=input_type, attr_name='edge_attr')\n                ret_feat = self.get_tensor_from_global_id(\n                    index=local_ids, **kwargs)\n        else:\n            if is_node_feat:\n                kwargs = dict(group_name=None, attr_name='x')\n                ret_feat = self.get_tensor_from_global_id(\n                    index=local_ids, **kwargs)\n            else:\n                kwargs = dict(group_name=(None, None), attr_name='edge_attr')\n                ret_feat = self.get_tensor_from_global_id(\n                    index=local_ids, **kwargs)\n\n        return ret_feat, local_index\n\n    def _remote_lookup_features(\n        self,\n        index: Tensor,\n        is_node_feat: bool = True,\n        input_type: Optional[Union[NodeType, EdgeType]] = None,\n    ) -> torch.futures.Future:\n        r\"\"\"Fetch the remote features with the remote node/edge IDs.\"\"\"\n        pb = self.node_feat_pb if is_node_feat else self.edge_feat_pb\n\n        input_order = torch.arange(index.size(0), dtype=torch.long)\n        if self.meta['is_hetero']:\n            partition_ids = pb[input_type][index]\n        else:\n            partition_ids = pb[index]\n\n        futs, indexes = [], []\n        for pidx in range(0, self.num_partitions):\n            if pidx == self.partition_idx:\n                continue\n            remote_mask = partition_ids == pidx\n            remote_ids = index[remote_mask]\n            if remote_ids.shape[0] > 0:\n                to_worker = self.rpc_router.get_to_worker(pidx)\n                futs.append(\n                    rpc_async(\n                        to_worker,\n                        self.rpc_call_id,\n                        
args=(remote_ids.cpu(), is_node_feat, input_type),\n                    ))\n                indexes.append(torch.masked_select(input_order, remote_mask))\n        collect_fut = torch.futures.collect_all(futs)\n        res_fut = torch.futures.Future()\n\n        def when_finish(*_):\n            try:\n                fut_list = collect_fut.wait()\n                result = []\n                for i, fut in enumerate(fut_list):\n                    result.append((fut.wait(), indexes[i]))\n            except Exception as e:\n                res_fut.set_exception(e)\n            else:\n                res_fut.set_result(result)\n\n        collect_fut.add_done_callback(when_finish)\n        return res_fut\n\n    def _rpc_local_feature_get(\n        self,\n        index: Tensor,\n        is_node_feat: bool = True,\n        input_type: Optional[Union[NodeType, EdgeType]] = None,\n    ) -> Tensor:\n        r\"\"\"Lookup of features in remote nodes.\"\"\"\n        if self.meta['is_hetero']:\n            feat = self\n            if is_node_feat:\n                kwargs = dict(group_name=input_type, attr_name='x')\n                ret_feat = feat.get_tensor_from_global_id(\n                    index=index, **kwargs)\n            else:\n                kwargs = dict(group_name=input_type, attr_name='edge_attr')\n                ret_feat = feat.get_tensor_from_global_id(\n                    index=index, **kwargs)\n        else:\n            feat = self\n            if is_node_feat:\n                kwargs = dict(group_name=None, attr_name='x')\n                ret_feat = feat.get_tensor_from_global_id(\n                    index=index, **kwargs)\n            else:\n                kwargs = dict(group_name=(None, None), attr_name='edge_attr')\n                ret_feat = feat.get_tensor_from_global_id(\n                    index=index, **kwargs)\n\n        return ret_feat\n\n    # Initialization ##########################################################\n\n    @classmethod\n    
def from_data(\n        cls,\n        node_id: Tensor,\n        x: Optional[Tensor] = None,\n        y: Optional[Tensor] = None,\n        edge_id: Optional[Tensor] = None,\n        edge_attr: Optional[Tensor] = None,\n    ) -> 'LocalFeatureStore':\n        r\"\"\"Creates a local feature store from homogeneous :pyg:`PyG` tensors.\n\n        Args:\n            node_id (torch.Tensor): The global identifier for every local node.\n            x (torch.Tensor, optional): The node features.\n                (default: :obj:`None`)\n            y (torch.Tensor, optional): The node labels. (default: :obj:`None`)\n            edge_id (torch.Tensor, optional): The global identifier for every\n                local edge. (default: :obj:`None`)\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n        \"\"\"\n        feat_store = cls()\n        feat_store.put_global_id(node_id, group_name=None)\n        if x is not None:\n            feat_store.put_tensor(x, group_name=None, attr_name='x')\n        if y is not None:\n            feat_store.put_tensor(y, group_name=None, attr_name='y')\n        if edge_id is not None:\n            feat_store.put_global_id(edge_id, group_name=(None, None))\n        if edge_attr is not None:\n            if edge_id is None:\n                raise ValueError(\"'edge_id' needs to be present in case \"\n                                 \"'edge_attr' is passed\")\n            feat_store.put_tensor(edge_attr, group_name=(None, None),\n                                  attr_name='edge_attr')\n        return feat_store\n\n    @classmethod\n    def from_hetero_data(\n        cls,\n        node_id_dict: Dict[NodeType, Tensor],\n        x_dict: Optional[Dict[NodeType, Tensor]] = None,\n        y_dict: Optional[Dict[NodeType, Tensor]] = None,\n        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,\n        edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,\n    ) -> 
'LocalFeatureStore':\n        r\"\"\"Creates a local feature store from heterogeneous :pyg:`PyG`\n        tensors.\n\n        Args:\n            node_id_dict (Dict[NodeType, torch.Tensor]): The global identifier\n                for every local node of every node type.\n            x_dict (Dict[NodeType, torch.Tensor], optional): The node features\n                of every node type. (default: :obj:`None`)\n            y_dict (Dict[NodeType, torch.Tensor], optional): The node labels of\n                every node type. (default: :obj:`None`)\n            edge_id_dict (Dict[EdgeType, torch.Tensor], optional): The global\n                identifier for every local edge of every edge type.\n                (default: :obj:`None`)\n            edge_attr_dict (Dict[EdgeType, torch.Tensor], optional): The edge\n                features of every edge type. (default: :obj:`None`)\n        \"\"\"\n        feat_store = cls()\n\n        for node_type, node_id in node_id_dict.items():\n            feat_store.put_global_id(node_id, group_name=node_type)\n        if x_dict is not None:\n            for node_type, x in x_dict.items():\n                feat_store.put_tensor(x, group_name=node_type, attr_name='x')\n        if y_dict is not None:\n            for node_type, y in y_dict.items():\n                feat_store.put_tensor(y, group_name=node_type, attr_name='y')\n        if edge_id_dict is not None:\n            for edge_type, edge_id in edge_id_dict.items():\n                feat_store.put_global_id(edge_id, group_name=edge_type)\n        if edge_attr_dict is not None:\n            for edge_type, edge_attr in edge_attr_dict.items():\n                if edge_id_dict is None or edge_type not in edge_id_dict:\n                    raise ValueError(\"'edge_id' needs to be present in case \"\n                                     \"'edge_attr' is passed\")\n                feat_store.put_tensor(edge_attr, group_name=edge_type,\n                                      
attr_name='edge_attr')\n\n        return feat_store\n\n    @classmethod\n    def from_partition(cls, root: str, pid: int) -> 'LocalFeatureStore':\n        part_dir = osp.join(root, f'part_{pid}')\n        assert osp.exists(part_dir)\n        feat_store = cls()\n        (\n            meta,\n            num_partitions,\n            partition_idx,\n            node_pb,\n            edge_pb,\n        ) = load_partition_info(root, pid)\n        feat_store.num_partitions = num_partitions\n        feat_store.partition_idx = partition_idx\n        feat_store.node_feat_pb = node_pb\n        feat_store.edge_feat_pb = edge_pb\n        feat_store.meta = meta\n\n        node_feats: Optional[Dict[str, Any]] = None\n        if osp.exists(osp.join(part_dir, 'node_feats.pt')):\n            node_feats = fs.torch_load(osp.join(part_dir, 'node_feats.pt'))\n\n        edge_feats: Optional[Dict[str, Any]] = None\n        if osp.exists(osp.join(part_dir, 'edge_feats.pt')):\n            edge_feats = fs.torch_load(osp.join(part_dir, 'edge_feats.pt'))\n\n        if not meta['is_hetero'] and node_feats is not None:\n            feat_store.put_global_id(node_feats['global_id'], group_name=None)\n            for key, value in node_feats['feats'].items():\n                feat_store.put_tensor(value, group_name=None, attr_name=key)\n            if 'time' in node_feats:\n                feat_store.put_tensor(node_feats['time'], group_name=None,\n                                      attr_name='time')\n\n        if not meta['is_hetero'] and edge_feats is not None:\n            if 'global_id' in edge_feats:\n                feat_store.put_global_id(edge_feats['global_id'],\n                                         group_name=(None, None))\n            if 'feats' in edge_feats:\n                for key, value in edge_feats['feats'].items():\n                    feat_store.put_tensor(value, group_name=(None, None),\n                                          attr_name=key)\n            if 'edge_time' 
in edge_feats:\n                feat_store.put_tensor(edge_feats['edge_time'],\n                                      group_name=(None, None),\n                                      attr_name='edge_time')\n\n        if meta['is_hetero'] and node_feats is not None:\n            for node_type, node_feat in node_feats.items():\n                feat_store.put_global_id(node_feat['global_id'],\n                                         group_name=node_type)\n                for key, value in node_feat['feats'].items():\n                    feat_store.put_tensor(value, group_name=node_type,\n                                          attr_name=key)\n                if 'time' in node_feat:\n                    feat_store.put_tensor(node_feat['time'],\n                                          group_name=node_type,\n                                          attr_name='time')\n\n        if meta['is_hetero'] and edge_feats is not None:\n            for edge_type, edge_feat in edge_feats.items():\n                if 'global_id' in edge_feat:\n                    feat_store.put_global_id(edge_feat['global_id'],\n                                             group_name=edge_type)\n                if 'feats' in edge_feat:\n                    for key, value in edge_feat['feats'].items():\n                        feat_store.put_tensor(value, group_name=edge_type,\n                                              attr_name=key)\n                if 'edge_time' in edge_feat:\n                    feat_store.put_tensor(edge_feat['edge_time'],\n                                          group_name=edge_type,\n                                          attr_name='edge_time')\n\n        return feat_store\n"
  },
  {
    "path": "torch_geometric/distributed/local_graph_store.py",
    "content": "import os.path as osp\nfrom typing import Any, Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import EdgeAttr, GraphStore\nfrom torch_geometric.distributed.partition import load_partition_info\nfrom torch_geometric.io import fs\nfrom torch_geometric.typing import EdgeTensorType, EdgeType, NodeType\nfrom torch_geometric.utils import sort_edge_index\n\n\nclass LocalGraphStore(GraphStore):\n    r\"\"\"Implements the :class:`~torch_geometric.data.GraphStore` interface to\n    act as a local graph store for distributed training.\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n        self._edge_index: Dict[Tuple, EdgeTensorType] = {}\n        self._edge_attr: Dict[Tuple, EdgeAttr] = {}\n        self._edge_id: Dict[Tuple, Tensor] = {}\n\n        self.num_partitions = 1\n        self.partition_idx = 0\n        # Mapping between node ID and partition ID\n        self.node_pb: Union[Tensor, Dict[NodeType, Tensor]] = None\n        # Mapping between edge ID and partition ID\n        self.edge_pb: Union[Tensor, Dict[EdgeType, Tensor]] = None\n        # Meta information related to partition and graph store info\n        self.meta: Optional[Dict[Any, Any]] = None\n        # If data is sorted based on destination nodes (CSC format):\n        self.is_sorted: Optional[bool] = None\n\n    @staticmethod\n    def key(attr: EdgeAttr) -> Tuple:\n        return (attr.edge_type, attr.layout.value)\n\n    def get_partition_ids_from_nids(\n        self,\n        ids: torch.Tensor,\n        node_type: Optional[NodeType] = None,\n    ) -> Tensor:\n        r\"\"\"Returns the partition IDs of node IDs for a specific node type.\"\"\"\n        if self.meta['is_hetero']:\n            return self.node_pb[node_type][ids]\n        else:\n            return self.node_pb[ids]\n\n    def get_partition_ids_from_eids(self, eids: torch.Tensor,\n                                    edge_type: Optional[EdgeType] = 
None):\n        r\"\"\"Returns the partition IDs of edge IDs for a specific edge type.\"\"\"\n        if self.meta['is_hetero']:\n            return self.edge_pb[edge_type][eids]\n        else:\n            return self.edge_pb[eids]\n\n    def put_edge_id(self, edge_id: Tensor, *args, **kwargs) -> bool:\n        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)\n        self._edge_id[self.key(edge_attr)] = edge_id\n        return True\n\n    def get_edge_id(self, *args, **kwargs) -> Optional[Tensor]:\n        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)\n        return self._edge_id.get(self.key(edge_attr))\n\n    def remove_edge_id(self, *args, **kwargs) -> bool:\n        edge_attr = self._edge_attr_cls.cast(*args, **kwargs)\n        return self._edge_id.pop(self.key(edge_attr), None) is not None\n\n    def _put_edge_index(self, edge_index: EdgeTensorType,\n                        edge_attr: EdgeAttr) -> bool:\n        self._edge_index[self.key(edge_attr)] = edge_index\n        self._edge_attr[self.key(edge_attr)] = edge_attr\n        return True\n\n    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:\n        return self._edge_index.get(self.key(edge_attr), None)\n\n    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:\n        self._edge_attr.pop(self.key(edge_attr), None)\n        return self._edge_index.pop(self.key(edge_attr), None) is not None\n\n    def get_all_edge_attrs(self) -> List[EdgeAttr]:\n        return [self._edge_attr[key] for key in self._edge_index.keys()]\n\n    # Initialization ##########################################################\n\n    @classmethod\n    def from_data(\n        cls,\n        edge_id: Tensor,\n        edge_index: Tensor,\n        num_nodes: int,\n        is_sorted: bool = False,\n    ) -> 'LocalGraphStore':\n        r\"\"\"Creates a local graph store from a homogeneous :pyg:`PyG` graph.\n\n        Args:\n            edge_id (torch.Tensor): 
The global identifier for every local edge.\n            edge_index (torch.Tensor): The local edge indices.\n            num_nodes (int): The number of nodes in the local graph.\n            is_sorted (bool): Whether edges are sorted by column/destination\n                nodes (CSC format). (default: :obj:`False`)\n        \"\"\"\n        graph_store = cls()\n        graph_store.meta = {'is_hetero': False}\n\n        if not is_sorted:\n            edge_index, edge_id = sort_edge_index(\n                edge_index,\n                edge_id,\n                sort_by_row=False,\n            )\n\n        attr = dict(\n            edge_type=None,\n            layout='coo',\n            size=(num_nodes, num_nodes),\n            is_sorted=True,\n        )\n\n        graph_store.put_edge_index(edge_index, **attr)\n        graph_store.put_edge_id(edge_id, **attr)\n\n        return graph_store\n\n    @classmethod\n    def from_hetero_data(\n        cls,\n        edge_id_dict: Dict[EdgeType, Tensor],\n        edge_index_dict: Dict[EdgeType, Tensor],\n        num_nodes_dict: Dict[NodeType, int],\n        is_sorted: bool = False,\n    ) -> \"LocalGraphStore\":\n        r\"\"\"Creates a local graph store from a heterogeneous :pyg:`PyG` graph.\n\n        Args:\n            edge_id_dict (Dict[EdgeType, torch.Tensor]): The global identifier\n                for every local edge of every edge type.\n            edge_index_dict (Dict[EdgeType, torch.Tensor]): The local edge\n                indices of every edge type.\n            num_nodes_dict: (Dict[str, int]): The number of nodes for every\n                node type.\n            is_sorted (bool): Whether edges are sorted by column/destination\n                nodes (CSC format). 
(default: :obj:`False`)\n        \"\"\"\n        graph_store = cls()\n        graph_store.meta = {'is_hetero': True}\n\n        for edge_type, edge_index in edge_index_dict.items():\n            src, _, dst = edge_type\n            attr = dict(\n                edge_type=edge_type,\n                layout='coo',\n                size=(num_nodes_dict[src], num_nodes_dict[dst]),\n                is_sorted=True,\n            )\n            edge_id = edge_id_dict[edge_type]\n            if not is_sorted:\n                edge_index, edge_id = sort_edge_index(\n                    edge_index,\n                    edge_id,\n                    sort_by_row=False,\n                )\n            graph_store.put_edge_index(edge_index, **attr)\n            graph_store.put_edge_id(edge_id, **attr)\n        return graph_store\n\n    @classmethod\n    def from_partition(cls, root: str, pid: int) -> 'LocalGraphStore':\n        part_dir = osp.join(root, f'part_{pid}')\n        assert osp.exists(part_dir)\n        graph_store = cls()\n        (\n            meta,\n            num_partitions,\n            partition_idx,\n            node_pb,\n            edge_pb,\n        ) = load_partition_info(root, pid)\n        graph_store.num_partitions = num_partitions\n        graph_store.partition_idx = partition_idx\n        graph_store.node_pb = node_pb\n        graph_store.edge_pb = edge_pb\n        graph_store.meta = meta\n\n        graph_data = fs.torch_load(osp.join(part_dir, 'graph.pt'))\n        graph_store.is_sorted = meta['is_sorted']\n\n        if not meta['is_hetero']:\n            edge_index = torch.stack((graph_data['row'], graph_data['col']),\n                                     dim=0)\n            edge_id = graph_data['edge_id']\n            if not graph_store.is_sorted:\n                edge_index, edge_id = sort_edge_index(edge_index, edge_id,\n                                                      sort_by_row=False)\n\n            attr = dict(\n                
edge_type=None,\n                layout='coo',\n                size=graph_data['size'],\n                is_sorted=True,\n            )\n            graph_store.put_edge_index(edge_index, **attr)\n            graph_store.put_edge_id(edge_id, **attr)\n\n        if meta['is_hetero']:\n            for edge_type, data in graph_data.items():\n                attr = dict(\n                    edge_type=edge_type,\n                    layout='coo',\n                    size=data['size'],\n                    is_sorted=True,\n                )\n                edge_index = torch.stack((data['row'], data['col']), dim=0)\n                edge_id = data['edge_id']\n\n                if not graph_store.is_sorted:\n                    edge_index, edge_id = sort_edge_index(\n                        edge_index, edge_id, sort_by_row=False)\n                graph_store.put_edge_index(edge_index, **attr)\n                graph_store.put_edge_id(edge_id, **attr)\n\n        return graph_store\n"
  },
  {
    "path": "torch_geometric/distributed/partition.py",
    "content": "import json\nimport logging\nimport os\nimport os.path as osp\nfrom collections import defaultdict\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\n\nimport torch_geometric.distributed as pyg_dist\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.io import fs\nfrom torch_geometric.loader.cluster import ClusterData\nfrom torch_geometric.sampler.utils import sort_csc\nfrom torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType\n\n\nclass Partitioner:\n    r\"\"\"Partitions the graph and its features of a\n    :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData` object.\n\n    Partitioned data output will be structured as shown below.\n\n    **Homogeneous graphs:**\n\n    .. code-block:: none\n\n        root/\n        |-- META.json\n        |-- node_map.pt\n        |-- edge_map.pt\n        |-- part0/\n            |-- graph.pt\n            |-- node_feats.pt\n            |-- edge_feats.pt\n        |-- part1/\n            |-- graph.pt\n            |-- node_feats.pt\n            |-- edge_feats.pt\n\n    **Heterogeneous graphs:**\n\n    .. 
code-block:: none\n\n        root/\n        |-- META.json\n        |-- node_map/\n            |-- ntype1.pt\n            |-- ntype2.pt\n        |-- edge_map/\n            |-- etype1.pt\n            |-- etype2.pt\n        |-- part_0/\n            |-- graph.pt\n            |-- node_feats.pt\n            |-- edge_feats.pt\n        |-- part_1/\n            |-- graph.pt\n            |-- node_feats.pt\n            |-- edge_feats.pt\n\n    Args:\n        data (Data or HeteroData): The data object.\n        num_parts (int): The number of partitions.\n        root (str): Root directory where the partitioned dataset should be\n            saved.\n        recursive (bool, optional): If set to :obj:`True`, will use multilevel\n            recursive bisection instead of multilevel k-way partitioning.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData],\n        num_parts: int,\n        root: str,\n        recursive: bool = False,\n    ):\n        assert num_parts > 1\n\n        self.data = data\n        self.num_parts = num_parts\n        self.root = root\n        self.recursive = recursive\n\n    @property\n    def is_hetero(self) -> bool:\n        return isinstance(self.data, HeteroData)\n\n    @property\n    def is_node_level_time(self) -> bool:\n        if 'time' not in self.data:\n            return False\n\n        if self.is_hetero:\n            return any(['time' in store for store in self.data.node_stores])\n\n        return self.data.is_node_attr('time')\n\n    @property\n    def is_edge_level_time(self) -> bool:\n        if 'edge_time' in self.data:\n            return True\n\n        if 'time' not in self.data:\n            return False\n\n        if self.is_hetero:\n            return any(['time' in store for store in self.data.edge_stores])\n\n        return self.data.is_edge_attr('time')\n\n    @property\n    def node_types(self) -> Optional[List[NodeType]]:\n        return self.data.node_types 
if self.is_hetero else None\n\n    @property\n    def edge_types(self) -> Optional[List[EdgeType]]:\n        return self.data.edge_types if self.is_hetero else None\n\n    def generate_partition(self):\n        r\"\"\"Generates the partitions.\"\"\"\n        os.makedirs(self.root, exist_ok=True)\n\n        if self.is_hetero and self.is_node_level_time:\n            time_data = {  # Get temporal information before converting data:\n                node_type: self.data[node_type].time\n                for node_type in self.data.node_types\n            }\n\n        data = self.data.to_homogeneous() if self.is_hetero else self.data\n        cluster_data = ClusterData(\n            data,\n            num_parts=self.num_parts,\n            recursive=self.recursive,\n            log=True,\n            keep_inter_cluster_edges=True,\n            sparse_format='csc',\n        )\n\n        node_perm = cluster_data.partition.node_perm\n        partptr = cluster_data.partition.partptr\n        edge_perm = cluster_data.partition.edge_perm\n\n        node_map = torch.empty(data.num_nodes, dtype=torch.int64)\n        edge_map = torch.empty(data.num_edges, dtype=torch.int64)\n        node_offset, edge_offset = {}, {}\n\n        if self.is_hetero:\n            offset = 0\n            for node_type in self.node_types:\n                node_offset[node_type] = offset\n                offset += self.data[node_type].num_nodes\n\n            offset = 0\n            for edge_name in self.edge_types:\n                edge_offset[edge_name] = offset\n                offset += self.data.num_edges_dict[edge_name]\n\n            edge_start = 0\n            for pid in range(self.num_parts):\n                logging.info(f'Saving graph partition {pid}')\n                path = osp.join(self.root, f'part_{pid}')\n                os.makedirs(path, exist_ok=True)\n\n                part_data = cluster_data[pid]\n                start, end = int(partptr[pid]), int(partptr[pid + 1])\n\n              
  num_edges = part_data.num_edges\n                edge_id = edge_perm[edge_start:edge_start + num_edges]\n                edge_map[edge_id] = pid\n                edge_start += num_edges\n\n                node_id = node_perm[start:end]\n                node_map[node_id] = pid\n\n                graph = {}\n                efeat = defaultdict(dict)\n                for i, edge_type in enumerate(self.edge_types):\n                    # Row vector refers to source nodes.\n                    # Column vector refers to destination nodes.\n                    src, _, dst = edge_type\n                    size = (self.data[src].num_nodes, self.data[dst].num_nodes)\n\n                    mask = part_data.edge_type == i\n                    row = part_data.edge_index[0, mask]\n                    col = part_data.edge_index[1, mask]\n                    global_col = node_id[col]\n                    global_row = node_perm[row]\n\n                    edge_time = src_node_time = None\n                    if self.is_edge_level_time:\n                        if 'edge_time' in part_data:\n                            edge_time = part_data.edge_time[mask]\n                        elif 'time' in part_data:\n                            edge_time = part_data.time[mask]\n\n                    elif self.is_node_level_time:\n                        src_node_time = time_data[src]\n\n                    offsetted_row = global_row - node_offset[src]\n                    offsetted_col = global_col - node_offset[dst]\n                    # Sort by column to avoid keeping track of permutations in\n                    # `NeighborSampler` when converting to CSC format:\n                    offsetted_row, offsetted_col, perm = sort_csc(\n                        offsetted_row, offsetted_col, src_node_time, edge_time)\n\n                    global_eid = edge_id[mask][perm]\n                    assert torch.equal(\n                        data.edge_index[:, global_eid],\n                        
torch.stack((offsetted_row + node_offset[src],\n                                     offsetted_col + node_offset[dst]), dim=0),\n                    )\n                    offsetted_eid = global_eid - edge_offset[edge_type]\n                    assert torch.equal(\n                        self.data[edge_type].edge_index[:, offsetted_eid],\n                        torch.stack((\n                            offsetted_row,\n                            offsetted_col,\n                        ), dim=0),\n                    )\n                    graph[edge_type] = {\n                        'edge_id': global_eid,\n                        'row': offsetted_row,\n                        'col': offsetted_col,\n                        'size': size,\n                    }\n\n                    if 'edge_attr' in part_data:\n                        edge_attr = part_data.edge_attr[mask][perm]\n                        efeat[edge_type].update({\n                            'global_id':\n                            offsetted_eid,\n                            'feats':\n                            dict(edge_attr=edge_attr),\n                        })\n                    if self.is_edge_level_time:\n                        efeat[edge_type].update({'edge_time': edge_time[perm]})\n\n                torch.save(efeat, osp.join(path, 'edge_feats.pt'))\n                torch.save(graph, osp.join(path, 'graph.pt'))\n\n                nfeat = {}\n                for i, node_type in enumerate(self.node_types):\n                    mask = part_data.node_type == i\n                    x = part_data.x[mask] if 'x' in part_data else None\n                    nfeat[node_type] = {\n                        'global_id': node_id[mask],\n                        'id': node_id[mask] - node_offset[node_type],\n                        'feats': dict(x=x),\n                    }\n                    if self.is_node_level_time:\n                        nfeat[node_type].update({'time': 
time_data[node_type]})\n\n                torch.save(nfeat, osp.join(path, 'node_feats.pt'))\n\n            logging.info('Saving partition mapping info')\n            path = osp.join(self.root, 'node_map')\n            os.makedirs(path, exist_ok=True)\n            for i, node_type in enumerate(self.node_types):\n                mask = data.node_type == i\n                torch.save(node_map[mask], osp.join(path, f'{node_type}.pt'))\n\n            path = osp.join(self.root, 'edge_map')\n            os.makedirs(path, exist_ok=True)\n            for i, edge_type in enumerate(self.edge_types):\n                mask = data.edge_type == i\n                torch.save(\n                    edge_map[mask],\n                    osp.join(path, f'{EdgeTypeStr(edge_type)}.pt'),\n                )\n\n        else:  # `if not self.is_hetero:`\n            edge_start = 0\n            for pid in range(self.num_parts):\n                logging.info(f'Saving graph partition {pid}')\n                path = osp.join(self.root, f'part_{pid}')\n                os.makedirs(path, exist_ok=True)\n\n                part_data = cluster_data[pid]\n                start, end = int(partptr[pid]), int(partptr[pid + 1])\n\n                num_edges = part_data.num_edges\n                edge_id = edge_perm[edge_start:edge_start + num_edges]\n                edge_map[edge_id] = pid\n                edge_start += num_edges\n\n                node_id = node_perm[start:end]  # global node_ids\n                node_map[node_id] = pid  # 0 or 1\n\n                row = part_data.edge_index[0]\n                col = part_data.edge_index[1]\n\n                global_col = node_id[col]  # part_ids -> global\n                global_row = node_perm[row]\n\n                edge_time = node_time = None\n                if self.is_edge_level_time:\n                    if 'edge_time' in part_data:\n                        edge_time = part_data.edge_time\n                    elif 'time' in part_data:\n            
            edge_time = part_data.time\n\n                elif self.is_node_level_time:\n                    node_time = data.time\n\n                # Sort by column to avoid keeping track of permutations in\n                # `NeighborSampler` when converting to CSC format:\n                global_row, global_col, perm = sort_csc(\n                    global_row, global_col, node_time, edge_time)\n\n                edge_id = edge_id[perm]\n\n                assert torch.equal(\n                    self.data.edge_index[:, edge_id],\n                    torch.stack((global_row, global_col)),\n                )\n                if 'edge_attr' in part_data:\n                    edge_attr = part_data.edge_attr[perm]\n                    assert torch.equal(self.data.edge_attr[edge_id, :],\n                                       edge_attr)\n\n                torch.save(\n                    {\n                        'edge_id': edge_id,\n                        'row': global_row,\n                        'col': global_col,\n                        'size': (data.num_nodes, data.num_nodes),\n                    }, osp.join(path, 'graph.pt'))\n\n                nfeat = {\n                    'global_id': node_id,\n                    'feats': dict(x=part_data.x),\n                }\n                if self.is_node_level_time:\n                    nfeat.update({'time': data.time})\n\n                torch.save(nfeat, osp.join(path, 'node_feats.pt'))\n\n                efeat = defaultdict()\n                if 'edge_attr' in part_data:\n                    efeat.update({\n                        'global_id':\n                        edge_id,\n                        'feats':\n                        dict(edge_attr=part_data.edge_attr[perm]),\n                    })\n                if self.is_edge_level_time:\n                    efeat.update({'edge_time': edge_time[perm]})\n\n                torch.save(efeat, osp.join(path, 'edge_feats.pt'))\n\n            
logging.info('Saving partition mapping info')\n            torch.save(node_map, osp.join(self.root, 'node_map.pt'))\n            torch.save(edge_map, osp.join(self.root, 'edge_map.pt'))\n\n        logging.info('Saving metadata')\n        meta = {\n            'num_parts': self.num_parts,\n            'node_types': self.node_types,\n            'edge_types': self.edge_types,\n            'node_offset': list(node_offset.values()) if node_offset else None,\n            'is_hetero': self.is_hetero,\n            'is_sorted': True,  # Based on column/destination.\n        }\n        with open(osp.join(self.root, 'META.json'), 'w') as f:\n            json.dump(meta, f)\n\n\ndef load_partition_info(\n    root_dir: str,\n    partition_idx: int,\n) -> Tuple[Dict, int, int, torch.Tensor, torch.Tensor]:\n    # load the partition with PyG format (graphstore/featurestore)\n    with open(osp.join(root_dir, 'META.json'), 'rb') as infile:\n        meta = json.load(infile)\n    num_partitions = meta['num_parts']\n    assert partition_idx >= 0\n    assert partition_idx < num_partitions\n    partition_dir = osp.join(root_dir, f'part_{partition_idx}')\n    assert osp.exists(partition_dir)\n\n    if meta['is_hetero'] is False:\n        node_pb = fs.torch_load(osp.join(root_dir, 'node_map.pt'))\n        edge_pb = fs.torch_load(osp.join(root_dir, 'edge_map.pt'))\n\n        return (meta, num_partitions, partition_idx, node_pb, edge_pb)\n    else:\n        node_pb_dict = {}\n        node_pb_dir = osp.join(root_dir, 'node_map')\n        for ntype in meta['node_types']:\n            node_pb_dict[ntype] = fs.torch_load(\n                osp.join(node_pb_dir, f'{pyg_dist.utils.as_str(ntype)}.pt'))\n\n        edge_pb_dict = {}\n        edge_pb_dir = osp.join(root_dir, 'edge_map')\n        for etype in meta['edge_types']:\n            edge_pb_dict[tuple(etype)] = fs.torch_load(\n                osp.join(edge_pb_dir, f'{pyg_dist.utils.as_str(etype)}.pt'))\n\n        return (meta, num_partitions, 
partition_idx, node_pb_dict,\n                edge_pb_dict)\n"
  },
  {
    "path": "torch_geometric/distributed/rpc.py",
    "content": "import logging\nimport threading\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Callable, Dict, List, Optional\n\nfrom torch.distributed import rpc\n\nfrom torch_geometric.distributed.dist_context import DistContext, DistRole\n\ntry:\n    from torch._C._distributed_rpc import _is_current_rpc_agent_set\nexcept Exception:\n\n    def _is_current_rpc_agent_set() -> bool:\n        return False\n\n\n_rpc_init_lock = threading.RLock()\n\n\ndef rpc_is_initialized() -> bool:\n    return _is_current_rpc_agent_set()\n\n\ndef rpc_require_initialized(func: Callable) -> Callable:\n    if hasattr(rpc, 'api'):\n        return rpc.api._require_initialized(func)\n    return func\n\n\n@rpc_require_initialized\ndef global_all_gather(obj, timeout: Optional[int] = None) -> Any:\n    r\"\"\"Gathers objects from all groups in a list.\"\"\"\n    if timeout is None:\n        return rpc.api._all_gather(obj)\n    return rpc.api._all_gather(obj, timeout=timeout)\n\n\n@rpc_require_initialized\ndef global_barrier(timeout: Optional[int] = None) -> None:\n    r\"\"\"Block until all local and remote RPC processes.\"\"\"\n    try:\n        global_all_gather(obj=None, timeout=timeout)\n    except RuntimeError:\n        logging.error('Failed to respond to global barrier')\n\n\ndef init_rpc(\n    current_ctx: DistContext,\n    master_addr: str,\n    master_port: int,\n    num_rpc_threads: int = 16,\n    rpc_timeout: float = 240.0,\n    rpc_worker_names: Optional[Dict[DistRole, List[str]]] = None,\n):\n    with _rpc_init_lock:\n        if rpc_is_initialized():\n            return\n\n        if current_ctx is None:\n            raise RuntimeError(\"'dist_context' has not been set in 'init_rpc'\")\n\n        options = rpc.TensorPipeRpcBackendOptions(\n            _transports=['ibv', 'uv'],\n            _channels=['mpt_uv', 'basic'],\n            num_worker_threads=num_rpc_threads,\n            rpc_timeout=rpc_timeout,\n            
init_method=f'tcp://{master_addr}:{master_port}',\n        )\n\n        rpc.init_rpc(\n            name=current_ctx.worker_name,\n            rank=current_ctx.global_rank,\n            world_size=current_ctx.global_world_size,\n            rpc_backend_options=options,\n        )\n\n        global_barrier(timeout=rpc_timeout)\n\n\ndef shutdown_rpc(id: str = None, graceful: bool = True,\n                 timeout: float = 240.0):\n    with _rpc_init_lock:\n        if rpc_is_initialized():\n            logging.info(f\"Shutdown RPC in {id}\"\n                         f\"{' gracefully' if graceful else ''}\")\n            rpc.shutdown(graceful, timeout)\n        else:\n            logging.info(f'RPC in {id} not initialized.')\n\n\nclass RPCRouter:\n    r\"\"\"A router to get the worker based on the partition ID.\"\"\"\n    def __init__(self, partition_to_workers: List[List[str]]):\n        for rpc_worker_list in partition_to_workers:\n            if len(rpc_worker_list) == 0:\n                raise ValueError('No RPC worker is in worker list')\n        self.partition_to_workers = partition_to_workers\n        self.rpc_worker_indices = [0 for _ in range(len(partition_to_workers))]\n\n    def get_to_worker(self, partition_idx: int) -> str:\n        rpc_worker_list = self.partition_to_workers[partition_idx]\n        worker_idx = self.rpc_worker_indices[partition_idx]\n        router_worker = rpc_worker_list[worker_idx]\n        self.rpc_worker_indices[partition_idx] = ((worker_idx + 1) %\n                                                  len(rpc_worker_list))\n        return router_worker\n\n\n@rpc_require_initialized\ndef rpc_partition_to_workers(\n    current_ctx: DistContext,\n    num_partitions: int,\n    current_partition_idx: int,\n):\n    r\"\"\"Performs an :obj:`all_gather` to get the mapping between partition and\n    workers.\n    \"\"\"\n    ctx = current_ctx\n    partition_to_workers = [[] for _ in range(num_partitions)]\n    gathered_results = 
global_all_gather(\n        (ctx.role, num_partitions, current_partition_idx))\n    for worker_name, (_, _, idx) in gathered_results.items():\n        partition_to_workers[idx].append(worker_name)\n    return partition_to_workers\n\n\nclass RPCCallBase(ABC):\n    r\"\"\"A wrapper base class for RPC calls in remote processes.\"\"\"\n    @abstractmethod\n    def rpc_sync(self, *args, **kwargs):\n        pass\n\n    @abstractmethod\n    def rpc_async(self, *args, **kwargs):\n        pass\n\n\n_rpc_call_lock = threading.RLock()\n_rpc_call_id: int = 0\n_rpc_call_pool: Dict[int, RPCCallBase] = {}\n\n\n@rpc_require_initialized\ndef rpc_register(call: RPCCallBase) -> int:\n    r\"\"\"Registers a call for RPC requests.\"\"\"\n    global _rpc_call_id\n\n    with _rpc_call_lock:\n        call_id = _rpc_call_id\n        _rpc_call_id += 1\n        if call_id in _rpc_call_pool:\n            raise RuntimeError(\"Registered function twice in 'rpc_register'\")\n        _rpc_call_pool[call_id] = call\n\n    return call_id\n\n\ndef _rpc_async_call(call_id: int, *args, **kwargs):\n    r\"\"\"Entry point for asynchronous RPC requests.\"\"\"\n    return _rpc_call_pool.get(call_id).rpc_async(*args, **kwargs)\n\n\n@rpc_require_initialized\ndef rpc_async(worker_name: str, call_id: int, args=None, kwargs=None):\n    r\"\"\"Performs an asynchronous RPC request and returns a future.\"\"\"\n    return rpc.rpc_async(\n        to=worker_name,\n        func=_rpc_async_call,\n        args=(call_id, *(args or ())),\n        kwargs=kwargs,\n    )\n\n\ndef _rpc_sync_call(call_id: int, *args, **kwargs):\n    r\"\"\"Entry point for synchronous RPC requests.\"\"\"\n    return _rpc_call_pool.get(call_id).rpc_sync(*args, **kwargs)\n\n\n@rpc_require_initialized\ndef rpc_sync(worker_name: str, call_id: int, args=None, kwargs=None):\n    r\"\"\"Performs a synchronous RPC request and returns the result.\"\"\"\n    future = rpc.rpc_async(\n        to=worker_name,\n        func=_rpc_sync_call,\n        args=(call_id, *(args or ())),\n        
kwargs=kwargs,\n    )\n    return future.wait()\n"
  },
  {
    "path": "torch_geometric/distributed/utils.py",
    "content": "from dataclasses import dataclass\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.distributed.local_feature_store import LocalFeatureStore\nfrom torch_geometric.distributed.local_graph_store import LocalGraphStore\nfrom torch_geometric.sampler import SamplerOutput\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\n@dataclass\nclass DistEdgeHeteroSamplerInput:\n    r\"\"\"The sampling input of\n    :meth:`~torch_geometric.dstributed.DistNeighborSampler.node_sample` used\n    during distributed heterogeneous link sampling when source and target node\n    types of an input edge are different.\n\n    Args:\n        input_id (torch.Tensor, optional): The indices of the data loader input\n            of the current mini-batch.\n        node_dict (Dict[NodeType, torch.Tensor]): The indices of seed nodes of\n            a given node types to start sampling from.\n        time_dict (Dict[NodeType, torch.Tensor], optional): The timestamp for\n            the seed nodes of a given node types. (default: :obj:`None`)\n        input_type (str, optional): The input node type. 
(default: :obj:`None`)\n    \"\"\"\n    input_id: Optional[Tensor]\n    node_dict: Dict[NodeType, Tensor]\n    time_dict: Optional[Dict[NodeType, Tensor]] = None\n    input_type: Optional[EdgeType] = None\n\n\nclass NodeDict:\n    r\"\"\"Class used during heterogeneous sampling.\n    1) The nodes to serve as source nodes in the next layer.\n    2) The nodes with duplicates that are further needed to create COO output.\n    3) The output nodes without duplicates.\n    \"\"\"\n    def __init__(self, node_types, num_hops):\n        self.src: Dict[NodeType, List[Tensor]] = {\n            k: (num_hops + 1) * [torch.empty(0, dtype=torch.int64)]\n            for k in node_types\n        }\n        self.with_dupl: Dict[NodeType, Tensor] = {\n            k: torch.empty(0, dtype=torch.int64)\n            for k in node_types\n        }\n        self.out: Dict[NodeType, Tensor] = {\n            k: torch.empty(0, dtype=torch.int64)\n            for k in node_types\n        }\n        self.seed_time: Dict[NodeType, List[Tensor]] = {\n            k: num_hops * [torch.empty(0, dtype=torch.int64)]\n            for k in node_types\n        }\n\n\nclass BatchDict:\n    r\"\"\"Class used during disjoint heterogeneous sampling.\n    1) The batch to serve as initial subgraph IDs for source nodes in the next\n       layer.\n    2) The subgraph IDs with duplicates that are further needed to create COO\n       output.\n    3) The output subgraph IDs without duplicates.\n    \"\"\"\n    def __init__(self, node_types, num_hops):\n        self.src: Dict[NodeType, List[Tensor]] = {\n            k: (num_hops + 1) * [torch.empty(0, dtype=torch.int64)]\n            for k in node_types\n        }\n        self.with_dupl: Dict[NodeType, Tensor] = {\n            k: torch.empty(0, dtype=torch.int64)\n            for k in node_types\n        }\n        self.out: Dict[NodeType, Tensor] = {\n            k: torch.empty(0, dtype=torch.int64)\n            for k in node_types\n        }\n\n\ndef 
remove_duplicates(\n    out: SamplerOutput,\n    node: Tensor,\n    batch: Optional[Tensor] = None,\n    disjoint: bool = False,\n) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:\n    num_nodes = node.numel()\n    node_combined = torch.cat([node, out.node])\n\n    if not disjoint:\n        _, idx = np.unique(node_combined.cpu().numpy(), return_index=True)\n        idx = torch.from_numpy(idx).to(node.device).sort().values\n\n        node = node_combined[idx]\n        src = node[num_nodes:]\n\n        return (src, node, None, None)\n\n    else:\n        batch_combined = torch.cat([batch, out.batch])\n        node_batch = torch.stack([batch_combined, node_combined], dim=0)\n\n        _, idx = np.unique(node_batch.cpu().numpy(), axis=1, return_index=True)\n        idx = torch.from_numpy(idx).to(node.device).sort().values\n\n        batch = batch_combined[idx]\n        node = node_combined[idx]\n        src_batch = batch[num_nodes:]\n        src = node[num_nodes:]\n\n        return (src, node, src_batch, batch)\n\n\ndef filter_dist_store(\n    feature_store: LocalFeatureStore,\n    graph_store: LocalGraphStore,\n    node_dict: Dict[str, Tensor],\n    row_dict: Dict[str, Tensor],\n    col_dict: Dict[str, Tensor],\n    edge_dict: Dict[str, Optional[Tensor]],\n    custom_cls: Optional[HeteroData] = None,\n    meta: Optional[Dict[str, Tensor]] = None,\n    input_type: str = None,\n) -> HeteroData:\n    r\"\"\"Constructs a :class:`HeteroData` object from a feature store that only\n    holds nodes in `node` and edges in `edge` for each node and edge type,\n    respectively. 
Sorted attribute values are provided as metadata from\n    :class:`DistNeighborSampler`.\n    \"\"\"\n    # Construct a new `HeteroData` object:\n    data = custom_cls() if custom_cls is not None else HeteroData()\n    nfeats, labels, efeats = meta[-3:]\n\n    # Filter edge storage:\n    required_edge_attrs = []\n    for attr in graph_store.get_all_edge_attrs():\n        key = attr.edge_type\n        if key in row_dict and key in col_dict:\n            required_edge_attrs.append(attr)\n            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)\n            data[attr.edge_type].edge_index = edge_index\n\n    # Filter node storage:\n    required_node_attrs = []\n    for attr in feature_store.get_all_tensor_attrs():\n        if attr.group_name in node_dict:\n            attr.index = node_dict[attr.group_name]\n            required_node_attrs.append(attr)\n            data[attr.group_name].num_nodes = attr.index.size(0)\n\n    if nfeats:\n        for attr in required_node_attrs:\n            if nfeats[attr.group_name] is not None:\n                data[attr.group_name][attr.attr_name] = nfeats[attr.group_name]\n\n    if efeats:\n        for attr in required_edge_attrs:\n            if efeats[attr.edge_type] is not None:\n                data[attr.edge_type].edge_attr = efeats[attr.edge_type]\n\n    if labels:\n        data[input_type].y = labels[input_type]\n\n    return data\n\n\ndef as_str(inputs: Union[NodeType, EdgeType]) -> str:\n    if isinstance(inputs, NodeType):\n        return inputs\n    elif isinstance(inputs, (list, tuple)) and len(inputs) == 3:\n        return '__'.join(inputs)\n    return ''\n\n\ndef reverse_edge_type(etype: EdgeType) -> EdgeType:\n    src, rel, dst = etype\n    if src != dst:\n        if rel.split('_', 1)[0] == 'rev':\n            # undirected edge with `rev_` prefix.\n            rel = rel.split('_', 1)[1]\n        else:\n            rel = 'rev_' + rel\n\n    return dst, rel, src\n"
  },
  {
    "path": "torch_geometric/edge_index.py",
    "content": "import functools\nfrom enum import Enum\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Literal,\n    NamedTuple,\n    Optional,\n    Sequence,\n    Tuple,\n    Type,\n    Union,\n    get_args,\n    overload,\n)\n\nimport numpy as np\nimport torch\nimport torch.utils._pytree as pytree\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import Index, is_compiling\nfrom torch_geometric.index import index2ptr, ptr2index\nfrom torch_geometric.typing import INDEX_DTYPES, SparseTensor\n\naten = torch.ops.aten\n\nHANDLED_FUNCTIONS: Dict[Callable, Callable] = {}\n\nReduceType = Literal['sum', 'mean', 'amin', 'amax', 'add', 'min', 'max']\nPYG_REDUCE: Dict[ReduceType, ReduceType] = {\n    'add': 'sum',\n    'amin': 'min',\n    'amax': 'max'\n}\nTORCH_REDUCE: Dict[ReduceType, ReduceType] = {\n    'add': 'sum',\n    'min': 'amin',\n    'max': 'amax'\n}\n\n\nclass SortOrder(Enum):\n    ROW = 'row'\n    COL = 'col'\n\n\nclass CatMetadata(NamedTuple):\n    nnz: List[int]\n    sparse_size: List[Tuple[Optional[int], Optional[int]]]\n    sort_order: List[Optional[SortOrder]]\n    is_undirected: List[bool]\n\n\ndef implements(torch_function: Callable) -> Callable:\n    r\"\"\"Registers a :pytorch:`PyTorch` function override.\"\"\"\n    @functools.wraps(torch_function)\n    def decorator(my_function: Callable) -> Callable:\n        HANDLED_FUNCTIONS[torch_function] = my_function\n        return my_function\n\n    return decorator\n\n\ndef set_tuple_item(\n    values: Tuple[Any, ...],\n    dim: int,\n    value: Any,\n) -> Tuple[Any, ...]:\n    if dim < -len(values) or dim >= len(values):\n        raise IndexError(\"tuple index out of range\")\n\n    dim = dim + len(values) if dim < 0 else dim\n    return values[:dim] + (value, ) + values[dim + 1:]\n\n\ndef maybe_add(\n    value: Sequence[Optional[int]],\n    other: Union[int, Sequence[Optional[int]]],\n    alpha: int = 1,\n) -> Tuple[Optional[int], 
...]:\n\n    if isinstance(other, int):\n        return tuple(v + alpha * other if v is not None else None\n                     for v in value)\n\n    assert len(value) == len(other)\n    return tuple(v + alpha * o if v is not None and o is not None else None\n                 for v, o in zip(value, other))\n\n\ndef maybe_sub(\n    value: Sequence[Optional[int]],\n    other: Union[int, Sequence[Optional[int]]],\n    alpha: int = 1,\n) -> Tuple[Optional[int], ...]:\n\n    if isinstance(other, int):\n        return tuple(v - alpha * other if v is not None else None\n                     for v in value)\n\n    assert len(value) == len(other)\n    return tuple(v - alpha * o if v is not None and o is not None else None\n                 for v, o in zip(value, other))\n\n\ndef assert_valid_dtype(tensor: Tensor) -> None:\n    if tensor.dtype not in INDEX_DTYPES:\n        raise ValueError(f\"'EdgeIndex' holds an unsupported data type \"\n                         f\"(got '{tensor.dtype}', but expected one of \"\n                         f\"{INDEX_DTYPES})\")\n\n\ndef assert_two_dimensional(tensor: Tensor) -> None:\n    if tensor.dim() != 2:\n        raise ValueError(f\"'EdgeIndex' needs to be two-dimensional \"\n                         f\"(got {tensor.dim()} dimensions)\")\n    if not torch.jit.is_tracing() and tensor.size(0) != 2:\n        raise ValueError(f\"'EdgeIndex' needs to have a shape of \"\n                         f\"[2, *] (got {list(tensor.size())})\")\n\n\ndef assert_contiguous(tensor: Tensor) -> None:\n    if not tensor[0].is_contiguous() or not tensor[1].is_contiguous():\n        raise ValueError(\"'EdgeIndex' needs to be contiguous. 
Please call \"\n                         \"`edge_index.contiguous()` before proceeding.\")\n\n\ndef assert_symmetric(size: Tuple[Optional[int], Optional[int]]) -> None:\n    if (not torch.jit.is_tracing() and size[0] is not None\n            and size[1] is not None and size[0] != size[1]):\n        raise ValueError(f\"'EdgeIndex' is undirected but received a \"\n                         f\"non-symmetric size (got {list(size)})\")\n\n\ndef assert_sorted(func: Callable) -> Callable:\n    @functools.wraps(func)\n    def wrapper(self: 'EdgeIndex', *args: Any, **kwargs: Any) -> Any:\n        if not self.is_sorted:\n            cls_name = self.__class__.__name__\n            raise ValueError(\n                f\"Cannot call '{func.__name__}' since '{cls_name}' is not \"\n                f\"sorted. Please call `{cls_name}.sort_by(...)` first.\")\n        return func(self, *args, **kwargs)\n\n    return wrapper\n\n\nclass EdgeIndex(Tensor):\n    r\"\"\"A COO :obj:`edge_index` tensor with additional (meta)data attached.\n\n    :class:`EdgeIndex` is a :pytorch:`null` :class:`torch.Tensor`, that holds\n    an :obj:`edge_index` representation of shape :obj:`[2, num_edges]`.\n    Edges are given as pairwise source and destination node indices in sparse\n    COO format.\n\n    While :class:`EdgeIndex` sub-classes a general :pytorch:`null`\n    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:\n\n    * :obj:`sparse_size`: The underlying sparse matrix size\n    * :obj:`sort_order`: The sort order (if present), either by row or column.\n    * :obj:`is_undirected`: Whether edges are bidirectional.\n\n    Additionally, :class:`EdgeIndex` caches data for fast CSR or CSC conversion\n    in case its representation is sorted, such as its :obj:`rowptr` or\n    :obj:`colptr`, or the permutation vector for going from CSR to CSC or vice\n    versa.\n    Caches are filled based on demand (*e.g.*, when calling\n    :meth:`EdgeIndex.sort_by`), or when explicitly requested via\n  
  :meth:`EdgeIndex.fill_cache_`, and are maintained and adjusted over its\n    lifespan (*e.g.*, when calling :meth:`EdgeIndex.flip`).\n\n    This representation ensures optimal computation in GNN message passing\n    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`\n    workflows.\n\n    .. code-block:: python\n\n        from torch_geometric import EdgeIndex\n\n        edge_index = EdgeIndex(\n            [[0, 1, 1, 2],\n             [1, 0, 2, 1]],\n            sparse_size=(3, 3),\n            sort_order='row',\n            is_undirected=True,\n            device='cpu',\n        )\n        >>> EdgeIndex([[0, 1, 1, 2],\n        ...            [1, 0, 2, 1]])\n        assert edge_index.is_sorted_by_row\n        assert edge_index.is_undirected\n\n        # Flipping order:\n        edge_index = edge_index.flip(0)\n        >>> EdgeIndex([[1, 0, 2, 1],\n        ...            [0, 1, 1, 2]])\n        assert edge_index.is_sorted_by_col\n        assert edge_index.is_undirected\n\n        # Filtering:\n        mask = torch.tensor([True, True, True, False])\n        edge_index = edge_index[:, mask]\n        >>> EdgeIndex([[1, 0, 2],\n        ...            
[0, 1, 1]])\n        assert edge_index.is_sorted_by_col\n        assert not edge_index.is_undirected\n\n        # Sparse-Dense Matrix Multiplication:\n        out = edge_index.flip(0) @ torch.randn(3, 16)\n        assert out.size() == (3, 16)\n    \"\"\"\n    # See \"https://pytorch.org/docs/stable/notes/extending.html\"\n    # for a basic tutorial on how to subclass `torch.Tensor`.\n\n    # The underlying tensor representation:\n    _data: Tensor\n\n    # The size of the underlying sparse matrix:\n    _sparse_size: Tuple[Optional[int], Optional[int]] = (None, None)\n\n    # Whether the `edge_index` representation is non-sorted (`None`), or sorted\n    # based on row or column values.\n    _sort_order: Optional[SortOrder] = None\n\n    # Whether the `edge_index` is undirected:\n    # NOTE `is_undirected` allows us to assume symmetric adjacency matrix size\n    # and to share compressed pointer representations, however, it does not\n    # allow us to get rid of CSR/CSC permutation vectors since ordering within\n    # neighborhoods is not necessarily deterministic.\n    _is_undirected: bool = False\n\n    # A cache for its compressed representation:\n    _indptr: Optional[Tensor] = None\n\n    # A cache for its transposed representation:\n    _T_perm: Optional[Tensor] = None\n    _T_index: Tuple[Optional[Tensor], Optional[Tensor]] = (None, None)\n    _T_indptr: Optional[Tensor] = None\n\n    # A cached \"1\"-value vector for `torch.sparse` matrix multiplication:\n    _value: Optional[Tensor] = None\n\n    # Whenever we perform a concatenation of edge indices, we cache the\n    # original metadata to be able to reconstruct individual edge indices:\n    _cat_metadata: Optional[CatMetadata] = None\n\n    @staticmethod\n    def __new__(\n        cls: Type,\n        data: Any,\n        *args: Any,\n        sparse_size: Optional[Tuple[Optional[int], Optional[int]]] = None,\n        sort_order: Optional[Union[str, SortOrder]] = None,\n        is_undirected: bool = False,\n     
   **kwargs: Any,\n    ) -> 'EdgeIndex':\n        if not isinstance(data, Tensor):\n            data = torch.tensor(data, *args, **kwargs)\n        elif len(args) > 0:\n            raise TypeError(\n                f\"new() received an invalid combination of arguments - got \"\n                f\"(Tensor, {', '.join(str(type(arg)) for arg in args)})\")\n        elif len(kwargs) > 0:\n            raise TypeError(f\"new() received invalid keyword arguments - got \"\n                            f\"{set(kwargs.keys())})\")\n\n        assert isinstance(data, Tensor)\n\n        indptr: Optional[Tensor] = None\n\n        if isinstance(data, cls):  # If passed `EdgeIndex`, inherit metadata:\n            indptr = data._indptr\n            sparse_size = sparse_size or data.sparse_size()\n            sort_order = sort_order or data.sort_order\n            is_undirected = is_undirected or data.is_undirected\n\n        # Convert `torch.sparse` tensors to `EdgeIndex` representation:\n        if data.layout == torch.sparse_coo:\n            sort_order = SortOrder.ROW\n            sparse_size = sparse_size or (data.size(0), data.size(1))\n            data = data.indices()\n\n        if data.layout == torch.sparse_csr:\n            indptr = data.crow_indices()\n            col = data.col_indices()\n\n            assert isinstance(indptr, Tensor)\n            row = ptr2index(indptr, output_size=col.numel())\n\n            sort_order = SortOrder.ROW\n            sparse_size = sparse_size or (data.size(0), data.size(1))\n            if sparse_size[0] is not None and sparse_size[0] != data.size(0):\n                indptr = None\n            data = torch.stack([row, col], dim=0)\n\n        if data.layout == torch.sparse_csc:\n            row = data.row_indices()\n            indptr = data.ccol_indices()\n\n            assert isinstance(indptr, Tensor)\n            col = ptr2index(indptr, output_size=row.numel())\n\n            sort_order = SortOrder.COL\n            sparse_size = 
sparse_size or (data.size(0), data.size(1))\n            if sparse_size[1] is not None and sparse_size[1] != data.size(1):\n                indptr = None\n            data = torch.stack([row, col], dim=0)\n\n        assert_valid_dtype(data)\n        assert_two_dimensional(data)\n        assert_contiguous(data)\n\n        if sparse_size is None:\n            sparse_size = (None, None)\n\n        if is_undirected:\n            assert_symmetric(sparse_size)\n            if sparse_size[0] is not None and sparse_size[1] is None:\n                sparse_size = (sparse_size[0], sparse_size[0])\n            elif sparse_size[0] is None and sparse_size[1] is not None:\n                sparse_size = (sparse_size[1], sparse_size[1])\n\n        out = Tensor._make_wrapper_subclass(\n            cls,\n            size=data.size(),\n            strides=data.stride(),\n            dtype=data.dtype,\n            device=data.device,\n            layout=data.layout,\n            requires_grad=False,\n        )\n        assert isinstance(out, EdgeIndex)\n\n        # Attach metadata:\n        out._data = data\n        out._sparse_size = sparse_size\n        out._sort_order = None if sort_order is None else SortOrder(sort_order)\n        out._is_undirected = is_undirected\n        out._indptr = indptr\n\n        if isinstance(data, cls):  # If passed `EdgeIndex`, inherit metadata:\n            out._data = data._data\n            out._T_perm = data._T_perm\n            out._T_index = data._T_index\n            out._T_indptr = data._T_indptr\n            out._value = data._value\n\n            # Reset metadata if cache is invalidated:\n            num_rows = sparse_size[0]\n            if num_rows is not None and num_rows != data.sparse_size(0):\n                out._indptr = None\n\n            num_cols = sparse_size[1]\n            if num_cols is not None and num_cols != data.sparse_size(1):\n                out._T_indptr = None\n\n        return out\n\n    # Validation 
##############################################################\n\n    def validate(self) -> 'EdgeIndex':\n        r\"\"\"Validates the :class:`EdgeIndex` representation.\n\n        In particular, it ensures that\n\n        * it only holds valid indices.\n        * the sort order is correctly set.\n        * indices are bidirectional in case it is specified as undirected.\n        \"\"\"\n        assert_valid_dtype(self._data)\n        assert_two_dimensional(self._data)\n        assert_contiguous(self._data)\n        if self.is_undirected:\n            assert_symmetric(self.sparse_size())\n\n        if self.numel() > 0 and self._data.min() < 0:\n            raise ValueError(f\"'{self.__class__.__name__}' contains negative \"\n                             f\"indices (got {int(self.min())})\")\n\n        if (self.numel() > 0 and self.num_rows is not None\n                and self._data[0].max() >= self.num_rows):\n            raise ValueError(f\"'{self.__class__.__name__}' contains larger \"\n                             f\"indices than its number of rows \"\n                             f\"(got {int(self._data[0].max())}, but expected \"\n                             f\"values smaller than {self.num_rows})\")\n\n        if (self.numel() > 0 and self.num_cols is not None\n                and self._data[1].max() >= self.num_cols):\n            raise ValueError(f\"'{self.__class__.__name__}' contains larger \"\n                             f\"indices than its number of columns \"\n                             f\"(got {int(self._data[1].max())}, but expected \"\n                             f\"values smaller than {self.num_cols})\")\n\n        if self.is_sorted_by_row and (self._data[0].diff() < 0).any():\n            raise ValueError(f\"'{self.__class__.__name__}' is not sorted by \"\n                             f\"row indices\")\n\n        if self.is_sorted_by_col and (self._data[1].diff() < 0).any():\n            raise ValueError(f\"'{self.__class__.__name__}' is not 
sorted by \"\n                             f\"column indices\")\n\n        if self.is_undirected:\n            flat_index1 = self._data[0] * self.get_num_rows() + self._data[1]\n            flat_index1 = flat_index1.sort()[0]\n            flat_index2 = self._data[1] * self.get_num_cols() + self._data[0]\n            flat_index2 = flat_index2.sort()[0]\n            if not torch.equal(flat_index1, flat_index2):\n                raise ValueError(f\"'{self.__class__.__name__}' is not \"\n                                 f\"undirected\")\n\n        return self\n\n    # Properties ##############################################################\n\n    @overload\n    def sparse_size(self) -> Tuple[Optional[int], Optional[int]]:\n        pass\n\n    @overload\n    def sparse_size(self, dim: int) -> Optional[int]:\n        pass\n\n    def sparse_size(\n        self,\n        dim: Optional[int] = None,\n    ) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]:\n        r\"\"\"The size of the underlying sparse matrix.\n        If :obj:`dim` is specified, returns an integer holding the size of that\n        sparse dimension.\n\n        Args:\n            dim (int, optional): The dimension for which to retrieve the size.\n                (default: :obj:`None`)\n        \"\"\"\n        if dim is not None:\n            return self._sparse_size[dim]\n        return self._sparse_size\n\n    @property\n    def num_rows(self) -> Optional[int]:\n        r\"\"\"The number of rows of the underlying sparse matrix.\"\"\"\n        return self._sparse_size[0]\n\n    @property\n    def num_cols(self) -> Optional[int]:\n        r\"\"\"The number of columns of the underlying sparse matrix.\"\"\"\n        return self._sparse_size[1]\n\n    @property\n    def sort_order(self) -> Optional[str]:\n        r\"\"\"The sort order of indices, either :obj:`\"row\"`, :obj:`\"col\"` or\n        :obj:`None`.\n        \"\"\"\n        return None if self._sort_order is None else 
self._sort_order.value\n\n    @property\n    def is_sorted(self) -> bool:\n        r\"\"\"Returns whether indices are either sorted by rows or columns.\"\"\"\n        return self._sort_order is not None\n\n    @property\n    def is_sorted_by_row(self) -> bool:\n        r\"\"\"Returns whether indices are sorted by rows.\"\"\"\n        return self._sort_order == SortOrder.ROW\n\n    @property\n    def is_sorted_by_col(self) -> bool:\n        r\"\"\"Returns whether indices are sorted by columns.\"\"\"\n        return self._sort_order == SortOrder.COL\n\n    @property\n    def is_undirected(self) -> bool:\n        r\"\"\"Returns whether indices are bidirectional.\"\"\"\n        return self._is_undirected\n\n    @property\n    def dtype(self) -> torch.dtype:  # type: ignore\n        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.\n        return self._data.dtype\n\n    # Cache Interface #########################################################\n\n    @overload\n    def get_sparse_size(self) -> torch.Size:\n        pass\n\n    @overload\n    def get_sparse_size(self, dim: int) -> int:\n        pass\n\n    def get_sparse_size(\n        self,\n        dim: Optional[int] = None,\n    ) -> Union[torch.Size, int]:\n        r\"\"\"The size of the underlying sparse matrix.\n        Automatically computed and cached when not explicitly set.\n        If :obj:`dim` is specified, returns an integer holding the size of that\n        sparse dimension.\n\n        Args:\n            dim (int, optional): The dimension for which to retrieve the size.\n                (default: :obj:`None`)\n        \"\"\"\n        if dim is not None:\n            size = self._sparse_size[dim]\n            if size is not None:\n                return size\n\n            if self.is_undirected:\n                size = int(self._data.max()) + 1 if self.numel() > 0 else 0\n                self._sparse_size = (size, size)\n                return size\n\n            size = 
int(self._data[dim].max()) + 1 if self.numel() > 0 else 0\n            self._sparse_size = set_tuple_item(self._sparse_size, dim, size)\n            return size\n\n        return torch.Size((self.get_sparse_size(0), self.get_sparse_size(1)))\n\n    def sparse_resize_(  # type: ignore\n        self,\n        num_rows: Optional[int],\n        num_cols: Optional[int],\n    ) -> 'EdgeIndex':\n        r\"\"\"Assigns or re-assigns the size of the underlying sparse matrix.\n\n        Args:\n            num_rows (int, optional): The number of rows.\n            num_cols (int, optional): The number of columns.\n        \"\"\"\n        if self.is_undirected:\n            if num_rows is not None and num_cols is None:\n                num_cols = num_rows\n            elif num_cols is not None and num_rows is None:\n                num_rows = num_cols\n\n            if num_rows is not None and num_rows != num_cols:\n                raise ValueError(f\"'EdgeIndex' is undirected but received a \"\n                                 f\"non-symmetric size \"\n                                 f\"(got [{num_rows}, {num_cols}])\")\n\n        def _modify_ptr(\n            ptr: Optional[Tensor],\n            size: Optional[int],\n        ) -> Optional[Tensor]:\n\n            if ptr is None or size is None:\n                return None\n\n            if ptr.numel() - 1 >= size:\n                return ptr[:size + 1]\n\n            fill_value = ptr.new_full(\n                (size - ptr.numel() + 1, ),\n                fill_value=ptr[-1],  # type: ignore\n            )\n            return torch.cat([ptr, fill_value], dim=0)\n\n        if self.is_sorted_by_row:\n            self._indptr = _modify_ptr(self._indptr, num_rows)\n            self._T_indptr = _modify_ptr(self._T_indptr, num_cols)\n\n        if self.is_sorted_by_col:\n            self._indptr = _modify_ptr(self._indptr, num_cols)\n            self._T_indptr = _modify_ptr(self._T_indptr, num_rows)\n\n        self._sparse_size = 
(num_rows, num_cols)\n\n        return self\n\n    def get_num_rows(self) -> int:\n        r\"\"\"The number of rows of the underlying sparse matrix.\n        Automatically computed and cached when not explicitly set.\n        \"\"\"\n        return self.get_sparse_size(0)\n\n    def get_num_cols(self) -> int:\n        r\"\"\"The number of columns of the underlying sparse matrix.\n        Automatically computed and cached when not explicitly set.\n        \"\"\"\n        return self.get_sparse_size(1)\n\n    @assert_sorted\n    def get_indptr(self) -> Tensor:\n        r\"\"\"Returns the compressed index representation in case\n        :class:`EdgeIndex` is sorted.\n        \"\"\"\n        if self._indptr is not None:\n            return self._indptr\n\n        if self.is_undirected and self._T_indptr is not None:\n            return self._T_indptr\n\n        dim = 0 if self.is_sorted_by_row else 1\n        self._indptr = index2ptr(self._data[dim], self.get_sparse_size(dim))\n\n        return self._indptr\n\n    @assert_sorted\n    def _sort_by_transpose(self) -> Tuple[Tuple[Tensor, Tensor], Tensor]:\n        from torch_geometric.utils import index_sort\n\n        dim = 1 if self.is_sorted_by_row else 0\n\n        if self._T_perm is None:\n            max_index = self.get_sparse_size(dim)\n            index, perm = index_sort(self._data[dim], max_index)\n            self._T_index = set_tuple_item(self._T_index, dim, index)\n            self._T_perm = perm.to(self.dtype)\n\n        if self._T_index[1 - dim] is None:\n            self._T_index = set_tuple_item(  #\n                self._T_index, 1 - dim, self._data[1 - dim][self._T_perm])\n\n        row, col = self._T_index\n        assert row is not None and col is not None\n\n        return (row, col), self._T_perm\n\n    @assert_sorted\n    def get_csr(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]:\n        r\"\"\"Returns the compressed CSR representation\n        :obj:`(rowptr, col), perm` in case 
:class:`EdgeIndex` is sorted.\n        \"\"\"\n        if self.is_sorted_by_row:\n            return (self.get_indptr(), self._data[1]), None\n\n        assert self.is_sorted_by_col\n        (row, col), perm = self._sort_by_transpose()\n\n        if self._T_indptr is not None:\n            rowptr = self._T_indptr\n        elif self.is_undirected and self._indptr is not None:\n            rowptr = self._indptr\n        else:\n            rowptr = self._T_indptr = index2ptr(row, self.get_num_rows())\n\n        return (rowptr, col), perm\n\n    @assert_sorted\n    def get_csc(self) -> Tuple[Tuple[Tensor, Tensor], Optional[Tensor]]:\n        r\"\"\"Returns the compressed CSC representation\n        :obj:`(colptr, row), perm` in case :class:`EdgeIndex` is sorted.\n        \"\"\"\n        if self.is_sorted_by_col:\n            return (self.get_indptr(), self._data[0]), None\n\n        assert self.is_sorted_by_row\n        (row, col), perm = self._sort_by_transpose()\n\n        if self._T_indptr is not None:\n            colptr = self._T_indptr\n        elif self.is_undirected and self._indptr is not None:\n            colptr = self._indptr\n        else:\n            colptr = self._T_indptr = index2ptr(col, self.get_num_cols())\n\n        return (colptr, row), perm\n\n    def _get_value(self, dtype: Optional[torch.dtype] = None) -> Tensor:\n        if self._value is not None:\n            if (dtype or torch.get_default_dtype()) == self._value.dtype:\n                return self._value\n\n        # Expanded tensors are not yet supported in all PyTorch code paths :(\n        # value = torch.ones(1, dtype=dtype, device=self.device)\n        # value = value.expand(self.size(1))\n        self._value = torch.ones(self.size(1), dtype=dtype, device=self.device)\n\n        return self._value\n\n    def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':\n        r\"\"\"Fills the cache with (meta)data information.\n\n        Args:\n            no_transpose (bool, 
optional): If set to :obj:`True`, will not fill\n                the cache with information about the transposed\n                :class:`EdgeIndex`. (default: :obj:`False`)\n        \"\"\"\n        self.get_sparse_size()\n\n        if self.is_sorted_by_row:\n            self.get_csr()\n            if not no_transpose:\n                self.get_csc()\n        elif self.is_sorted_by_col:\n            self.get_csc()\n            if not no_transpose:\n                self.get_csr()\n\n        return self\n\n    # Methods #################################################################\n\n    def share_memory_(self) -> 'EdgeIndex':\n        \"\"\"\"\"\"  # noqa: D419\n        self._data.share_memory_()\n        if self._indptr is not None:\n            self._indptr.share_memory_()\n        if self._T_perm is not None:\n            self._T_perm.share_memory_()\n        if self._T_index[0] is not None:\n            self._T_index[0].share_memory_()\n        if self._T_index[1] is not None:\n            self._T_index[1].share_memory_()\n        if self._T_indptr is not None:\n            self._T_indptr.share_memory_()\n        if self._value is not None:\n            self._value.share_memory_()\n        return self\n\n    def is_shared(self) -> bool:\n        \"\"\"\"\"\"  # noqa: D419\n        return self._data.is_shared()\n\n    def as_tensor(self) -> Tensor:\n        r\"\"\"Zero-copies the :class:`EdgeIndex` representation back to a\n        :class:`torch.Tensor` representation.\n        \"\"\"\n        return self._data\n\n    def sort_by(\n        self,\n        sort_order: Union[str, SortOrder],\n        stable: bool = False,\n    ) -> 'SortReturnType':\n        r\"\"\"Sorts the elements by row or column indices.\n\n        Args:\n            sort_order (str): The sort order, either :obj:`\"row\"` or\n                :obj:`\"col\"`.\n            stable (bool, optional): Makes the sorting routine stable, which\n                guarantees that the order of equivalent 
elements is preserved.\n                (default: :obj:`False`)\n        \"\"\"\n        from torch_geometric.utils import index_sort\n\n        sort_order = SortOrder(sort_order)\n\n        if self._sort_order == sort_order:  # Nothing to do.\n            return SortReturnType(self, None)\n\n        if self.is_sorted:\n            (row, col), perm = self._sort_by_transpose()\n            edge_index = torch.stack([row, col], dim=0)\n\n        # Otherwise, perform sorting:\n        elif sort_order == SortOrder.ROW:\n            row, perm = index_sort(self._data[0], self.get_num_rows(), stable)\n            edge_index = torch.stack([row, self._data[1][perm]], dim=0)\n\n        else:\n            col, perm = index_sort(self._data[1], self.get_num_cols(), stable)\n            edge_index = torch.stack([self._data[0][perm], col], dim=0)\n\n        out = self.__class__(edge_index)\n\n        # We can inherit metadata and (mostly) cache:\n        out._sparse_size = self.sparse_size()\n        out._sort_order = sort_order\n        out._is_undirected = self.is_undirected\n\n        out._indptr = self._indptr\n        out._T_indptr = self._T_indptr\n\n        # NOTE We cannot copy CSR<>CSC permutations since we don't require that\n        # local neighborhoods are sorted, and thus they may run out of sync.\n\n        out._value = self._value\n\n        return SortReturnType(out, perm)\n\n    def to_dense(  # type: ignore\n        self,\n        value: Optional[Tensor] = None,\n        fill_value: float = 0.0,\n        dtype: Optional[torch.dtype] = None,\n    ) -> Tensor:\n        r\"\"\"Converts :class:`EdgeIndex` into a dense :class:`torch.Tensor`.\n\n        .. warning::\n\n            In case of duplicated edges, the behavior is non-deterministic (one\n            of the values from :obj:`value` will be picked arbitrarily). 
For\n            deterministic behavior, consider calling\n            :meth:`~torch_geometric.utils.coalesce` beforehand.\n\n        Args:\n            value (torch.Tensor, optional): The values for non-zero elements.\n                If not specified, non-zero elements will be assigned a value of\n                :obj:`1.0`. (default: :obj:`None`)\n            fill_value (float, optional): The fill value for remaining elements\n                in the dense matrix. (default: :obj:`0.0`)\n            dtype (torch.dtype, optional): The data type of the returned\n                tensor. (default: :obj:`None`)\n        \"\"\"\n        dtype = value.dtype if value is not None else dtype\n\n        size = self.get_sparse_size()\n        if value is not None and value.dim() > 1:\n            size = size + value.size()[1:]\n\n        out = torch.full(size, fill_value, dtype=dtype, device=self.device)\n        out[self._data[0], self._data[1]] = value if value is not None else 1\n\n        return out\n\n    def to_sparse_coo(self, value: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Converts :class:`EdgeIndex` into a :pytorch:`null`\n        :class:`torch.sparse_coo_tensor`.\n\n        Args:\n            value (torch.Tensor, optional): The values for non-zero elements.\n                If not specified, non-zero elements will be assigned a value of\n                :obj:`1.0`. 
(default: :obj:`None`)\n        \"\"\"\n        value = self._get_value() if value is None else value\n\n        if not torch_geometric.typing.WITH_PT21:\n            out = torch.sparse_coo_tensor(\n                indices=self._data,\n                values=value,\n                size=self.get_sparse_size(),\n                device=self.device,\n                requires_grad=value.requires_grad,\n            )\n            if self.is_sorted_by_row:\n                out = out._coalesced_(True)\n            return out\n\n        return torch.sparse_coo_tensor(\n            indices=self._data,\n            values=value,\n            size=self.get_sparse_size(),\n            device=self.device,\n            requires_grad=value.requires_grad,\n            is_coalesced=True if self.is_sorted_by_row else None,\n        )\n\n    def to_sparse_csr(  # type: ignore\n            self,\n            value: Optional[Tensor] = None,\n    ) -> Tensor:\n        r\"\"\"Converts :class:`EdgeIndex` into a :pytorch:`null`\n        :class:`torch.sparse_csr_tensor`.\n\n        Args:\n            value (torch.Tensor, optional): The values for non-zero elements.\n                If not specified, non-zero elements will be assigned a value of\n                :obj:`1.0`. 
(default: :obj:`None`)\n        \"\"\"\n        (rowptr, col), perm = self.get_csr()\n        if value is not None and perm is not None:\n            value = value[perm]\n        elif value is None:\n            value = self._get_value()\n\n        return torch.sparse_csr_tensor(\n            crow_indices=rowptr,\n            col_indices=col,\n            values=value,\n            size=self.get_sparse_size(),\n            device=self.device,\n            requires_grad=value.requires_grad,\n        )\n\n    def to_sparse_csc(  # type: ignore\n            self,\n            value: Optional[Tensor] = None,\n    ) -> Tensor:\n        r\"\"\"Converts :class:`EdgeIndex` into a :pytorch:`null`\n        :class:`torch.sparse_csc_tensor`.\n\n        Args:\n            value (torch.Tensor, optional): The values for non-zero elements.\n                If not specified, non-zero elements will be assigned a value of\n                :obj:`1.0`. (default: :obj:`None`)\n        \"\"\"\n        (colptr, row), perm = self.get_csc()\n        if value is not None and perm is not None:\n            value = value[perm]\n        elif value is None:\n            value = self._get_value()\n\n        return torch.sparse_csc_tensor(\n            ccol_indices=colptr,\n            row_indices=row,\n            values=value,\n            size=self.get_sparse_size(),\n            device=self.device,\n            requires_grad=value.requires_grad,\n        )\n\n    def to_sparse(  # type: ignore\n        self,\n        *,\n        layout: torch.layout = torch.sparse_coo,\n        value: Optional[Tensor] = None,\n    ) -> Tensor:\n        r\"\"\"Converts :class:`EdgeIndex` into a\n        :pytorch:`null` :class:`torch.sparse` tensor.\n\n        Args:\n            layout (torch.layout, optional): The desired sparse layout. One of\n                :obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`, or\n                :obj:`torch.sparse_csc`. 
(default: :obj:`torch.sparse_coo`)\n            value (torch.Tensor, optional): The values for non-zero elements.\n                If not specified, non-zero elements will be assigned a value of\n                :obj:`1.0`. (default: :obj:`None`)\n        \"\"\"\n        if layout is None or layout == torch.sparse_coo:\n            return self.to_sparse_coo(value)\n        if layout == torch.sparse_csr:\n            return self.to_sparse_csr(value)\n        if layout == torch.sparse_csc:\n            return self.to_sparse_csc(value)\n\n        raise ValueError(f\"Unexpected tensor layout (got '{layout}')\")\n\n    def to_sparse_tensor(\n        self,\n        value: Optional[Tensor] = None,\n    ) -> SparseTensor:\n        r\"\"\"Converts :class:`EdgeIndex` into a\n        :class:`torch_sparse.SparseTensor`.\n        Requires that :obj:`torch-sparse` is installed.\n\n        Args:\n            value (torch.Tensor, optional): The values for non-zero elements.\n                (default: :obj:`None`)\n        \"\"\"\n        return SparseTensor(\n            row=self._data[0],\n            col=self._data[1],\n            rowptr=self._indptr if self.is_sorted_by_row else None,\n            value=value,\n            sparse_sizes=self.get_sparse_size(),\n            is_sorted=self.is_sorted_by_row,\n            trust_data=True,\n        )\n\n    # TODO Investigate how to avoid overlapping return types here.\n    @overload\n    def matmul(  # type: ignore\n        self,\n        other: 'EdgeIndex',\n        input_value: Optional[Tensor] = None,\n        other_value: Optional[Tensor] = None,\n        reduce: ReduceType = 'sum',\n        transpose: bool = False,\n    ) -> Tuple['EdgeIndex', Tensor]:\n        pass\n\n    @overload\n    def matmul(\n        self,\n        other: Tensor,\n        input_value: Optional[Tensor] = None,\n        other_value: None = None,\n        reduce: ReduceType = 'sum',\n        transpose: bool = False,\n    ) -> Tensor:\n        pass\n\n    
def matmul(\n        self,\n        other: Union[Tensor, 'EdgeIndex'],\n        input_value: Optional[Tensor] = None,\n        other_value: Optional[Tensor] = None,\n        reduce: ReduceType = 'sum',\n        transpose: bool = False,\n    ) -> Union[Tensor, Tuple['EdgeIndex', Tensor]]:\n        r\"\"\"Performs a matrix multiplication of the matrices :obj:`input` and\n        :obj:`other`.\n        If :obj:`input` is a :math:`(n \\times m)` matrix and :obj:`other` is a\n        :math:`(m \\times p)` tensor, then the output will be a\n        :math:`(n \\times p)` tensor.\n        See :meth:`torch.matmul` for more information.\n\n        :obj:`input` is a sparse matrix as denoted by the indices in\n        :class:`EdgeIndex`, and :obj:`input_value` corresponds to the values\n        of non-zero elements in :obj:`input`.\n        If not specified, non-zero elements will be assigned a value of\n        :obj:`1.0`.\n\n        :obj:`other` can either be a dense :class:`torch.Tensor` or a sparse\n        :class:`EdgeIndex`.\n        If :obj:`other` is a sparse :class:`EdgeIndex`, then :obj:`other_value`\n        corresponds to the values of its non-zero elements.\n\n        This function additionally accepts an optional :obj:`reduce` argument\n        that allows specification of an optional reduction operation.\n        See :meth:`torch.sparse.mm` for more information.\n\n        Lastly, the :obj:`transpose` option allows performing matrix\n        multiplication where :obj:`input` is transposed first, *i.e.*:\n\n        .. math::\n\n            \\textrm{input}^{\\top} \\cdot \\textrm{other}\n\n        Args:\n            other (torch.Tensor or EdgeIndex): The second matrix to be\n                multiplied, which can be sparse or dense.\n            input_value (torch.Tensor, optional): The values for non-zero\n                elements of :obj:`input`.\n                If not specified, non-zero elements will be assigned a value of\n                :obj:`1.0`. 
(default: :obj:`None`)\n            other_value (torch.Tensor, optional): The values for non-zero\n                elements of :obj:`other` in case it is sparse.\n                If not specified, non-zero elements will be assigned a value of\n                :obj:`1.0`. (default: :obj:`None`)\n            reduce (str, optional): The reduce operation, one of\n                :obj:`\"sum\"`/:obj:`\"add\"`, :obj:`\"mean\"`,\n                :obj:`\"min\"`/:obj:`\"amin\"` or :obj:`\"max\"`/:obj:`\"amax\"`.\n                (default: :obj:`\"sum\"`)\n            transpose (bool, optional): If set to :obj:`True`, will perform\n                matrix multiplication based on the transposed :obj:`input`.\n                (default: :obj:`False`)\n        \"\"\"\n        return matmul(self, other, input_value, other_value, reduce, transpose)\n\n    def sparse_narrow(\n        self,\n        dim: int,\n        start: Union[int, Tensor],\n        length: int,\n    ) -> 'EdgeIndex':\n        r\"\"\"Returns a new :class:`EdgeIndex` that is a narrowed version of\n        itself. 
Narrowing is performed by interpreting :class:`EdgeIndex` as a\n        sparse matrix of shape :obj:`(num_rows, num_cols)`.\n\n        In contrast to :meth:`torch.narrow`, the returned tensor does not share\n        the same underlying storage anymore.\n\n        Args:\n            dim (int): The dimension along which to narrow.\n            start (int or torch.Tensor): Index of the element to start the\n                narrowed dimension from.\n            length (int): Length of the narrowed dimension.\n        \"\"\"\n        dim = dim + 2 if dim < 0 else dim\n        if dim != 0 and dim != 1:\n            raise ValueError(f\"Expected dimension to be 0 or 1 (got {dim})\")\n\n        if start < 0:\n            raise ValueError(f\"Expected 'start' value to be non-negative \"\n                             f\"(got {start})\")\n\n        if dim == 0:\n            if self.is_sorted_by_row:\n                (rowptr, col), _ = self.get_csr()\n                rowptr = rowptr.narrow(0, start, length + 1)\n\n                if rowptr.numel() < 2:\n                    row, col = self._data[0, :0], self._data[1, :0]\n                    rowptr = None\n                    num_rows = 0\n                else:\n                    col = col[rowptr[0]:rowptr[-1]]\n                    rowptr = rowptr - rowptr[0]\n                    num_rows = rowptr.numel() - 1\n\n                    row = torch.arange(\n                        num_rows,\n                        dtype=col.dtype,\n                        device=col.device,\n                    ).repeat_interleave(\n                        rowptr.diff(),\n                        output_size=col.numel(),\n                    )\n\n                edge_index = EdgeIndex(\n                    torch.stack([row, col], dim=0),\n                    sparse_size=(num_rows, self.sparse_size(1)),\n                    sort_order='row',\n                )\n                edge_index._indptr = rowptr\n                return edge_index\n\n            
else:\n                mask = self._data[0] >= start\n                mask &= self._data[0] < (start + length)\n                offset = torch.tensor([[start], [0]], device=self.device)\n                edge_index = self[:, mask].sub_(offset)  # type: ignore\n                edge_index._sparse_size = (length, edge_index._sparse_size[1])\n                return edge_index\n\n        else:\n            assert dim == 1\n\n            if self.is_sorted_by_col:\n                (colptr, row), _ = self.get_csc()\n                colptr = colptr.narrow(0, start, length + 1)\n\n                if colptr.numel() < 2:\n                    row, col = self._data[0, :0], self._data[1, :0]\n                    colptr = None\n                    num_cols = 0\n                else:\n                    row = row[colptr[0]:colptr[-1]]\n                    colptr = colptr - colptr[0]\n                    num_cols = colptr.numel() - 1\n\n                    col = torch.arange(\n                        num_cols,\n                        dtype=row.dtype,\n                        device=row.device,\n                    ).repeat_interleave(\n                        colptr.diff(),\n                        output_size=row.numel(),\n                    )\n\n                edge_index = EdgeIndex(\n                    torch.stack([row, col], dim=0),\n                    sparse_size=(self.sparse_size(0), num_cols),\n                    sort_order='col',\n                )\n                edge_index._indptr = colptr\n                return edge_index\n\n            else:\n                mask = self._data[1] >= start\n                mask &= self._data[1] < (start + length)\n                offset = torch.tensor([[0], [start]], device=self.device)\n                edge_index = self[:, mask].sub_(offset)  # type: ignore\n                edge_index._sparse_size = (edge_index._sparse_size[0], length)\n                return edge_index\n\n    def to_vector(self) -> Tensor:\n        
r\"\"\"Converts :class:`EdgeIndex` into a one-dimensional index\n        vector representation.\n        \"\"\"\n        num_rows, num_cols = self.get_sparse_size()\n\n        if num_rows * num_cols > torch_geometric.typing.MAX_INT64:\n            raise ValueError(\"'to_vector()' will result in an overflow\")\n\n        return self._data[0] * num_cols + self._data[1]\n\n    # PyTorch/Python builtins #################################################\n\n    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:\n        attrs = ['_data']\n        if self._indptr is not None:\n            attrs.append('_indptr')\n        if self._T_perm is not None:\n            attrs.append('_T_perm')\n        # TODO We cannot save `_T_index` for now since it is stored as tuple.\n        if self._T_indptr is not None:\n            attrs.append('_T_indptr')\n\n        ctx = (\n            self._sparse_size,\n            self._sort_order,\n            self._is_undirected,\n            self._cat_metadata,\n        )\n\n        return attrs, ctx\n\n    @staticmethod\n    def __tensor_unflatten__(\n        inner_tensors: Dict[str, Any],\n        ctx: Tuple[Any, ...],\n        outer_size: Tuple[int, ...],\n        outer_stride: Tuple[int, ...],\n    ) -> 'EdgeIndex':\n        edge_index = EdgeIndex(\n            inner_tensors['_data'],\n            sparse_size=ctx[0],\n            sort_order=ctx[1],\n            is_undirected=ctx[2],\n        )\n\n        edge_index._indptr = inner_tensors.get('_indptr', None)\n        edge_index._T_perm = inner_tensors.get('_T_perm', None)\n        edge_index._T_indptr = inner_tensors.get('_T_indptr', None)\n        edge_index._cat_metadata = ctx[3]\n\n        return edge_index\n\n    # Prevent auto-wrapping outputs back into the proper subclass type:\n    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore\n\n    @classmethod\n    def __torch_dispatch__(  # type: ignore\n        cls: Type,\n        func: Callable[..., 
Any],\n        types: Iterable[Type[Any]],\n        args: Iterable[Tuple[Any, ...]] = (),\n        kwargs: Optional[Dict[Any, Any]] = None,\n    ) -> Any:\n        # `EdgeIndex` should be treated as a regular PyTorch tensor for all\n        # standard PyTorch functionalities. However,\n        # * some of its metadata can be transferred to new functions, e.g.,\n        #   `torch.cat(dim=1)` can inherit the sparse matrix size, or\n        #   `torch.narrow(dim=1)` can inherit cached pointers.\n        # * not all operations lead to valid `EdgeIndex` tensors again, e.g.,\n        #   `torch.sum()` does not yield an `EdgeIndex` as its output, or\n        #   `torch.cat(dim=0)` violates the [2, *] shape assumption.\n\n        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that\n        # implement specific functions for valid `EdgeIndex` routines.\n        if func in HANDLED_FUNCTIONS:\n            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))\n\n        # For all other PyTorch functions, we treat them as vanilla tensors.\n        args = pytree.tree_map_only(EdgeIndex, lambda x: x._data, args)\n        if kwargs is not None:\n            kwargs = pytree.tree_map_only(EdgeIndex, lambda x: x._data, kwargs)\n        return func(*args, **(kwargs or {}))\n\n    def __repr__(self) -> str:  # type: ignore\n        prefix = f'{self.__class__.__name__}('\n        indent = len(prefix)\n        tensor_str = torch._tensor_str._tensor_str(self._data, indent)\n\n        suffixes = []\n        num_rows, num_cols = self.sparse_size()\n        if num_rows is not None or num_cols is not None:\n            size_repr = f\"({num_rows or '?'}, {num_cols or '?'})\"\n            suffixes.append(f'sparse_size={size_repr}')\n        suffixes.append(f'nnz={self._data.size(1)}')\n        if (self.device.type != torch._C._get_default_device()\n                or (self.device.type == 'cuda'\n                    and torch.cuda.current_device() != self.device.index)\n         
       or (self.device.type == 'mps')):\n            suffixes.append(f\"device='{self.device}'\")\n        if self.dtype != torch.int64:\n            suffixes.append(f'dtype={self.dtype}')\n        if self.is_sorted:\n            suffixes.append(f'sort_order={self.sort_order}')\n        if self.is_undirected:\n            suffixes.append('is_undirected=True')\n\n        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,\n                                               indent, force_newline=False)\n\n    def tolist(self) -> List[Any]:\n        \"\"\"\"\"\"  # noqa: D419\n        return self._data.tolist()\n\n    def numpy(self, *, force: bool = False) -> np.ndarray:\n        \"\"\"\"\"\"  # noqa: D419\n        return self._data.numpy(force=force)\n\n    # Helpers #################################################################\n\n    def _shallow_copy(self) -> 'EdgeIndex':\n        out = EdgeIndex(self._data)\n        out._sparse_size = self._sparse_size\n        out._sort_order = self._sort_order\n        out._is_undirected = self._is_undirected\n        out._indptr = self._indptr\n        out._T_perm = self._T_perm\n        out._T_index = self._T_index\n        out._T_indptr = self._T_indptr\n        out._value = self._value\n        out._cat_metadata = self._cat_metadata\n        return out\n\n    def _clear_metadata(self) -> 'EdgeIndex':\n        self._sparse_size = (None, None)\n        self._sort_order = None\n        self._is_undirected = False\n        self._indptr = None\n        self._T_perm = None\n        self._T_index = (None, None)\n        self._T_indptr = None\n        self._value = None\n        self._cat_metadata = None\n        return self\n\n\nclass SortReturnType(NamedTuple):\n    values: EdgeIndex\n    indices: Optional[Tensor]\n\n\ndef apply_(\n    tensor: EdgeIndex,\n    fn: Callable,\n    *args: Any,\n    **kwargs: Any,\n) -> Union[EdgeIndex, Tensor]:\n\n    data = fn(tensor._data, *args, **kwargs)\n\n    if data.dtype not 
in INDEX_DTYPES:\n        return data\n\n    if tensor._data.data_ptr() != data.data_ptr():\n        out = EdgeIndex(data)\n    else:  # In-place:\n        tensor._data = data\n        out = tensor\n\n    # Copy metadata:\n    out._sparse_size = tensor._sparse_size\n    out._sort_order = tensor._sort_order\n    out._is_undirected = tensor._is_undirected\n    out._cat_metadata = tensor._cat_metadata\n\n    # Convert cache (but do not consider `_value`):\n    if tensor._indptr is not None:\n        out._indptr = fn(tensor._indptr, *args, **kwargs)\n\n    if tensor._T_perm is not None:\n        out._T_perm = fn(tensor._T_perm, *args, **kwargs)\n\n    _T_row, _T_col = tensor._T_index\n    if _T_row is not None:\n        _T_row = fn(_T_row, *args, **kwargs)\n    if _T_col is not None:\n        _T_col = fn(_T_col, *args, **kwargs)\n    out._T_index = (_T_row, _T_col)\n\n    if tensor._T_indptr is not None:\n        out._T_indptr = fn(tensor._T_indptr, *args, **kwargs)\n\n    return out\n\n\n@implements(aten.clone.default)\ndef _clone(\n    tensor: EdgeIndex,\n    *,\n    memory_format: torch.memory_format = torch.preserve_format,\n) -> EdgeIndex:\n    out = apply_(tensor, aten.clone.default, memory_format=memory_format)\n    assert isinstance(out, EdgeIndex)\n    return out\n\n\n@implements(aten._to_copy.default)\ndef _to_copy(\n    tensor: EdgeIndex,\n    *,\n    dtype: Optional[torch.dtype] = None,\n    layout: Optional[torch.layout] = None,\n    device: Optional[torch.device] = None,\n    pin_memory: bool = False,\n    non_blocking: bool = False,\n    memory_format: Optional[torch.memory_format] = None,\n) -> Union[EdgeIndex, Tensor]:\n    return apply_(\n        tensor,\n        aten._to_copy.default,\n        dtype=dtype,\n        layout=layout,\n        device=device,\n        pin_memory=pin_memory,\n        non_blocking=non_blocking,\n        memory_format=memory_format,\n    )\n\n\n@implements(aten.alias.default)\ndef _alias(tensor: EdgeIndex) -> EdgeIndex:\n    
return tensor._shallow_copy()\n\n\n@implements(aten._pin_memory.default)\ndef _pin_memory(tensor: EdgeIndex) -> EdgeIndex:\n    out = apply_(tensor, aten._pin_memory.default)\n    assert isinstance(out, EdgeIndex)\n    return out\n\n\n@implements(aten.cat.default)\ndef _cat(\n    tensors: List[Union[EdgeIndex, Tensor]],\n    dim: int = 0,\n) -> Union[EdgeIndex, Tensor]:\n\n    data_list = pytree.tree_map_only(EdgeIndex, lambda x: x._data, tensors)\n    data = aten.cat.default(data_list, dim=dim)\n\n    if dim != 1 and dim != -1:  # No valid `EdgeIndex` anymore.\n        return data\n\n    if any([not isinstance(tensor, EdgeIndex) for tensor in tensors]):\n        return data\n\n    out = EdgeIndex(data)\n\n    nnz_list = [t.size(1) for t in tensors]\n    sparse_size_list = [t.sparse_size() for t in tensors]  # type: ignore\n    sort_order_list = [t._sort_order for t in tensors]  # type: ignore\n    is_undirected_list = [t.is_undirected for t in tensors]  # type: ignore\n\n    # Post-process `sparse_size`:\n    total_num_rows: Optional[int] = 0\n    for num_rows, _ in sparse_size_list:\n        if num_rows is None:\n            total_num_rows = None\n            break\n        assert isinstance(total_num_rows, int)\n        total_num_rows = max(num_rows, total_num_rows)\n\n    total_num_cols: Optional[int] = 0\n    for _, num_cols in sparse_size_list:\n        if num_cols is None:\n            total_num_cols = None\n            break\n        assert isinstance(total_num_cols, int)\n        total_num_cols = max(num_cols, total_num_cols)\n\n    out._sparse_size = (total_num_rows, total_num_cols)\n\n    # Post-process `is_undirected`:\n    out._is_undirected = all(is_undirected_list)\n\n    out._cat_metadata = CatMetadata(\n        nnz=nnz_list,\n        sparse_size=sparse_size_list,\n        sort_order=sort_order_list,\n        is_undirected=is_undirected_list,\n    )\n\n    return out\n\n\n@implements(aten.flip.default)\ndef _flip(\n    input: EdgeIndex,\n    dims: 
Union[List[int], Tuple[int, ...]],\n) -> EdgeIndex:\n\n    data = aten.flip.default(input._data, dims)\n    out = EdgeIndex(data)\n\n    out._value = input._value\n    out._is_undirected = input.is_undirected\n\n    # Flip metadata and cache:\n    if 0 in dims or -2 in dims:\n        out._sparse_size = input.sparse_size()[::-1]\n\n    if len(dims) == 1 and (dims[0] == 0 or dims[0] == -2):\n        if input.is_sorted_by_row:\n            out._sort_order = SortOrder.COL\n        elif input.is_sorted_by_col:\n            out._sort_order = SortOrder.ROW\n\n        out._indptr = input._T_indptr\n        out._T_perm = input._T_perm\n        out._T_index = input._T_index[::-1]\n        out._T_indptr = input._indptr\n\n    return out\n\n\n@implements(aten.index_select.default)\ndef _index_select(\n    input: EdgeIndex,\n    dim: int,\n    index: Tensor,\n) -> Union[EdgeIndex, Tensor]:\n\n    out = aten.index_select.default(input._data, dim, index)\n\n    if dim == 1 or dim == -1:\n        out = EdgeIndex(out)\n        out._sparse_size = input.sparse_size()\n\n    return out\n\n\n@implements(aten.slice.Tensor)\ndef _slice(\n    input: EdgeIndex,\n    dim: int,\n    start: Optional[int] = None,\n    end: Optional[int] = None,\n    step: int = 1,\n) -> Union[EdgeIndex, Tensor]:\n\n    if ((start is None or start == 0 or start <= -input.size(dim))\n            and (end is None or end > input.size(dim)) and step == 1):\n        return input._shallow_copy()  # No-op.\n\n    out = aten.slice.Tensor(input._data, dim, start, end, step)\n\n    if dim == 1 or dim == -1:\n        if step != 1:\n            out = out.contiguous()\n\n        out = EdgeIndex(out)\n        out._sparse_size = input.sparse_size()\n        # NOTE We could potentially maintain `rowptr`/`colptr` attributes here,\n        # but it is not really clear if this is worth it. 
The most important\n        # information, the sort order, needs to be maintained though:\n        if step >= 0:\n            out._sort_order = input._sort_order\n        else:\n            if input._sort_order == SortOrder.ROW:\n                out._sort_order = SortOrder.COL\n            elif input._sort_order == SortOrder.COL:\n                out._sort_order = SortOrder.ROW\n\n    return out\n\n\n@implements(aten.index.Tensor)\ndef _index(\n    input: Union[EdgeIndex, Tensor],\n    indices: List[Optional[Union[Tensor, EdgeIndex]]],\n) -> Union[EdgeIndex, Tensor]:\n\n    if not isinstance(input, EdgeIndex):\n        indices = pytree.tree_map_only(EdgeIndex, lambda x: x._data, indices)\n        return aten.index.Tensor(input, indices)\n\n    out = aten.index.Tensor(input._data, indices)\n\n    if len(indices) != 2 or indices[0] is not None:\n        return out\n\n    index = indices[1]\n    assert isinstance(index, Tensor)\n\n    out = EdgeIndex(out)\n\n    # 1. `edge_index[:, mask]` or `edge_index[..., mask]`.\n    if index.dtype in (torch.bool, torch.uint8):\n        out._sparse_size = input.sparse_size()\n        out._sort_order = input._sort_order\n\n    else:  # 2. 
`edge_index[:, index]` or `edge_index[..., index]`.\n        out._sparse_size = input.sparse_size()\n\n    return out\n\n\n@implements(aten.select.int)\ndef _select(input: EdgeIndex, dim: int, index: int) -> Union[Tensor, Index]:\n    out = aten.select.int(input._data, dim, index)\n\n    if dim == 0 or dim == -2:\n        out = Index(out)\n\n        if index == 0 or index == -2:  # Row-select:\n            out._dim_size = input.sparse_size(0)\n            out._is_sorted = input.is_sorted_by_row\n            if input.is_sorted_by_row:\n                out._indptr = input._indptr\n\n        else:  # Col-select:\n            assert index == 1 or index == -1\n            out._dim_size = input.sparse_size(1)\n            out._is_sorted = input.is_sorted_by_col\n            if input.is_sorted_by_col:\n                out._indptr = input._indptr\n\n    return out\n\n\n@implements(aten.unbind.int)\ndef _unbind(\n    input: EdgeIndex,\n    dim: int = 0,\n) -> Union[List[Index], List[Tensor]]:\n\n    if dim == 0 or dim == -2:\n        row = input[0]\n        assert isinstance(row, Index)\n        col = input[1]\n        assert isinstance(col, Index)\n        return [row, col]\n\n    return aten.unbind.int(input._data, dim)\n\n\n@implements(aten.add.Tensor)\ndef _add(\n    input: EdgeIndex,\n    other: Union[int, Tensor, EdgeIndex],\n    *,\n    alpha: int = 1,\n) -> Union[EdgeIndex, Tensor]:\n\n    out = aten.add.Tensor(\n        input._data,\n        other._data if isinstance(other, EdgeIndex) else other,\n        alpha=alpha,\n    )\n\n    if out.dtype not in INDEX_DTYPES:\n        return out\n    if out.dim() != 2 or out.size(0) != 2:\n        return out\n\n    out = EdgeIndex(out)\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        size = maybe_add(input._sparse_size, other, alpha)\n        assert len(size) == 2\n        out._sparse_size = size\n        out._sort_order = input._sort_order\n    
    out._is_undirected = input.is_undirected\n        out._T_perm = input._T_perm\n\n    elif isinstance(other, Tensor) and other.size() == (2, 1):\n        size = maybe_add(input._sparse_size, other.view(-1).tolist(), alpha)\n        assert len(size) == 2\n        out._sparse_size = size\n        out._sort_order = input._sort_order\n        if torch.equal(other[0], other[1]):\n            out._is_undirected = input.is_undirected\n        out._T_perm = input._T_perm\n\n    elif isinstance(other, EdgeIndex):\n        size = maybe_add(input._sparse_size, other._sparse_size, alpha)\n        assert len(size) == 2\n        out._sparse_size = size\n\n    return out\n\n\n@implements(aten.add_.Tensor)\ndef add_(\n    input: EdgeIndex,\n    other: Union[int, Tensor, EdgeIndex],\n    *,\n    alpha: int = 1,\n) -> EdgeIndex:\n\n    sparse_size = input._sparse_size\n    sort_order = input._sort_order\n    is_undirected = input._is_undirected\n    T_perm = input._T_perm\n    input._clear_metadata()\n\n    aten.add_.Tensor(\n        input._data,\n        other._data if isinstance(other, EdgeIndex) else other,\n        alpha=alpha,\n    )\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        size = maybe_add(sparse_size, other, alpha)\n        assert len(size) == 2\n        input._sparse_size = size\n        input._sort_order = sort_order\n        input._is_undirected = is_undirected\n        input._T_perm = T_perm\n\n    elif isinstance(other, Tensor) and other.size() == (2, 1):\n        size = maybe_add(sparse_size, other.view(-1).tolist(), alpha)\n        assert len(size) == 2\n        input._sparse_size = size\n        input._sort_order = sort_order\n        if torch.equal(other[0], other[1]):\n            input._is_undirected = is_undirected\n        input._T_perm = T_perm\n\n    elif isinstance(other, EdgeIndex):\n        size = maybe_add(sparse_size, other._sparse_size, alpha)\n        assert 
len(size) == 2\n        input._sparse_size = size\n\n    return input\n\n\n@implements(aten.sub.Tensor)\ndef _sub(\n    input: EdgeIndex,\n    other: Union[int, Tensor, EdgeIndex],\n    *,\n    alpha: int = 1,\n) -> Union[EdgeIndex, Tensor]:\n\n    out = aten.sub.Tensor(\n        input._data,\n        other._data if isinstance(other, EdgeIndex) else other,\n        alpha=alpha,\n    )\n\n    if out.dtype not in INDEX_DTYPES:\n        return out\n    if out.dim() != 2 or out.size(0) != 2:\n        return out\n\n    out = EdgeIndex(out)\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        size = maybe_sub(input._sparse_size, other, alpha)\n        assert len(size) == 2\n        out._sparse_size = size\n        out._sort_order = input._sort_order\n        out._is_undirected = input.is_undirected\n        out._T_perm = input._T_perm\n\n    elif isinstance(other, Tensor) and other.size() == (2, 1):\n        size = maybe_sub(input._sparse_size, other.view(-1).tolist(), alpha)\n        assert len(size) == 2\n        out._sparse_size = size\n        out._sort_order = input._sort_order\n        if torch.equal(other[0], other[1]):\n            out._is_undirected = input.is_undirected\n        out._T_perm = input._T_perm\n\n    return out\n\n\n@implements(aten.sub_.Tensor)\ndef sub_(\n    input: EdgeIndex,\n    other: Union[int, Tensor, EdgeIndex],\n    *,\n    alpha: int = 1,\n) -> EdgeIndex:\n\n    sparse_size = input._sparse_size\n    sort_order = input._sort_order\n    is_undirected = input._is_undirected\n    T_perm = input._T_perm\n    input._clear_metadata()\n\n    aten.sub_.Tensor(\n        input._data,\n        other._data if isinstance(other, EdgeIndex) else other,\n        alpha=alpha,\n    )\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        size = maybe_sub(sparse_size, other, alpha)\n        assert 
len(size) == 2\n        input._sparse_size = size\n        input._sort_order = sort_order\n        input._is_undirected = is_undirected\n        input._T_perm = T_perm\n\n    elif isinstance(other, Tensor) and other.size() == (2, 1):\n        size = maybe_sub(sparse_size, other.view(-1).tolist(), alpha)\n        assert len(size) == 2\n        input._sparse_size = size\n        input._sort_order = sort_order\n        if torch.equal(other[0], other[1]):\n            input._is_undirected = is_undirected\n        input._T_perm = T_perm\n\n    return input\n\n\n# Sparse-Dense Matrix Multiplication ##########################################\n\n\ndef _torch_sparse_spmm(\n    input: EdgeIndex,\n    other: Tensor,\n    value: Optional[Tensor] = None,\n    reduce: ReduceType = 'sum',\n    transpose: bool = False,\n) -> Tensor:\n    # `torch-sparse` still provides a faster sparse-dense matrix multiplication\n    # code path on GPUs (after all these years...):\n    assert torch_geometric.typing.WITH_TORCH_SPARSE\n    reduce = PYG_REDUCE[reduce] if reduce in PYG_REDUCE else reduce\n\n    # Optional arguments for backpropagation:\n    colptr: Optional[Tensor] = None\n    perm: Optional[Tensor] = None\n\n    if not transpose:\n        assert input.is_sorted_by_row\n        (rowptr, col), _ = input.get_csr()\n        row = input._data[0]\n        if other.requires_grad and reduce in ['sum', 'mean']:\n            (colptr, _), perm = input.get_csc()\n    else:\n        assert input.is_sorted_by_col\n        (rowptr, col), _ = input.get_csc()\n        row = input._data[1]\n        if other.requires_grad and reduce in ['sum', 'mean']:\n            (colptr, _), perm = input.get_csr()\n\n    if reduce == 'sum':\n        return torch.ops.torch_sparse.spmm_sum(  #\n            row, rowptr, col, value, colptr, perm, other)\n\n    if reduce == 'mean':\n        rowcount = rowptr.diff() if other.requires_grad else None\n        return torch.ops.torch_sparse.spmm_mean(  #\n            row, 
rowptr, col, value, rowcount, colptr, perm, other)\n\n    if reduce == 'min':\n        return torch.ops.torch_sparse.spmm_min(rowptr, col, value, other)[0]\n\n    if reduce == 'max':\n        return torch.ops.torch_sparse.spmm_max(rowptr, col, value, other)[0]\n\n    raise NotImplementedError\n\n\nclass _TorchSPMM(torch.autograd.Function):\n    @staticmethod\n    def forward(\n        ctx: Any,\n        input: EdgeIndex,\n        other: Tensor,\n        value: Optional[Tensor] = None,\n        reduce: ReduceType = 'sum',\n        transpose: bool = False,\n    ) -> Tensor:\n\n        reduce = TORCH_REDUCE[reduce] if reduce in TORCH_REDUCE else reduce\n\n        value = value.detach() if value is not None else value\n        if other.requires_grad:\n            other = other.detach()\n            ctx.save_for_backward(input, value)\n            ctx.reduce = reduce\n            ctx.transpose = transpose\n\n        if not transpose:\n            assert input.is_sorted_by_row\n            adj = input.to_sparse_csr(value)\n        else:\n            assert input.is_sorted_by_col\n            adj = input.to_sparse_csc(value).t()\n\n        if torch_geometric.typing.WITH_PT20 and not other.is_cuda:\n            return torch.sparse.mm(adj, other, reduce)\n        else:  # pragma: no cover\n            assert reduce == 'sum'\n            return adj @ other\n\n    @staticmethod\n    def backward(\n        ctx: Any,\n        *grad_outputs: Any,\n    ) -> Tuple[None, Optional[Tensor], None, None, None]:\n\n        grad_out, = grad_outputs\n\n        other_grad: Optional[Tensor] = None\n        if ctx.needs_input_grad[1]:\n            input, value = ctx.saved_tensors\n            assert ctx.reduce == 'sum'\n\n            if not ctx.transpose:\n                if value is None and input.is_undirected:\n                    adj = input.to_sparse_csr(value)\n                else:\n                    (colptr, row), perm = input.get_csc()\n                    if value is not None and 
perm is not None:\n                        value = value[perm]\n                    else:\n                        value = input._get_value()\n                    adj = torch.sparse_csr_tensor(\n                        crow_indices=colptr,\n                        col_indices=row,\n                        values=value,\n                        size=input.get_sparse_size()[::-1],\n                        device=input.device,\n                    )\n            else:\n                if value is None and input.is_undirected:\n                    adj = input.to_sparse_csc(value).t()\n                else:\n                    (rowptr, col), perm = input.get_csr()\n                    if value is not None and perm is not None:\n                        value = value[perm]\n                    else:\n                        value = input._get_value()\n                    adj = torch.sparse_csr_tensor(\n                        crow_indices=rowptr,\n                        col_indices=col,\n                        values=value,\n                        size=input.get_sparse_size()[::-1],\n                        device=input.device,\n                    )\n\n            other_grad = adj @ grad_out\n\n        if ctx.needs_input_grad[2]:\n            raise NotImplementedError(\"Gradient computation for 'value' not \"\n                                      \"yet supported\")\n\n        return None, other_grad, None, None, None\n\n\ndef _scatter_spmm(\n    input: EdgeIndex,\n    other: Tensor,\n    value: Optional[Tensor] = None,\n    reduce: ReduceType = 'sum',\n    transpose: bool = False,\n) -> Tensor:\n    from torch_geometric.utils import scatter\n\n    if not transpose:\n        other_j = other[input._data[1]]\n        index = input._data[0]\n        dim_size = input.get_sparse_size(0)\n    else:\n        other_j = other[input._data[0]]\n        index = input._data[1]\n        dim_size = input.get_sparse_size(1)\n\n    other_j = other_j * value.view(-1, 1) if value is 
not None else other_j\n    return scatter(other_j, index, 0, dim_size=dim_size, reduce=reduce)\n\n\ndef _spmm(\n    input: EdgeIndex,\n    other: Tensor,\n    value: Optional[Tensor] = None,\n    reduce: ReduceType = 'sum',\n    transpose: bool = False,\n) -> Tensor:\n\n    if reduce not in get_args(ReduceType):\n        raise ValueError(f\"`reduce='{reduce}'` is not a valid reduction\")\n\n    if not transpose and not input.is_sorted_by_row:\n        cls_name = input.__class__.__name__\n        raise ValueError(f\"'matmul(..., transpose=False)' requires \"\n                         f\"'{cls_name}' to be sorted by rows\")\n\n    if transpose and not input.is_sorted_by_col:\n        cls_name = input.__class__.__name__\n        raise ValueError(f\"'matmul(..., transpose=True)' requires \"\n                         f\"'{cls_name}' to be sorted by columns\")\n\n    if (torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling()\n            and other.is_cuda):  # pragma: no cover\n        return _torch_sparse_spmm(input, other, value, reduce, transpose)\n\n    if value is not None and value.requires_grad:\n        if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():\n            return _torch_sparse_spmm(input, other, value, reduce, transpose)\n        return _scatter_spmm(input, other, value, reduce, transpose)\n\n    if torch_geometric.typing.WITH_PT20:\n        if reduce == 'sum' or reduce == 'add':\n            return _TorchSPMM.apply(input, other, value, 'sum', transpose)\n\n        if reduce == 'mean':\n            out = _TorchSPMM.apply(input, other, value, 'sum', transpose)\n            count = input.get_indptr().diff()\n            return out / count.clamp_(min=1).to(out.dtype).view(-1, 1)\n\n        if not other.is_cuda and not other.requires_grad:\n            return _TorchSPMM.apply(input, other, value, reduce, transpose)\n\n    if torch_geometric.typing.WITH_TORCH_SPARSE and not is_compiling():\n        return _torch_sparse_spmm(input, 
other, value, reduce, transpose)\n\n    return _scatter_spmm(input, other, value, reduce, transpose)\n\n\ndef matmul(\n    input: EdgeIndex,\n    other: Union[Tensor, EdgeIndex],\n    input_value: Optional[Tensor] = None,\n    other_value: Optional[Tensor] = None,\n    reduce: ReduceType = 'sum',\n    transpose: bool = False,\n) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:\n\n    if not isinstance(other, EdgeIndex):\n        if other_value is not None:\n            raise ValueError(\"'other_value' not supported for sparse-dense \"\n                             \"matrix multiplication\")\n        return _spmm(input, other, input_value, reduce, transpose)\n\n    if reduce not in ['sum', 'add']:\n        raise NotImplementedError(f\"`reduce='{reduce}'` not yet supported for \"\n                                  f\"sparse-sparse matrix multiplication\")\n\n    transpose &= not input.is_undirected or input_value is not None\n\n    if torch_geometric.typing.NO_MKL:  # pragma: no cover\n        sparse_input = input.to_sparse_coo(input_value)\n    elif input.is_sorted_by_col:\n        sparse_input = input.to_sparse_csc(input_value)\n    else:\n        sparse_input = input.to_sparse_csr(input_value)\n\n    if transpose:\n        sparse_input = sparse_input.t()\n\n    if torch_geometric.typing.NO_MKL:  # pragma: no cover\n        other = other.to_sparse_coo(other_value)\n    elif other.is_sorted_by_col:\n        other = other.to_sparse_csc(other_value)\n    else:\n        other = other.to_sparse_csr(other_value)\n\n    out = torch.matmul(sparse_input, other)\n\n    rowptr: Optional[Tensor] = None\n    if out.layout == torch.sparse_csr:\n        rowptr = out.crow_indices().to(input.dtype)\n        col = out.col_indices().to(input.dtype)\n        edge_index = torch._convert_indices_from_csr_to_coo(\n            rowptr, col, out_int32=rowptr.dtype != torch.int64)\n\n    elif out.layout == torch.sparse_coo:  # pragma: no cover\n        out = out.coalesce()\n        edge_index = 
out.indices()\n\n    else:\n        raise NotImplementedError\n\n    edge_index = EdgeIndex(edge_index)\n    edge_index._sort_order = SortOrder.ROW\n    edge_index._sparse_size = (out.size(0), out.size(1))\n    edge_index._indptr = rowptr\n\n    return edge_index, out.values()\n\n\n@implements(aten.mm.default)\ndef _mm(\n    input: EdgeIndex,\n    other: Union[Tensor, EdgeIndex],\n) -> Union[Tensor, Tuple[EdgeIndex, Tensor]]:\n    return matmul(input, other)\n\n\n@implements(aten._sparse_addmm.default)\ndef _addmm(\n    input: Tensor,\n    mat1: EdgeIndex,\n    mat2: Tensor,\n    beta: float = 1.0,\n    alpha: float = 1.0,\n) -> Tensor:\n    assert input.abs().sum() == 0.0\n    out = matmul(mat1, mat2)\n    assert isinstance(out, Tensor)\n    return alpha * out if alpha != 1.0 else out\n\n\nif hasattr(aten, '_sparse_mm_reduce_impl'):\n\n    @implements(aten._sparse_mm_reduce_impl.default)\n    def _mm_reduce(\n        mat1: EdgeIndex,\n        mat2: Tensor,\n        reduce: ReduceType = 'sum',\n    ) -> Tuple[Tensor, Tensor]:\n        out = matmul(mat1, mat2, reduce=reduce)\n        assert isinstance(out, Tensor)\n        return out, out  # We return a dummy tensor for `argout` for now.\n"
  },
  {
    "path": "torch_geometric/experimental.py",
    "content": "import functools\nimport inspect\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\n\n# TODO (matthias) This file currently requires manual imports to let\n# TorchScript work on decorated functions. Not totally sure why :(\nfrom torch_geometric.utils import *  # noqa\n\n__experimental_flag__: Dict[str, bool] = {\n    'disable_dynamic_shapes': False,\n}\n\nOptions = Optional[Union[str, List[str]]]\n\n\ndef get_options(options: Options) -> List[str]:\n    if options is None:\n        options = list(__experimental_flag__.keys())\n    if isinstance(options, str):\n        options = [options]\n    return options\n\n\ndef is_experimental_mode_enabled(options: Options = None) -> bool:\n    r\"\"\"Returns :obj:`True` if the experimental mode is enabled. See\n    :class:`torch_geometric.experimental_mode` for a list of (optional)\n    options.\n    \"\"\"\n    if torch.jit.is_scripting() or torch.jit.is_tracing():\n        return False\n    options = get_options(options)\n    return all([__experimental_flag__[option] for option in options])\n\n\ndef set_experimental_mode_enabled(mode: bool, options: Options = None) -> None:\n    for option in get_options(options):\n        __experimental_flag__[option] = mode\n\n\nclass experimental_mode:\n    r\"\"\"Context-manager that enables the experimental mode to test new but\n    potentially unstable features.\n\n    .. 
code-block:: python\n\n        with torch_geometric.experimental_mode():\n            out = model(data.x, data.edge_index)\n\n    Args:\n        options (str or list, optional): Currently there are no experimental\n            features.\n    \"\"\"\n    def __init__(self, options: Options = None) -> None:\n        self.options = get_options(options)\n        self.previous_state = {\n            option: __experimental_flag__[option]\n            for option in self.options\n        }\n\n    def __enter__(self) -> None:\n        set_experimental_mode_enabled(True, self.options)\n\n    def __exit__(self, *args: Any) -> None:\n        for option, value in self.previous_state.items():\n            __experimental_flag__[option] = value\n\n\nclass set_experimental_mode:\n    r\"\"\"Context-manager that sets the experimental mode on or off.\n\n    :class:`set_experimental_mode` will enable or disable the experimental mode\n    based on its argument :attr:`mode`.\n    It can be used as a context-manager or as a function.\n\n    See :class:`experimental_mode` above for more details.\n    \"\"\"\n    def __init__(self, mode: bool, options: Options = None) -> None:\n        self.options = get_options(options)\n        self.previous_state = {\n            option: __experimental_flag__[option]\n            for option in self.options\n        }\n        set_experimental_mode_enabled(mode, self.options)\n\n    def __enter__(self) -> None:\n        pass\n\n    def __exit__(self, *args: Any) -> None:\n        for option, value in self.previous_state.items():\n            __experimental_flag__[option] = value\n\n\ndef disable_dynamic_shapes(required_args: List[str]) -> Callable:\n    r\"\"\"A decorator that disables the usage of dynamic shapes for the given\n    arguments, i.e., it will raise an error in case :obj:`required_args` are\n    not passed and need to be automatically inferred.\n    \"\"\"\n    def decorator(func: Callable) -> Callable:\n        spec = 
inspect.getfullargspec(func)\n\n        required_args_pos: Dict[str, int] = {}\n        for arg_name in required_args:\n            if arg_name not in spec.args:\n                raise ValueError(f\"The function '{func}' does not have a \"\n                                 f\"'{arg_name}' argument\")\n            required_args_pos[arg_name] = spec.args.index(arg_name)\n\n        num_args = len(spec.args)\n        num_default_args = 0 if spec.defaults is None else len(spec.defaults)\n        num_positional_args = num_args - num_default_args\n\n        @functools.wraps(func)\n        def wrapper(*args: Any, **kwargs: Any) -> Any:\n            if not is_experimental_mode_enabled('disable_dynamic_shapes'):\n                return func(*args, **kwargs)\n\n            for required_arg in required_args:\n                index = required_args_pos[required_arg]\n\n                value: Optional[Any] = None\n                if index < len(args):\n                    value = args[index]\n                elif required_arg in kwargs:\n                    value = kwargs[required_arg]\n                elif num_default_args > 0:\n                    assert spec.defaults is not None\n                    value = spec.defaults[index - num_positional_args]\n\n                if value is None:\n                    raise ValueError(f\"Dynamic shapes disabled. Argument \"\n                                     f\"'{required_arg}' needs to be set\")\n\n            return func(*args, **kwargs)\n\n        return wrapper\n\n    return decorator\n"
  },
  {
    "path": "torch_geometric/explain/__init__.py",
    "content": "from .config import ExplainerConfig, ModelConfig, ThresholdConfig\nfrom .explanation import Explanation, HeteroExplanation\nfrom .algorithm import *  # noqa\nfrom .explainer import Explainer\nfrom .metric import *  # noqa\n\n__all__ = [\n    'ExplainerConfig',\n    'ModelConfig',\n    'ThresholdConfig',\n    'Explanation',\n    'HeteroExplanation',\n    'Explainer',\n]\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/__init__.py",
    "content": "from .base import ExplainerAlgorithm\nfrom .dummy_explainer import DummyExplainer\nfrom .gnn_explainer import GNNExplainer\nfrom .captum_explainer import CaptumExplainer\nfrom .pg_explainer import PGExplainer\nfrom .attention_explainer import AttentionExplainer\nfrom .graphmask_explainer import GraphMaskExplainer\n\n__all__ = classes = [\n    'ExplainerAlgorithm',\n    'DummyExplainer',\n    'GNNExplainer',\n    'CaptumExplainer',\n    'PGExplainer',\n    'AttentionExplainer',\n    'GraphMaskExplainer',\n]\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/attention_explainer.py",
    "content": "import logging\nfrom typing import Dict, List, Optional, Union, overload\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.config import ExplanationType, ModelTaskLevel\nfrom torch_geometric.nn.conv.message_passing import MessagePassing\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass AttentionExplainer(ExplainerAlgorithm):\n    r\"\"\"An explainer that uses the attention coefficients produced by an\n    attention-based GNN (*e.g.*,\n    :class:`~torch_geometric.nn.conv.GATConv`,\n    :class:`~torch_geometric.nn.conv.GATv2Conv`, or\n    :class:`~torch_geometric.nn.conv.TransformerConv`) as edge explanation.\n    Attention scores across layers and heads will be aggregated according to\n    the :obj:`reduce` argument.\n\n    Args:\n        reduce (str, optional): The method to reduce the attention scores\n            across layers and heads. 
(default: :obj:`\"max\"`)\n    \"\"\"\n    def __init__(self, reduce: str = 'max'):\n        super().__init__()\n        self.reduce = reduce\n        self.is_hetero = False\n\n    @overload\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Explanation:\n        ...\n\n    @overload\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Dict[NodeType, Tensor],\n        edge_index: Dict[EdgeType, Tensor],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> HeteroExplanation:\n        ...\n\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Union[Explanation, HeteroExplanation]:\n        \"\"\"Generate explanations based on attention coefficients.\"\"\"\n        self.is_hetero = isinstance(x, dict)\n\n        # Collect attention coefficients\n        alphas_dict = self._collect_attention_coefficients(\n            model, x, edge_index, **kwargs)\n\n        # Process attention coefficients\n        if self.is_hetero:\n            return self._create_hetero_explanation(model, alphas_dict,\n                                                   edge_index, index, x)\n        else:\n            return self._create_homo_explanation(model, alphas_dict,\n                                                 edge_index, index, x)\n\n    @overload\n    def _collect_attention_coefficients(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        **kwargs,\n    ) -> List[Tensor]:\n        ...\n\n    @overload\n    def 
_collect_attention_coefficients(\n        self,\n        model: torch.nn.Module,\n        x: Dict[NodeType, Tensor],\n        edge_index: Dict[EdgeType, Tensor],\n        **kwargs,\n    ) -> Dict[EdgeType, List[Tensor]]:\n        ...\n\n    def _collect_attention_coefficients(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        **kwargs,\n    ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:\n        \"\"\"Collect attention coefficients from model layers.\"\"\"\n        if self.is_hetero:\n            # For heterogeneous graphs, store alphas by edge type\n            alphas_dict: Dict[EdgeType, List[Tensor]] = {}\n\n            # Get list of edge types\n            edge_types = list(edge_index.keys())\n\n            # Hook function to capture attention coefficients by edge type\n            def hook(module, msg_kwargs, out):\n                # Find edge type from the module's full name\n                module_name = getattr(module, '_name', None)\n                if module_name is None:\n                    return\n\n                edge_type = None\n                for edge_tuple in edge_types:\n                    src_type, edge_name, dst_type = edge_tuple\n                    # Check if all components appear in the module name in\n                    # order\n                    try:\n                        src_idx = module_name.index(src_type)\n                        edge_idx = module_name.index(edge_name, src_idx)\n                        dst_idx = module_name.index(dst_type, edge_idx)\n                        if src_idx < edge_idx < dst_idx:\n                            edge_type = edge_tuple\n                            break\n                    except ValueError:  # Component not found\n                        continue\n\n                if edge_type is None:\n                    return\n\n                if edge_type not in 
alphas_dict:\n                    alphas_dict[edge_type] = []\n\n                # Extract alpha from message kwargs or module\n                if 'alpha' in msg_kwargs[0]:\n                    alphas_dict[edge_type].append(\n                        msg_kwargs[0]['alpha'].detach())\n                elif getattr(module, '_alpha', None) is not None:\n                    alphas_dict[edge_type].append(module._alpha.detach())\n        else:\n            # For homogeneous graphs, store all alphas in a list\n            alphas: List[Tensor] = []\n\n            def hook(module, msg_kwargs, out):\n                if 'alpha' in msg_kwargs[0]:\n                    alphas.append(msg_kwargs[0]['alpha'].detach())\n                elif getattr(module, '_alpha', None) is not None:\n                    alphas.append(module._alpha.detach())\n\n        # Register hooks for all message passing modules\n        hook_handles = []\n        for name, module in model.named_modules():\n            if isinstance(module,\n                          MessagePassing) and module.explain is not False:\n                # Store name for hetero graph lookup in the hook\n                if self.is_hetero:\n                    module._name = name\n\n                hook_handles.append(module.register_message_forward_hook(hook))\n\n        # Forward pass to collect attention coefficients.\n        model(x, edge_index, **kwargs)\n\n        # Remove hooks\n        for handle in hook_handles:\n            handle.remove()\n\n        # Check if we collected any attention coefficients.\n        if self.is_hetero:\n            if not alphas_dict:\n                raise ValueError(\n                    \"Could not collect any attention coefficients. 
\"\n                    \"Please ensure that your model is using \"\n                    \"attention-based GNN layers.\")\n            return alphas_dict\n        else:\n            if not alphas:\n                raise ValueError(\n                    \"Could not collect any attention coefficients. \"\n                    \"Please ensure that your model is using \"\n                    \"attention-based GNN layers.\")\n            return alphas\n\n    def _process_attention_coefficients(\n        self,\n        alphas: List[Tensor],\n        edge_index_size: int,\n    ) -> Tensor:\n        \"\"\"Process collected attention coefficients into a single mask.\"\"\"\n        for i, alpha in enumerate(alphas):\n            # Ensure alpha doesn't exceed edge_index size\n            alpha = alpha[:edge_index_size]\n\n            # Reduce multi-head attention\n            if alpha.dim() == 2:\n                alpha = getattr(torch, self.reduce)(alpha, dim=-1)\n                if isinstance(alpha, tuple):  # Handle torch.max output\n                    alpha = alpha[0]\n            elif alpha.dim() > 2:\n                raise ValueError(f\"Cannot reduce attention coefficients of \"\n                                 f\"shape {list(alpha.size())}\")\n            alphas[i] = alpha\n\n        # Combine attention coefficients across layers\n        if len(alphas) > 1:\n            alpha = torch.stack(alphas, dim=-1)\n            alpha = getattr(torch, self.reduce)(alpha, dim=-1)\n            if isinstance(alpha, tuple):  # Handle torch.max output\n                alpha = alpha[0]\n        else:\n            alpha = alphas[0]\n\n        return alpha\n\n    def _create_homo_explanation(\n        self,\n        model: torch.nn.Module,\n        alphas: List[Tensor],\n        edge_index: Tensor,\n        index: Optional[Union[int, Tensor]],\n        x: Tensor,\n    ) -> Explanation:\n        \"\"\"Create explanation for homogeneous graph.\"\"\"\n        # Get hard edge mask for 
node-level tasks\n        hard_edge_mask = None\n        if self.model_config.task_level == ModelTaskLevel.node:\n            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,\n                                                     num_nodes=x.size(0))\n\n        # Process attention coefficients\n        alpha = self._process_attention_coefficients(alphas,\n                                                     edge_index.size(1))\n\n        # Post-process mask with hard edge mask if needed\n        alpha = self._post_process_mask(alpha, hard_edge_mask,\n                                        apply_sigmoid=False)\n\n        return Explanation(edge_mask=alpha)\n\n    def _create_hetero_explanation(\n        self,\n        model: torch.nn.Module,\n        alphas_dict: Dict[EdgeType, List[Tensor]],\n        edge_index: Dict[EdgeType, Tensor],\n        index: Optional[Union[int, Tensor]],\n        x: Dict[NodeType, Tensor],\n    ) -> HeteroExplanation:\n        \"\"\"Create explanation for heterogeneous graph.\"\"\"\n        edge_masks_dict = {}\n\n        # Process each edge type separately\n        for edge_type, alphas in alphas_dict.items():\n            if not alphas:\n                continue\n\n            # Get hard edge mask for node-level tasks\n            hard_edge_mask = None\n            if self.model_config.task_level == ModelTaskLevel.node:\n                src_type, _, dst_type = edge_type\n                _, hard_edge_mask = self._get_hard_masks(\n                    model, index, edge_index[edge_type],\n                    num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))\n\n            # Process attention coefficients for this edge type\n            alpha = self._process_attention_coefficients(\n                alphas, edge_index[edge_type].size(1))\n\n            # Apply hard mask if available\n            edge_masks_dict[edge_type] = self._post_process_mask(\n                alpha, hard_edge_mask, apply_sigmoid=False)\n\n      
  # Create heterogeneous explanation\n        explanation = HeteroExplanation()\n        explanation.set_value_dict('edge_mask', edge_masks_dict)\n        return explanation\n\n    def supports(self) -> bool:\n        explanation_type = self.explainer_config.explanation_type\n        if explanation_type != ExplanationType.model:\n            logging.error(f\"'{self.__class__.__name__}' only supports \"\n                          f\"model explanations \"\n                          f\"(got `explanation_type={explanation_type.value}`)\")\n            return False\n\n        node_mask_type = self.explainer_config.node_mask_type\n        if node_mask_type is not None:\n            logging.error(f\"'{self.__class__.__name__}' does not support \"\n                          f\"explaining input node features \"\n                          f\"(got `node_mask_type={node_mask_type.value}`)\")\n            return False\n\n        return True\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/base.py",
    "content": "from abc import abstractmethod\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.config import (\n    ExplainerConfig,\n    ModelConfig,\n    ModelReturnType,\n)\nfrom torch_geometric.nn import MessagePassing\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.utils import k_hop_subgraph\n\n\nclass ExplainerAlgorithm(torch.nn.Module):\n    r\"\"\"An abstract base class for implementing explainer algorithms.\"\"\"\n    @abstractmethod\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Union[Explanation, HeteroExplanation]:\n        r\"\"\"Computes the explanation.\n\n        Args:\n            model (torch.nn.Module): The model to explain.\n            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input\n                node features of a homogeneous or heterogeneous graph.\n            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]): The\n                input edge indices of a homogeneous or heterogeneous graph.\n            target (torch.Tensor): The target of the model.\n            index (Union[int, Tensor], optional): The index of the model\n                output to explain. Can be a single index or a tensor of\n                indices. 
(default: :obj:`None`)\n            **kwargs (optional): Additional keyword arguments passed to\n                :obj:`model`.\n        \"\"\"\n\n    @abstractmethod\n    def supports(self) -> bool:\n        r\"\"\"Checks if the explainer supports the user-defined settings provided\n        in :obj:`self.explainer_config`, :obj:`self.model_config`.\n        \"\"\"\n\n    ###########################################################################\n\n    @property\n    def explainer_config(self) -> ExplainerConfig:\n        r\"\"\"Returns the connected explainer configuration.\"\"\"\n        if not hasattr(self, '_explainer_config'):\n            raise ValueError(\n                f\"The explanation algorithm '{self.__class__.__name__}' is \"\n                f\"not yet connected to any explainer configuration. Please \"\n                f\"call `{self.__class__.__name__}.connect(...)` before \"\n                f\"proceeding.\")\n        return self._explainer_config\n\n    @property\n    def model_config(self) -> ModelConfig:\n        r\"\"\"Returns the connected model configuration.\"\"\"\n        if not hasattr(self, '_model_config'):\n            raise ValueError(\n                f\"The explanation algorithm '{self.__class__.__name__}' is \"\n                f\"not yet connected to any model configuration. 
Please call \"\n                f\"`{self.__class__.__name__}.connect(...)` before \"\n                f\"proceeding.\")\n        return self._model_config\n\n    def connect(\n        self,\n        explainer_config: ExplainerConfig,\n        model_config: ModelConfig,\n    ):\n        r\"\"\"Connects an explainer and model configuration to the explainer\n        algorithm.\n        \"\"\"\n        self._explainer_config = ExplainerConfig.cast(explainer_config)\n        self._model_config = ModelConfig.cast(model_config)\n\n        if not self.supports():\n            raise ValueError(\n                f\"The explanation algorithm '{self.__class__.__name__}' does \"\n                f\"not support the given explanation settings.\")\n\n    # Helper functions ########################################################\n\n    @staticmethod\n    def _post_process_mask(\n        mask: Optional[Tensor],\n        hard_mask: Optional[Tensor] = None,\n        apply_sigmoid: bool = True,\n    ) -> Optional[Tensor]:\n        r\"\"\"Post-processes any mask to not include any attributions of\n        elements not involved during message passing.\n        \"\"\"\n        if mask is None:\n            return mask\n\n        mask = mask.detach()\n\n        if apply_sigmoid:\n            mask = mask.sigmoid()\n\n        if hard_mask is not None and mask.size(0) == hard_mask.size(0):\n            mask[~hard_mask] = 0.\n\n        return mask\n\n    @staticmethod\n    def _get_hard_masks(\n        model: torch.nn.Module,\n        node_index: Optional[Union[int, Tensor]],\n        edge_index: Tensor,\n        num_nodes: int,\n    ) -> Tuple[Optional[Tensor], Optional[Tensor]]:\n        r\"\"\"Returns hard node and edge masks that only include the nodes and\n        edges visited during message passing.\n        \"\"\"\n        if node_index is None:\n            return None, None  # Consider all nodes and edges.\n\n        index, _, _, edge_mask = k_hop_subgraph(\n            
node_index,\n            num_hops=ExplainerAlgorithm._num_hops(model),\n            edge_index=edge_index,\n            num_nodes=num_nodes,\n            flow=ExplainerAlgorithm._flow(model),\n        )\n\n        node_mask = edge_index.new_zeros(num_nodes, dtype=torch.bool)\n        node_mask[index] = True\n\n        return node_mask, edge_mask\n\n    @staticmethod\n    def _num_hops(model: torch.nn.Module) -> int:\n        r\"\"\"Returns the number of hops the :obj:`model` is aggregating\n        information from.\n        \"\"\"\n        num_hops = 0\n        for module in model.modules():\n            if isinstance(module, MessagePassing):\n                num_hops += 1\n        return num_hops\n\n    @staticmethod\n    def _flow(model: torch.nn.Module) -> str:\n        r\"\"\"Determines the message passing flow of the :obj:`model`.\"\"\"\n        for module in model.modules():\n            if isinstance(module, MessagePassing):\n                return module.flow\n        return 'source_to_target'\n\n    def _loss_binary_classification(self, y_hat: Tensor, y: Tensor) -> Tensor:\n        if self.model_config.return_type == ModelReturnType.raw:\n            loss_fn = F.binary_cross_entropy_with_logits\n        elif self.model_config.return_type == ModelReturnType.probs:\n            loss_fn = F.binary_cross_entropy\n        else:\n            raise AssertionError()\n\n        return loss_fn(y_hat.view_as(y), y.float())\n\n    def _loss_multiclass_classification(\n        self,\n        y_hat: Tensor,\n        y: Tensor,\n    ) -> Tensor:\n        if self.model_config.return_type == ModelReturnType.raw:\n            loss_fn = F.cross_entropy\n        elif self.model_config.return_type == ModelReturnType.probs:\n            loss_fn = F.nll_loss\n            y_hat = y_hat.log()\n        elif self.model_config.return_type == ModelReturnType.log_probs:\n            loss_fn = F.nll_loss\n        else:\n            raise AssertionError()\n\n        return 
loss_fn(y_hat, y)\n\n    def _loss_regression(self, y_hat: Tensor, y: Tensor) -> Tensor:\n        assert self.model_config.return_type == ModelReturnType.raw\n        return F.mse_loss(y_hat, y)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/captum.py",
    "content": "from enum import Enum\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain.algorithm.utils import (\n    clear_masks,\n    set_hetero_masks,\n    set_masks,\n)\nfrom torch_geometric.explain.config import (\n    ModelConfig,\n    ModelMode,\n    ModelReturnType,\n)\nfrom torch_geometric.typing import EdgeType, Metadata, NodeType\n\n\nclass MaskLevelType(Enum):\n    \"\"\"Enum class for the mask level type.\"\"\"\n    node = 'node'\n    edge = 'edge'\n    node_and_edge = 'node_and_edge'\n\n    @property\n    def with_edge(self) -> bool:\n        return self in [MaskLevelType.edge, MaskLevelType.node_and_edge]\n\n\nclass CaptumModel(torch.nn.Module):\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        mask_type: Union[str, MaskLevelType],\n        output_idx: Optional[Union[int, Tensor]] = None,\n        model_config: Optional[ModelConfig] = None,\n    ):\n        super().__init__()\n\n        self.mask_type = MaskLevelType(mask_type)\n        self.model = model\n        self.output_idx = output_idx\n        self.model_config = model_config\n\n    def forward(self, mask, *args):\n        \"\"\"\"\"\"  # noqa: D419\n        # The mask tensor, which comes from Captum's attribution methods,\n        # contains the number of samples in dimension 0. 
Since we are\n        # working with only one sample, we squeeze the tensors below.\n        assert mask.shape[0] == 1, \"Dimension 0 of input should be 1\"\n        if self.mask_type == MaskLevelType.edge:\n            assert len(args) >= 2, \"Expects at least x and edge_index as args.\"\n        if self.mask_type == MaskLevelType.node:\n            assert len(args) >= 1, \"Expects at least edge_index as args.\"\n        if self.mask_type == MaskLevelType.node_and_edge:\n            assert args[0].shape[0] == 1, \"Dimension 0 of input should be 1\"\n            assert len(args[1:]) >= 1, \"Expects at least edge_index as args.\"\n\n        # Set edge mask:\n        if self.mask_type == MaskLevelType.edge:\n            set_masks(self.model, mask.squeeze(0), args[1],\n                      apply_sigmoid=False)\n        elif self.mask_type == MaskLevelType.node_and_edge:\n            set_masks(self.model, args[0].squeeze(0), args[1],\n                      apply_sigmoid=False)\n            args = args[1:]\n\n        if self.mask_type == MaskLevelType.edge:\n            x = self.model(*args)\n\n        else:\n            x = self.model(mask.squeeze(0), *args)\n\n        return self.postprocess(x)\n\n    def postprocess(self, x: Tensor) -> Tensor:\n        if self.mask_type.with_edge:\n            clear_masks(self.model)\n\n        if self.output_idx is not None:  # Filter by output index:\n            x = x[self.output_idx]\n            if (isinstance(self.output_idx, int)\n                    or self.output_idx.dim() == 0):\n                x = x.unsqueeze(0)\n\n        # Convert binary classification to multi-class classification:\n        if (self.model_config is not None\n                and self.model_config.mode == ModelMode.binary_classification):\n            assert self.model_config.return_type == ModelReturnType.probs\n            x = x.view(-1, 1)\n            x = torch.cat([1 - x, x], dim=-1)\n\n        return x\n\n\n# TODO(jinu) Is there any point of 
inheriting from `CaptumModel`\nclass CaptumHeteroModel(CaptumModel):\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        mask_type: Union[str, MaskLevelType],\n        output_idx: Optional[Union[int, Tensor]],\n        metadata: Metadata,\n        model_config: Optional[ModelConfig] = None,\n    ):\n        super().__init__(model, mask_type, output_idx, model_config)\n        self.node_types = metadata[0]\n        self.edge_types = metadata[1]\n        self.num_node_types = len(self.node_types)\n        self.num_edge_types = len(self.edge_types)\n\n    def _captum_data_to_hetero_data(\n        self, *args\n    ) -> Tuple[Dict[NodeType, Tensor], Dict[EdgeType, Tensor], Optional[Dict[\n            EdgeType, Tensor]]]:\n        \"\"\"Converts tuple of tensors to `x_dict`, `edge_index_dict` and\n        `edge_mask_dict`.\n        \"\"\"\n        if self.mask_type == MaskLevelType.node:\n            node_tensors = args[:self.num_node_types]\n            node_tensors = [mask.squeeze(0) for mask in node_tensors]\n            x_dict = dict(zip(self.node_types, node_tensors))\n            edge_index_dict = args[self.num_node_types]\n        elif self.mask_type == MaskLevelType.edge:\n            edge_mask_tensors = args[:self.num_edge_types]\n            x_dict = args[self.num_edge_types]\n            edge_index_dict = args[self.num_edge_types + 1]\n        else:\n            node_tensors = args[:self.num_node_types]\n            node_tensors = [mask.squeeze(0) for mask in node_tensors]\n            x_dict = dict(zip(self.node_types, node_tensors))\n            edge_mask_tensors = args[self.num_node_types:self.num_node_types +\n                                     self.num_edge_types]\n            edge_index_dict = args[self.num_node_types + self.num_edge_types]\n\n        if self.mask_type.with_edge:\n            edge_mask_tensors = [mask.squeeze(0) for mask in edge_mask_tensors]\n            edge_mask_dict = dict(zip(self.edge_types, 
edge_mask_tensors))\n        else:\n            edge_mask_dict = None\n        return x_dict, edge_index_dict, edge_mask_dict\n\n    def forward(self, *args):\n        # Validate args:\n        if self.mask_type == MaskLevelType.node:\n            assert len(args) >= self.num_node_types + 1\n            len_remaining_args = len(args) - (self.num_node_types + 1)\n        elif self.mask_type == MaskLevelType.edge:\n            assert len(args) >= self.num_edge_types + 2\n            len_remaining_args = len(args) - (self.num_edge_types + 2)\n        else:\n            assert len(args) >= self.num_node_types + self.num_edge_types + 1\n            len_remaining_args = len(args) - (self.num_node_types +\n                                              self.num_edge_types + 1)\n\n        # Get main args:\n        (x_dict, edge_index_dict,\n         edge_mask_dict) = self._captum_data_to_hetero_data(*args)\n\n        if self.mask_type.with_edge:\n            set_hetero_masks(self.model, edge_mask_dict, edge_index_dict)\n\n        if len_remaining_args > 0:\n            # If there are args other than `x_dict` and `edge_index_dict`\n            x = self.model(x_dict, edge_index_dict,\n                           *args[-len_remaining_args:])\n        else:\n            x = self.model(x_dict, edge_index_dict)\n\n        return self.postprocess(x)\n\n\ndef _to_edge_mask(edge_index: Tensor) -> Tensor:\n    num_edges = edge_index.shape[1]\n    return torch.ones(num_edges, requires_grad=True, device=edge_index.device)\n\n\ndef to_captum_input(\n    x: Union[Tensor, Dict[NodeType, Tensor]],\n    edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n    mask_type: Union[str, MaskLevelType],\n    *args,\n) -> Tuple[Tuple[Tensor, ...], Tuple[Tensor, ...]]:\n    r\"\"\"Given :obj:`x`, :obj:`edge_index` and :obj:`mask_type`, converts it\n    to a format to use in `Captum <https://captum.ai/>`_ attribution\n    methods. 
Returns :obj:`inputs` and :obj:`additional_forward_args`\n    required for :captum:`Captum's` :obj:`attribute` functions.\n    See :meth:`~torch_geometric.nn.models.to_captum_model` for example usage.\n\n    Args:\n        x (torch.Tensor or Dict[NodeType, torch.Tensor]): The node features.\n            For heterogeneous graphs this is a dictionary holding node features\n            for each node type.\n        edge_index (torch.Tensor or Dict[EdgeType, torch.Tensor]): The edge\n            indices. For heterogeneous graphs this is a dictionary holding the\n            :obj:`edge_index` for each edge type.\n        mask_type (str): Denotes the type of mask to be created with\n            a Captum explainer. Valid inputs are :obj:`\"edge\"`, :obj:`\"node\"`,\n            and :obj:`\"node_and_edge\"`.\n        *args: Additional forward arguments of the model being explained\n            which will be added to :obj:`additional_forward_args`.\n    \"\"\"\n    mask_type = MaskLevelType(mask_type)\n\n    additional_forward_args = []\n    if isinstance(x, Tensor) and isinstance(edge_index, Tensor):\n        if mask_type == MaskLevelType.node:\n            inputs = [x.unsqueeze(0)]\n        elif mask_type == MaskLevelType.edge:\n            inputs = [_to_edge_mask(edge_index).unsqueeze(0)]\n            additional_forward_args.append(x)\n        else:\n            inputs = [x.unsqueeze(0), _to_edge_mask(edge_index).unsqueeze(0)]\n        additional_forward_args.append(edge_index)\n\n    elif isinstance(x, Dict) and isinstance(edge_index, Dict):\n        node_types = x.keys()\n        edge_types = edge_index.keys()\n        inputs = []\n        if mask_type == MaskLevelType.node:\n            for key in node_types:\n                inputs.append(x[key].unsqueeze(0))\n        elif mask_type == MaskLevelType.edge:\n            for key in edge_types:\n                inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0))\n            additional_forward_args.append(x)\n       
 else:\n            for key in node_types:\n                inputs.append(x[key].unsqueeze(0))\n            for key in edge_types:\n                inputs.append(_to_edge_mask(edge_index[key]).unsqueeze(0))\n        additional_forward_args.append(edge_index)\n\n    else:\n        raise ValueError(\n            \"'x' and 'edge_index' need to be either \"\n            f\"'Dict' or 'Tensor', got ({type(x)}, {type(edge_index)})\")\n\n    additional_forward_args.extend(args)\n\n    return tuple(inputs), tuple(additional_forward_args)\n\n\ndef captum_output_to_dicts(\n    captum_attrs: Tuple[Tensor, ...],\n    mask_type: Union[str, MaskLevelType],\n    metadata: Metadata,\n) -> Tuple[Optional[Dict[NodeType, Tensor]], Optional[Dict[EdgeType, Tensor]]]:\n    r\"\"\"Convert the output of `Captum <https://captum.ai/>`_ attribution\n    methods which is a tuple of attributions to two dictionaries with node and\n    edge attribution tensors. This function is used while explaining\n    :class:`~torch_geometric.data.HeteroData` objects.\n    See :meth:`~torch_geometric.nn.models.to_captum_model` for example usage.\n\n    Args:\n        captum_attrs (tuple[torch.Tensor]): The output of attribution methods.\n        mask_type (str): Denotes the type of mask to be created with\n            a Captum explainer. Valid inputs are :obj:`\"edge\"`, :obj:`\"node\"`,\n            and :obj:`\"node_and_edge\"`:\n\n            1. :obj:`\"edge\"`: :obj:`captum_attrs` contains only edge\n               attributions. The returned tuple has no node attributions, and\n               an edge attribution dictionary with edge types as keys and\n               edge mask tensors of shape :obj:`[num_edges]` as values.\n\n            2. :obj:`\"node\"`: :obj:`captum_attrs` contains only node\n               attributions. 
The returned tuple has a node attribution\n               dictionary with node types as keys and node mask tensors of\n               shape :obj:`[num_nodes, num_features]` as values, and no edge\n               attributions.\n\n            3. :obj:`\"node_and_edge\"`: :obj:`captum_attrs` contains node and\n                edge attributions.\n\n        metadata (Metadata): The metadata of the heterogeneous graph.\n    \"\"\"\n    mask_type = MaskLevelType(mask_type)\n    node_types = metadata[0]\n    edge_types = metadata[1]\n    x_attr_dict, edge_attr_dict = None, None\n    captum_attrs = [captum_attr.squeeze(0) for captum_attr in captum_attrs]\n    if mask_type == MaskLevelType.node:\n        assert len(node_types) == len(captum_attrs)\n        x_attr_dict = dict(zip(node_types, captum_attrs))\n    elif mask_type == MaskLevelType.edge:\n        assert len(edge_types) == len(captum_attrs)\n        edge_attr_dict = dict(zip(edge_types, captum_attrs))\n    elif mask_type == MaskLevelType.node_and_edge:\n        assert len(edge_types) + len(node_types) == len(captum_attrs)\n        x_attr_dict = dict(zip(node_types, captum_attrs[:len(node_types)]))\n        edge_attr_dict = dict(zip(edge_types, captum_attrs[len(node_types):]))\n    return x_attr_dict, edge_attr_dict\n\n\ndef convert_captum_output(\n    captum_attrs: Tuple[Tensor, ...],\n    mask_type: Union[str, MaskLevelType],\n    metadata: Optional[Metadata] = None,\n):\n    r\"\"\"Convert the output of `Captum.ai <https://captum.ai/>`_ attribution\n    methods which is a tuple of attributions to either\n    :obj:`(node_mask, edge_mask)` or :obj:`(node_mask_dict, edge_mask_dict)`.\n    \"\"\"\n    mask_type = MaskLevelType(mask_type)\n    if metadata is not None:\n        return captum_output_to_dicts(captum_attrs, mask_type, metadata)\n\n    node_mask = edge_mask = None\n    if mask_type == MaskLevelType.edge:\n        edge_mask = captum_attrs[0].squeeze(0)\n    elif mask_type == MaskLevelType.node:\n        
node_mask = captum_attrs[0].squeeze(0)\n    else:\n        node_mask = captum_attrs[0].squeeze(0)\n        edge_mask = captum_attrs[1].squeeze(0)\n\n    return node_mask, edge_mask\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/captum_explainer.py",
    "content": "import inspect\nimport logging\nimport warnings\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.algorithm.captum import (\n    CaptumHeteroModel,\n    CaptumModel,\n    MaskLevelType,\n    convert_captum_output,\n    to_captum_input,\n)\nfrom torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass CaptumExplainer(ExplainerAlgorithm):\n    \"\"\"A `Captum <https://captum.ai>`__-based explainer for identifying compact\n    subgraph structures and node features that play a crucial role in the\n    predictions made by a GNN.\n\n    This explainer algorithm uses :captum:`null` `Captum <https://captum.ai/>`_\n    to compute attributions.\n\n    Currently, the following attribution methods are supported:\n\n    * :class:`captum.attr.IntegratedGradients`\n    * :class:`captum.attr.Saliency`\n    * :class:`captum.attr.InputXGradient`\n    * :class:`captum.attr.Deconvolution`\n    * :class:`captum.attr.ShapleyValueSampling`\n    * :class:`captum.attr.GuidedBackprop`\n\n    Args:\n        attribution_method (Attribution or str): The Captum attribution method\n            to use. 
Can be a string or a :class:`captum.attr` method.\n        **kwargs: Additional arguments for the Captum attribution method.\n    \"\"\"\n    SUPPORTED_METHODS = [  # TODO: Add support for more methods.\n        'IntegratedGradients',\n        'Saliency',\n        'InputXGradient',\n        'Deconvolution',\n        'ShapleyValueSampling',\n        'GuidedBackprop',\n    ]\n\n    def __init__(\n        self,\n        attribution_method: Union[str, Any],\n        **kwargs,\n    ):\n        super().__init__()\n\n        import captum.attr\n\n        if isinstance(attribution_method, str):\n            self.attribution_method_class = getattr(\n                captum.attr,\n                attribution_method,\n            )\n        else:\n            self.attribution_method_class = attribution_method\n\n        if not self._is_supported_attribution_method():\n            raise ValueError(f\"{self.__class__.__name__} does not support \"\n                             f\"attribution method \"\n                             f\"{self.attribution_method_class.__name__}\")\n\n        if kwargs.get('internal_batch_size', 1) != 1:\n            warnings.warn(\"Overriding 'internal_batch_size' to 1\",\n                          stacklevel=2)\n\n        if 'internal_batch_size' in self._get_attribute_parameters():\n            kwargs['internal_batch_size'] = 1\n\n        self.kwargs = kwargs\n\n    def _get_mask_type(self) -> MaskLevelType:\n        r\"\"\"Based on the explainer config, return the mask type.\"\"\"\n        node_mask_type = self.explainer_config.node_mask_type\n        edge_mask_type = self.explainer_config.edge_mask_type\n        if node_mask_type is not None and edge_mask_type is not None:\n            mask_type = MaskLevelType.node_and_edge\n        elif node_mask_type is not None:\n            mask_type = MaskLevelType.node\n        elif edge_mask_type is not None:\n            mask_type = MaskLevelType.edge\n        else:\n            raise 
ValueError(\"Neither node mask type nor \"\n                             \"edge mask type is specified.\")\n        return mask_type\n\n    def _get_attribute_parameters(self) -> Dict[str, Any]:\n        r\"\"\"Returns the attribute arguments.\"\"\"\n        signature = inspect.signature(self.attribution_method_class.attribute)\n        return signature.parameters\n\n    def _needs_baseline(self) -> bool:\n        r\"\"\"Checks if the method needs a baseline.\"\"\"\n        parameters = self._get_attribute_parameters()\n        if 'baselines' in parameters:\n            param = parameters['baselines']\n            if param.default is inspect.Parameter.empty:\n                return True\n        return False\n\n    def _is_supported_attribution_method(self) -> bool:\n        r\"\"\"Returns :obj:`True` if the attribution method is supported.\"\"\"\n        # This check is redundant for now since no supported method requires a\n        # baseline.\n        if self._needs_baseline():\n            return False\n        elif self.attribution_method_class.__name__ in self.SUPPORTED_METHODS:\n            return True\n        return False\n\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Union[Explanation, HeteroExplanation]:\n\n        mask_type = self._get_mask_type()\n\n        inputs, add_forward_args = to_captum_input(\n            x,\n            edge_index,\n            mask_type,\n            *kwargs.values(),\n        )\n\n        if isinstance(x, dict):  # Heterogeneous GNN:\n            metadata = (list(x.keys()), list(edge_index.keys()))\n            captum_model = CaptumHeteroModel(\n                model,\n                mask_type,\n                index,\n                metadata,\n                self.model_config,\n            
)\n        else:  # Homogeneous GNN:\n            metadata = None\n            captum_model = CaptumModel(\n                model,\n                mask_type,\n                index,\n                self.model_config,\n            )\n\n        self.attribution_method_instance = self.attribution_method_class(\n            captum_model)\n\n        # In Captum, the target is the class index for which the attribution is\n        # computed. Within CaptumModel, we transform the binary classification\n        # into a multi-class classification task.\n        if self.model_config.mode == ModelMode.regression:\n            target = None\n        elif index is not None:\n            target = target[index]\n\n        attributions = self.attribution_method_instance.attribute(\n            inputs=inputs,\n            target=target,\n            additional_forward_args=add_forward_args,\n            **self.kwargs,\n        )\n\n        node_mask, edge_mask = convert_captum_output(\n            attributions,\n            mask_type,\n            metadata,\n        )\n\n        if not isinstance(x, dict):\n            return Explanation(node_mask=node_mask, edge_mask=edge_mask)\n\n        explanation = HeteroExplanation()\n        explanation.set_value_dict('node_mask', node_mask)\n        explanation.set_value_dict('edge_mask', edge_mask)\n        return explanation\n\n    def supports(self) -> bool:\n        node_mask_type = self.explainer_config.node_mask_type\n        if node_mask_type not in [None, MaskType.attributes]:\n            logging.error(f\"'{self.__class__.__name__}' expects \"\n                          f\"'node_mask_type' to be 'None' or 'attributes' \"\n                          f\"(got '{node_mask_type.value}')\")\n            return False\n\n        return_type = self.model_config.return_type\n        if (self.model_config.mode == ModelMode.binary_classification\n                and return_type != ModelReturnType.probs):\n            
logging.error(f\"'{self.__class__.__name__}' expects \"\n                          f\"'return_type' to be 'probs' for binary \"\n                          f\"classification tasks (got '{return_type.value}')\")\n            return False\n\n        # TODO (ramona) Confirm that output type is valid.\n        return True\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/dummy_explainer.py",
    "content": "from collections import defaultdict\nfrom typing import Dict, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.config import MaskType\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass DummyExplainer(ExplainerAlgorithm):\n    r\"\"\"A dummy explainer that returns random explanations (useful for testing\n    purposes).\n    \"\"\"\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        edge_attr: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None,\n        **kwargs,\n    ) -> Union[Explanation, HeteroExplanation]:\n        assert isinstance(x, (Tensor, dict))\n\n        node_mask_type = self.explainer_config.node_mask_type\n        edge_mask_type = self.explainer_config.edge_mask_type\n\n        if isinstance(x, Tensor):  # Homogeneous graph.\n            assert isinstance(edge_index, Tensor)\n\n            node_mask = None\n            if node_mask_type == MaskType.object:\n                node_mask = torch.rand(x.size(0), 1, device=x.device)\n            elif node_mask_type == MaskType.common_attributes:\n                node_mask = torch.rand(1, x.size(1), device=x.device)\n            elif node_mask_type == MaskType.attributes:\n                node_mask = torch.rand_like(x)\n\n            edge_mask = None\n            if edge_mask_type == MaskType.object:\n                edge_mask = torch.rand(edge_index.size(1), device=x.device)\n\n            return Explanation(node_mask=node_mask, edge_mask=edge_mask)\n        else:  # isinstance(x, dict):  # Heterogeneous graph.\n            assert isinstance(edge_index, dict)\n\n            node_dict = defaultdict(dict)\n            for k, v in x.items():\n                node_mask = 
None\n                if node_mask_type == MaskType.object:\n                    node_mask = torch.rand(v.size(0), 1, device=v.device)\n                elif node_mask_type == MaskType.common_attributes:\n                    node_mask = torch.rand(1, v.size(1), device=v.device)\n                elif node_mask_type == MaskType.attributes:\n                    node_mask = torch.rand_like(v)\n                if node_mask is not None:\n                    node_dict[k]['node_mask'] = node_mask\n\n            edge_dict = defaultdict(dict)\n            for k, v in edge_index.items():\n                edge_mask = None\n                if edge_mask_type == MaskType.object:\n                    edge_mask = torch.rand(v.size(1), device=v.device)\n                if edge_mask is not None:\n                    edge_dict[k]['edge_mask'] = edge_mask\n\n            return HeteroExplanation({**node_dict, **edge_dict})\n\n    def supports(self) -> bool:\n        return True\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/gnn_explainer.py",
    "content": "from math import sqrt\nfrom typing import Dict, Optional, Tuple, Union, overload\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn.parameter import Parameter\n\nfrom torch_geometric.explain import (\n    ExplainerConfig,\n    Explanation,\n    HeteroExplanation,\n    ModelConfig,\n)\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.algorithm.utils import (\n    clear_masks,\n    set_hetero_masks,\n    set_masks,\n)\nfrom torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass GNNExplainer(ExplainerAlgorithm):\n    r\"\"\"The GNN-Explainer model from the `\"GNNExplainer: Generating\n    Explanations for Graph Neural Networks\"\n    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph\n    structures and node features that play a crucial role in the predictions\n    made by a GNN.\n\n    .. note::\n\n        For an example of using :class:`GNNExplainer`, see\n        `examples/explain/gnn_explainer.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/explain/gnn_explainer.py>`_,\n        `examples/explain/gnn_explainer_ba_shapes.py <https://github.com/\n        pyg-team/pytorch_geometric/blob/master/examples/\n        explain/gnn_explainer_ba_shapes.py>`_, and `examples/explain/\n        gnn_explainer_link_pred.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/explain/gnn_explainer_link_pred.py>`_.\n\n    .. 
note::\n\n        The :obj:`edge_size` coefficient is multiplied by the number of nodes\n        in the explanation at every iteration, and the resulting value is added\n        to the loss as a regularization term, with the goal of producing\n        compact explanations.\n        A higher value will push the algorithm towards explanations with fewer\n        elements.\n        Consider adjusting the :obj:`edge_size` coefficient according to the\n        average node degree in the dataset, especially if this value is bigger\n        than in the datasets used in the original paper.\n\n    Args:\n        epochs (int, optional): The number of epochs to train.\n            (default: :obj:`100`)\n        lr (float, optional): The learning rate to apply.\n            (default: :obj:`0.01`)\n        **kwargs (optional): Additional hyper-parameters to override default\n            settings in\n            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.\n    \"\"\"\n\n    default_coeffs = {\n        'edge_size': 0.005,\n        'edge_reduction': 'sum',\n        'node_feat_size': 1.0,\n        'node_feat_reduction': 'mean',\n        'edge_ent': 1.0,\n        'node_feat_ent': 0.1,\n        'EPS': 1e-15,\n    }\n\n    def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):\n        super().__init__()\n        self.epochs = epochs\n        self.lr = lr\n        self.coeffs = dict(self.default_coeffs)\n        self.coeffs.update(kwargs)\n\n        self.node_mask = self.hard_node_mask = None\n        self.edge_mask = self.hard_edge_mask = None\n        self.is_hetero = False\n\n    @overload\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Explanation:\n        ...\n\n    @overload\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: 
Dict[NodeType, Tensor],\n        edge_index: Dict[EdgeType, Tensor],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> HeteroExplanation:\n        ...\n\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Union[Explanation, HeteroExplanation]:\n        self.is_hetero = isinstance(x, dict)\n        self._train(model, x, edge_index, target=target, index=index, **kwargs)\n        explanation = self._create_explanation()\n        self._clean_model(model)\n        return explanation\n\n    def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:\n        \"\"\"Create an explanation object from the current masks.\"\"\"\n        if self.is_hetero:\n            # For heterogeneous graphs, process each type separately\n            node_mask_dict = {}\n            edge_mask_dict = {}\n\n            for node_type, mask in self.node_mask.items():\n                if mask is not None:\n                    node_mask_dict[node_type] = self._post_process_mask(\n                        mask,\n                        self.hard_node_mask[node_type],\n                        apply_sigmoid=True,\n                    )\n\n            for edge_type, mask in self.edge_mask.items():\n                if mask is not None:\n                    edge_mask_dict[edge_type] = self._post_process_mask(\n                        mask,\n                        self.hard_edge_mask[edge_type],\n                        apply_sigmoid=True,\n                    )\n\n            # Create heterogeneous explanation\n            explanation = HeteroExplanation()\n            explanation.set_value_dict('node_mask', node_mask_dict)\n            explanation.set_value_dict('edge_mask', 
edge_mask_dict)\n\n        else:\n            # For homogeneous graphs, process single masks\n            node_mask = self._post_process_mask(\n                self.node_mask,\n                self.hard_node_mask,\n                apply_sigmoid=True,\n            )\n            edge_mask = self._post_process_mask(\n                self.edge_mask,\n                self.hard_edge_mask,\n                apply_sigmoid=True,\n            )\n\n            # Create homogeneous explanation\n            explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)\n\n        return explanation\n\n    def supports(self) -> bool:\n        return True\n\n    @overload\n    def _train(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> None:\n        ...\n\n    @overload\n    def _train(\n        self,\n        model: torch.nn.Module,\n        x: Dict[NodeType, Tensor],\n        edge_index: Dict[EdgeType, Tensor],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> None:\n        ...\n\n    def _train(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> None:\n        # Initialize masks based on input type\n        self._initialize_masks(x, edge_index)\n\n        # Collect parameters for optimization\n        parameters = self._collect_parameters(model, edge_index)\n\n        # Create optimizer\n        optimizer = torch.optim.Adam(parameters, lr=self.lr)\n\n        # Training loop\n        for i in range(self.epochs):\n            optimizer.zero_grad()\n\n            # Forward pass with masked inputs\n            y_hat = 
self._forward_with_masks(model, x, edge_index, **kwargs)\n            y = target\n\n            # Handle index if provided\n            if index is not None:\n                y_hat, y = y_hat[index], y[index]\n\n            # Calculate loss\n            loss = self._loss(y_hat, y)\n\n            # Backward pass\n            loss.backward()\n            optimizer.step()\n\n            # In the first iteration, collect gradients to identify important\n            # nodes/edges\n            if i == 0:\n                self._collect_gradients()\n\n    def _collect_parameters(self, model, edge_index):\n        \"\"\"Collect parameters for optimization.\"\"\"\n        parameters = []\n\n        if self.is_hetero:\n            # For heterogeneous graphs, collect parameters from all types\n            for mask in self.node_mask.values():\n                if mask is not None:\n                    parameters.append(mask)\n            if any(v is not None for v in self.edge_mask.values()):\n                set_hetero_masks(model, self.edge_mask, edge_index)\n            for mask in self.edge_mask.values():\n                if mask is not None:\n                    parameters.append(mask)\n        else:\n            # For homogeneous graphs, collect single parameters\n            if self.node_mask is not None:\n                parameters.append(self.node_mask)\n            if self.edge_mask is not None:\n                set_masks(model, self.edge_mask, edge_index,\n                          apply_sigmoid=True)\n                parameters.append(self.edge_mask)\n\n        return parameters\n\n    @overload\n    def _forward_with_masks(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        **kwargs,\n    ) -> Tensor:\n        ...\n\n    @overload\n    def _forward_with_masks(\n        self,\n        model: torch.nn.Module,\n        x: Dict[NodeType, Tensor],\n        edge_index: Dict[EdgeType, Tensor],\n        **kwargs,\n    ) 
-> Tensor:\n        ...\n\n    def _forward_with_masks(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        **kwargs,\n    ) -> Tensor:\n        \"\"\"Forward pass with masked inputs.\"\"\"\n        if self.is_hetero:\n            # Apply masks to heterogeneous inputs\n            h_dict = {}\n            for node_type, features in x.items():\n                if node_type in self.node_mask and self.node_mask[\n                        node_type] is not None:\n                    h_dict[node_type] = features * self.node_mask[\n                        node_type].sigmoid()\n                else:\n                    h_dict[node_type] = features\n\n            # Forward pass with masked features\n            return model(h_dict, edge_index, **kwargs)\n        else:\n            # Apply mask to homogeneous input\n            h = x if self.node_mask is None else x * self.node_mask.sigmoid()\n\n            # Forward pass with masked features\n            return model(h, edge_index, **kwargs)\n\n    def _initialize_masks(\n        self,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n    ) -> None:\n        node_mask_type = self.explainer_config.node_mask_type\n        edge_mask_type = self.explainer_config.edge_mask_type\n\n        if self.is_hetero:\n            # Initialize dictionaries for heterogeneous masks\n            self.node_mask = {}\n            self.hard_node_mask = {}\n            self.edge_mask = {}\n            self.hard_edge_mask = {}\n\n            # Initialize node masks for each node type\n            for node_type, features in x.items():\n                device = features.device\n                N, F = features.size()\n                self._initialize_node_mask(node_mask_type, node_type, N, F,\n                                           device)\n\n            # Initialize 
edge masks for each edge type\n            for edge_type, indices in edge_index.items():\n                device = indices.device\n                E = indices.size(1)\n                N = max(indices.max().item() + 1,\n                        max(feat.size(0) for feat in x.values()))\n                self._initialize_edge_mask(edge_mask_type, edge_type, E, N,\n                                           device)\n        else:\n            # Initialize masks for homogeneous graph\n            device = x.device\n            (N, F), E = x.size(), edge_index.size(1)\n\n            # Initialize homogeneous node and edge masks\n            self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,\n                                               N, F, E, device)\n\n    def _initialize_node_mask(\n        self,\n        node_mask_type,\n        node_type,\n        N,\n        F,\n        device,\n    ) -> None:\n        \"\"\"Initialize node mask for a specific node type.\"\"\"\n        std = 0.1\n        if node_mask_type is None:\n            self.node_mask[node_type] = None\n            self.hard_node_mask[node_type] = None\n        elif node_mask_type == MaskType.object:\n            self.node_mask[node_type] = Parameter(\n                torch.randn(N, 1, device=device) * std)\n            self.hard_node_mask[node_type] = None\n        elif node_mask_type == MaskType.attributes:\n            self.node_mask[node_type] = Parameter(\n                torch.randn(N, F, device=device) * std)\n            self.hard_node_mask[node_type] = None\n        elif node_mask_type == MaskType.common_attributes:\n            self.node_mask[node_type] = Parameter(\n                torch.randn(1, F, device=device) * std)\n            self.hard_node_mask[node_type] = None\n        else:\n            raise ValueError(f\"Invalid node mask type: {node_mask_type}\")\n\n    def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):\n        \"\"\"Initialize edge mask for 
a specific edge type.\"\"\"\n        if edge_mask_type is None:\n            self.edge_mask[edge_type] = None\n            self.hard_edge_mask[edge_type] = None\n        elif edge_mask_type == MaskType.object:\n            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))\n            self.edge_mask[edge_type] = Parameter(\n                torch.randn(E, device=device) * std)\n            self.hard_edge_mask[edge_type] = None\n        else:\n            raise ValueError(f\"Invalid edge mask type: {edge_mask_type}\")\n\n    def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,\n                                      F, E, device):\n        \"\"\"Initialize masks for homogeneous graph.\"\"\"\n        # Initialize node mask\n        std = 0.1\n        if node_mask_type is None:\n            self.node_mask = None\n        elif node_mask_type == MaskType.object:\n            self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)\n        elif node_mask_type == MaskType.attributes:\n            self.node_mask = Parameter(torch.randn(N, F, device=device) * std)\n        elif node_mask_type == MaskType.common_attributes:\n            self.node_mask = Parameter(torch.randn(1, F, device=device) * std)\n        else:\n            raise ValueError(f\"Invalid node mask type: {node_mask_type}\")\n\n        # Initialize edge mask\n        if edge_mask_type is None:\n            self.edge_mask = None\n        elif edge_mask_type == MaskType.object:\n            std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))\n            self.edge_mask = Parameter(torch.randn(E, device=device) * std)\n        else:\n            raise ValueError(f\"Invalid edge mask type: {edge_mask_type}\")\n\n    def _collect_gradients(self) -> None:\n        if self.is_hetero:\n            self._collect_hetero_gradients()\n        else:\n            self._collect_homo_gradients()\n\n    def _collect_hetero_gradients(self):\n        \"\"\"Collect 
gradients for heterogeneous graph.\"\"\"\n        for node_type, mask in self.node_mask.items():\n            if mask is not None:\n                if mask.grad is None:\n                    raise ValueError(\n                        f\"Could not compute gradients for node masks of type \"\n                        f\"'{node_type}'. Please make sure that node masks are \"\n                        f\"used inside the model or disable it via \"\n                        f\"`node_mask_type=None`.\")\n\n                self.hard_node_mask[node_type] = mask.grad != 0.0\n\n        for edge_type, mask in self.edge_mask.items():\n            if mask is not None:\n                if mask.grad is None:\n                    raise ValueError(\n                        f\"Could not compute gradients for edge masks of type \"\n                        f\"'{edge_type}'. Please make sure that edge masks are \"\n                        f\"used inside the model or disable it via \"\n                        f\"`edge_mask_type=None`.\")\n                self.hard_edge_mask[edge_type] = mask.grad != 0.0\n\n    def _collect_homo_gradients(self):\n        \"\"\"Collect gradients for homogeneous graph.\"\"\"\n        if self.node_mask is not None:\n            if self.node_mask.grad is None:\n                raise ValueError(\"Could not compute gradients for node \"\n                                 \"features. Please make sure that node \"\n                                 \"features are used inside the model or \"\n                                 \"disable it via `node_mask_type=None`.\")\n            self.hard_node_mask = self.node_mask.grad != 0.0\n\n        if self.edge_mask is not None:\n            if self.edge_mask.grad is None:\n                raise ValueError(\"Could not compute gradients for edges. 
\"\n                                 \"Please make sure that edges are used \"\n                                 \"via message passing inside the model or \"\n                                 \"disable it via `edge_mask_type=None`.\")\n            self.hard_edge_mask = self.edge_mask.grad != 0.0\n\n    def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:\n        # Calculate base loss based on model configuration\n        loss = self._calculate_base_loss(y_hat, y)\n\n        # Apply regularization based on graph type\n        if self.is_hetero:\n            # Apply regularization for heterogeneous graph\n            loss = self._apply_hetero_regularization(loss)\n        else:\n            # Apply regularization for homogeneous graph\n            loss = self._apply_homo_regularization(loss)\n\n        return loss\n\n    def _calculate_base_loss(self, y_hat, y):\n        \"\"\"Calculate base loss based on model configuration.\"\"\"\n        if self.model_config.mode == ModelMode.binary_classification:\n            return self._loss_binary_classification(y_hat, y)\n        elif self.model_config.mode == ModelMode.multiclass_classification:\n            return self._loss_multiclass_classification(y_hat, y)\n        elif self.model_config.mode == ModelMode.regression:\n            return self._loss_regression(y_hat, y)\n        else:\n            raise ValueError(f\"Invalid model mode: {self.model_config.mode}\")\n\n    def _apply_hetero_regularization(self, loss):\n        \"\"\"Apply regularization for heterogeneous graph.\"\"\"\n        # Apply regularization for each edge type\n        for edge_type, mask in self.edge_mask.items():\n            if (mask is not None\n                    and self.hard_edge_mask[edge_type] is not None):\n                loss = self._add_mask_regularization(\n                    loss, mask, self.hard_edge_mask[edge_type],\n                    self.coeffs['edge_size'], self.coeffs['edge_reduction'],\n                    
self.coeffs['edge_ent'])\n\n        # Apply regularization for each node type\n        for node_type, mask in self.node_mask.items():\n            if (mask is not None\n                    and self.hard_node_mask[node_type] is not None):\n                loss = self._add_mask_regularization(\n                    loss, mask, self.hard_node_mask[node_type],\n                    self.coeffs['node_feat_size'],\n                    self.coeffs['node_feat_reduction'],\n                    self.coeffs['node_feat_ent'])\n\n        return loss\n\n    def _apply_homo_regularization(self, loss):\n        \"\"\"Apply regularization for homogeneous graph.\"\"\"\n        # Apply regularization for edge mask\n        if self.hard_edge_mask is not None:\n            assert self.edge_mask is not None\n            loss = self._add_mask_regularization(loss, self.edge_mask,\n                                                 self.hard_edge_mask,\n                                                 self.coeffs['edge_size'],\n                                                 self.coeffs['edge_reduction'],\n                                                 self.coeffs['edge_ent'])\n\n        # Apply regularization for node mask\n        if self.hard_node_mask is not None:\n            assert self.node_mask is not None\n            loss = self._add_mask_regularization(\n                loss, self.node_mask, self.hard_node_mask,\n                self.coeffs['node_feat_size'],\n                self.coeffs['node_feat_reduction'],\n                self.coeffs['node_feat_ent'])\n\n        return loss\n\n    def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,\n                                 reduction_name, ent_coeff):\n        \"\"\"Add size and entropy regularization for a mask.\"\"\"\n        m = mask[hard_mask].sigmoid()\n        reduce_fn = getattr(torch, reduction_name)\n        # Add size regularization\n        loss = loss + size_coeff * reduce_fn(m)\n        # Add entropy 
regularization\n        ent = -m * torch.log(m + self.coeffs['EPS']) - (\n            1 - m) * torch.log(1 - m + self.coeffs['EPS'])\n        loss = loss + ent_coeff * ent.mean()\n\n        return loss\n\n    def _clean_model(self, model):\n        clear_masks(model)\n        self.node_mask = self.hard_node_mask = None\n        self.edge_mask = self.hard_edge_mask = None\n\n\nclass GNNExplainer_:\n    r\"\"\"Deprecated version for :class:`GNNExplainer`.\"\"\"\n\n    coeffs = GNNExplainer.default_coeffs\n\n    conversion_node_mask_type = {\n        'feature': 'common_attributes',\n        'individual_feature': 'attributes',\n        'scalar': 'object',\n    }\n\n    conversion_return_type = {\n        'log_prob': 'log_probs',\n        'prob': 'probs',\n        'raw': 'raw',\n        'regression': 'raw',\n    }\n\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        epochs: int = 100,\n        lr: float = 0.01,\n        return_type: str = 'log_prob',\n        feat_mask_type: str = 'feature',\n        allow_edge_mask: bool = True,\n        **kwargs,\n    ):\n        assert feat_mask_type in ['feature', 'individual_feature', 'scalar']\n\n        explainer_config = ExplainerConfig(\n            explanation_type='model',\n            node_mask_type=self.conversion_node_mask_type[feat_mask_type],\n            edge_mask_type=MaskType.object if allow_edge_mask else None,\n        )\n        model_config = ModelConfig(\n            mode='regression'\n            if return_type == 'regression' else 'multiclass_classification',\n            task_level=ModelTaskLevel.node,\n            return_type=self.conversion_return_type[return_type],\n        )\n\n        self.model = model\n        self._explainer = GNNExplainer(epochs=epochs, lr=lr, **kwargs)\n        self._explainer.connect(explainer_config, model_config)\n\n    @torch.no_grad()\n    def get_initial_prediction(self, *args, **kwargs) -> Tensor:\n\n        training = self.model.training\n        
self.model.eval()\n\n        out = self.model(*args, **kwargs)\n        if (self._explainer.model_config.mode ==\n                ModelMode.multiclass_classification):\n            out = out.argmax(dim=-1)\n\n        self.model.train(training)\n\n        return out\n\n    def explain_graph(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        **kwargs,\n    ) -> Tuple[Tensor, Tensor]:\n        self._explainer.model_config.task_level = ModelTaskLevel.graph\n\n        explanation = self._explainer(\n            self.model,\n            x,\n            edge_index,\n            target=self.get_initial_prediction(x, edge_index, **kwargs),\n            **kwargs,\n        )\n        return self._convert_output(explanation, edge_index)\n\n    def explain_node(\n        self,\n        node_idx: int,\n        x: Tensor,\n        edge_index: Tensor,\n        **kwargs,\n    ) -> Tuple[Tensor, Tensor]:\n        self._explainer.model_config.task_level = ModelTaskLevel.node\n        explanation = self._explainer(\n            self.model,\n            x,\n            edge_index,\n            target=self.get_initial_prediction(x, edge_index, **kwargs),\n            index=node_idx,\n            **kwargs,\n        )\n        return self._convert_output(explanation, edge_index, index=node_idx,\n                                    x=x)\n\n    def _convert_output(self, explanation, edge_index, index=None, x=None):\n        node_mask = explanation.get('node_mask')\n        edge_mask = explanation.get('edge_mask')\n\n        if node_mask is not None:\n            node_mask_type = self._explainer.explainer_config.node_mask_type\n            if node_mask_type in {MaskType.object, MaskType.common_attributes}:\n                node_mask = node_mask.view(-1)\n\n        if edge_mask is None:\n            if index is not None:\n                _, edge_mask = self._explainer._get_hard_masks(\n                    self.model, index, edge_index, num_nodes=x.size(0))\n             
   edge_mask = edge_mask.to(x.dtype)\n            else:\n                edge_mask = torch.ones(edge_index.size(1),\n                                       device=edge_index.device)\n\n        return node_mask, edge_mask\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/graphmask_explainer.py",
    "content": "import math\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import LayerNorm, Linear, Parameter, ReLU\nfrom tqdm import tqdm\n\nfrom torch_geometric.explain import Explanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel\nfrom torch_geometric.nn import MessagePassing\n\n\ndef explain_message(self, out: Tensor, x_i: Tensor, x_j: Tensor) -> Tensor:\n    basis_messages = F.layer_norm(out, (out.size(-1), )).relu()\n\n    if getattr(self, 'message_scale', None) is not None:\n        basis_messages = basis_messages * self.message_scale.unsqueeze(-1)\n\n        if self.message_replacement is not None:\n            if basis_messages.shape == self.message_replacement.shape:\n                basis_messages = (basis_messages +\n                                  (1 - self.message_scale).unsqueeze(-1) *\n                                  self.message_replacement)\n            else:\n                basis_messages = (basis_messages +\n                                  ((1 - self.message_scale).unsqueeze(-1) *\n                                   self.message_replacement.unsqueeze(0)))\n\n    self.latest_messages = basis_messages\n    self.latest_source_embeddings = x_j\n    self.latest_target_embeddings = x_i\n\n    return basis_messages\n\n\nclass GraphMaskExplainer(ExplainerAlgorithm):\n    r\"\"\"The GraphMask-Explainer model from the `\"Interpreting Graph Neural\n    Networks for NLP With Differentiable Edge Masking\"\n    <https://arxiv.org/abs/2010.00577>`_ paper for identifying layer-wise\n    compact subgraph structures and node features that play a crucial role in\n    the predictions made by a GNN.\n\n    .. 
note::\n        For an example of using :class:`GraphMaskExplainer`,\n        see `examples/explain/graphmask_explainer.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        explain/graphmask_explainer.py>`_.\n\n        A working real-time example of :class:`GraphMaskExplainer` in the form\n        of a deployed app can be accessed `here\n        <https://graph-explainability.streamlit.app/>`_.\n\n    Args:\n        num_layers (int): The number of layers to use.\n        epochs (int, optional): The number of epochs to train.\n            (default: :obj:`100`)\n        lr (float, optional): The learning rate to apply.\n            (default: :obj:`0.01`)\n        penalty_scaling (int, optional): Scaling value of the penalty term.\n            Value must lie between :obj:`0` and :obj:`10`.\n            (default: :obj:`5`)\n        lambda_optimizer_lr (float, optional): The learning rate to optimize\n            the Lagrange multiplier. (default: :obj:`1e-2`)\n        init_lambda (float, optional): The initial value of the Lagrange\n            multiplier. Value must lie between :obj:`0` and :obj:`1`.\n            (default: :obj:`0.55`)\n        allowance (float, optional): The tolerance level, a float value\n            between :obj:`0` and :obj:`1`. (default: :obj:`0.03`)\n        log (bool, optional): If set to :obj:`False`, will not log any\n            learning progress. 
(default: :obj:`True`)\n        **kwargs (optional): Additional hyper-parameters to override default\n            settings in\n            :attr:`~torch_geometric.explain.GraphMaskExplainer.coeffs`.\n    \"\"\"\n    coeffs = {\n        'node_feat_size': 1.0,\n        'node_feat_reduction': 'mean',\n        'node_feat_ent': 0.1,\n        'EPS': 1e-15,\n    }\n\n    def __init__(\n        self,\n        num_layers: int,\n        epochs: int = 100,\n        lr: float = 0.01,\n        penalty_scaling: int = 5,\n        lambda_optimizer_lr: float = 1e-2,\n        init_lambda: float = 0.55,\n        allowance: float = 0.03,\n        allow_multiple_explanations: bool = False,\n        log: bool = True,\n        **kwargs,\n    ):\n        super().__init__()\n        assert 0 <= penalty_scaling <= 10\n        assert 0 <= init_lambda <= 1\n        assert 0 <= allowance <= 1\n\n        self.num_layers = num_layers\n        self.init_lambda = init_lambda\n        self.lambda_optimizer_lr = lambda_optimizer_lr\n        self.penalty_scaling = penalty_scaling\n        self.allowance = allowance\n        self.allow_multiple_explanations = allow_multiple_explanations\n        self.epochs = epochs\n        self.lr = lr\n        self.log = log\n        self.coeffs.update(kwargs)\n\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Explanation:\n\n        hard_node_mask = None\n\n        if self.model_config.task_level == ModelTaskLevel.node:\n            hard_node_mask, hard_edge_mask = self._get_hard_masks(\n                model, index, edge_index, num_nodes=x.size(0))\n        self._train_explainer(model, x, edge_index, target=target, index=index,\n                              **kwargs)\n        node_mask = self._post_process_mask(self.node_feat_mask,\n                                            
hard_node_mask, apply_sigmoid=True)\n        edge_mask = self._explain(model, index=index)\n        edge_mask = edge_mask[:edge_index.size(1)]\n\n        return Explanation(node_mask=node_mask, edge_mask=edge_mask)\n\n    def supports(self) -> bool:\n        return True\n\n    def _hard_concrete(\n        self,\n        input_element: Tensor,\n        summarize_penalty: bool = True,\n        beta: float = 1 / 3,\n        gamma: float = -0.2,\n        zeta: float = 1.2,\n        loc_bias: int = 2,\n        min_val: int = 0,\n        max_val: int = 1,\n        training: bool = True,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Helps to set the edge mask while sampling its values from the\n        hard-concrete distribution.\n        \"\"\"\n        input_element = input_element + loc_bias\n\n        if training:\n            u = torch.empty_like(input_element).uniform_(1e-6, 1.0 - 1e-6)\n\n            s = torch.sigmoid(\n                (torch.log(u) - torch.log(1 - u) + input_element) / beta)\n\n            penalty = torch.sigmoid(input_element -\n                                    beta * math.log(-gamma / zeta))\n        else:\n            s = torch.sigmoid(input_element)\n            penalty = torch.zeros_like(input_element)\n\n        if summarize_penalty:\n            penalty = penalty.mean()\n\n        s = s * (zeta - gamma) + gamma\n\n        clipped_s = s.clamp(min_val, max_val)\n\n        clip_value = (torch.min(clipped_s) + torch.max(clipped_s)) / 2\n        hard_concrete = (clipped_s > clip_value).float()\n        clipped_s = clipped_s + (hard_concrete - clipped_s).detach()\n\n        return clipped_s, penalty\n\n    def _set_masks(\n        self,\n        i_dim: List[int],\n        j_dim: List[int],\n        h_dim: List[int],\n        x: Tensor,\n    ):\n        r\"\"\"Sets the node masks and edge masks.\"\"\"\n        (num_nodes, num_feat), std, device = x.size(), 0.1, x.device\n        self.feat_mask_type = self.explainer_config.node_mask_type\n\n  
      if self.feat_mask_type == MaskType.attributes:\n            self.node_feat_mask = torch.nn.Parameter(\n                torch.randn(num_nodes, num_feat, device=device) * std)\n        elif self.feat_mask_type == MaskType.object:\n            self.node_feat_mask = torch.nn.Parameter(\n                torch.randn(num_nodes, 1, device=device) * std)\n        else:\n            self.node_feat_mask = torch.nn.Parameter(\n                torch.randn(1, num_feat, device=device) * std)\n\n        baselines, self.gates, full_biases = [], torch.nn.ModuleList(), []\n\n        for v_dim, m_dim, o_dim in zip(i_dim, j_dim, h_dim):\n            self.transform, self.layer_norm = [], []\n            input_dims = [v_dim, m_dim, v_dim]\n            for _, input_dim in enumerate(input_dims):\n                self.transform.append(\n                    Linear(input_dim, o_dim, bias=False).to(device))\n                self.layer_norm.append(LayerNorm(o_dim).to(device))\n\n            self.transforms = torch.nn.ModuleList(self.transform)\n            self.layer_norms = torch.nn.ModuleList(self.layer_norm)\n\n            self.full_bias = Parameter(\n                torch.tensor(o_dim, dtype=torch.float, device=device))\n            full_biases.append(self.full_bias)\n\n            self.reset_parameters(input_dims, o_dim)\n\n            self.non_linear = ReLU()\n            self.output_layer = Linear(o_dim, 1).to(device)\n\n            gate = [\n                self.transforms, self.layer_norms, self.non_linear,\n                self.output_layer\n            ]\n            self.gates.extend(gate)\n\n            baseline = torch.tensor(m_dim, dtype=torch.float, device=device)\n            stdv = 1. 
/ math.sqrt(m_dim)\n            baseline.uniform_(-stdv, stdv)\n            baseline = torch.nn.Parameter(baseline)\n            baselines.append(baseline)\n\n        full_biases = torch.nn.ParameterList(full_biases)\n        self.full_biases = full_biases\n\n        baselines = torch.nn.ParameterList(baselines)\n        self.baselines = baselines\n\n        for parameter in self.parameters():\n            parameter.requires_grad = False\n\n    def _enable_layer(self, layer: int):\n        r\"\"\"Enables the input layer's edge mask.\"\"\"\n        for d in range(layer * 4, (layer * 4) + 4):\n            for parameter in self.gates[d].parameters():\n                parameter.requires_grad = True\n        self.full_biases[layer].requires_grad = True\n        self.baselines[layer].requires_grad = True\n\n    def reset_parameters(self, input_dims: List[int], h_dim: int):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        fan_in = sum(input_dims)\n\n        std = math.sqrt(2.0 / float(fan_in + h_dim))\n        a = math.sqrt(3.0) * std\n\n        for transform in self.transforms:\n            torch.nn.init._no_grad_uniform_(transform.weight, -a, a)\n\n        torch.nn.init.zeros_(self.full_bias)\n\n        for layer_norm in self.layer_norms:\n            layer_norm.reset_parameters()\n\n    def _loss(self, y_hat: Tensor, y: Tensor, penalty: float) -> Tensor:\n        if self.model_config.mode == ModelMode.binary_classification:\n            loss = self._loss_binary_classification(y_hat, y)\n        elif self.model_config.mode == ModelMode.multiclass_classification:\n            loss = self._loss_multiclass_classification(y_hat, y)\n        elif self.model_config.mode == ModelMode.regression:\n            loss = self._loss_regression(y_hat, y)\n        else:\n            raise AssertionError()\n\n        g = torch.relu(loss - self.allowance).mean()\n        f = penalty * self.penalty_scaling\n\n        loss = f + F.softplus(self.lambda_op) 
* g\n\n        m = self.node_feat_mask.sigmoid()\n        node_feat_reduce = getattr(torch, self.coeffs['node_feat_reduction'])\n        loss = loss + self.coeffs['node_feat_size'] * node_feat_reduce(m)\n        ent = -m * torch.log(m + self.coeffs['EPS']) - (\n            1 - m) * torch.log(1 - m + self.coeffs['EPS'])\n        loss = loss + self.coeffs['node_feat_ent'] * ent.mean()\n\n        return loss\n\n    def _freeze_model(self, module: torch.nn.Module):\n        r\"\"\"Freezes the parameters of the original GNN model by disabling\n        their gradients.\n        \"\"\"\n        for param in module.parameters():\n            param.requires_grad = False\n\n    def _set_flags(self, model: torch.nn.Module):\n        r\"\"\"Initializes the underlying explainer model's parameters for each\n        layer of the original GNN model.\n        \"\"\"\n        for module in model.modules():\n            if isinstance(module, MessagePassing):\n                module.explain_message = explain_message.__get__(\n                    module, MessagePassing)\n                module.explain = True\n\n    def _inject_messages(\n        self,\n        model: torch.nn.Module,\n        message_scale: List[Tensor],\n        message_replacement: torch.nn.ParameterList,\n        set: bool = False,\n    ):\n        r\"\"\"Injects the computed messages into each layer of the original GNN\n        model.\n        \"\"\"\n        i = 0\n        for module in model.modules():\n            if isinstance(module, MessagePassing):\n                if not set:\n                    module.message_scale = message_scale[i]\n                    module.message_replacement = message_replacement[i]\n                    i = i + 1\n                else:\n                    module.message_scale = None\n                    module.message_replacement = None\n\n    def _train_explainer(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        
target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ):\n        r\"\"\"Trains the underlying explainer model.\n\n        Args:\n            model (torch.nn.Module): The model to explain.\n            x (torch.Tensor): The input node features.\n            edge_index (torch.Tensor): The input edge indices.\n            target (torch.Tensor): The target of the model.\n            index (int or torch.Tensor, optional): The index of the model\n                output to explain. Needs to be a single index.\n                (default: :obj:`None`)\n            **kwargs (optional): Additional keyword arguments passed to\n                :obj:`model`.\n        \"\"\"\n        if (not isinstance(index, Tensor) and not isinstance(index, int)\n                and index is not None):\n            raise ValueError(\"'index' parameter can only be a 'Tensor', \"\n                             \"'integer' or set to 'None' instead.\")\n\n        self._freeze_model(model)\n        self._set_flags(model)\n\n        input_dims, output_dims = [], []\n        for module in model.modules():\n            if isinstance(module, MessagePassing):\n                input_dims.append(module.in_channels)\n                output_dims.append(module.out_channels)\n\n        self._set_masks(input_dims, output_dims, output_dims, x)\n\n        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)\n\n        for layer in reversed(list(range(self.num_layers))):\n            if self.log:\n                pbar = tqdm(total=self.epochs)\n                if self.model_config.task_level == ModelTaskLevel.node:\n                    pbar.set_description(\n                        f'Train explainer for node(s) {index} with layer '\n                        f'{layer}')\n                elif self.model_config.task_level == ModelTaskLevel.edge:\n                    pbar.set_description(\n                        f\"Train explainer for edge-level task with layer \"\n         
               f\"{layer}\")\n                else:\n                    pbar.set_description(\n                        f'Train explainer for graph {index} with layer '\n                        f'{layer}')\n            self._enable_layer(layer)\n            for _ in range(self.epochs):\n                with torch.no_grad():\n                    model(x, edge_index, **kwargs)\n                gates, total_penalty = [], 0\n                latest_source_embeddings, latest_messages = [], []\n                latest_target_embeddings = []\n                for module in model.modules():\n                    if isinstance(module, MessagePassing):\n                        latest_source_embeddings.append(\n                            module.latest_source_embeddings)\n                        latest_messages.append(module.latest_messages)\n                        latest_target_embeddings.append(\n                            module.latest_target_embeddings)\n                gate_input = [\n                    latest_source_embeddings, latest_messages,\n                    latest_target_embeddings\n                ]\n                for i in range(self.num_layers):\n                    output = self.full_biases[i]\n                    for j in range(len(gate_input)):\n                        try:\n                            partial = self.gates[i * 4][j](gate_input[j][i])\n                        except Exception:\n                            try:\n                                self._set_masks(output_dims, output_dims,\n                                                output_dims, x)\n                                partial = self.gates[i * 4][j](\n                                    gate_input[j][i])\n                            except Exception:\n                                self._set_masks(input_dims, input_dims,\n                                                output_dims, x)\n                                partial = self.gates[i * 4][j](\n                              
      gate_input[j][i])\n                        result = self.gates[(i * 4) + 1][j](partial)\n                        output = output + result\n                    relu_output = self.gates[(i * 4) + 2](output /\n                                                          len(gate_input))\n                    sampling_weights = self.gates[(i * 4) +\n                                                  3](relu_output).squeeze(\n                                                      dim=-1)\n                    sampling_weights, penalty = self._hard_concrete(\n                        sampling_weights)\n                    gates.append(sampling_weights)\n                    total_penalty += penalty\n\n                self._inject_messages(model, gates, self.baselines)\n\n                self.lambda_op = torch.tensor(self.init_lambda,\n                                              requires_grad=True)\n                optimizer_lambda = torch.optim.RMSprop(\n                    [self.lambda_op], lr=self.lambda_optimizer_lr,\n                    centered=True)\n\n                optimizer.zero_grad()\n                optimizer_lambda.zero_grad()\n\n                h = x * self.node_feat_mask.sigmoid()\n                y_hat, y = model(x=h, edge_index=edge_index, **kwargs), target\n\n                if (self.model_config.task_level == ModelTaskLevel.node or\n                        self.model_config.task_level == ModelTaskLevel.edge):\n                    if index is not None:\n                        y_hat, y = y_hat[index], y[index]\n\n                self._inject_messages(model, gates, self.baselines, True)\n\n                loss = self._loss(y_hat, y, total_penalty)\n\n                loss.backward()\n                optimizer.step()\n                self.lambda_op.grad *= -1\n                optimizer_lambda.step()\n\n                if self.lambda_op.item() < -2:\n                    self.lambda_op.data = torch.full_like(\n                        self.lambda_op.data, 
-2)\n                elif self.lambda_op.item() > 30:\n                    self.lambda_op.data = torch.full_like(\n                        self.lambda_op.data, 30)\n\n                if self.log:\n                    pbar.update(1)\n\n            if self.log:\n                pbar.close()\n\n    def _explain(\n        self,\n        model: torch.nn.Module,\n        *,\n        index: Optional[Union[int, Tensor]] = None,\n    ) -> Tensor:\n        r\"\"\"Generates explanations for the original GNN model.\n\n        Args:\n            model (torch.nn.Module): The model to explain.\n            index (int or torch.Tensor, optional): The index of the model\n                output to explain. Needs to be a single index.\n                (default: :obj:`None`).\n        \"\"\"\n        if (not isinstance(index, Tensor) and not isinstance(index, int)\n                and index is not None):\n            raise ValueError(\"'index' parameter can only be a 'Tensor', \"\n                             \"'integer' or set to 'None' instead.\")\n\n        self._freeze_model(model)\n        self._set_flags(model)\n\n        with torch.no_grad():\n            latest_source_embeddings, latest_messages = [], []\n            latest_target_embeddings = []\n            for module in model.modules():\n                if isinstance(module, MessagePassing):\n                    latest_source_embeddings.append(\n                        module.latest_source_embeddings)\n                    latest_messages.append(module.latest_messages)\n                    latest_target_embeddings.append(\n                        module.latest_target_embeddings)\n            gate_input = [\n                latest_source_embeddings, latest_messages,\n                latest_target_embeddings\n            ]\n            if self.log:\n                pbar = tqdm(total=self.num_layers)\n            for i in range(self.num_layers):\n                if self.log:\n                    
pbar.set_description(\"Explain\")\n                output = self.full_biases[i]\n                for j in range(len(gate_input)):\n                    partial = self.gates[i * 4][j](gate_input[j][i])\n                    result = self.gates[(i * 4) + 1][j](partial)\n                    output = output + result\n                relu_output = self.gates[(i * 4) + 2](output / len(gate_input))\n                sampling_weights = self.gates[(i * 4) +\n                                              3](relu_output).squeeze(dim=-1)\n                sampling_weights, _ = self._hard_concrete(\n                    sampling_weights, training=False)\n                if i == 0:\n                    edge_weight = sampling_weights\n                else:\n                    edge_weight = torch.cat((edge_weight, sampling_weights), 0)\n                if self.log:\n                    pbar.update(1)\n        if self.log:\n            pbar.close()\n\n        edge_mask = edge_weight.view(-1,\n                                     edge_weight.size(0) // self.num_layers)\n        edge_mask = torch.mean(edge_mask, 0)\n\n        return edge_mask\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/pg_explainer.py",
    "content": "import logging\nfrom typing import Dict, Optional, Tuple, Union, overload\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import ReLU, Sequential\n\nfrom torch_geometric.explain import Explanation, HeteroExplanation\nfrom torch_geometric.explain.algorithm import ExplainerAlgorithm\nfrom torch_geometric.explain.algorithm.utils import (\n    clear_masks,\n    set_hetero_masks,\n    set_masks,\n)\nfrom torch_geometric.explain.config import (\n    ExplanationType,\n    ModelMode,\n    ModelTaskLevel,\n)\nfrom torch_geometric.nn import HANConv, HeteroConv, HGTConv, Linear\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.utils import get_embeddings, get_embeddings_hetero\n\n\nclass PGExplainer(ExplainerAlgorithm):\n    r\"\"\"The PGExplainer model from the `\"Parameterized Explainer for Graph\n    Neural Network\" <https://arxiv.org/abs/2011.04573>`_ paper.\n\n    Internally, it utilizes a neural network to identify subgraph structures\n    that play a crucial role in the predictions made by a GNN.\n    Importantly, the :class:`PGExplainer` needs to be trained via\n    :meth:`~PGExplainer.train` before being able to generate explanations:\n\n    .. 
code-block:: python\n\n        explainer = Explainer(\n            model=model,\n            algorithm=PGExplainer(epochs=30, lr=0.003),\n            explanation_type='phenomenon',\n            edge_mask_type='object',\n            model_config=ModelConfig(...),\n        )\n\n        # Train against a variety of node-level or graph-level predictions:\n        for epoch in range(30):\n            for index in [...]:  # Indices to train against.\n                loss = explainer.algorithm.train(epoch, model, x, edge_index,\n                                                 target=target, index=index)\n\n        # Get the final explanations:\n        explanation = explainer(x, edge_index, target=target, index=0)\n\n    Args:\n        epochs (int): The number of epochs to train.\n        lr (float, optional): The learning rate to apply.\n            (default: :obj:`0.003`).\n        **kwargs (optional): Additional hyper-parameters to override default\n            settings in\n            :attr:`~torch_geometric.explain.algorithm.PGExplainer.coeffs`.\n    \"\"\"\n\n    coeffs = {\n        'edge_size': 0.05,\n        'edge_ent': 1.0,\n        'temp': [5.0, 2.0],\n        'bias': 0.01,\n    }\n\n    # NOTE: Add more in the future as needed.\n    SUPPORTED_HETERO_MODELS = [\n        HGTConv,\n        HANConv,\n        HeteroConv,\n    ]\n\n    def __init__(self, epochs: int, lr: float = 0.003, **kwargs):\n        super().__init__()\n        self.epochs = epochs\n        self.lr = lr\n        self.coeffs.update(kwargs)\n\n        self.mlp = Sequential(\n            Linear(-1, 64),\n            ReLU(),\n            Linear(64, 1),\n        )\n        self.optimizer = torch.optim.Adam(self.mlp.parameters(), lr=lr)\n        self._curr_epoch = -1\n        self.is_hetero = False\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        reset(self.mlp)\n\n    @overload\n    def train(\n        self,\n        epoch: int,\n       
 model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> float:\n        ...\n\n    @overload\n    def train(\n        self,\n        epoch: int,\n        model: torch.nn.Module,\n        x: Dict[NodeType, Tensor],\n        edge_index: Dict[EdgeType, Tensor],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> float:\n        ...\n\n    def train(\n        self,\n        epoch: int,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> float:\n        r\"\"\"Trains the underlying explainer model.\n        Needs to be called before being able to make predictions.\n\n        Args:\n            epoch (int): The current epoch of the training phase.\n            model (torch.nn.Module): The model to explain.\n            x (torch.Tensor or Dict[str, torch.Tensor]): The input node\n                features. Can be either homogeneous or heterogeneous.\n            edge_index (torch.Tensor or Dict[Tuple[str, str, str]): The input\n                edge indices. Can be either homogeneous or heterogeneous.\n            target (torch.Tensor): The target of the model.\n            index (int or torch.Tensor, optional): The index of the model\n                output to explain. 
Needs to be a single index.\n                (default: :obj:`None`)\n            **kwargs (optional): Additional keyword arguments passed to\n                :obj:`model`.\n        \"\"\"\n        self.is_hetero = isinstance(x, dict)\n        if self.is_hetero:\n            assert isinstance(edge_index, dict)\n\n        if self.model_config.task_level == ModelTaskLevel.node:\n            if index is None:\n                raise ValueError(f\"The 'index' argument needs to be provided \"\n                                 f\"in '{self.__class__.__name__}' for \"\n                                 f\"node-level explanations\")\n            if isinstance(index, Tensor) and index.numel() > 1:\n                raise ValueError(f\"Only scalars are supported for the 'index' \"\n                                 f\"argument in '{self.__class__.__name__}'\")\n\n        # Get embeddings based on whether the graph is homogeneous or\n        # heterogeneous\n        node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)\n\n        # Train the model\n        self.optimizer.zero_grad()\n        temperature = self._get_temperature(epoch)\n\n        # Process embeddings and generate edge masks\n        edge_mask = self._generate_edge_masks(node_embeddings, edge_index,\n                                              index, temperature)\n\n        # Apply masks to the model\n        if self.is_hetero:\n            set_hetero_masks(model, edge_mask, edge_index, apply_sigmoid=True)\n\n            # For node-level tasks, we can compute hard masks\n            if self.model_config.task_level == ModelTaskLevel.node:\n                # Process each edge type separately\n                for edge_type, mask in edge_mask.items():\n                    # Get the edge indices for this edge type\n                    edges = edge_index[edge_type]\n                    src_type, _, dst_type = edge_type\n\n                    # Get hard masks for this specific edge type\n                  
  _, hard_mask = self._get_hard_masks(\n                        model, index, edges,\n                        num_nodes=max(x[src_type].size(0),\n                                      x[dst_type].size(0)))\n\n                    edge_mask[edge_type] = mask[hard_mask]\n        else:\n            # Apply masks for homogeneous graphs\n            set_masks(model, edge_mask, edge_index, apply_sigmoid=True)\n\n            # For node-level tasks, we may need to apply hard masks\n            hard_edge_mask = None\n            if self.model_config.task_level == ModelTaskLevel.node:\n                _, hard_edge_mask = self._get_hard_masks(\n                    model, index, edge_index, num_nodes=x.size(0))\n                edge_mask = edge_mask[hard_edge_mask]\n\n        # Forward pass with masks applied\n        y_hat, y = model(x, edge_index, **kwargs), target\n\n        if index is not None:\n            y_hat, y = y_hat[index], y[index]\n\n        # Calculate loss\n        loss = self._loss(y_hat, y, edge_mask)\n\n        # Backward pass and optimization\n        loss.backward()\n        self.optimizer.step()\n\n        # Clean up\n        clear_masks(model)\n        self._curr_epoch = epoch\n\n        return float(loss)\n\n    @overload\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Tensor,\n        edge_index: Tensor,\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Explanation:\n        ...\n\n    @overload\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Dict[NodeType, Tensor],\n        edge_index: Dict[EdgeType, Tensor],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> HeteroExplanation:\n        ...\n\n    def forward(\n        self,\n        model: torch.nn.Module,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, 
Tensor]],\n        *,\n        target: Tensor,\n        index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Union[Explanation, HeteroExplanation]:\n        self.is_hetero = isinstance(x, dict)\n\n        if self._curr_epoch < self.epochs - 1:  # Safety check:\n            raise ValueError(f\"'{self.__class__.__name__}' is not yet fully \"\n                             f\"trained (got {self._curr_epoch + 1} epochs \"\n                             f\"from {self.epochs} epochs). Please first train \"\n                             f\"the underlying explainer model by running \"\n                             f\"`explainer.algorithm.train(...)`.\")\n\n        if self.model_config.task_level == ModelTaskLevel.node:\n            if index is None:\n                raise ValueError(f\"The 'index' argument needs to be provided \"\n                                 f\"in '{self.__class__.__name__}' for \"\n                                 f\"node-level explanations\")\n            if isinstance(index, Tensor) and index.numel() > 1:\n                raise ValueError(f\"Only scalars are supported for the 'index' \"\n                                 f\"argument in '{self.__class__.__name__}'\")\n\n        # Get embeddings\n        node_embeddings = self._get_embeddings(model, x, edge_index, **kwargs)\n\n        # Generate explanations\n        if self.is_hetero:\n            # Generate edge masks for each edge type\n            edge_masks = {}\n\n            # Generate masks for each edge type\n            for edge_type, edge_idx in edge_index.items():\n                src_node_type, _, dst_node_type = edge_type\n\n                assert src_node_type in node_embeddings\n                assert dst_node_type in node_embeddings\n\n                inputs = self._get_inputs_hetero(node_embeddings, edge_type,\n                                                 edge_idx, index)\n                logits = self.mlp(inputs).view(-1)\n\n                # For node-level 
explanations, get hard masks for this\n                # specific edge type\n                hard_edge_mask = None\n                if self.model_config.task_level == ModelTaskLevel.node:\n                    _, hard_edge_mask = self._get_hard_masks(\n                        model, index, edge_idx,\n                        num_nodes=max(x[src_node_type].size(0),\n                                      x[dst_node_type].size(0)))\n\n                # Apply hard mask if available and it has any True values\n                edge_masks[edge_type] = self._post_process_mask(\n                    logits, hard_edge_mask, apply_sigmoid=True)\n\n            explanation = HeteroExplanation()\n            explanation.set_value_dict('edge_mask', edge_masks)\n            return explanation\n        else:\n            hard_edge_mask = None\n            if self.model_config.task_level == ModelTaskLevel.node:\n                # We need to compute hard masks to properly clean up edges\n                _, hard_edge_mask = self._get_hard_masks(\n                    model, index, edge_index, num_nodes=x.size(0))\n\n            inputs = self._get_inputs(node_embeddings, edge_index, index)\n            logits = self.mlp(inputs).view(-1)\n\n            edge_mask = self._post_process_mask(logits, hard_edge_mask,\n                                                apply_sigmoid=True)\n\n            return Explanation(edge_mask=edge_mask)\n\n    def supports(self) -> bool:\n        explanation_type = self.explainer_config.explanation_type\n        if explanation_type != ExplanationType.phenomenon:\n            logging.error(f\"'{self.__class__.__name__}' only supports \"\n                          f\"phenomenon explanations \"\n                          f\"got (`explanation_type={explanation_type.value}`)\")\n            return False\n\n        task_level = self.model_config.task_level\n        if task_level not in {ModelTaskLevel.node, ModelTaskLevel.graph}:\n            
logging.error(f\"'{self.__class__.__name__}' only supports \"\n                          f\"node-level or graph-level explanations \"\n                          f\"(got `task_level={task_level.value}`)\")\n            return False\n\n        node_mask_type = self.explainer_config.node_mask_type\n        if node_mask_type is not None:\n            logging.error(f\"'{self.__class__.__name__}' does not support \"\n                          f\"explaining input node features \"\n                          f\"(got `node_mask_type={node_mask_type.value}`)\")\n            return False\n\n        return True\n\n    ###########################################################################\n\n    def _get_embeddings(self, model: torch.nn.Module, x: Union[Tensor,\n                                                               Dict[NodeType,\n                                                                    Tensor]],\n                        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n                        **kwargs) -> Union[Tensor, Dict[NodeType, Tensor]]:\n        \"\"\"Get embeddings from the model based on input type.\"\"\"\n        if self.is_hetero:\n            # For heterogeneous graphs, get embeddings for each node type\n            embeddings_dict = get_embeddings_hetero(\n                model,\n                self.SUPPORTED_HETERO_MODELS,\n                x,\n                edge_index,\n                **kwargs,\n            )\n\n            # Use the last layer's embeddings for each node type\n            last_embedding_dict = {\n                node_type: embs[-1] if embs and len(embs) > 0 else None\n                for node_type, embs in embeddings_dict.items()\n            }\n\n            # Skip if no embeddings were captured\n            if not any(emb is not None\n                       for emb in last_embedding_dict.values()):\n                raise ValueError(\n                    \"No embeddings were captured from the model. 
\"\n                    \"Please check if the model architecture is supported.\")\n\n            return last_embedding_dict\n        else:\n            # For homogeneous graphs, get embeddings directly\n            return get_embeddings(model, x, edge_index, **kwargs)[-1]\n\n    def _generate_edge_masks(\n            self, emb: Union[Tensor, Dict[NodeType, Tensor]],\n            edge_index: Union[Tensor,\n                              Dict[EdgeType,\n                                   Tensor]], index: Optional[Union[int,\n                                                                   Tensor]],\n            temperature: float) -> Union[Tensor, Dict[EdgeType, Tensor]]:\n        \"\"\"Generate edge masks based on embeddings.\"\"\"\n        if self.is_hetero:\n            # For heterogeneous graphs, generate masks for each edge type\n            edge_masks = {}\n\n            for edge_type, edge_idx in edge_index.items():\n                src, _, dst = edge_type\n\n                assert src in emb and dst in emb\n                # Generate inputs for this edge type\n                inputs = self._get_inputs_hetero(emb, edge_type, edge_idx,\n                                                 index)\n                logits = self.mlp(inputs).view(-1)\n                edge_masks[edge_type] = self._concrete_sample(\n                    logits, temperature)\n\n            # Ensure we have at least one valid edge mask\n            if not edge_masks:\n                raise ValueError(\n                    \"Could not generate edge masks for any edge type. 
\"\n                    \"Please ensure the model architecture is supported.\")\n\n            return edge_masks\n        else:\n            # For homogeneous graphs, generate a single mask\n            inputs = self._get_inputs(emb, edge_index, index)\n            logits = self.mlp(inputs).view(-1)\n            return self._concrete_sample(logits, temperature)\n\n    def _get_inputs(self, embedding: Tensor, edge_index: Tensor,\n                    index: Optional[int] = None) -> Tensor:\n        zs = [embedding[edge_index[0]], embedding[edge_index[1]]]\n        if self.model_config.task_level == ModelTaskLevel.node:\n            assert index is not None\n            zs.append(embedding[index].view(1, -1).repeat(zs[0].size(0), 1))\n        return torch.cat(zs, dim=-1)\n\n    def _get_inputs_hetero(self, embedding_dict: Dict[NodeType, Tensor],\n                           edge_type: Tuple[str, str, str], edge_index: Tensor,\n                           index: Optional[int] = None) -> Tensor:\n        src, _, dst = edge_type\n\n        # Get embeddings for source and destination nodes\n        src_emb = embedding_dict[src]\n        dst_emb = embedding_dict[dst]\n\n        # Source and destination node embeddings\n        zs = [src_emb[edge_index[0]], dst_emb[edge_index[1]]]\n\n        # For node-level explanations, add the target node embedding\n        if self.model_config.task_level == ModelTaskLevel.node:\n            assert index is not None\n            # Assuming index refers to a node of type 'src'\n            target_emb = src_emb[index].view(1, -1).repeat(zs[0].size(0), 1)\n            zs.append(target_emb)\n\n        return torch.cat(zs, dim=-1)\n\n    def _get_temperature(self, epoch: int) -> float:\n        temp = self.coeffs['temp']\n        return temp[0] * pow(temp[1] / temp[0], epoch / self.epochs)\n\n    def _concrete_sample(self, logits: Tensor,\n                         temperature: float = 1.0) -> Tensor:\n        bias = self.coeffs['bias']\n        
eps = (1 - 2 * bias) * torch.rand_like(logits) + bias\n        return (eps.log() - (1 - eps).log() + logits) / temperature\n\n    def _loss(self, y_hat: Tensor, y: Tensor,\n              edge_mask: Union[Tensor, Dict[EdgeType, Tensor]]) -> Tensor:\n        # Calculate base loss based on model configuration\n        loss = self._calculate_base_loss(y_hat, y)\n\n        # Apply regularization based on graph type\n        if self.is_hetero:\n            loss = self._apply_hetero_regularization(loss, edge_mask)\n        else:\n            loss = self._apply_homo_regularization(loss, edge_mask)\n\n        return loss\n\n    def _calculate_base_loss(self, y_hat: Tensor, y: Tensor) -> Tensor:\n        \"\"\"Calculate base loss based on model configuration.\"\"\"\n        if self.model_config.mode == ModelMode.binary_classification:\n            return self._loss_binary_classification(y_hat, y)\n        elif self.model_config.mode == ModelMode.multiclass_classification:\n            return self._loss_multiclass_classification(y_hat, y)\n        elif self.model_config.mode == ModelMode.regression:\n            return self._loss_regression(y_hat, y)\n        else:\n            raise ValueError(\n                f\"Unsupported model mode: {self.model_config.mode}\")\n\n    def _apply_hetero_regularization(\n            self, loss: Tensor, edge_mask: Dict[EdgeType, Tensor]) -> Tensor:\n        \"\"\"Apply regularization for heterogeneous graph.\"\"\"\n        for _, mask in edge_mask.items():\n            loss = self._add_mask_regularization(loss, mask)\n\n        return loss\n\n    def _apply_homo_regularization(self, loss: Tensor,\n                                   edge_mask: Tensor) -> Tensor:\n        \"\"\"Apply regularization for homogeneous graph.\"\"\"\n        return self._add_mask_regularization(loss, edge_mask)\n\n    def _add_mask_regularization(self, loss: Tensor, mask: Tensor) -> Tensor:\n        \"\"\"Add size and entropy regularization for a mask.\"\"\"\n      
  # Apply sigmoid for mask values\n        mask = mask.sigmoid()\n\n        # Size regularization\n        size_loss = mask.sum() * self.coeffs['edge_size']\n\n        # Entropy regularization\n        masked = 0.99 * mask + 0.005\n        mask_ent = -masked * masked.log() - (1 - masked) * (1 - masked).log()\n        mask_ent_loss = mask_ent.mean() * self.coeffs['edge_ent']\n\n        return loss + size_loss + mask_ent_loss\n"
  },
  {
    "path": "torch_geometric/explain/algorithm/utils.py",
    "content": "from typing import Dict, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn import MessagePassing\nfrom torch_geometric.typing import EdgeType\n\n\ndef set_masks(\n    model: torch.nn.Module,\n    mask: Union[Tensor, Parameter],\n    edge_index: Tensor,\n    apply_sigmoid: bool = True,\n):\n    r\"\"\"Apply mask to every graph layer in the :obj:`model`.\"\"\"\n    loop_mask = edge_index[0] != edge_index[1]\n\n    # Loop over layers and set masks on MessagePassing layers:\n    for module in model.modules():\n        if isinstance(module, MessagePassing):\n            # Skip layers that have been explicitly set to `False`:\n            if module.explain is False:\n                continue\n\n            # Convert mask to a param if it was previously registered as one.\n            # This is a workaround for the fact that PyTorch does not allow\n            # assignments of pure tensors to parameter attributes:\n            if (not isinstance(mask, Parameter)\n                    and '_edge_mask' in module._parameters):\n                mask = Parameter(mask)\n\n            module.explain = True\n            module._edge_mask = mask\n            module._loop_mask = loop_mask\n            module._apply_sigmoid = apply_sigmoid\n\n\ndef set_hetero_masks(\n    model: torch.nn.Module,\n    mask_dict: Dict[EdgeType, Union[Tensor, Parameter]],\n    edge_index_dict: Dict[EdgeType, Tensor],\n    apply_sigmoid: bool = True,\n):\n    r\"\"\"Apply masks to every heterogeneous graph layer in the :obj:`model`\n    according to edge types.\n    \"\"\"\n    for module in model.modules():\n        if isinstance(module, torch.nn.ModuleDict):\n            for edge_type in mask_dict.keys():\n                if edge_type in module:\n                    edge_level_module = module[edge_type]\n                elif '__'.join(edge_type) in module:\n                    edge_level_module = module['__'.join(edge_type)]\n      
          else:\n                    continue\n\n                set_masks(\n                    edge_level_module,\n                    mask_dict[edge_type],\n                    edge_index_dict[edge_type],\n                    apply_sigmoid=apply_sigmoid,\n                )\n\n\ndef clear_masks(model: torch.nn.Module):\n    r\"\"\"Clear all masks from the model.\"\"\"\n    for module in model.modules():\n        if isinstance(module, MessagePassing):\n            if module.explain is True:\n                module.explain = None\n            module._edge_mask = None\n            module._loop_mask = None\n            module._apply_sigmoid = True\n    return model\n"
  },
  {
    "path": "torch_geometric/explain/config.py",
    "content": "from dataclasses import dataclass\nfrom enum import Enum\nfrom typing import Optional, Union\n\nfrom torch_geometric.utils.mixin import CastMixin\n\n\nclass ExplanationType(Enum):\n    \"\"\"Enum class for the explanation type.\"\"\"\n    model = 'model'\n    phenomenon = 'phenomenon'\n\n\nclass MaskType(Enum):\n    \"\"\"Enum class for the mask type.\"\"\"\n    object = 'object'\n    common_attributes = 'common_attributes'\n    attributes = 'attributes'\n\n\nclass ModelMode(Enum):\n    \"\"\"Enum class for the model mode.\"\"\"\n    binary_classification = 'binary_classification'\n    multiclass_classification = 'multiclass_classification'\n    regression = 'regression'\n\n\nclass ModelTaskLevel(Enum):\n    \"\"\"Enum class for the model task level.\"\"\"\n    node = 'node'\n    edge = 'edge'\n    graph = 'graph'\n\n\nclass ModelReturnType(Enum):\n    \"\"\"Enum class for the model return type.\"\"\"\n    raw = 'raw'\n    probs = 'probs'\n    log_probs = 'log_probs'\n\n\nclass ThresholdType(Enum):\n    \"\"\"Enum class for the threshold type.\"\"\"\n    hard = 'hard'\n    topk = 'topk'\n    topk_hard = 'topk_hard'\n    # connected = 'connected'  # TODO\n\n\n@dataclass\nclass ExplainerConfig(CastMixin):\n    r\"\"\"Configuration class to store and validate high level explanation\n    parameters.\n\n    Args:\n        explanation_type (ExplanationType or str): The type of explanation to\n            compute. 
The possible values are:\n\n                - :obj:`\"model\"`: Explains the model prediction.\n\n                - :obj:`\"phenomenon\"`: Explains the phenomenon that the model\n                  is trying to predict.\n\n            In practice, this means that the explanation algorithm will either\n            compute their losses with respect to the model output\n            (:obj:`\"model\"`) or the target output (:obj:`\"phenomenon\"`).\n\n        node_mask_type (MaskType or str, optional): The type of mask to apply\n            on nodes. The possible values are (default: :obj:`None`):\n\n                - :obj:`None`: Will not apply any mask on nodes.\n\n                - :obj:`\"object\"`: Will mask each node.\n\n                - :obj:`\"common_attributes\"`: Will mask each feature.\n\n                - :obj:`\"attributes\"`: Will mask each feature across all nodes.\n\n        edge_mask_type (MaskType or str, optional): The type of mask to apply\n            on edges. Has the same possible values as :obj:`node_mask_type`.\n            (default: :obj:`None`)\n    \"\"\"\n    explanation_type: ExplanationType\n    node_mask_type: Optional[MaskType]\n    edge_mask_type: Optional[MaskType]\n\n    def __init__(\n        self,\n        explanation_type: Union[ExplanationType, str],\n        node_mask_type: Optional[Union[MaskType, str]] = None,\n        edge_mask_type: Optional[Union[MaskType, str]] = None,\n    ):\n        if node_mask_type is not None:\n            node_mask_type = MaskType(node_mask_type)\n        if edge_mask_type is not None:\n            edge_mask_type = MaskType(edge_mask_type)\n\n        if edge_mask_type is not None and edge_mask_type != MaskType.object:\n            raise ValueError(f\"'edge_mask_type' needs be None or of type \"\n                             f\"'object' (got '{edge_mask_type.value}')\")\n\n        if node_mask_type is None and edge_mask_type is None:\n            raise ValueError(\"Either 'node_mask_type' or 
'edge_mask_type' \"\n                             \"must be provided\")\n\n        self.explanation_type = ExplanationType(explanation_type)\n        self.node_mask_type = node_mask_type\n        self.edge_mask_type = edge_mask_type\n\n\n@dataclass\nclass ModelConfig(CastMixin):\n    r\"\"\"Configuration class to store model parameters.\n\n    Args:\n        mode (ModelMode or str): The mode of the model. The possible values\n            are:\n\n                - :obj:`\"binary_classification\"`: A binary classification\n                  model.\n\n                - :obj:`\"multiclass_classification\"`: A multiclass\n                  classification model.\n\n                - :obj:`\"regression\"`: A regression model.\n\n        task_level (ModelTaskLevel or str): The task-level of the model.\n            The possible values are:\n\n                - :obj:`\"node\"`: A node-level prediction model.\n\n                - :obj:`\"edge\"`: An edge-level prediction model.\n\n                - :obj:`\"graph\"`: A graph-level prediction model.\n\n        return_type (ModelReturnType or str, optional): The return type of the\n            model. 
The possible values are (default: :obj:`None`):\n\n                - :obj:`\"raw\"`: The model returns raw values.\n\n                - :obj:`\"probs\"`: The model returns probabilities.\n\n                - :obj:`\"log_probs\"`: The model returns log-probabilities.\n    \"\"\"\n    mode: ModelMode\n    task_level: ModelTaskLevel\n    return_type: ModelReturnType\n\n    def __init__(\n        self,\n        mode: Union[ModelMode, str],\n        task_level: Union[ModelTaskLevel, str],\n        return_type: Optional[Union[ModelReturnType, str]] = None,\n    ):\n        self.mode = ModelMode(mode)\n        self.task_level = ModelTaskLevel(task_level)\n\n        if return_type is None and self.mode == ModelMode.regression:\n            return_type = ModelReturnType.raw\n\n        self.return_type = ModelReturnType(return_type)\n\n        if (self.mode == ModelMode.regression\n                and self.return_type != ModelReturnType.raw):\n            raise ValueError(f\"A model for regression needs to return raw \"\n                             f\"outputs (got {self.return_type.value})\")\n\n        if (self.mode == ModelMode.binary_classification and self.return_type\n                not in [ModelReturnType.raw, ModelReturnType.probs]):\n            raise ValueError(\n                f\"A model for binary classification needs to return raw \"\n                f\"outputs or probabilities (got {self.return_type.value})\")\n\n\n@dataclass\nclass ThresholdConfig(CastMixin):\n    r\"\"\"Configuration class to store and validate threshold parameters.\n\n    Args:\n        threshold_type (ThresholdType or str): The type of threshold to apply.\n            The possible values are:\n\n                - :obj:`None`: No threshold is applied.\n\n                - :obj:`\"hard\"`: A hard threshold is applied to each mask.\n                  The elements of the mask with a value below the :obj:`value`\n                  are set to :obj:`0`, the others are set to :obj:`1`.\n\n        
        - :obj:`\"topk\"`: A soft threshold is applied to each mask.\n                  The top :obj:`value` elements of each mask are kept, the\n                  others are set to :obj:`0`.\n\n                - :obj:`\"topk_hard\"`: Same as :obj:`\"topk\"` but values are set\n                  to :obj:`1` for all elements which are kept.\n\n        value (int or float, optional): The value to use when thresholding.\n            (default: :obj:`None`)\n    \"\"\"\n    type: ThresholdType\n    value: Union[float, int]\n\n    def __init__(\n        self,\n        threshold_type: Union[ThresholdType, str],\n        value: Union[float, int],\n    ):\n        self.type = ThresholdType(threshold_type)\n        self.value = value\n\n        if not isinstance(self.value, (int, float)):\n            raise ValueError(f\"Threshold value must be a float or int \"\n                             f\"(got {type(self.value)}).\")\n\n        if (self.type == ThresholdType.hard\n                and (self.value < 0 or self.value > 1)):\n            raise ValueError(f\"Threshold value must be between 0 and 1 \"\n                             f\"(got {self.value})\")\n\n        if self.type in [ThresholdType.topk, ThresholdType.topk_hard]:\n            if not isinstance(self.value, int):\n                raise ValueError(f\"Threshold value needs to be an integer \"\n                                 f\"(got {type(self.value)}).\")\n            if self.value <= 0:\n                raise ValueError(f\"Threshold value needs to be positive \"\n                                 f\"(got {self.value}).\")\n"
  },
  {
    "path": "torch_geometric/explain/explainer.py",
    "content": "import warnings\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import (\n    ExplainerAlgorithm,\n    Explanation,\n    HeteroExplanation,\n)\nfrom torch_geometric.explain.algorithm.utils import (\n    clear_masks,\n    set_hetero_masks,\n    set_masks,\n)\nfrom torch_geometric.explain.config import (\n    ExplainerConfig,\n    ExplanationType,\n    MaskType,\n    ModelConfig,\n    ModelMode,\n    ModelReturnType,\n    ThresholdConfig,\n)\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass Explainer:\n    r\"\"\"An explainer class for instance-level explanations of Graph Neural\n    Networks.\n\n    Args:\n        model (torch.nn.Module): The model to explain.\n        algorithm (ExplainerAlgorithm): The explanation algorithm.\n        explanation_type (ExplanationType or str): The type of explanation to\n            compute. The possible values are:\n\n                - :obj:`\"model\"`: Explains the model prediction.\n\n                - :obj:`\"phenomenon\"`: Explains the phenomenon that the model\n                  is trying to predict.\n\n            In practice, this means that the explanation algorithm will either\n            compute their losses with respect to the model output\n            (:obj:`\"model\"`) or the target output (:obj:`\"phenomenon\"`).\n        model_config (ModelConfig): The model configuration.\n            See :class:`~torch_geometric.explain.config.ModelConfig` for\n            available options. (default: :obj:`None`)\n        node_mask_type (MaskType or str, optional): The type of mask to apply\n            on nodes. 
The possible values are (default: :obj:`None`):\n\n                - :obj:`None`: Will not apply any mask on nodes.\n\n                - :obj:`\"object\"`: Will mask each node.\n\n                - :obj:`\"common_attributes\"`: Will mask each feature.\n\n                - :obj:`\"attributes\"`: Will mask each feature across all nodes.\n\n        edge_mask_type (MaskType or str, optional): The type of mask to apply\n            on edges. Has the same possible values as :obj:`node_mask_type`.\n            (default: :obj:`None`)\n        threshold_config (ThresholdConfig, optional): The threshold\n            configuration.\n            See :class:`~torch_geometric.explain.config.ThresholdConfig` for\n            available options. (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        algorithm: ExplainerAlgorithm,\n        explanation_type: Union[ExplanationType, str],\n        model_config: Union[ModelConfig, Dict[str, Any]],\n        node_mask_type: Optional[Union[MaskType, str]] = None,\n        edge_mask_type: Optional[Union[MaskType, str]] = None,\n        threshold_config: Optional[ThresholdConfig] = None,\n    ):\n        explainer_config = ExplainerConfig(\n            explanation_type=explanation_type,\n            node_mask_type=node_mask_type,\n            edge_mask_type=edge_mask_type,\n        )\n\n        self.model = model\n        self.algorithm = algorithm\n\n        self.explanation_type = explainer_config.explanation_type\n        self.model_config = ModelConfig.cast(model_config)\n        self.node_mask_type = explainer_config.node_mask_type\n        self.edge_mask_type = explainer_config.edge_mask_type\n        self.threshold_config = ThresholdConfig.cast(threshold_config)\n\n        self.algorithm.connect(explainer_config, self.model_config)\n\n    @torch.no_grad()\n    def get_prediction(self, *args, **kwargs) -> Tensor:\n        r\"\"\"Returns the prediction of the model on the input 
graph.\n\n        If the model mode is :obj:`\"regression\"`, the prediction is returned as\n        a scalar value.\n        If the model mode is :obj:`\"multiclass_classification\"` or\n        :obj:`\"binary_classification\"`, the prediction is returned as the\n        predicted class label.\n\n        Args:\n            *args: Arguments passed to the model.\n            **kwargs (optional): Additional keyword arguments passed to the\n                model.\n        \"\"\"\n        training = self.model.training\n        self.model.eval()\n\n        with torch.no_grad():\n            out = self.model(*args, **kwargs)\n\n        self.model.train(training)\n\n        return out\n\n    def get_masked_prediction(\n        self,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        node_mask: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,\n        edge_mask: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None,\n        **kwargs,\n    ) -> Tensor:\n        r\"\"\"Returns the prediction of the model on the input graph with node\n        and edge masks applied.\n        \"\"\"\n        if isinstance(x, Tensor) and node_mask is not None:\n            x = node_mask * x\n        elif isinstance(x, dict) and node_mask is not None:\n            x = {key: value * node_mask[key] for key, value in x.items()}\n\n        if isinstance(edge_mask, Tensor):\n            set_masks(self.model, edge_mask, edge_index, apply_sigmoid=False)\n        elif isinstance(edge_mask, dict):\n            set_hetero_masks(self.model, edge_mask, edge_index,\n                             apply_sigmoid=False)\n\n        out = self.get_prediction(x, edge_index, **kwargs)\n        clear_masks(self.model)\n        return out\n\n    def __call__(\n        self,\n        x: Union[Tensor, Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        *,\n        target: Optional[Tensor] = None,\n     
   index: Optional[Union[int, Tensor]] = None,\n        **kwargs,\n    ) -> Union[Explanation, HeteroExplanation]:\n        r\"\"\"Computes the explanation of the GNN for the given inputs and\n        target.\n\n        .. note::\n\n            If you get an error message like \"Trying to backward through the\n            graph a second time\", make sure that the target you provided\n            was computed with :meth:`torch.no_grad`.\n\n        Args:\n            x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input\n                node features of a homogeneous or heterogeneous graph.\n            edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]): The\n                input edge indices of a homogeneous or heterogeneous graph.\n            target (torch.Tensor): The target of the model.\n                If the explanation type is :obj:`\"phenomenon\"`, the target has\n                to be provided.\n                If the explanation type is :obj:`\"model\"`, the target should be\n                set to :obj:`None` and will get automatically inferred. For\n                classification tasks, the target needs to contain the class\n                labels. 
(default: :obj:`None`)\n            index (Union[int, Tensor], optional): The indices in the\n                first-dimension of the model output to explain.\n                Can be a single index or a tensor of indices.\n                If set to :obj:`None`, all model outputs will be explained.\n                (default: :obj:`None`)\n            **kwargs: additional arguments to pass to the GNN.\n        \"\"\"\n        # Choose the `target` depending on the explanation type:\n        prediction: Optional[Tensor] = None\n        if self.explanation_type == ExplanationType.phenomenon:\n            if target is None:\n                raise ValueError(\n                    f\"The 'target' has to be provided for the explanation \"\n                    f\"type '{self.explanation_type.value}'\")\n        elif self.explanation_type == ExplanationType.model:\n            if target is not None:\n                warnings.warn(\n                    f\"The 'target' should not be provided for the explanation \"\n                    f\"type '{self.explanation_type.value}'\", stacklevel=2)\n            prediction = self.get_prediction(x, edge_index, **kwargs)\n            target = self.get_target(prediction)\n\n        if isinstance(index, int):\n            index = torch.tensor([index])\n\n        training = self.model.training\n        self.model.eval()\n\n        explanation = self.algorithm(\n            self.model,\n            x,\n            edge_index,\n            target=target,\n            index=index,\n            **kwargs,\n        )\n\n        self.model.train(training)\n\n        # Add explainer objectives to the `Explanation` object:\n        explanation._model_config = self.model_config\n        explanation.prediction = prediction\n        explanation.target = target\n        explanation.index = index\n\n        # Add model inputs to the `Explanation` object:\n        if isinstance(explanation, Explanation):\n            explanation._model_args = 
list(kwargs.keys())\n            explanation.x = x\n            explanation.edge_index = edge_index\n\n            for key, arg in kwargs.items():  # Add remaining `kwargs`:\n                explanation[key] = arg\n\n        elif isinstance(explanation, HeteroExplanation):\n            # TODO Add `explanation._model_args`\n\n            assert isinstance(x, dict)\n            explanation.set_value_dict('x', x)\n\n            assert isinstance(edge_index, dict)\n            explanation.set_value_dict('edge_index', edge_index)\n\n            for key, arg in kwargs.items():  # Add remaining `kwargs`:\n                if isinstance(arg, dict):\n                    # Keyword arguments are likely named `{attr_name}_dict`\n                    # while we only want to assign the `{attr_name}` to the\n                    # `HeteroExplanation` object:\n                    key = key[:-5] if key.endswith('_dict') else key\n                    explanation.set_value_dict(key, arg)\n                else:\n                    explanation[key] = arg\n\n        explanation.validate_masks()\n        return explanation.threshold(self.threshold_config)\n\n    def get_target(self, prediction: Tensor) -> Tensor:\n        r\"\"\"Returns the target of the model from a given prediction.\n\n        If the model mode is of type :obj:`\"regression\"`, the prediction is\n        returned as it is.\n        If the model mode is of type :obj:`\"multiclass_classification\"` or\n        :obj:`\"binary_classification\"`, the prediction is returned as the\n        predicted class label.\n        \"\"\"\n        if self.model_config.mode == ModelMode.binary_classification:\n            # TODO: Allow customization of the thresholds used below.\n            if self.model_config.return_type == ModelReturnType.raw:\n                return (prediction > 0).long().view(-1)\n            if self.model_config.return_type == ModelReturnType.probs:\n                return (prediction > 0.5).long().view(-1)\n      
      raise AssertionError()\n\n        if self.model_config.mode == ModelMode.multiclass_classification:\n            return prediction.argmax(dim=-1)\n\n        return prediction\n"
  },
  {
    "path": "torch_geometric/explain/explanation.py",
    "content": "import copy\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data.data import Data, warn_or_raise\nfrom torch_geometric.data.hetero_data import HeteroData\nfrom torch_geometric.explain.config import ThresholdConfig, ThresholdType\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.visualization import (\n    visualize_graph,\n    visualize_hetero_graph,\n)\n\n\nclass ExplanationMixin:\n    @property\n    def available_explanations(self) -> List[str]:\n        \"\"\"Returns the available explanation masks.\"\"\"\n        return [key for key in self.keys() if key.endswith('_mask')]\n\n    def validate_masks(self, raise_on_error: bool = True) -> bool:\n        r\"\"\"Validates the correctness of the :class:`Explanation` masks.\"\"\"\n        status = True\n\n        for store in self.node_stores:\n            if 'node_mask' not in store:\n                continue\n\n            if store.node_mask.dim() != 2:\n                status = False\n                warn_or_raise(\n                    f\"Expected a 'node_mask' with two dimensions (got \"\n                    f\"{store.node_mask.dim()} dimensions)\", raise_on_error)\n\n            if store.node_mask.size(0) not in {1, store.num_nodes}:\n                status = False\n                warn_or_raise(\n                    f\"Expected a 'node_mask' with {store.num_nodes} nodes \"\n                    f\"(got {store.node_mask.size(0)} nodes)\", raise_on_error)\n\n            if 'x' in store:\n                num_features = store.x.size(-1)\n            else:\n                num_features = store.node_mask.size(-1)\n\n            if store.node_mask.size(1) not in {1, num_features}:\n                status = False\n                warn_or_raise(\n                    f\"Expected a 'node_mask' with {num_features} features (\"\n                    f\"got {store.node_mask.size(1)} features)\", 
raise_on_error)\n\n        for store in self.edge_stores:\n            if 'edge_mask' not in store:\n                continue\n\n            if store.edge_mask.dim() != 1:\n                status = False\n                warn_or_raise(\n                    f\"Expected an 'edge_mask' with one dimension (got \"\n                    f\"{store.edge_mask.dim()} dimensions)\", raise_on_error)\n\n            if store.edge_mask.size(0) != store.num_edges:\n                status = False\n                warn_or_raise(\n                    f\"Expected an 'edge_mask' with {store.num_edges} edges \"\n                    f\"(got {store.edge_mask.size(0)} edges)\", raise_on_error)\n\n        return status\n\n    def _threshold_mask(\n        self,\n        mask: Optional[Tensor],\n        threshold_config: ThresholdConfig,\n    ) -> Optional[Tensor]:\n\n        if mask is None:\n            return None\n\n        if threshold_config.type == ThresholdType.hard:\n            return (mask > threshold_config.value).float()\n\n        if threshold_config.type in [\n                ThresholdType.topk,\n                ThresholdType.topk_hard,\n        ]:\n            if threshold_config.value >= mask.numel():\n                if threshold_config.type == ThresholdType.topk:\n                    return mask\n                else:\n                    return torch.ones_like(mask)\n\n            value, index = torch.topk(\n                mask.flatten(),\n                k=threshold_config.value,\n            )\n\n            out = torch.zeros_like(mask.flatten())\n            if threshold_config.type == ThresholdType.topk:\n                out[index] = value\n            else:\n                out[index] = 1.0\n            return out.view(mask.size())\n\n        raise AssertionError()\n\n    def threshold(\n        self,\n        *args,\n        **kwargs,\n    ) -> Union['Explanation', 'HeteroExplanation']:\n        \"\"\"Thresholds the explanation masks according to the thresholding\n  
      method.\n\n        Args:\n            *args: Arguments passed to :class:`ThresholdConfig`.\n            **kwargs: Keyword arguments passed to :class:`ThresholdConfig`.\n        \"\"\"\n        threshold_config = ThresholdConfig.cast(*args, **kwargs)\n\n        if threshold_config is None:\n            return self\n\n        # Avoid modification of the original explanation:\n        out = copy.copy(self)\n\n        for store in out.node_stores:\n            store.node_mask = self._threshold_mask(store.get('node_mask'),\n                                                   threshold_config)\n\n        for store in out.edge_stores:\n            store.edge_mask = self._threshold_mask(store.get('edge_mask'),\n                                                   threshold_config)\n\n        return out\n\n\nclass Explanation(Data, ExplanationMixin):\n    r\"\"\"Holds all the obtained explanations of a homogeneous graph.\n\n    The explanation object is a :obj:`~torch_geometric.data.Data` object and\n    can hold node attributions and edge attributions.\n    It can also hold the original graph if needed.\n\n    Args:\n        node_mask (Tensor, optional): Node-level mask with shape\n            :obj:`[num_nodes, 1]`, :obj:`[1, num_features]` or\n            :obj:`[num_nodes, num_features]`. (default: :obj:`None`)\n        edge_mask (Tensor, optional): Edge-level mask with shape\n            :obj:`[num_edges]`. 
(default: :obj:`None`)\n        **kwargs (optional): Additional attributes.\n    \"\"\"\n    def validate(self, raise_on_error: bool = True) -> bool:\n        r\"\"\"Validates the correctness of the :class:`Explanation` object.\"\"\"\n        status = super().validate(raise_on_error)\n        status &= self.validate_masks(raise_on_error)\n        return status\n\n    def get_explanation_subgraph(self) -> 'Explanation':\n        r\"\"\"Returns the induced subgraph, in which all nodes and edges with\n        zero attribution are masked out.\n        \"\"\"\n        node_mask = self.get('node_mask')\n        if node_mask is not None:\n            node_mask = node_mask.sum(dim=-1) > 0\n        edge_mask = self.get('edge_mask')\n        if edge_mask is not None:\n            edge_mask = edge_mask > 0\n        return self._apply_masks(node_mask, edge_mask)\n\n    def get_complement_subgraph(self) -> 'Explanation':\n        r\"\"\"Returns the induced subgraph, in which all nodes and edges with any\n        attribution are masked out.\n        \"\"\"\n        node_mask = self.get('node_mask')\n        if node_mask is not None:\n            node_mask = node_mask.sum(dim=-1) == 0\n        edge_mask = self.get('edge_mask')\n        if edge_mask is not None:\n            edge_mask = edge_mask == 0\n        return self._apply_masks(node_mask, edge_mask)\n\n    def _apply_masks(\n        self,\n        node_mask: Optional[Tensor] = None,\n        edge_mask: Optional[Tensor] = None,\n    ) -> 'Explanation':\n        out = copy.copy(self)\n\n        if edge_mask is not None:\n            for key, value in self.items():\n                if key == 'edge_index':\n                    out.edge_index = value[:, edge_mask]\n                elif self.is_edge_attr(key):\n                    out[key] = value[edge_mask]\n\n        if node_mask is not None:\n            out = out.subgraph(node_mask)\n\n        return out\n\n    def visualize_feature_importance(\n        self,\n        path: 
Optional[str] = None,\n        feat_labels: Optional[List[str]] = None,\n        top_k: Optional[int] = None,\n    ):\n        r\"\"\"Creates a bar plot of the node feature importances by summing up\n        the node mask across all nodes.\n\n        Args:\n            path (str, optional): The path to where the plot is saved.\n                If set to :obj:`None`, will visualize the plot on-the-fly.\n                (default: :obj:`None`)\n            feat_labels (List[str], optional): The labels of features.\n                (default: :obj:`None`)\n            top_k (int, optional): Top k features to plot. If :obj:`None`\n                plots all features. (default: :obj:`None`)\n        \"\"\"\n        node_mask = self.get('node_mask')\n        if node_mask is None:\n            raise ValueError(f\"The attribute 'node_mask' is not available \"\n                             f\"in '{self.__class__.__name__}' \"\n                             f\"(got {self.available_explanations})\")\n        if node_mask.dim() != 2 or node_mask.size(1) <= 1:\n            raise ValueError(f\"Cannot compute feature importance for \"\n                             f\"object-level 'node_mask' \"\n                             f\"(got shape {node_mask.size()})\")\n\n        if feat_labels is None:\n            feat_labels = range(node_mask.size(1))\n\n        score = node_mask.sum(dim=0)\n\n        return _visualize_score(score, feat_labels, path, top_k)\n\n    def visualize_graph(\n        self,\n        path: Optional[str] = None,\n        backend: Optional[str] = None,\n        node_labels: Optional[List[str]] = None,\n    ) -> None:\n        r\"\"\"Visualizes the explanation graph with edge opacity corresponding to\n        edge importance.\n\n        Args:\n            path (str, optional): The path to where the plot is saved.\n                If set to :obj:`None`, will visualize the plot on-the-fly.\n                (default: :obj:`None`)\n            backend (str, optional): The 
graph drawing backend to use for\n                visualization (:obj:`\"graphviz\"`, :obj:`\"networkx\"`).\n                If set to :obj:`None`, will use the most appropriate\n                visualization backend based on available system packages.\n                (default: :obj:`None`)\n            node_labels (list[str], optional): The labels/IDs of nodes.\n                (default: :obj:`None`)\n        \"\"\"\n        edge_mask = self.get('edge_mask')\n        if edge_mask is None:\n            raise ValueError(f\"The attribute 'edge_mask' is not available \"\n                             f\"in '{self.__class__.__name__}' \"\n                             f\"(got {self.available_explanations})\")\n        visualize_graph(self.edge_index, edge_mask, path, backend, node_labels)\n\n\nclass HeteroExplanation(HeteroData, ExplanationMixin):\n    r\"\"\"Holds all the obtained explanations of a heterogeneous graph.\n\n    The explanation object is a :obj:`~torch_geometric.data.HeteroData` object\n    and can hold node attributions and edge attributions.\n    It can also hold the original graph if needed.\n    \"\"\"\n    def validate(self, raise_on_error: bool = True) -> bool:\n        r\"\"\"Validates the correctness of the :class:`Explanation` object.\"\"\"\n        status = super().validate(raise_on_error)\n        status &= self.validate_masks(raise_on_error)\n        return status\n\n    def get_explanation_subgraph(self) -> 'HeteroExplanation':\n        r\"\"\"Returns the induced subgraph, in which all nodes and edges with\n        zero attribution are masked out.\n        \"\"\"\n        return self._apply_masks(\n            node_mask_dict={\n                key: mask.sum(dim=-1) > 0\n                for key, mask in self.collect('node_mask', True).items()\n            },\n            edge_mask_dict={\n                key: mask > 0\n                for key, mask in self.collect('edge_mask', True).items()\n            },\n        )\n\n    def 
get_complement_subgraph(self) -> 'HeteroExplanation':\n        r\"\"\"Returns the induced subgraph, in which all nodes and edges with any\n        attribution are masked out.\n        \"\"\"\n        return self._apply_masks(\n            node_mask_dict={\n                key: mask.sum(dim=-1) == 0\n                for key, mask in self.collect('node_mask', True).items()\n            },\n            edge_mask_dict={\n                key: mask == 0\n                for key, mask in self.collect('edge_mask', True).items()\n            },\n        )\n\n    def _apply_masks(\n        self,\n        node_mask_dict: Dict[NodeType, Tensor],\n        edge_mask_dict: Dict[EdgeType, Tensor],\n    ) -> 'HeteroExplanation':\n        out = copy.copy(self)\n\n        for edge_type, edge_mask in edge_mask_dict.items():\n            for key, value in self[edge_type].items():\n                if key == 'edge_index':\n                    out[edge_type].edge_index = value[:, edge_mask]\n                elif self[edge_type].is_edge_attr(key):\n                    out[edge_type][key] = value[edge_mask]\n\n        return out.subgraph(node_mask_dict)\n\n    def visualize_feature_importance(\n        self,\n        path: Optional[str] = None,\n        feat_labels: Optional[Dict[NodeType, List[str]]] = None,\n        top_k: Optional[int] = None,\n    ):\n        r\"\"\"Creates a bar plot of the node feature importances by summing up\n        node masks across all nodes for each node type.\n\n        Args:\n            path (str, optional): The path to where the plot is saved.\n                If set to :obj:`None`, will visualize the plot on-the-fly.\n                (default: :obj:`None`)\n            feat_labels (Dict[NodeType, List[str]], optional): The labels of\n                features for each node type. (default: :obj:`None`)\n            top_k (int, optional): Top k features to plot. If :obj:`None`\n                plots all features. 
(default: :obj:`None`)\n        \"\"\"\n        node_mask_dict = self.node_mask_dict\n        for node_mask in node_mask_dict.values():\n            if node_mask.dim() != 2:\n                raise ValueError(f\"Cannot compute feature importance for \"\n                                 f\"object-level 'node_mask' \"\n                                 f\"(got shape {node_mask.size()})\")\n\n        if feat_labels is None:\n            feat_labels = {}\n            for node_type, node_mask in node_mask_dict.items():\n                feat_labels[node_type] = range(node_mask.size(1))\n\n        score = torch.cat(\n            [node_mask.sum(dim=0) for node_mask in node_mask_dict.values()],\n            dim=0)\n\n        all_feat_labels = []\n        for node_type in node_mask_dict.keys():\n            all_feat_labels += [\n                f'{node_type}#{label}' for label in feat_labels[node_type]\n            ]\n\n        return _visualize_score(score, all_feat_labels, path, top_k)\n\n    def visualize_graph(\n            self,\n            path: Optional[str] = None,\n            node_labels: Optional[Dict[NodeType, List[str]]] = None,\n            node_size_range: Tuple[float, float] = (50, 500),\n            node_opacity_range: Tuple[float, float] = (0.2, 1.0),\n            edge_width_range: Tuple[float, float] = (0.1, 2.0),\n            edge_opacity_range: Tuple[float, float] = (0.2, 1.0),\n    ) -> None:\n        r\"\"\"Visualizes the explanation subgraph using networkx, with edge\n        opacity corresponding to edge importance and node colors\n        corresponding to node types.\n\n        Args:\n            path (str, optional): The path to where the plot is saved.\n                If set to :obj:`None`, will visualize the plot on-the-fly.\n                (default: :obj:`None`)\n            node_labels (Dict[NodeType, List[str]], optional): The display\n                names of nodes for each node type that will be shown in the\n                visualization. 
(default: :obj:`None`)\n            node_size_range (Tuple[float, float], optional): The minimum and\n                maximum node size in the visualization.\n                (default: :obj:`(50, 500)`)\n            node_opacity_range (Tuple[float, float], optional): The minimum and\n                maximum node opacity in the visualization.\n                (default: :obj:`(0.2, 1.0)`)\n            edge_width_range (Tuple[float, float], optional): The minimum and\n                maximum edge width in the visualization.\n                (default: :obj:`(0.1, 2.0)`)\n            edge_opacity_range (Tuple[float, float], optional): The minimum and\n                maximum edge opacity in the visualization.\n                (default: :obj:`(0.2, 1.0)`)\n        \"\"\"\n        # Validate node labels if provided\n        if node_labels is not None:\n            for node_type, labels in node_labels.items():\n                if node_type not in self.node_types:\n                    raise ValueError(\n                        f\"Node type '{node_type}' in node_labels \"\n                        f\"does not exist in the explanation graph\")\n                if len(labels) != self[node_type].num_nodes:\n                    raise ValueError(f\"Number of labels for node type \"\n                                     f\"'{node_type}' (got {len(labels)}) does \"\n                                     f\"not match the number of nodes \"\n                                     f\"(got {self[node_type].num_nodes})\")\n        # Get the explanation subgraph\n        subgraph = self.get_explanation_subgraph()\n\n        # Prepare edge indices and weights for each edge type\n        edge_index_dict = {}\n        edge_weight_dict = {}\n        for edge_type in subgraph.edge_types:\n            if edge_type[0] == 'x' or edge_type[-1] == 'x':  # Skip edges\n                continue\n            edge_index_dict[edge_type] = subgraph[edge_type].edge_index\n            
edge_weight_dict[edge_type] = subgraph[edge_type].get(\n                'edge_mask',\n                torch.ones(subgraph[edge_type].edge_index.size(1)))\n\n        # Prepare node weights for each node type\n        node_weight_dict = {}\n        for node_type in subgraph.node_types:\n            if node_type == 'x':  # Skip the global store\n                continue\n            node_weight_dict[node_type] = subgraph[node_type] \\\n                .get('node_mask',\n                     torch.ones(subgraph[node_type].num_nodes)).squeeze(-1)\n\n        # Call the visualization function\n        visualize_hetero_graph(\n            edge_index_dict=edge_index_dict,\n            edge_weight_dict=edge_weight_dict,\n            path=path,\n            node_labels_dict=node_labels,\n            node_weight_dict=node_weight_dict,\n            node_size_range=node_size_range,\n            node_opacity_range=node_opacity_range,\n            edge_width_range=edge_width_range,\n            edge_opacity_range=edge_opacity_range,\n        )\n\n\ndef _visualize_score(\n    score: torch.Tensor,\n    labels: List[str],\n    path: Optional[str] = None,\n    top_k: Optional[int] = None,\n):\n    import matplotlib.pyplot as plt\n    import pandas as pd\n\n    if len(labels) != score.numel():\n        raise ValueError(f\"The number of labels (got {len(labels)}) must \"\n                         f\"match the number of scores (got {score.numel()})\")\n\n    score = score.cpu().numpy()\n\n    df = pd.DataFrame({'score': score}, index=labels)\n    df = df.sort_values('score', ascending=False)\n    df = df.round(decimals=3)\n\n    if top_k is not None:\n        df = df.head(top_k)\n        title = f\"Feature importance for top {len(df)} features\"\n    else:\n        title = f\"Feature importance for {len(df)} features\"\n\n    ax = df.plot(\n        kind='barh',\n        figsize=(10, 7),\n        title=title,\n        ylabel='Feature label',\n        xlim=[0, float(df['score'].max()) + 
0.3],\n        legend=False,\n    )\n    plt.gca().invert_yaxis()\n    ax.bar_label(container=ax.containers[0], label_type='edge')\n\n    if path is not None:\n        plt.savefig(path)\n    else:\n        plt.show()\n\n    plt.close()\n"
  },
  {
    "path": "torch_geometric/explain/metric/__init__.py",
    "content": "from .basic import groundtruth_metrics\nfrom .fidelity import fidelity, characterization_score, fidelity_curve_auc\nfrom .faithfulness import unfaithfulness\n\n__all__ = classes = [\n    'groundtruth_metrics',\n    'fidelity',\n    'characterization_score',\n    'fidelity_curve_auc',\n    'unfaithfulness',\n]\n"
  },
  {
    "path": "torch_geometric/explain/metric/basic.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nfrom torch import Tensor\n\nMETRICS = ['accuracy', 'recall', 'precision', 'f1_score', 'auroc']\n\n\ndef groundtruth_metrics(\n    pred_mask: Tensor,\n    target_mask: Tensor,\n    metrics: Optional[Union[str, List[str]]] = None,\n    threshold: float = 0.5,\n) -> Union[float, Tuple[float, ...]]:\n    r\"\"\"Compares and evaluates an explanation mask with the ground-truth\n    explanation mask.\n\n    Args:\n        pred_mask (torch.Tensor): The prediction mask to evaluate.\n        target_mask (torch.Tensor): The ground-truth target mask.\n        metrics (str or List[str], optional): The metrics to return\n            (:obj:`\"accuracy\"`, :obj:`\"recall\"`, :obj:`\"precision\"`,\n            :obj:`\"f1_score\"`, :obj:`\"auroc\"`). (default: :obj:`[\"accuracy\",\n            \"recall\", \"precision\", \"f1_score\", \"auroc\"]`)\n        threshold (float, optional): The threshold value to perform hard\n            thresholding of :obj:`mask` and :obj:`groundtruth`.\n            (default: :obj:`0.5`)\n    \"\"\"\n    import torchmetrics\n\n    if metrics is None:\n        metrics = METRICS\n\n    if isinstance(metrics, str):\n        metrics = [metrics]\n\n    if not isinstance(metrics, (tuple, list)):\n        raise ValueError(f\"Expected metrics to be a string or a list of \"\n                         f\"strings (got {type(metrics)})\")\n\n    pred_mask = pred_mask.view(-1)\n    target_mask = (target_mask >= threshold).view(-1)\n\n    outs = []\n    for metric in metrics:\n        if metric not in METRICS:\n            raise ValueError(f\"Encountered invalid metric {metric}\")\n\n        fn = getattr(torchmetrics.functional, metric)\n        if metric in {'auroc'}:\n            out = fn(pred_mask, target_mask, 'binary')\n        else:\n            out = fn(pred_mask, target_mask, 'binary', threshold)\n\n        outs.append(float(out))\n\n    return tuple(outs) if len(outs) > 1 else outs[0]\n"
  },
  {
    "path": "torch_geometric/explain/metric/faithfulness.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.explain import Explainer, Explanation\nfrom torch_geometric.explain.config import MaskType, ModelMode, ModelReturnType\n\n\ndef unfaithfulness(\n    explainer: Explainer,\n    explanation: Explanation,\n    top_k: Optional[int] = None,\n) -> float:\n    r\"\"\"Evaluates how faithful an :class:`~torch_geometric.explain.Explanation`\n    is to an underlying GNN predictor, as described in the\n    `\"Evaluating Explainability for Graph Neural Networks\"\n    <https://arxiv.org/abs/2208.09339>`_ paper.\n\n    In particular, the graph explanation unfaithfulness metric is defined as\n\n    .. math::\n        \\textrm{GEF}(y, \\hat{y}) = 1 - \\exp(- \\textrm{KL}(y || \\hat{y}))\n\n    where :math:`y` refers to the prediction probability vector obtained from\n    the original graph, and :math:`\\hat{y}` refers to the prediction\n    probability vector obtained from the masked subgraph.\n    Finally, the Kullback-Leibler (KL) divergence score quantifies the distance\n    between the two probability distributions.\n\n    Args:\n        explainer (Explainer): The explainer to evaluate.\n        explanation (Explanation): The explanation to evaluate.\n        top_k (int, optional): If set, will only keep the original values of\n            the top-:math:`k` node features identified by an explanation.\n            If set to :obj:`None`, will use :obj:`explanation.node_mask` as it\n            is for masking node features. 
(default: :obj:`None`)\n    \"\"\"\n    if explainer.model_config.mode == ModelMode.regression:\n        raise ValueError(\"Fidelity not defined for 'regression' models\")\n\n    if top_k is not None and explainer.node_mask_type == MaskType.object:\n        raise ValueError(\"Cannot apply top-k feature selection based on a \"\n                         \"node mask of type 'object'\")\n\n    node_mask = explanation.get('node_mask')\n    edge_mask = explanation.get('edge_mask')\n    x, edge_index = explanation.x, explanation.edge_index\n    kwargs = {key: explanation[key] for key in explanation._model_args}\n\n    y = explanation.get('prediction')\n    if y is None:  # == ExplanationType.phenomenon\n        y = explainer.get_prediction(x, edge_index, **kwargs)\n\n    if node_mask is not None and top_k is not None:\n        feat_importance = node_mask.sum(dim=0)\n        _, top_k_index = feat_importance.topk(top_k)\n        node_mask = torch.zeros_like(node_mask)\n        node_mask[:, top_k_index] = 1.0\n\n    y_hat = explainer.get_masked_prediction(x, edge_index, node_mask,\n                                            edge_mask, **kwargs)\n\n    if explanation.get('index') is not None:\n        y, y_hat = y[explanation.index], y_hat[explanation.index]\n\n    if explainer.model_config.return_type == ModelReturnType.raw:\n        y, y_hat = y.softmax(dim=-1), y_hat.softmax(dim=-1)\n    elif explainer.model_config.return_type == ModelReturnType.log_probs:\n        y, y_hat = y.exp(), y_hat.exp()\n\n    kl_div = F.kl_div(y.log(), y_hat, reduction='batchmean')\n    return 1 - float(torch.exp(-kl_div))\n"
  },
  {
    "path": "torch_geometric/explain/metric/fidelity.py",
    "content": "from typing import Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.explain import Explainer, Explanation\nfrom torch_geometric.explain.config import ExplanationType, ModelMode\n\n\ndef fidelity(\n    explainer: Explainer,\n    explanation: Explanation,\n) -> Tuple[float, float]:\n    r\"\"\"Evaluates the fidelity of an\n    :class:`~torch_geometric.explain.Explainer` given an\n    :class:`~torch_geometric.explain.Explanation`, as described in the\n    `\"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for\n    Graph Neural Networks\" <https://arxiv.org/abs/2206.09677>`_ paper.\n\n    Fidelity evaluates the contribution of the produced explanatory subgraph\n    to the initial prediction, either by giving only the subgraph to the model\n    (fidelity-) or by removing it from the entire graph (fidelity+).\n    The fidelity scores capture how good an explainable model reproduces the\n    natural phenomenon or the GNN model logic.\n\n    For **phenomenon** explanations, the fidelity scores are given by:\n\n    .. math::\n        \\textrm{fid}_{+} &= \\frac{1}{N} \\sum_{i = 1}^N\n        \\| \\mathbb{1}(\\hat{y}_i = y_i) -\n        \\mathbb{1}( \\hat{y}_i^{G_{C \\setminus S}} = y_i) \\|\n\n        \\textrm{fid}_{-} &= \\frac{1}{N} \\sum_{i = 1}^N\n        \\| \\mathbb{1}(\\hat{y}_i = y_i) -\n        \\mathbb{1}( \\hat{y}_i^{G_S} = y_i) \\|\n\n    For **model** explanations, the fidelity scores are given by:\n\n    .. 
math::\n        \\textrm{fid}_{+} &= 1 - \\frac{1}{N} \\sum_{i = 1}^N\n        \\mathbb{1}( \\hat{y}_i^{G_{C \\setminus S}} = \\hat{y}_i)\n\n        \\textrm{fid}_{-} &= 1 - \\frac{1}{N} \\sum_{i = 1}^N\n        \\mathbb{1}( \\hat{y}_i^{G_S} = \\hat{y}_i)\n\n    Args:\n        explainer (Explainer): The explainer to evaluate.\n        explanation (Explanation): The explanation to evaluate.\n    \"\"\"\n    if explainer.model_config.mode == ModelMode.regression:\n        raise ValueError(\"Fidelity not defined for 'regression' models\")\n\n    node_mask = explanation.get('node_mask')\n    edge_mask = explanation.get('edge_mask')\n    kwargs = {key: explanation[key] for key in explanation._model_args}\n\n    y = explanation.target\n    if explainer.explanation_type == ExplanationType.phenomenon:\n        y_hat = explainer.get_prediction(\n            explanation.x,\n            explanation.edge_index,\n            **kwargs,\n        )\n        y_hat = explainer.get_target(y_hat)\n\n    explain_y_hat = explainer.get_masked_prediction(\n        explanation.x,\n        explanation.edge_index,\n        node_mask,\n        edge_mask,\n        **kwargs,\n    )\n    explain_y_hat = explainer.get_target(explain_y_hat)\n\n    complement_y_hat = explainer.get_masked_prediction(\n        explanation.x,\n        explanation.edge_index,\n        1. - node_mask if node_mask is not None else None,\n        1. - edge_mask if edge_mask is not None else None,\n        **kwargs,\n    )\n    complement_y_hat = explainer.get_target(complement_y_hat)\n\n    if explanation.get('index') is not None:\n        y = y[explanation.index]\n        if explainer.explanation_type == ExplanationType.phenomenon:\n            y_hat = y_hat[explanation.index]\n        explain_y_hat = explain_y_hat[explanation.index]\n        complement_y_hat = complement_y_hat[explanation.index]\n\n    if explainer.explanation_type == ExplanationType.model:\n        pos_fidelity = 1. 
- (complement_y_hat == y).float().mean()\n        neg_fidelity = 1. - (explain_y_hat == y).float().mean()\n    else:\n        pos_fidelity = ((y_hat == y).float() -\n                        (complement_y_hat == y).float()).abs().mean()\n        neg_fidelity = ((y_hat == y).float() -\n                        (explain_y_hat == y).float()).abs().mean()\n\n    return float(pos_fidelity), float(neg_fidelity)\n\n\ndef characterization_score(\n    pos_fidelity: Tensor,\n    neg_fidelity: Tensor,\n    pos_weight: float = 0.5,\n    neg_weight: float = 0.5,\n) -> Tensor:\n    r\"\"\"Returns the componentwise characterization score as described in the\n    `\"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for\n    Graph Neural Networks\" <https://arxiv.org/abs/2206.09677>`_ paper.\n\n    ..  math::\n       \\textrm{charact} = \\frac{w_{+} + w_{-}}{\\frac{w_{+}}{\\textrm{fid}_{+}} +\n        \\frac{w_{-}}{1 - \\textrm{fid}_{-}}}\n\n    Args:\n        pos_fidelity (torch.Tensor): The positive fidelity\n            :math:`\\textrm{fid}_{+}`.\n        neg_fidelity (torch.Tensor): The negative fidelity\n            :math:`\\textrm{fid}_{-}`.\n        pos_weight (float, optional): The weight :math:`w_{+}` for\n            :math:`\\textrm{fid}_{+}`. (default: :obj:`0.5`)\n        neg_weight (float, optional): The weight :math:`w_{-}` for\n            :math:`\\textrm{fid}_{-}`. (default: :obj:`0.5`)\n    \"\"\"\n    if (pos_weight + neg_weight) != 1.0:\n        raise ValueError(f\"The weights need to sum up to 1 \"\n                         f\"(got {pos_weight} and {neg_weight})\")\n\n    denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))\n    return 1. 
/ denom\n\n\ndef fidelity_curve_auc(\n    pos_fidelity: Tensor,\n    neg_fidelity: Tensor,\n    x: Tensor,\n) -> Tensor:\n    r\"\"\"Returns the AUC for the fidelity curve as described in the\n    `\"GraphFramEx: Towards Systematic Evaluation of Explainability Methods for\n    Graph Neural Networks\" <https://arxiv.org/abs/2206.09677>`_ paper.\n\n    More precisely, returns the AUC of\n\n    .. math::\n        f(x) = \\frac{\\textrm{fid}_{+}}{1 - \\textrm{fid}_{-}}\n\n    Args:\n        pos_fidelity (torch.Tensor): The positive fidelity\n            :math:`\\textrm{fid}_{+}`.\n        neg_fidelity (torch.Tensor): The negative fidelity\n            :math:`\\textrm{fid}_{-}`.\n        x (torch.Tensor): Tensor containing the points on the :math:`x`-axis.\n            Needs to be sorted in ascending order.\n    \"\"\"\n    if torch.any(neg_fidelity == 1):\n        raise ValueError(\"There exists negative fidelity values containing 1, \"\n                         \"leading to a division by zero\")\n\n    y = pos_fidelity / (1. - neg_fidelity)\n    return auc(x, y)\n\n\ndef auc(x: Tensor, y: Tensor) -> Tensor:\n    if torch.any(x.diff() < 0):\n        raise ValueError(\"'x' must be given in ascending order\")\n    return torch.trapezoid(y, x)\n"
  },
  {
    "path": "torch_geometric/graphgym/__init__.py",
    "content": "from .contrib import *  # noqa\nfrom .models import *  # noqa\nfrom .utils import *  # noqa\nfrom .checkpoint import load_ckpt, save_ckpt, remove_ckpt, clean_ckpt\nfrom .cmd_args import parse_args\nfrom .config import (cfg, set_cfg, load_cfg, dump_cfg, set_run_dir,\n                     set_out_dir, get_fname)\nfrom .init import init_weights\nfrom .loader import create_loader\nfrom .logger import set_printing, create_logger\nfrom .loss import compute_loss\nfrom .model_builder import create_model\nfrom .optim import create_optimizer, create_scheduler\nfrom .train import train\nfrom .register import (register_base, register_act, register_node_encoder,\n                       register_edge_encoder, register_stage, register_head,\n                       register_layer, register_pooling, register_network,\n                       register_config, register_dataset, register_loader,\n                       register_optimizer, register_scheduler, register_loss,\n                       register_train, register_metric)\n\n__all__ = classes = [\n    'load_ckpt',\n    'save_ckpt',\n    'remove_ckpt',\n    'clean_ckpt',\n    'parse_args',\n    'cfg',\n    'set_cfg',\n    'load_cfg',\n    'dump_cfg',\n    'set_run_dir',\n    'set_out_dir',\n    'get_fname',\n    'init_weights',\n    'create_loader',\n    'set_printing',\n    'create_logger',\n    'compute_loss',\n    'create_model',\n    'create_optimizer',\n    'create_scheduler',\n    'train',\n    'register_base',\n    'register_act',\n    'register_node_encoder',\n    'register_edge_encoder',\n    'register_stage',\n    'register_head',\n    'register_layer',\n    'register_pooling',\n    'register_network',\n    'register_config',\n    'register_dataset',\n    'register_loader',\n    'register_optimizer',\n    'register_scheduler',\n    'register_loss',\n    'register_train',\n    'register_metric',\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/benchmark.py",
    "content": "# Do not change; required for benchmarking\n\nimport torch_geometric_benchmark.torchprof_local as torchprof  # noqa\nfrom pytorch_memlab import LineProfiler  # noqa\nfrom torch_geometric_benchmark.utils import count_parameters  # noqa\nfrom torch_geometric_benchmark.utils import get_gpu_memory_nvdia  # noqa\nfrom torch_geometric_benchmark.utils import get_memory_status  # noqa\nfrom torch_geometric_benchmark.utils import get_model_size  # noqa\n\nglobal_line_profiler = LineProfiler()\nglobal_line_profiler.enable()\n"
  },
  {
    "path": "torch_geometric/graphgym/checkpoint.py",
    "content": "import glob\nimport os\nimport os.path as osp\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.io import fs\n\nMODEL_STATE = 'model_state'\nOPTIMIZER_STATE = 'optimizer_state'\nSCHEDULER_STATE = 'scheduler_state'\n\n\ndef load_ckpt(\n    model: torch.nn.Module,\n    optimizer: Optional[torch.optim.Optimizer] = None,\n    scheduler: Optional[Any] = None,\n    epoch: int = -1,\n) -> int:\n    r\"\"\"Loads the model checkpoint at a given epoch.\"\"\"\n    epoch = get_ckpt_epoch(epoch)\n    path = get_ckpt_path(epoch)\n\n    if not osp.exists(path):\n        return 0\n\n    ckpt = fs.torch_load(path)\n    model.load_state_dict(ckpt[MODEL_STATE])\n    if optimizer is not None and OPTIMIZER_STATE in ckpt:\n        optimizer.load_state_dict(ckpt[OPTIMIZER_STATE])\n    if scheduler is not None and SCHEDULER_STATE in ckpt:\n        scheduler.load_state_dict(ckpt[SCHEDULER_STATE])\n\n    return epoch + 1\n\n\ndef save_ckpt(\n    model: torch.nn.Module,\n    optimizer: Optional[torch.optim.Optimizer] = None,\n    scheduler: Optional[Any] = None,\n    epoch: int = 0,\n):\n    r\"\"\"Saves the model checkpoint at a given epoch.\"\"\"\n    ckpt: Dict[str, Any] = {}\n    ckpt[MODEL_STATE] = model.state_dict()\n    if optimizer is not None:\n        ckpt[OPTIMIZER_STATE] = optimizer.state_dict()\n    if scheduler is not None:\n        ckpt[SCHEDULER_STATE] = scheduler.state_dict()\n\n    os.makedirs(get_ckpt_dir(), exist_ok=True)\n    torch.save(ckpt, get_ckpt_path(get_ckpt_epoch(epoch)))\n\n\ndef remove_ckpt(epoch: int = -1):\n    r\"\"\"Removes the model checkpoint at a given epoch.\"\"\"\n    os.remove(get_ckpt_path(get_ckpt_epoch(epoch)))\n\n\ndef clean_ckpt():\n    r\"\"\"Removes all but the last model checkpoint.\"\"\"\n    for epoch in get_ckpt_epochs()[:-1]:\n        
os.remove(get_ckpt_path(epoch))\n\n\n###############################################################################\n\n\ndef get_ckpt_dir() -> str:\n    return osp.join(cfg.run_dir, 'ckpt')\n\n\ndef get_ckpt_path(epoch: Union[int, str]) -> str:\n    return osp.join(get_ckpt_dir(), f'{epoch}.ckpt')\n\n\ndef get_ckpt_epochs() -> List[int]:\n    paths = glob.glob(get_ckpt_path('*'))\n    return sorted([int(osp.basename(path).split('.')[0]) for path in paths])\n\n\ndef get_ckpt_epoch(epoch: int) -> int:\n    if epoch < 0:\n        epochs = get_ckpt_epochs()\n        epoch = epochs[epoch] if len(epochs) > 0 else 0\n    return epoch\n"
  },
  {
    "path": "torch_geometric/graphgym/cmd_args.py",
    "content": "import argparse\n\n\ndef parse_args() -> argparse.Namespace:\n    r\"\"\"Parses the command line arguments.\"\"\"\n    parser = argparse.ArgumentParser(description='GraphGym')\n\n    parser.add_argument('--cfg', dest='cfg_file', type=str, required=True,\n                        help='The configuration file path.')\n    parser.add_argument('--repeat', type=int, default=1,\n                        help='The number of repeated jobs.')\n    parser.add_argument('--mark_done', action='store_true',\n                        help='Mark yaml as done after a job has finished.')\n    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,\n                        help='See graphgym/config.py for remaining options.')\n\n    return parser.parse_args()\n"
  },
  {
    "path": "torch_geometric/graphgym/config.py",
    "content": "import functools\nimport inspect\nimport logging\nimport os\nimport os.path as osp\nimport warnings\nfrom collections.abc import Iterable\nfrom dataclasses import asdict\nfrom typing import Any\n\nimport torch_geometric.graphgym.register as register\nfrom torch_geometric.io import fs\n\ntry:  # Define global config object\n    from yacs.config import CfgNode as CN\n    cfg = CN()\nexcept ImportError:\n    cfg = None\n    warnings.warn(\n        \"Could not define global config object. Please install \"\n        \"'yacs' via 'pip install yacs' in order to use GraphGym\", stacklevel=2)\n\n\ndef set_cfg(cfg):\n    r\"\"\"This function sets the default config value.\n\n    1) Note that for an experiment, only part of the arguments will be used\n       The remaining unused arguments won't affect anything.\n       So feel free to register any argument in graphgym.contrib.config\n    2) We support *at most* two levels of configs, *e.g.*,\n       :obj:`cfg.dataset.name`.\n\n    :return: Configuration use by the experiment.\n    \"\"\"\n    if cfg is None:\n        return cfg\n\n    # ----------------------------------------------------------------------- #\n    # Basic options\n    # ----------------------------------------------------------------------- #\n\n    # Set print destination: stdout / file / both\n    cfg.print = 'both'\n\n    # Select device: 'cpu', 'cuda', 'auto'\n    cfg.accelerator = 'auto'\n\n    # number of devices: eg. 
for 2 GPU set cfg.devices=2\n    cfg.devices = 1\n\n    # Output directory\n    cfg.out_dir = 'results'\n\n    # Config name (in out_dir)\n    cfg.cfg_dest = 'config.yaml'\n\n    # Names of registered custom metric funcs to be used (use defaults if none)\n    cfg.custom_metrics = []\n\n    # Random seed\n    cfg.seed = 0\n\n    # Print rounding\n    cfg.round = 4\n\n    # Tensorboard support for each run\n    cfg.tensorboard_each_run = False\n\n    # Tensorboard support for aggregated results\n    cfg.tensorboard_agg = True\n\n    # Additional num of worker for data loading\n    cfg.num_workers = 0\n\n    # Max threads used by PyTorch\n    cfg.num_threads = 6\n\n    # The metric for selecting the best epoch for each run\n    cfg.metric_best = 'auto'\n\n    # argmax or argmin in aggregating results\n    cfg.metric_agg = 'argmax'\n\n    # If visualize embedding.\n    cfg.view_emb = False\n\n    # If get GPU usage\n    cfg.gpu_mem = False\n\n    # If do benchmark analysis\n    cfg.benchmark = False\n\n    # ----------------------------------------------------------------------- #\n    # Globally shared variables:\n    # These variables will be set dynamically based on the input dataset\n    # Do not directly set them here or in .yaml files\n    # ----------------------------------------------------------------------- #\n\n    cfg.share = CN()\n\n    # Size of input dimension\n    cfg.share.dim_in = 1\n\n    # Size of out dimension, i.e., number of labels to be predicted\n    cfg.share.dim_out = 1\n\n    # Number of dataset splits: train/val/test\n    cfg.share.num_splits = 1\n\n    # ----------------------------------------------------------------------- #\n    # Dataset options\n    # ----------------------------------------------------------------------- #\n    cfg.dataset = CN()\n\n    # Name of the dataset\n    cfg.dataset.name = 'Cora'\n\n    # if PyG: look for it in Pytorch Geometric dataset\n    # if NetworkX/nx: load data in NetworkX format\n    
cfg.dataset.format = 'PyG'\n\n    # Dir to load the dataset. If the dataset is downloaded, this is the\n    # cache dir\n    cfg.dataset.dir = './datasets'\n\n    # Task: node, edge, graph, link_pred\n    cfg.dataset.task = 'node'\n\n    # Type of task: classification, regression, classification_binary\n    # classification_multi\n    cfg.dataset.task_type = 'classification'\n\n    # Transductive / Inductive\n    # Graph classification is always inductive\n    cfg.dataset.transductive = True\n\n    # Split ratio of dataset. Len=2: Train, Val. Len=3: Train, Val, Test\n    cfg.dataset.split = [0.8, 0.1, 0.1]\n\n    # Whether to shuffle the graphs for splitting\n    cfg.dataset.shuffle_split = True\n\n    # Whether random split or use custom split: random / custom\n    cfg.dataset.split_mode = 'random'\n\n    # Whether to use an encoder for general attribute features\n    cfg.dataset.encoder = True\n\n    # Name of general encoder\n    cfg.dataset.encoder_name = 'db'\n\n    # If add batchnorm after general encoder\n    cfg.dataset.encoder_bn = True\n\n    # Whether to use an encoder for the node features\n    cfg.dataset.node_encoder = False\n\n    # Name of node encoder\n    cfg.dataset.node_encoder_name = 'Atom'\n\n    # If add batchnorm after node encoder\n    cfg.dataset.node_encoder_bn = True\n\n    # Whether to use an encoder for the edge features\n    cfg.dataset.edge_encoder = False\n\n    # Name of edge encoder\n    cfg.dataset.edge_encoder_name = 'Bond'\n\n    # If add batchnorm after edge encoder\n    cfg.dataset.edge_encoder_bn = True\n\n    # Dimension of the encoded features.\n    # For now the node and edge encoding dimensions\n    # are the same.\n    cfg.dataset.encoder_dim = 128\n\n    # Dimension for edge feature. Updated by the real dim of the dataset\n    cfg.dataset.edge_dim = 128\n\n    # ============== Link/edge tasks only\n\n    # all or disjoint\n    cfg.dataset.edge_train_mode = 'all'\n\n    # Used in disjoint edge_train_mode. 
The proportion of edges used for\n    # message-passing\n    cfg.dataset.edge_message_ratio = 0.8\n\n    # The ratio of negative samples to positive samples\n    cfg.dataset.edge_negative_sampling_ratio = 1.0\n\n    # Whether resample disjoint when dataset.edge_train_mode is 'disjoint'\n    cfg.dataset.resample_disjoint = False\n\n    # Whether resample negative edges at training time (link prediction only)\n    cfg.dataset.resample_negative = False\n\n    # What transformation function is applied to the dataset\n    cfg.dataset.transform = 'none'\n\n    # Whether to cache the split dataset\n    # NOTE: use with caution, as the cached dataset may not have\n    # exactly the same setting as the config file\n    cfg.dataset.cache_save = False\n    cfg.dataset.cache_load = False\n\n    # Whether to remove the original node features in the dataset\n    cfg.dataset.remove_feature = False\n\n    # Simplify TU dataset for synthetic tasks\n    cfg.dataset.tu_simple = True\n\n    # Convert to undirected graph (save 2*E edges)\n    cfg.dataset.to_undirected = False\n\n    # dataset location: local, snowflake\n    cfg.dataset.location = 'local'\n\n    # Define label: Table name\n    cfg.dataset.label_table = 'none'\n\n    # Define label: Column name\n    cfg.dataset.label_column = 'none'\n\n    # ----------------------------------------------------------------------- #\n    # Training options\n    # ----------------------------------------------------------------------- #\n    cfg.train = CN()\n\n    # Total graph mini-batch size\n    cfg.train.batch_size = 16\n\n    # Sampling strategy for a train loader\n    cfg.train.sampler = 'full_batch'\n\n    # Minibatch node\n    cfg.train.sample_node = False\n\n    # Number of sampled nodes per graph\n    cfg.train.node_per_graph = 32\n\n    # Radius: same, extend. 
same: same as cfg.gnn.layers_mp, extend: layers+1\n    cfg.train.radius = 'extend'\n\n    # Evaluate model on test data every eval period epochs\n    cfg.train.eval_period = 10\n\n    # Option to skip training epoch evaluation\n    cfg.train.skip_train_eval = False\n\n    # Save model checkpoint every checkpoint period epochs\n    cfg.train.ckpt_period = 100\n\n    # Enable checkpointing; set to False to disable it and save I/O\n    cfg.train.enable_ckpt = True\n\n    # Resume training from the latest checkpoint in the output directory\n    cfg.train.auto_resume = False\n\n    # The epoch to resume. -1 means resume the latest epoch.\n    cfg.train.epoch_resume = -1\n\n    # Clean checkpoint: only keep the last ckpt\n    cfg.train.ckpt_clean = True\n\n    # Number of iterations per epoch (for sampling-based loaders only)\n    cfg.train.iter_per_epoch = 32\n\n    # GraphSAINTRandomWalkSampler: random walk length\n    cfg.train.walk_length = 4\n\n    # NeighborSampler: number of sampled nodes per layer\n    cfg.train.neighbor_sizes = [20, 15, 10, 5]\n\n    # ----------------------------------------------------------------------- #\n    # Validation options\n    # ----------------------------------------------------------------------- #\n    cfg.val = CN()\n\n    # Minibatch node\n    cfg.val.sample_node = False\n\n    # Sampling strategy for a val/test loader\n    cfg.val.sampler = 'full_batch'\n\n    # Number of sampled nodes per graph\n    cfg.val.node_per_graph = 32\n\n    # Radius: same, extend. 
same: same as cfg.gnn.layers_mp, extend: layers+1\n    cfg.val.radius = 'extend'\n\n    # ----------------------------------------------------------------------- #\n    # Model options\n    # ----------------------------------------------------------------------- #\n    cfg.model = CN()\n\n    # Model type to use\n    cfg.model.type = 'gnn'\n\n    # Auto match computational budget, match upper bound / lower bound\n    cfg.model.match_upper = True\n\n    # Loss function: cross_entropy, mse\n    cfg.model.loss_fun = 'cross_entropy'\n\n    # size average for loss function. 'mean' or 'sum'\n    cfg.model.size_average = 'mean'\n\n    # Threshold for binary classification\n    cfg.model.thresh = 0.5\n\n    # ============== Link/edge tasks only\n    # Edge decoding methods.\n    #   - dot: compute dot(u, v) to predict link (binary)\n    #   - cosine_similarity: use cosine similarity (u, v) to predict link (\n    #   binary)\n    #   - concat: use u||v followed by an nn.Linear to obtain edge embedding\n    #   (multi-class)\n    cfg.model.edge_decoding = 'dot'\n    # ===================================\n\n    # ================== Graph tasks only\n    # Pooling methods.\n    #   - add: global add pool\n    #   - mean: global mean pool\n    #   - max: global max pool\n    cfg.model.graph_pooling = 'add'\n    # ===================================\n\n    # ----------------------------------------------------------------------- #\n    # GNN options\n    # ----------------------------------------------------------------------- #\n    cfg.gnn = CN()\n\n    # Prediction head. Use cfg.dataset.task by default\n    cfg.gnn.head = 'default'\n\n    # Number of layers before message passing\n    cfg.gnn.layers_pre_mp = 0\n\n    # Number of layers for message passing\n    cfg.gnn.layers_mp = 2\n\n    # Number of layers after message passing\n    cfg.gnn.layers_post_mp = 0\n\n    # Hidden layer dim. 
Automatically set if train.auto_match = True\n    cfg.gnn.dim_inner = 16\n\n    # Type of graph conv: generalconv, gcnconv, sageconv, gatconv, ...\n    cfg.gnn.layer_type = 'generalconv'\n\n    # Stage type: 'stack', 'skipsum', 'skipconcat'\n    cfg.gnn.stage_type = 'stack'\n\n    # How many layers to skip each time\n    cfg.gnn.skip_every = 1\n\n    # Whether to use batch norm\n    cfg.gnn.batchnorm = True\n\n    # Activation\n    cfg.gnn.act = 'relu'\n\n    # Dropout\n    cfg.gnn.dropout = 0.0\n\n    # Aggregation type: add, mean, max\n    # Note: only for certain layers that explicitly set aggregation type\n    # e.g., when cfg.gnn.layer_type = 'generalconv'\n    cfg.gnn.agg = 'add'\n\n    # Normalize adj\n    cfg.gnn.normalize_adj = False\n\n    # Message direction: single, both\n    cfg.gnn.msg_direction = 'single'\n\n    # Whether to add a message from the node itself: none, add, concat\n    cfg.gnn.self_msg = 'concat'\n\n    # Number of attention heads\n    cfg.gnn.att_heads = 1\n\n    # After concatenating attention heads, add a linear layer\n    cfg.gnn.att_final_linear = False\n\n    # After concatenating attention heads, add a linear layer\n    cfg.gnn.att_final_linear_bn = False\n\n    # Normalize after message passing\n    cfg.gnn.l2norm = True\n\n    # Randomly use fewer edges for message passing\n    cfg.gnn.keep_edge = 0.5\n\n    # Clear cached feature_new\n    cfg.gnn.clear_feature = True\n\n    # ----------------------------------------------------------------------- #\n    # Optimizer options\n    # ----------------------------------------------------------------------- #\n    cfg.optim = CN()\n\n    # optimizer: sgd, adam\n    cfg.optim.optimizer = 'adam'\n\n    # Base learning rate\n    cfg.optim.base_lr = 0.01\n\n    # L2 regularization\n    cfg.optim.weight_decay = 5e-4\n\n    # SGD momentum\n    cfg.optim.momentum = 0.9\n\n    # scheduler: none, steps, cos\n    cfg.optim.scheduler = 'cos'\n\n    # Steps for 'steps' policy (in epochs)\n    cfg.optim.steps = [30, 60, 
90]\n\n    # Learning rate multiplier for 'steps' policy\n    cfg.optim.lr_decay = 0.1\n\n    # Maximal number of epochs\n    cfg.optim.max_epoch = 200\n\n    # ----------------------------------------------------------------------- #\n    # Batch norm options\n    # ----------------------------------------------------------------------- #\n    cfg.bn = CN()\n\n    # BN epsilon\n    cfg.bn.eps = 1e-5\n\n    # BN momentum (BN momentum in PyTorch = 1 - BN momentum in Caffe2)\n    cfg.bn.mom = 0.1\n\n    # ----------------------------------------------------------------------- #\n    # Memory options\n    # ----------------------------------------------------------------------- #\n    cfg.mem = CN()\n\n    # Perform ReLU inplace\n    cfg.mem.inplace = False\n\n    # Set user customized cfgs\n    for func in register.config_dict.values():\n        func(cfg)\n\n\ndef assert_cfg(cfg):\n    r\"\"\"Checks config values, do necessary post processing to the configs.\"\"\"\n    if cfg.dataset.task not in ['node', 'edge', 'graph', 'link_pred']:\n        raise ValueError(f\"Task '{cfg.dataset.task}' not supported. 
Must be \"\n                         f\"one of node, edge, graph, link_pred\")\n    if 'classification' in cfg.dataset.task_type and cfg.model.loss_fun == \\\n            'mse':\n        cfg.model.loss_fun = 'cross_entropy'\n        logging.warning(\n            'model.loss_fun changed to cross_entropy for classification.')\n    if cfg.dataset.task_type == 'regression' and cfg.model.loss_fun == \\\n            'cross_entropy':\n        cfg.model.loss_fun = 'mse'\n        logging.warning('model.loss_fun changed to mse for regression.')\n    if cfg.dataset.task == 'graph' and cfg.dataset.transductive:\n        cfg.dataset.transductive = False\n        logging.warning('dataset.transductive changed '\n                        'to False for graph task.')\n    if cfg.gnn.layers_post_mp < 1:\n        cfg.gnn.layers_post_mp = 1\n        logging.warning('Layers after message passing should be >=1')\n    if cfg.gnn.head == 'default':\n        cfg.gnn.head = cfg.dataset.task\n    cfg.run_dir = cfg.out_dir\n\n\ndef dump_cfg(cfg):\n    r\"\"\"Dumps the config to the output directory specified in\n    :obj:`cfg.out_dir`.\n\n    Args:\n        cfg (CfgNode): Configuration node\n    \"\"\"\n    os.makedirs(cfg.out_dir, exist_ok=True)\n    cfg_file = osp.join(cfg.out_dir, cfg.cfg_dest)\n    with open(cfg_file, 'w') as f:\n        cfg.dump(stream=f)\n\n\ndef load_cfg(cfg, args):\n    r\"\"\"Load configurations from file system and command line.\n\n    Args:\n        cfg (CfgNode): Configuration node\n        args (ArgumentParser): Command argument parser\n    \"\"\"\n    cfg.merge_from_file(args.cfg_file)\n    cfg.merge_from_list(args.opts)\n    assert_cfg(cfg)\n\n\ndef makedirs_rm_exist(dir):\n    if osp.isdir(dir):\n        fs.rm(dir)\n    os.makedirs(dir, exist_ok=True)\n\n\ndef get_fname(fname):\n    r\"\"\"Extract filename from file name path.\n\n    Args:\n        fname (str): Filename for the yaml format configuration file\n    \"\"\"\n    fname = osp.basename(fname)\n    if 
fname.endswith('.yaml'):\n        fname = fname[:-5]\n    elif fname.endswith('.yml'):\n        fname = fname[:-4]\n    return fname\n\n\ndef set_out_dir(out_dir, fname):\n    r\"\"\"Create the directory for full experiment run.\n\n    Args:\n        out_dir (str): Directory for output, specified in :obj:`cfg.out_dir`\n        fname (str): Filename for the yaml format configuration file\n    \"\"\"\n    fname = get_fname(fname)\n    cfg.out_dir = osp.join(out_dir, fname)\n    # Make output directory\n    if cfg.train.auto_resume:\n        os.makedirs(cfg.out_dir, exist_ok=True)\n    else:\n        makedirs_rm_exist(cfg.out_dir)\n\n\ndef set_run_dir(out_dir):\n    r\"\"\"Create the directory for each random seed experiment run.\n\n    Args:\n        out_dir (str): Directory for output, specified in :obj:`cfg.out_dir`\n    \"\"\"\n    cfg.run_dir = osp.join(out_dir, str(cfg.seed))\n    # Make output directory\n    if cfg.train.auto_resume:\n        os.makedirs(cfg.run_dir, exist_ok=True)\n    else:\n        makedirs_rm_exist(cfg.run_dir)\n\n\nset_cfg(cfg)\n\n\ndef from_config(func):\n    if inspect.isclass(func):\n        params = list(inspect.signature(func.__init__).parameters.values())[1:]\n    else:\n        params = list(inspect.signature(func).parameters.values())\n\n    arg_names = [p.name for p in params]\n    has_defaults = [p.default != inspect.Parameter.empty for p in params]\n\n    @functools.wraps(func)\n    def wrapper(*args, cfg: Any = None, **kwargs):\n        if cfg is not None:\n            cfg = dict(cfg) if isinstance(cfg, Iterable) else asdict(cfg)\n\n            iterator = zip(arg_names[len(args):], has_defaults[len(args):])\n            for arg_name, has_default in iterator:\n                if arg_name in kwargs:\n                    continue\n                elif arg_name in cfg:\n                    kwargs[arg_name] = cfg[arg_name]\n                elif not has_default:\n                    raise ValueError(f\"'cfg.{arg_name}' undefined\")\n 
       return func(*args, **kwargs)\n\n    return wrapper\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/__init__.py",
    "content": "from .act import *  # noqa\nfrom .config import *  # noqa\nfrom .encoder import *  # noqa\nfrom .head import *  # noqa\nfrom .layer import *  # noqa\nfrom .loader import *  # noqa\nfrom .loss import *  # noqa\nfrom .network import *  # noqa\nfrom .optimizer import *  # noqa\nfrom .pooling import *  # noqa\nfrom .stage import *  # noqa\nfrom .train import *  # noqa\nfrom .transform import *  # noqa\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/act/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/config/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/encoder/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/head/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/layer/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/layer/generalconv.py",
    "content": "import torch\nfrom torch.nn import Parameter\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.utils import add_remaining_self_loops, scatter\n\n\nclass GeneralConvLayer(MessagePassing):\n    r\"\"\"A general GNN layer.\"\"\"\n    def __init__(self, in_channels, out_channels, improved=False, cached=False,\n                 bias=True, **kwargs):\n        super().__init__(aggr=cfg.gnn.agg, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.improved = improved\n        self.cached = cached\n        self.normalize = cfg.gnn.normalize_adj\n\n        self.weight = Parameter(torch.empty(in_channels, out_channels))\n        if cfg.gnn.self_msg == 'concat':\n            self.weight_self = Parameter(torch.empty(in_channels,\n                                                     out_channels))\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.weight)\n        if cfg.gnn.self_msg == 'concat':\n            glorot(self.weight_self)\n        zeros(self.bias)\n        self.cached_result = None\n        self.cached_num_edges = None\n\n    @staticmethod\n    def norm(edge_index, num_nodes, edge_weight=None, improved=False,\n             dtype=None):\n        if edge_weight is None:\n            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,\n                                     device=edge_index.device)\n\n        fill_value = 1 if not improved else 2\n        edge_index, edge_weight = add_remaining_self_loops(\n            edge_index, edge_weight, fill_value, num_nodes)\n\n        row, col = edge_index\n        deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum')\n        
deg_inv_sqrt = deg.pow(-0.5)\n        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n\n        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n\n    def forward(self, x, edge_index, edge_weight=None, edge_feature=None):\n        if cfg.gnn.self_msg == 'concat':\n            x_self = torch.matmul(x, self.weight_self)\n        x = torch.matmul(x, self.weight)\n\n        if self.cached and self.cached_result is not None:\n            if edge_index.size(1) != self.cached_num_edges:\n                raise RuntimeError(\n                    'Cached {} number of edges, but found {}. Please '\n                    'disable the caching behavior of this layer by removing '\n                    'the `cached=True` argument in its constructor.'.format(\n                        self.cached_num_edges, edge_index.size(1)))\n\n        if not self.cached or self.cached_result is None:\n            self.cached_num_edges = edge_index.size(1)\n            if self.normalize:\n                edge_index, norm = self.norm(edge_index, x.size(self.node_dim),\n                                             edge_weight, self.improved,\n                                             x.dtype)\n            else:\n                norm = edge_weight\n            self.cached_result = edge_index, norm\n\n        edge_index, norm = self.cached_result\n        x_msg = self.propagate(edge_index, x=x, norm=norm,\n                               edge_feature=edge_feature)\n        if cfg.gnn.self_msg == 'none':\n            return x_msg\n        elif cfg.gnn.self_msg == 'add':\n            return x_msg + x\n        elif cfg.gnn.self_msg == 'concat':\n            return x_msg + x_self\n        else:\n            raise ValueError('self_msg {} not defined'.format(\n                cfg.gnn.self_msg))\n\n    def message(self, x_j, norm, edge_feature):\n        if edge_feature is None:\n            return norm.view(-1, 1) * x_j if norm is not None else x_j\n        else:\n            return 
norm.view(-1, 1) * (\n                x_j + edge_feature) if norm is not None else (x_j +\n                                                              edge_feature)\n\n    def update(self, aggr_out):\n        if self.bias is not None:\n            aggr_out = aggr_out + self.bias\n        return aggr_out\n\n    def __repr__(self):\n        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,\n                                   self.out_channels)\n\n\nclass GeneralEdgeConvLayer(MessagePassing):\n    r\"\"\"General GNN layer, with edge features.\"\"\"\n    def __init__(self, in_channels, out_channels, edge_dim, improved=False,\n                 cached=False, bias=True, **kwargs):\n        super().__init__(aggr=cfg.gnn.agg, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.improved = improved\n        self.cached = cached\n        self.normalize = cfg.gnn.normalize_adj\n        self.msg_direction = cfg.gnn.msg_direction\n\n        if self.msg_direction == 'single':\n            self.linear_msg = torch.nn.Linear(\n                in_channels + edge_dim,\n                out_channels,\n                bias=False,\n            )\n        else:\n            self.linear_msg = torch.nn.Linear(\n                in_channels * 2 + edge_dim,\n                out_channels,\n                bias=False,\n            )\n        if cfg.gnn.self_msg == 'concat':\n            self.linear_self = torch.nn.Linear(\n                in_channels,\n                out_channels,\n                bias=False,\n            )\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        zeros(self.bias)\n        self.cached_result = None\n        self.cached_num_edges = None\n\n    @staticmethod\n    def norm(edge_index, num_nodes, edge_weight=None, 
improved=False,\n             dtype=None):\n        if edge_weight is None:\n            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,\n                                     device=edge_index.device)\n\n        fill_value = 1 if not improved else 2\n        edge_index, edge_weight = add_remaining_self_loops(\n            edge_index, edge_weight, fill_value, num_nodes)\n\n        row, col = edge_index\n        deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum')\n        deg_inv_sqrt = deg.pow(-0.5)\n        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n\n        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n\n    def forward(self, x, edge_index, edge_weight=None, edge_feature=None):\n        if self.cached and self.cached_result is not None:\n            if edge_index.size(1) != self.cached_num_edges:\n                raise RuntimeError(\n                    'Cached {} number of edges, but found {}. Please '\n                    'disable the caching behavior of this layer by removing '\n                    'the `cached=True` argument in its constructor.'.format(\n                        self.cached_num_edges, edge_index.size(1)))\n\n        if not self.cached or self.cached_result is None:\n            self.cached_num_edges = edge_index.size(1)\n            if self.normalize:\n                edge_index, norm = self.norm(edge_index, x.size(self.node_dim),\n                                             edge_weight, self.improved,\n                                             x.dtype)\n            else:\n                norm = edge_weight\n            self.cached_result = edge_index, norm\n\n        edge_index, norm = self.cached_result\n\n        x_msg = self.propagate(edge_index, x=x, norm=norm,\n                               edge_feature=edge_feature)\n\n        if cfg.gnn.self_msg == 'concat':\n            x_self = self.linear_self(x)\n            return x_self + x_msg\n        elif cfg.gnn.self_msg == 'add':\n    
        return x + x_msg\n        else:\n            return x_msg\n\n    def message(self, x_i, x_j, norm, edge_feature):\n        if self.msg_direction == 'both':\n            x_j = torch.cat((x_i, x_j, edge_feature), dim=-1)\n        else:\n            x_j = torch.cat((x_j, edge_feature), dim=-1)\n        x_j = self.linear_msg(x_j)\n        return norm.view(-1, 1) * x_j if norm is not None else x_j\n\n    def update(self, aggr_out):\n        if self.bias is not None:\n            aggr_out = aggr_out + self.bias\n        return aggr_out\n\n    def __repr__(self):\n        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,\n                                   self.out_channels)\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/loader/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/loss/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/network/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/optimizer/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/pooling/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/stage/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/train/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/contrib/transform/__init__.py",
    "content": "from os.path import dirname, basename, isfile, join\nimport glob\n\nmodules = glob.glob(join(dirname(__file__), \"*.py\"))\n__all__ = [\n    basename(f)[:-3] for f in modules\n    if isfile(f) and not f.endswith('__init__.py')\n]\n"
  },
  {
    "path": "torch_geometric/graphgym/imports.py",
    "content": "import warnings\n\nimport torch\n\ntry:\n    import lightning.pytorch as pl\n    _pl_is_available = True\nexcept ImportError:\n    try:\n        import pytorch_lightning as pl\n        _pl_is_available = True\n    except ImportError:\n        _pl_is_available = False\n\nif _pl_is_available:\n    LightningModule = pl.LightningModule\n    Callback = pl.Callback\nelse:\n    pl = object\n    LightningModule = torch.nn.Module\n    Callback = object\n\n    warnings.warn(\n        \"To use GraphGym, install 'pytorch_lightning' or 'lightning' via \"\n        \"'pip install pytorch_lightning' or 'pip install lightning'\",\n        stacklevel=2)\n"
  },
  {
    "path": "torch_geometric/graphgym/init.py",
    "content": "import torch\n\n\ndef init_weights(m):\n    r\"\"\"Performs weight initialization.\n\n    Args:\n        m (nn.Module): PyTorch module\n\n    \"\"\"\n    if (isinstance(m, torch.nn.BatchNorm2d)\n            or isinstance(m, torch.nn.BatchNorm1d)):\n        m.weight.data.fill_(1.0)\n        m.bias.data.zero_()\n    elif isinstance(m, torch.nn.Linear):\n        m.weight.data = torch.nn.init.xavier_uniform_(\n            m.weight.data, gain=torch.nn.init.calculate_gain('relu'))\n        if m.bias is not None:\n            m.bias.data.zero_()\n"
  },
  {
    "path": "torch_geometric/graphgym/loader.py",
    "content": "import os.path as osp\nfrom typing import Callable\n\nimport torch\n\nimport torch_geometric.graphgym.register as register\nimport torch_geometric.transforms as T\nfrom torch_geometric.datasets import (\n    PPI,\n    Amazon,\n    Coauthor,\n    KarateClub,\n    MNISTSuperpixels,\n    Planetoid,\n    QM7b,\n    TUDataset,\n)\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.models.transform import (\n    create_link_label,\n    neg_sampling_transform,\n)\nfrom torch_geometric.loader import (\n    ClusterLoader,\n    DataLoader,\n    GraphSAINTEdgeSampler,\n    GraphSAINTNodeSampler,\n    GraphSAINTRandomWalkSampler,\n    NeighborSampler,\n    RandomNodeLoader,\n)\nfrom torch_geometric.utils import (\n    index_to_mask,\n    negative_sampling,\n    to_undirected,\n)\n\nindex2mask = index_to_mask  # TODO Backward compatibility\n\n\ndef planetoid_dataset(name: str) -> Callable:\n    return lambda root: Planetoid(root, name)\n\n\nregister.register_dataset('Cora', planetoid_dataset('Cora'))\nregister.register_dataset('CiteSeer', planetoid_dataset('CiteSeer'))\nregister.register_dataset('PubMed', planetoid_dataset('PubMed'))\nregister.register_dataset('PPI', PPI)\n\n\ndef load_pyg(name, dataset_dir):\n    \"\"\"Load PyG dataset objects. 
(More PyG datasets will be supported).\n\n    Args:\n        name (str): dataset name\n        dataset_dir (str): data directory\n\n    Returns: PyG dataset object\n\n    \"\"\"\n    dataset_dir = osp.join(dataset_dir, name)\n    if name in ['Cora', 'CiteSeer', 'PubMed']:\n        dataset = Planetoid(dataset_dir, name)\n    elif name[:3] == 'TU_':\n        # TU_IMDB doesn't have node features\n        if name[3:] == 'IMDB':\n            name = 'IMDB-MULTI'\n            dataset = TUDataset(dataset_dir, name, transform=T.Constant())\n        else:\n            dataset = TUDataset(dataset_dir, name[3:])\n    elif name == 'Karate':\n        dataset = KarateClub()\n    elif 'Coauthor' in name:\n        if 'CS' in name:\n            dataset = Coauthor(dataset_dir, name='CS')\n        else:\n            dataset = Coauthor(dataset_dir, name='Physics')\n    elif 'Amazon' in name:\n        if 'Computers' in name:\n            dataset = Amazon(dataset_dir, name='Computers')\n        else:\n            dataset = Amazon(dataset_dir, name='Photo')\n    elif name == 'MNIST':\n        dataset = MNISTSuperpixels(dataset_dir)\n    elif name == 'PPI':\n        dataset = PPI(dataset_dir)\n    elif name == 'QM7b':\n        dataset = QM7b(dataset_dir)\n    else:\n        raise ValueError(f\"'{name}' is not supported\")\n\n    return dataset\n\n\ndef set_dataset_attr(dataset, name, value, size):\n    dataset._data_list = None\n    dataset.data[name] = value\n    if dataset.slices is not None:\n        dataset.slices[name] = torch.tensor([0, size], dtype=torch.long)\n\n\ndef load_ogb(name, dataset_dir):\n    r\"\"\"Load OGB dataset objects.\n\n    Args:\n        name (str): dataset name\n        dataset_dir (str): data directory\n\n    Returns: PyG dataset object\n\n    \"\"\"\n    from ogb.graphproppred import PygGraphPropPredDataset\n    from ogb.linkproppred import PygLinkPropPredDataset\n    from ogb.nodeproppred import PygNodePropPredDataset\n\n    if name[:4] == 'ogbn':\n        dataset 
= PygNodePropPredDataset(name=name, root=dataset_dir)\n        splits = dataset.get_idx_split()\n        split_names = ['train_mask', 'val_mask', 'test_mask']\n        for i, key in enumerate(splits.keys()):\n            mask = index_to_mask(splits[key], size=dataset._data.y.shape[0])\n            set_dataset_attr(dataset, split_names[i], mask, len(mask))\n        edge_index = to_undirected(dataset._data.edge_index)\n        set_dataset_attr(dataset, 'edge_index', edge_index,\n                         edge_index.shape[1])\n\n    elif name[:4] == 'ogbg':\n        dataset = PygGraphPropPredDataset(name=name, root=dataset_dir)\n        splits = dataset.get_idx_split()\n        split_names = [\n            'train_graph_index', 'val_graph_index', 'test_graph_index'\n        ]\n        for i, key in enumerate(splits.keys()):\n            id = splits[key]\n            set_dataset_attr(dataset, split_names[i], id, len(id))\n\n    elif name[:4] == \"ogbl\":\n        dataset = PygLinkPropPredDataset(name=name, root=dataset_dir)\n        splits = dataset.get_edge_split()\n        id = splits['train']['edge'].T\n        if cfg.dataset.resample_negative:\n            set_dataset_attr(dataset, 'train_pos_edge_index', id, id.shape[1])\n            dataset.transform = neg_sampling_transform\n        else:\n            id_neg = negative_sampling(edge_index=id,\n                                       num_nodes=dataset._data.num_nodes,\n                                       num_neg_samples=id.shape[1])\n            id_all = torch.cat([id, id_neg], dim=-1)\n            label = create_link_label(id, id_neg)\n            set_dataset_attr(dataset, 'train_edge_index', id_all,\n                             id_all.shape[1])\n            set_dataset_attr(dataset, 'train_edge_label', label, len(label))\n\n        id, id_neg = splits['valid']['edge'].T, splits['valid']['edge_neg'].T\n        id_all = torch.cat([id, id_neg], dim=-1)\n        label = create_link_label(id, id_neg)\n        
set_dataset_attr(dataset, 'val_edge_index', id_all, id_all.shape[1])\n        set_dataset_attr(dataset, 'val_edge_label', label, len(label))\n\n        id, id_neg = splits['test']['edge'].T, splits['test']['edge_neg'].T\n        id_all = torch.cat([id, id_neg], dim=-1)\n        label = create_link_label(id, id_neg)\n        set_dataset_attr(dataset, 'test_edge_index', id_all, id_all.shape[1])\n        set_dataset_attr(dataset, 'test_edge_label', label, len(label))\n\n    else:\n        raise ValueError(f'OGB dataset {name} does not exist')\n    return dataset\n\n\ndef load_dataset():\n    r\"\"\"Load dataset objects.\n\n    Returns: PyG dataset object\n\n    \"\"\"\n    format = cfg.dataset.format\n    name = cfg.dataset.name\n    dataset_dir = cfg.dataset.dir\n    # Try to load customized data format\n    for func in register.loader_dict.values():\n        dataset = func(format, name, dataset_dir)\n        if dataset is not None:\n            return dataset\n    # Load from PyTorch Geometric dataset\n    if format == 'PyG':\n        dataset = load_pyg(name, dataset_dir)\n    # Load from OGB formatted data\n    elif format == 'OGB':\n        dataset = load_ogb(name.replace('_', '-'), dataset_dir)\n    else:\n        raise ValueError(f\"Unknown data format '{format}'\")\n    return dataset\n\n\ndef set_dataset_info(dataset):\n    r\"\"\"Set global dataset information.\n\n    Args:\n        dataset: PyG dataset object\n\n    \"\"\"\n    # get dim_in and dim_out\n    try:\n        cfg.share.dim_in = dataset._data.x.shape[1]\n    except Exception:\n        cfg.share.dim_in = 1\n    try:\n        if cfg.dataset.task_type == 'classification':\n            cfg.share.dim_out = torch.unique(dataset._data.y).shape[0]\n        else:\n            cfg.share.dim_out = dataset._data.y.shape[1]\n    except Exception:\n        cfg.share.dim_out = 1\n\n    # count number of dataset splits\n    cfg.share.num_splits = 1\n    for key in dataset._data.keys():\n        if 'val' in key:\n          
  cfg.share.num_splits += 1\n            break\n    for key in dataset._data.keys():\n        if 'test' in key:\n            cfg.share.num_splits += 1\n            break\n\n\ndef create_dataset():\n    r\"\"\"Create dataset object.\n\n    Returns: PyG dataset object\n\n    \"\"\"\n    dataset = load_dataset()\n    set_dataset_info(dataset)\n\n    return dataset\n\n\ndef get_loader(dataset, sampler, batch_size, shuffle=True):\n    pw = cfg.num_workers > 0\n    if sampler == \"full_batch\" or len(dataset) > 1:\n        loader_train = DataLoader(dataset, batch_size=batch_size,\n                                  shuffle=shuffle, num_workers=cfg.num_workers,\n                                  pin_memory=True, persistent_workers=pw)\n    elif sampler == \"neighbor\":\n        loader_train = NeighborSampler(\n            dataset[0], sizes=cfg.train.neighbor_sizes[:cfg.gnn.layers_mp],\n            batch_size=batch_size, shuffle=shuffle,\n            num_workers=cfg.num_workers, pin_memory=True)\n    elif sampler == \"random_node\":\n        loader_train = RandomNodeLoader(dataset[0],\n                                        num_parts=cfg.train.train_parts,\n                                        shuffle=shuffle,\n                                        num_workers=cfg.num_workers,\n                                        pin_memory=True, persistent_workers=pw)\n    elif sampler == \"saint_rw\":\n        loader_train = \\\n            GraphSAINTRandomWalkSampler(dataset[0],\n                                        batch_size=batch_size,\n                                        walk_length=cfg.train.walk_length,\n                                        num_steps=cfg.train.iter_per_epoch,\n                                        sample_coverage=0,\n                                        shuffle=shuffle,\n                                        num_workers=cfg.num_workers,\n                                        pin_memory=True,\n                                        
persistent_workers=pw)\n    elif sampler == \"saint_node\":\n        loader_train = \\\n            GraphSAINTNodeSampler(dataset[0], batch_size=batch_size,\n                                  num_steps=cfg.train.iter_per_epoch,\n                                  sample_coverage=0, shuffle=shuffle,\n                                  num_workers=cfg.num_workers,\n                                  pin_memory=True,\n                                  persistent_workers=pw)\n    elif sampler == \"saint_edge\":\n        loader_train = \\\n            GraphSAINTEdgeSampler(dataset[0], batch_size=batch_size,\n                                  num_steps=cfg.train.iter_per_epoch,\n                                  sample_coverage=0, shuffle=shuffle,\n                                  num_workers=cfg.num_workers,\n                                  pin_memory=True,\n                                  persistent_workers=pw)\n    elif sampler == \"cluster\":\n        loader_train = ClusterLoader(\n            dataset[0],\n            num_parts=cfg.train.train_parts,\n            save_dir=osp.join(\n                cfg.dataset.dir,\n                cfg.dataset.name.replace(\"-\", \"_\"),\n            ),\n            batch_size=batch_size,\n            shuffle=shuffle,\n            num_workers=cfg.num_workers,\n            pin_memory=True,\n            persistent_workers=pw,\n        )\n\n    else:\n        raise NotImplementedError(f\"'{sampler}' is not implemented\")\n\n    return loader_train\n\n\ndef create_loader():\n    \"\"\"Create data loader object.\n\n    Returns: List of PyTorch data loaders\n\n    \"\"\"\n    dataset = create_dataset()\n    # train loader\n    if cfg.dataset.task == 'graph':\n        id = dataset.data['train_graph_index']\n        loaders = [\n            get_loader(dataset[id], cfg.train.sampler, cfg.train.batch_size,\n                       shuffle=True)\n        ]\n        delattr(dataset.data, 'train_graph_index')\n    else:\n        loaders = [\n    
        get_loader(dataset, cfg.train.sampler, cfg.train.batch_size,\n                       shuffle=True)\n        ]\n\n    # val and test loaders\n    for i in range(cfg.share.num_splits - 1):\n        if cfg.dataset.task == 'graph':\n            split_names = ['val_graph_index', 'test_graph_index']\n            id = dataset.data[split_names[i]]\n            loaders.append(\n                get_loader(dataset[id], cfg.val.sampler, cfg.train.batch_size,\n                           shuffle=False))\n            delattr(dataset.data, split_names[i])\n        else:\n            loaders.append(\n                get_loader(dataset, cfg.val.sampler, cfg.train.batch_size,\n                           shuffle=False))\n\n    return loaders\n"
  },
  {
    "path": "torch_geometric/graphgym/logger.py",
    "content": "import logging\nimport math\nimport os\nimport sys\nimport time\nfrom typing import Any, Dict, Optional\n\nimport torch\n\nfrom torch_geometric.graphgym import register\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.imports import Callback, pl\nfrom torch_geometric.graphgym.utils.device import get_current_gpu_usage\nfrom torch_geometric.graphgym.utils.io import dict_to_json, dict_to_tb\n\n\ndef set_printing():\n    \"\"\"Set up printing options.\"\"\"\n    logging.root.handlers = []\n    logging_cfg = {'level': logging.INFO, 'format': '%(message)s'}\n    os.makedirs(cfg.run_dir, exist_ok=True)\n    h_file = logging.FileHandler(f'{cfg.run_dir}/logging.log')\n    h_stdout = logging.StreamHandler(sys.stdout)\n    if cfg.print == 'file':\n        logging_cfg['handlers'] = [h_file]\n    elif cfg.print == 'stdout':\n        logging_cfg['handlers'] = [h_stdout]\n    elif cfg.print == 'both':\n        logging_cfg['handlers'] = [h_file, h_stdout]\n    else:\n        raise ValueError('Print option not supported')\n    logging.basicConfig(**logging_cfg)\n\n\nclass Logger:\n    def __init__(self, name='train', task_type=None):\n        self.name = name\n        self.task_type = task_type\n\n        self._epoch_total = cfg.optim.max_epoch\n        self._time_total = 0  # won't be reset\n\n        self.out_dir = f'{cfg.run_dir}/{name}'\n        os.makedirs(self.out_dir, exist_ok=True)\n        if cfg.tensorboard_each_run:\n            from tensorboardX import SummaryWriter\n            self.tb_writer = SummaryWriter(self.out_dir)\n\n        self.reset()\n\n    def __getitem__(self, key):\n        return getattr(self, key, None)\n\n    def __setitem__(self, key, value):\n        setattr(self, key, value)\n\n    def reset(self):\n        self._iter = 0\n        self._size_current = 0\n        self._loss = 0\n        self._lr = 0\n        self._params = 0\n        self._time_used = 0\n        self._true = []\n        self._pred = 
[]\n        self._custom_stats = {}\n\n    # basic properties\n    def basic(self):\n        stats = {\n            'loss': round(self._loss / self._size_current, cfg.round),\n            'lr': round(self._lr, cfg.round),\n            'params': self._params,\n            'time_iter': round(self.time_iter(), cfg.round),\n        }\n        gpu_memory = get_current_gpu_usage()\n        if gpu_memory > 0:\n            stats['gpu_memory'] = gpu_memory\n        return stats\n\n    # customized input properties\n    def custom(self):\n        if len(self._custom_stats) == 0:\n            return {}\n        out = {}\n        for key, val in self._custom_stats.items():\n            out[key] = val / self._size_current\n        return out\n\n    def _get_pred_int(self, pred_score):\n        if len(pred_score.shape) == 1 or pred_score.shape[1] == 1:\n            return (pred_score > cfg.model.thresh).long()\n        else:\n            return pred_score.max(dim=1)[1]\n\n    # task properties\n    def classification_binary(self):\n        from sklearn.metrics import (\n            accuracy_score,\n            f1_score,\n            precision_score,\n            recall_score,\n            roc_auc_score,\n        )\n\n        true, pred_score = torch.cat(self._true), torch.cat(self._pred)\n        pred_int = self._get_pred_int(pred_score)\n        try:\n            r_a_score = roc_auc_score(true, pred_score)\n        except ValueError:\n            r_a_score = 0.0\n        return {\n            'accuracy': round(accuracy_score(true, pred_int), cfg.round),\n            'precision': round(precision_score(true, pred_int), cfg.round),\n            'recall': round(recall_score(true, pred_int), cfg.round),\n            'f1': round(f1_score(true, pred_int), cfg.round),\n            'auc': round(r_a_score, cfg.round),\n        }\n\n    def classification_multi(self):\n        from sklearn.metrics import accuracy_score\n\n        true, pred_score = torch.cat(self._true), 
torch.cat(self._pred)\n        pred_int = self._get_pred_int(pred_score)\n        return {'accuracy': round(accuracy_score(true, pred_int), cfg.round)}\n\n    def regression(self):\n        from sklearn.metrics import mean_absolute_error, mean_squared_error\n\n        true, pred = torch.cat(self._true), torch.cat(self._pred)\n        return {\n            'mae':\n            float(round(mean_absolute_error(true, pred), cfg.round)),\n            'mse':\n            float(round(mean_squared_error(true, pred), cfg.round)),\n            'rmse':\n            float(round(math.sqrt(mean_squared_error(true, pred)), cfg.round))\n        }\n\n    def time_iter(self):\n        return self._time_used / self._iter\n\n    def eta(self, epoch_current):\n        epoch_current += 1  # since counter starts from 0\n        time_per_epoch = self._time_total / epoch_current\n        return time_per_epoch * (self._epoch_total - epoch_current)\n\n    def update_stats(self, true, pred, loss, lr, time_used, params, **kwargs):\n        assert true.shape[0] == pred.shape[0]\n        self._iter += 1\n        self._true.append(true)\n        self._pred.append(pred)\n        batch_size = true.shape[0]\n        self._size_current += batch_size\n        self._loss += loss * batch_size\n        self._lr = lr\n        self._params = params\n        self._time_used += time_used\n        self._time_total += time_used\n        for key, val in kwargs.items():\n            if key not in self._custom_stats:\n                self._custom_stats[key] = val * batch_size\n            else:\n                self._custom_stats[key] += val * batch_size\n\n    def write_iter(self):\n        raise NotImplementedError\n\n    def write_epoch(self, cur_epoch):\n        basic_stats = self.basic()\n\n        # Try to load customized metrics\n        task_stats = {}\n        for custom_metric in cfg.custom_metrics:\n            func = register.metric_dict.get(custom_metric)\n            if not func:\n                
raise ValueError(\n                    f'Unknown custom metric function name: {custom_metric}')\n            custom_metric_score = func(self._true, self._pred, self.task_type)\n            task_stats[custom_metric] = custom_metric_score\n\n        if not task_stats:  # use default metrics if no matching custom metric\n            if self.task_type == 'regression':\n                task_stats = self.regression()\n            elif self.task_type == 'classification_binary':\n                task_stats = self.classification_binary()\n            elif self.task_type == 'classification_multi':\n                task_stats = self.classification_multi()\n            else:\n                raise ValueError('Task has to be regression or classification')\n\n        epoch_stats = {'epoch': cur_epoch}\n        eta_stats = {'eta': round(self.eta(cur_epoch), cfg.round)}\n        custom_stats = self.custom()\n\n        if self.name == 'train':\n            stats = {\n                **epoch_stats,\n                **eta_stats,\n                **basic_stats,\n                **task_stats,\n                **custom_stats\n            }\n        else:\n            stats = {\n                **epoch_stats,\n                **basic_stats,\n                **task_stats,\n                **custom_stats\n            }\n\n        # print\n        logging.info(f'{self.name}: {stats}')\n        # json\n        dict_to_json(stats, f'{self.out_dir}/stats.json')\n        # tensorboard\n        if cfg.tensorboard_each_run:\n            dict_to_tb(stats, self.tb_writer, cur_epoch)\n        self.reset()\n\n    def close(self):\n        if cfg.tensorboard_each_run:\n            self.tb_writer.close()\n\n\ndef infer_task():\n    num_label = cfg.share.dim_out\n    if cfg.dataset.task_type == 'classification':\n        if num_label <= 2:\n            task_type = 'classification_binary'\n        else:\n            task_type = 'classification_multi'\n    else:\n        task_type = 
cfg.dataset.task_type\n    return task_type\n\n\ndef create_logger():\n    r\"\"\"Create logger for the experiment.\"\"\"\n    loggers = []\n    names = ['train', 'val', 'test']\n    for i in range(cfg.share.num_splits):\n        loggers.append(Logger(name=names[i], task_type=infer_task()))\n    return loggers\n\n\nclass LoggerCallback(Callback):\n    def __init__(self):\n        self._logger = create_logger()\n        self._train_epoch_start_time = None\n        self._val_epoch_start_time = None\n        self._test_epoch_start_time = None\n\n    @property\n    def train_logger(self) -> Any:\n        return self._logger[0]\n\n    @property\n    def val_logger(self) -> Any:\n        return self._logger[1]\n\n    @property\n    def test_logger(self) -> Any:\n        return self._logger[2]\n\n    def close(self):\n        for logger in self._logger:\n            logger.close()\n\n    def _get_stats(\n        self,\n        epoch_start_time: float,\n        outputs: Dict[str, Any],\n        trainer: 'pl.Trainer',\n    ) -> Dict:\n        return dict(\n            true=outputs['true'].detach().cpu(),\n            pred=outputs['pred_score'].detach().cpu(),\n            loss=float(outputs['loss']),\n            lr=trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0],\n            time_used=time.time() - epoch_start_time,\n            params=cfg.params,\n        )\n\n    def on_train_epoch_start(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n    ):\n        self._train_epoch_start_time = time.time()\n\n    def on_validation_epoch_start(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n    ):\n        self._val_epoch_start_time = time.time()\n\n    def on_test_epoch_start(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n    ):\n        self._test_epoch_start_time = time.time()\n\n    def on_train_batch_end(\n        self,\n        
trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n        outputs: Dict[str, Any],\n        batch: Any,\n        batch_idx: int,\n        unused: int = 0,\n    ):\n        stats = self._get_stats(self._train_epoch_start_time, outputs, trainer)\n        self.train_logger.update_stats(**stats)\n\n    def on_validation_batch_end(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n        outputs: Optional[Dict[str, Any]],\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n    ):\n        stats = self._get_stats(self._val_epoch_start_time, outputs, trainer)\n        self.val_logger.update_stats(**stats)\n\n    def on_test_batch_end(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n        outputs: Optional[Dict[str, Any]],\n        batch: Any,\n        batch_idx: int,\n        dataloader_idx: int = 0,\n    ):\n        stats = self._get_stats(self._test_epoch_start_time, outputs, trainer)\n        self.test_logger.update_stats(**stats)\n\n    def on_train_epoch_end(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n    ):\n        self.train_logger.write_epoch(trainer.current_epoch)\n\n    def on_validation_epoch_end(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n    ):\n        self.val_logger.write_epoch(trainer.current_epoch)\n\n    def on_test_epoch_end(\n        self,\n        trainer: 'pl.Trainer',\n        pl_module: 'pl.LightningModule',\n    ):\n        self.test_logger.write_epoch(trainer.current_epoch)\n\n    def on_fit_end(self, trainer, pl_module):\n        self.close()\n"
  },
  {
    "path": "torch_geometric/graphgym/loss.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nimport torch_geometric.graphgym.register as register\nfrom torch_geometric.graphgym.config import cfg\n\n\ndef compute_loss(pred, true):\n    \"\"\"Compute loss and prediction score.\n\n    Args:\n        pred (torch.tensor): Unnormalized prediction\n        true (torch.tensor): Ground truth labels\n\n    Returns: Loss, normalized prediction score\n\n    \"\"\"\n    bce_loss = torch.nn.BCEWithLogitsLoss(reduction=cfg.model.size_average)\n    mse_loss = torch.nn.MSELoss(reduction=cfg.model.size_average)\n\n    # default manipulation for pred and true\n    # can be skipped if special loss computation is needed\n    pred = pred.squeeze(-1) if pred.ndim > 1 else pred\n    true = true.squeeze(-1) if true.ndim > 1 else true\n\n    # Try to load customized loss\n    for func in register.loss_dict.values():\n        value = func(pred, true)\n        if value is not None:\n            return value\n\n    if cfg.model.loss_fun == 'cross_entropy':\n        # multiclass\n        if pred.ndim > 1 and true.ndim == 1:\n            pred = F.log_softmax(pred, dim=-1)\n            return F.nll_loss(pred, true), pred\n        # binary or multilabel\n        else:\n            true = true.float()\n            return bce_loss(pred, true), torch.sigmoid(pred)\n    elif cfg.model.loss_fun == 'mse':\n        true = true.float()\n        return mse_loss(pred, true), pred\n    else:\n        raise ValueError(f\"Loss function '{cfg.model.loss_fun}' not supported\")\n"
  },
  {
    "path": "torch_geometric/graphgym/model_builder.py",
    "content": "import time\nfrom typing import Any, Dict, Tuple\n\nimport torch\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.imports import LightningModule\nfrom torch_geometric.graphgym.loss import compute_loss\nfrom torch_geometric.graphgym.models.gnn import GNN\nfrom torch_geometric.graphgym.optim import create_optimizer, create_scheduler\nfrom torch_geometric.graphgym.register import network_dict, register_network\n\nregister_network('gnn', GNN)\n\n\nclass GraphGymModule(LightningModule):\n    def __init__(self, dim_in, dim_out, cfg):\n        super().__init__()\n        self.cfg = cfg\n        self.model = network_dict[cfg.model.type](dim_in=dim_in,\n                                                  dim_out=dim_out)\n\n    def forward(self, *args, **kwargs):\n        return self.model(*args, **kwargs)\n\n    def configure_optimizers(self) -> Tuple[Any, Any]:\n        optimizer = create_optimizer(self.model.parameters(), self.cfg.optim)\n        scheduler = create_scheduler(optimizer, self.cfg.optim)\n        return [optimizer], [scheduler]\n\n    def _shared_step(self, batch, split: str) -> Dict:\n        batch.split = split\n        pred, true = self(batch)\n        loss, pred_score = compute_loss(pred, true)\n        step_end_time = time.time()\n        return dict(loss=loss, true=true, pred_score=pred_score.detach(),\n                    step_end_time=step_end_time)\n\n    def training_step(self, batch, *args, **kwargs):\n        return self._shared_step(batch, split=\"train\")\n\n    def validation_step(self, batch, *args, **kwargs):\n        return self._shared_step(batch, split=\"val\")\n\n    def test_step(self, batch, *args, **kwargs):\n        return self._shared_step(batch, split=\"test\")\n\n    @property\n    def encoder(self) -> torch.nn.Module:\n        return self.model.encoder\n\n    @property\n    def mp(self) -> torch.nn.Module:\n        return self.model.mp\n\n    @property\n    def post_mp(self) -> 
torch.nn.Module:\n        return self.model.post_mp\n\n    @property\n    def pre_mp(self) -> torch.nn.Module:\n        return self.model.pre_mp\n\n    def lr_scheduler_step(self, *args, **kwargs):\n        # Needed for PyTorch 2.0 since the base class of LR schedulers changed.\n        # TODO Remove once we only want to support PyTorch Lightning >= 2.0.\n        return super().lr_scheduler_step(*args, **kwargs)\n\n\ndef create_model(to_device=True, dim_in=None, dim_out=None) -> GraphGymModule:\n    r\"\"\"Create model for graph machine learning.\n\n    Args:\n        to_device (bool, optional): Whether to transfer the model to the\n            specified device. (default: :obj:`True`)\n        dim_in (int, optional): Input dimension to the model\n        dim_out (int, optional): Output dimension to the model\n    \"\"\"\n    dim_in = cfg.share.dim_in if dim_in is None else dim_in\n    dim_out = cfg.share.dim_out if dim_out is None else dim_out\n    # binary classification, output dim = 1\n    if 'classification' == cfg.dataset.task_type and dim_out == 2:\n        dim_out = 1\n\n    model = GraphGymModule(dim_in, dim_out, cfg)\n    if to_device:\n        model.to(torch.device(cfg.accelerator))\n    return model\n"
  },
  {
    "path": "torch_geometric/graphgym/models/__init__.py",
    "content": "from .encoder import (IntegerFeatureEncoder, AtomEncoder, BondEncoder)\nfrom .gnn import (GNNLayer, GNNPreMP, GNNStackStage, FeatureEncoder, GNN)\nfrom .head import (GNNNodeHead, GNNEdgeHead, GNNGraphHead)\nfrom .layer import (GeneralLayer, GeneralMultiLayer, Linear, BatchNorm1dNode,\n                    BatchNorm1dEdge, MLP, GCNConv, SAGEConv, GATConv, GINConv,\n                    SplineConv, GeneralConv, GeneralEdgeConv,\n                    GeneralSampleEdgeConv)\nfrom .pooling import (global_add_pool, global_mean_pool, global_max_pool)\n\n__all__ = [\n    'IntegerFeatureEncoder',\n    'AtomEncoder',\n    'BondEncoder',\n    'GNNLayer',\n    'GNNPreMP',\n    'GNNStackStage',\n    'FeatureEncoder',\n    'GNN',\n    'GNNNodeHead',\n    'GNNEdgeHead',\n    'GNNGraphHead',\n    'GeneralLayer',\n    'GeneralMultiLayer',\n    'Linear',\n    'BatchNorm1dNode',\n    'BatchNorm1dEdge',\n    'MLP',\n    'GCNConv',\n    'SAGEConv',\n    'GATConv',\n    'GINConv',\n    'SplineConv',\n    'GeneralConv',\n    'GeneralEdgeConv',\n    'GeneralSampleEdgeConv',\n    'global_add_pool',\n    'global_mean_pool',\n    'global_max_pool',\n]\n\nclasses = __all__\n"
  },
  {
    "path": "torch_geometric/graphgym/models/act.py",
    "content": "import torch\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.register import register_act\n\n\ndef relu():\n    return torch.nn.ReLU(inplace=cfg.mem.inplace)\n\n\ndef selu():\n    return torch.nn.SELU(inplace=cfg.mem.inplace)\n\n\ndef prelu():\n    return torch.nn.PReLU()\n\n\ndef elu():\n    return torch.nn.ELU(inplace=cfg.mem.inplace)\n\n\ndef lrelu_01():\n    return torch.nn.LeakyReLU(0.1, inplace=cfg.mem.inplace)\n\n\ndef lrelu_025():\n    return torch.nn.LeakyReLU(0.25, inplace=cfg.mem.inplace)\n\n\ndef lrelu_05():\n    return torch.nn.LeakyReLU(0.5, inplace=cfg.mem.inplace)\n\n\nif cfg is not None:\n    register_act('relu', relu)\n    register_act('selu', selu)\n    register_act('prelu', prelu)\n    register_act('elu', elu)\n    register_act('lrelu_01', lrelu_01)\n    register_act('lrelu_025', lrelu_025)\n    register_act('lrelu_05', lrelu_05)\n"
  },
  {
    "path": "torch_geometric/graphgym/models/encoder.py",
    "content": "import torch\n\nfrom torch_geometric.graphgym.register import (\n    register_edge_encoder,\n    register_node_encoder,\n)\n\n\n@register_node_encoder('Integer')\nclass IntegerFeatureEncoder(torch.nn.Module):\n    r\"\"\"Provides an encoder for integer node features.\n\n    Args:\n        emb_dim (int): The output embedding dimension.\n        num_classes (int): The number of classes/integers.\n\n    Example:\n        >>> encoder = IntegerFeatureEncoder(emb_dim=16, num_classes=10)\n        >>> batch = torch.randint(0, 10, (10, 2))\n        >>> encoder(batch).size()\n        torch.Size([10, 16])\n    \"\"\"\n    def __init__(self, emb_dim: int, num_classes: int):\n        super().__init__()\n\n        self.encoder = torch.nn.Embedding(num_classes, emb_dim)\n        torch.nn.init.xavier_uniform_(self.encoder.weight.data)\n\n    def forward(self, batch):\n        # Encode just the first dimension if more exist\n        batch.x = self.encoder(batch.x[:, 0])\n\n        return batch\n\n\n@register_node_encoder('Atom')\nclass AtomEncoder(torch.nn.Module):\n    r\"\"\"The atom encoder used in OGB molecule dataset.\n\n    Args:\n        emb_dim (int): The output embedding dimension.\n\n    Example:\n        >>> encoder = AtomEncoder(emb_dim=16)\n        >>> batch = torch.randint(0, 10, (10, 3))\n        >>> encoder(batch).size()\n        torch.Size([10, 16])\n    \"\"\"\n    def __init__(self, emb_dim, *args, **kwargs):\n        super().__init__()\n\n        from ogb.utils.features import get_atom_feature_dims\n\n        self.atom_embedding_list = torch.nn.ModuleList()\n\n        for dim in get_atom_feature_dims():\n            emb = torch.nn.Embedding(dim, emb_dim)\n            torch.nn.init.xavier_uniform_(emb.weight.data)\n            self.atom_embedding_list.append(emb)\n\n    def forward(self, batch):\n        encoded_features = 0\n        for i in range(batch.x.shape[1]):\n            encoded_features += self.atom_embedding_list[i](batch.x[:, i])\n\n   
     batch.x = encoded_features\n        return batch\n\n\n@register_edge_encoder('Bond')\nclass BondEncoder(torch.nn.Module):\n    r\"\"\"The bond encoder used in OGB molecule dataset.\n\n    Args:\n        emb_dim (int): The output embedding dimension.\n\n    Example:\n        >>> encoder = BondEncoder(emb_dim=16)\n        >>> batch = torch.randint(0, 10, (10, 3))\n        >>> encoder(batch).size()\n        torch.Size([10, 16])\n    \"\"\"\n    def __init__(self, emb_dim: int):\n        super().__init__()\n\n        from ogb.utils.features import get_bond_feature_dims\n\n        self.bond_embedding_list = torch.nn.ModuleList()\n\n        for dim in get_bond_feature_dims():\n            emb = torch.nn.Embedding(dim, emb_dim)\n            torch.nn.init.xavier_uniform_(emb.weight.data)\n            self.bond_embedding_list.append(emb)\n\n    def forward(self, batch):\n        bond_embedding = 0\n        for i in range(batch.edge_attr.shape[1]):\n            edge_attr = batch.edge_attr\n            bond_embedding += self.bond_embedding_list[i](edge_attr[:, i])\n\n        batch.edge_attr = bond_embedding\n        return batch\n"
  },
  {
    "path": "torch_geometric/graphgym/models/gnn.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nimport torch_geometric.graphgym.register as register\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.init import init_weights\nfrom torch_geometric.graphgym.models.layer import (\n    BatchNorm1dNode,\n    GeneralLayer,\n    GeneralMultiLayer,\n    new_layer_config,\n)\nfrom torch_geometric.graphgym.register import register_stage\n\n\ndef GNNLayer(dim_in: int, dim_out: int, has_act: bool = True) -> GeneralLayer:\n    r\"\"\"Creates a GNN layer, given the specified input and output dimensions\n    and the underlying configuration in :obj:`cfg`.\n\n    Args:\n        dim_in (int): The input dimension\n        dim_out (int): The output dimension.\n        has_act (bool, optional): Whether to apply an activation function\n            after the layer. (default: :obj:`True`)\n    \"\"\"\n    return GeneralLayer(\n        cfg.gnn.layer_type,\n        layer_config=new_layer_config(\n            dim_in,\n            dim_out,\n            1,\n            has_act=has_act,\n            has_bias=False,\n            cfg=cfg,\n        ),\n    )\n\n\ndef GNNPreMP(dim_in: int, dim_out: int, num_layers: int) -> GeneralMultiLayer:\n    r\"\"\"Creates a NN layer used before message passing, given the specified\n    input and output dimensions and the underlying configuration in :obj:`cfg`.\n\n    Args:\n        dim_in (int): The input dimension\n        dim_out (int): The output dimension.\n        num_layers (int): The number of layers.\n    \"\"\"\n    return GeneralMultiLayer(\n        'linear',\n        layer_config=new_layer_config(\n            dim_in,\n            dim_out,\n            num_layers,\n            has_act=False,\n            has_bias=False,\n            cfg=cfg,\n        ),\n    )\n\n\n@register_stage('stack')\n@register_stage('skipsum')\n@register_stage('skipconcat')\nclass GNNStackStage(torch.nn.Module):\n    r\"\"\"Stacks a number of GNN layers.\n\n    Args:\n        
dim_in (int): The input dimension\n        dim_out (int): The output dimension.\n        num_layers (int): The number of layers.\n    \"\"\"\n    def __init__(self, dim_in, dim_out, num_layers):\n        super().__init__()\n        self.num_layers = num_layers\n        for i in range(num_layers):\n            if cfg.gnn.stage_type == 'skipconcat':\n                d_in = dim_in if i == 0 else dim_in + i * dim_out\n            else:\n                d_in = dim_in if i == 0 else dim_out\n            layer = GNNLayer(d_in, dim_out)\n            self.add_module(f'layer{i}', layer)\n\n    def forward(self, batch):\n        for i, layer in enumerate(self.children()):\n            x = batch.x\n            batch = layer(batch)\n            if cfg.gnn.stage_type == 'skipsum':\n                batch.x = x + batch.x\n            elif (cfg.gnn.stage_type == 'skipconcat'\n                  and i < self.num_layers - 1):\n                batch.x = torch.cat([x, batch.x], dim=1)\n        if cfg.gnn.l2norm:\n            batch.x = F.normalize(batch.x, p=2, dim=-1)\n        return batch\n\n\nclass FeatureEncoder(torch.nn.Module):\n    r\"\"\"Encodes node and edge features, given the specified input dimension and\n    the underlying configuration in :obj:`cfg`.\n\n    Args:\n        dim_in (int): The input feature dimension.\n    \"\"\"\n    def __init__(self, dim_in: int):\n        super().__init__()\n        self.dim_in = dim_in\n        if cfg.dataset.node_encoder:\n            # Encode integer node features via `torch.nn.Embedding`:\n            NodeEncoder = register.node_encoder_dict[\n                cfg.dataset.node_encoder_name]\n            self.node_encoder = NodeEncoder(cfg.gnn.dim_inner)\n            if cfg.dataset.node_encoder_bn:\n                self.node_encoder_bn = BatchNorm1dNode(\n                    new_layer_config(\n                        cfg.gnn.dim_inner,\n                        -1,\n                        -1,\n                        has_act=False,\n      
                  has_bias=False,\n                        cfg=cfg,\n                    ))\n            # Update `dim_in` to reflect the new dimension of the node features\n            self.dim_in = cfg.gnn.dim_inner\n        if cfg.dataset.edge_encoder:\n            # Encode integer edge features via `torch.nn.Embedding`:\n            EdgeEncoder = register.edge_encoder_dict[\n                cfg.dataset.edge_encoder_name]\n            self.edge_encoder = EdgeEncoder(cfg.gnn.dim_inner)\n            if cfg.dataset.edge_encoder_bn:\n                self.edge_encoder_bn = BatchNorm1dNode(\n                    new_layer_config(\n                        cfg.gnn.dim_inner,\n                        -1,\n                        -1,\n                        has_act=False,\n                        has_bias=False,\n                        cfg=cfg,\n                    ))\n\n    def forward(self, batch):\n        for module in self.children():\n            batch = module(batch)\n        return batch\n\n\nclass GNN(torch.nn.Module):\n    r\"\"\"A general Graph Neural Network (GNN) model.\n\n    The GNN model consists of three main components:\n\n    1. An encoder to transform input features into a fixed-size embedding\n       space.\n    2. A processing or message passing stage for information exchange between\n       nodes.\n    3. 
A head to produce the final output features/predictions.\n\n    The configuration of each component is determined by the underlying\n    configuration in :obj:`cfg`.\n\n    Args:\n        dim_in (int): The input feature dimension.\n        dim_out (int): The output feature dimension.\n        **kwargs (optional): Additional keyword arguments.\n    \"\"\"\n    def __init__(self, dim_in: int, dim_out: int, **kwargs):\n        super().__init__()\n        GNNStage = register.stage_dict[cfg.gnn.stage_type]\n        GNNHead = register.head_dict[cfg.gnn.head]\n\n        self.encoder = FeatureEncoder(dim_in)\n        dim_in = self.encoder.dim_in\n\n        if cfg.gnn.layers_pre_mp > 0:\n            self.pre_mp = GNNPreMP(dim_in, cfg.gnn.dim_inner,\n                                   cfg.gnn.layers_pre_mp)\n            dim_in = cfg.gnn.dim_inner\n        if cfg.gnn.layers_mp > 0:\n            self.mp = GNNStage(dim_in=dim_in, dim_out=cfg.gnn.dim_inner,\n                               num_layers=cfg.gnn.layers_mp)\n        self.post_mp = GNNHead(dim_in=cfg.gnn.dim_inner, dim_out=dim_out)\n\n        self.apply(init_weights)\n\n    def forward(self, batch):\n        for module in self.children():\n            batch = module(batch)\n        return batch\n"
  },
  {
    "path": "torch_geometric/graphgym/models/head.py",
    "content": "import torch\n\nimport torch_geometric.graphgym.register as register\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.models.layer import MLP, new_layer_config\nfrom torch_geometric.graphgym.register import register_head\n\n\n@register_head('node')\nclass GNNNodeHead(torch.nn.Module):\n    r\"\"\"A GNN prediction head for node-level prediction tasks.\n\n    Args:\n        dim_in (int): The input feature dimension.\n        dim_out (int): The output feature dimension.\n    \"\"\"\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.layer_post_mp = MLP(\n            new_layer_config(\n                dim_in,\n                dim_out,\n                cfg.gnn.layers_post_mp,\n                has_act=False,\n                has_bias=True,\n                cfg=cfg,\n            ))\n\n    def _apply_index(self, batch):\n        x = batch.x\n        y = batch.y if 'y' in batch else None\n\n        if 'split' not in batch:\n            return x, y\n\n        mask = batch[f'{batch.split}_mask']\n        return x[mask], y[mask] if y is not None else None\n\n    def forward(self, batch):\n        batch = self.layer_post_mp(batch)\n        pred, label = self._apply_index(batch)\n        return pred, label\n\n\n@register_head('edge')\n@register_head('link_pred')\nclass GNNEdgeHead(torch.nn.Module):\n    r\"\"\"A GNN prediction head for edge-level/link-level prediction tasks.\n\n    Args:\n        dim_in (int): The input feature dimension.\n        dim_out (int): The output feature dimension.\n    \"\"\"\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        # Module to decode edges from node embeddings:\n        if cfg.model.edge_decoding == 'concat':\n            self.layer_post_mp = MLP(\n                new_layer_config(\n                    dim_in * 2,\n                    dim_out,\n                    cfg.gnn.layers_post_mp,\n                    
has_act=False,\n                    has_bias=True,\n                    cfg=cfg,\n                ))\n            self.decode_module = lambda v1, v2: \\\n                self.layer_post_mp(torch.cat((v1, v2), dim=-1))\n        else:\n            if dim_out > 1:\n                raise ValueError(f\"Binary edge decoding \"\n                                 f\"'{cfg.model.edge_decoding}' is used for \"\n                                 f\"multi-class classification\")\n            self.layer_post_mp = MLP(\n                new_layer_config(\n                    dim_in,\n                    dim_in,\n                    cfg.gnn.layers_post_mp,\n                    has_act=False,\n                    has_bias=True,\n                    cfg=cfg,\n                ))\n            if cfg.model.edge_decoding == 'dot':\n                self.decode_module = lambda v1, v2: torch.sum(v1 * v2, dim=-1)\n            elif cfg.model.edge_decoding == 'cosine_similarity':\n                self.decode_module = torch.nn.CosineSimilarity(dim=-1)\n            else:\n                raise ValueError(f\"Unknown edge decoding \"\n                                 f\"'{cfg.model.edge_decoding}'\")\n\n    def _apply_index(self, batch):\n        index = f'{batch.split}_edge_index'\n        label = f'{batch.split}_edge_label'\n        return batch.x[batch[index]], batch[label]\n\n    def forward(self, batch):\n        if cfg.model.edge_decoding != 'concat':\n            batch = self.layer_post_mp(batch)\n        pred, label = self._apply_index(batch)\n        nodes_first = pred[0]\n        nodes_second = pred[1]\n        pred = self.decode_module(nodes_first, nodes_second)\n        return pred, label\n\n\n@register_head('graph')\nclass GNNGraphHead(torch.nn.Module):\n    r\"\"\"A GNN prediction head for graph-level prediction tasks.\n    A post message passing layer (as specified by\n    :obj:`cfg.gnn.layers_post_mp`) is used to transform the pooled graph-level\n    embeddings using an MLP.\n\n    Args:\n      
  dim_in (int): The input feature dimension.\n        dim_out (int): The output feature dimension.\n    \"\"\"\n    def __init__(self, dim_in: int, dim_out: int):\n        super().__init__()\n        self.layer_post_mp = MLP(\n            new_layer_config(dim_in, dim_out, cfg.gnn.layers_post_mp,\n                             has_act=False, has_bias=True, cfg=cfg))\n        self.pooling_fun = register.pooling_dict[cfg.model.graph_pooling]\n\n    def _apply_index(self, batch):\n        return batch.graph_feature, batch.y\n\n    def forward(self, batch):\n        graph_emb = self.pooling_fun(batch.x, batch.batch)\n        graph_emb = self.layer_post_mp(graph_emb)\n        batch.graph_feature = graph_emb\n        pred, label = self._apply_index(batch)\n        return pred, label\n"
  },
  {
    "path": "torch_geometric/graphgym/models/layer.py",
    "content": "import copy\nfrom dataclasses import dataclass, replace\n\nimport torch\nimport torch.nn.functional as F\n\nimport torch_geometric as pyg\nimport torch_geometric.graphgym.models.act\nimport torch_geometric.graphgym.register as register\nfrom torch_geometric.graphgym.contrib.layer.generalconv import (\n    GeneralConvLayer,\n    GeneralEdgeConvLayer,\n)\nfrom torch_geometric.graphgym.register import register_layer\nfrom torch_geometric.nn import Linear as Linear_pyg\n\n\n@dataclass\nclass LayerConfig:\n    # batchnorm parameters.\n    has_batchnorm: bool = False\n    bn_eps: float = 1e-5\n    bn_mom: float = 0.1\n\n    # mem parameters.\n    mem_inplace: bool = False\n\n    # gnn parameters.\n    dim_in: int = -1\n    dim_out: int = -1\n    edge_dim: int = -1\n    dim_inner: int = None\n    num_layers: int = 2\n    has_bias: bool = True\n    # regularizer parameters.\n    has_l2norm: bool = True\n    dropout: float = 0.0\n    # activation parameters.\n    has_act: bool = True\n    final_act: bool = True\n    act: str = 'relu'\n\n    # other parameters.\n    keep_edge: float = 0.5\n\n\ndef new_layer_config(\n    dim_in: int,\n    dim_out: int,\n    num_layers: int,\n    has_act: bool,\n    has_bias: bool,\n    cfg,\n) -> LayerConfig:\n    r\"\"\"Create a layer configuration for a GNN layer.\n\n    Args:\n        dim_in (int): The input feature dimension.\n        dim_out (int): The output feature dimension.\n        num_layers (int): The number of hidden layers\n        has_act (bool): Whether to apply an activation function after the\n            layer.\n        has_bias (bool): Whether to apply a bias term in the layer.\n        cfg (ConfigNode): The underlying configuration.\n    \"\"\"\n    return LayerConfig(\n        has_batchnorm=cfg.gnn.batchnorm,\n        bn_eps=cfg.bn.eps,\n        bn_mom=cfg.bn.mom,\n        mem_inplace=cfg.mem.inplace,\n        dim_in=dim_in,\n        dim_out=dim_out,\n        edge_dim=cfg.dataset.edge_dim,\n        
has_l2norm=cfg.gnn.l2norm,\n        dropout=cfg.gnn.dropout,\n        has_act=has_act,\n        final_act=True,\n        act=cfg.gnn.act,\n        has_bias=has_bias,\n        keep_edge=cfg.gnn.keep_edge,\n        dim_inner=cfg.gnn.dim_inner,\n        num_layers=num_layers,\n    )\n\n\nclass GeneralLayer(torch.nn.Module):\n    r\"\"\"A general wrapper for layers.\n\n    Args:\n        name (str): The registered name of the layer.\n        layer_config (LayerConfig): The configuration of the layer.\n        **kwargs (optional): Additional keyword arguments.\n    \"\"\"\n    def __init__(self, name, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.has_l2norm = layer_config.has_l2norm\n        has_bn = layer_config.has_batchnorm\n        layer_config.has_bias = not has_bn\n        self.layer = register.layer_dict[name](layer_config, **kwargs)\n        layer_wrapper = []\n        if has_bn:\n            layer_wrapper.append(\n                torch.nn.BatchNorm1d(\n                    layer_config.dim_out,\n                    eps=layer_config.bn_eps,\n                    momentum=layer_config.bn_mom,\n                ))\n        if layer_config.dropout > 0:\n            layer_wrapper.append(\n                torch.nn.Dropout(\n                    p=layer_config.dropout,\n                    inplace=layer_config.mem_inplace,\n                ))\n        if layer_config.has_act:\n            layer_wrapper.append(register.act_dict[layer_config.act]())\n        self.post_layer = torch.nn.Sequential(*layer_wrapper)\n\n    def forward(self, batch):\n        batch = self.layer(batch)\n        if isinstance(batch, torch.Tensor):\n            batch = self.post_layer(batch)\n            if self.has_l2norm:\n                batch = F.normalize(batch, p=2, dim=1)\n        else:\n            batch.x = self.post_layer(batch.x)\n            if self.has_l2norm:\n                batch.x = F.normalize(batch.x, p=2, dim=1)\n        return batch\n\n\nclass 
GeneralMultiLayer(torch.nn.Module):\n    r\"\"\"A general wrapper class for stacking multiple NN layers.\n\n    Args:\n        name (str): The registered name of the layer.\n        layer_config (LayerConfig): The configuration of the layer.\n        **kwargs (optional): Additional keyword arguments.\n    \"\"\"\n    def __init__(self, name, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        # Fall back to `dim_out` when no inner dimension is configured:\n        if layer_config.dim_inner is None:\n            dim_inner = layer_config.dim_out\n        else:\n            dim_inner = layer_config.dim_inner\n\n        for i in range(layer_config.num_layers):\n            d_in = layer_config.dim_in if i == 0 else dim_inner\n            d_out = layer_config.dim_out \\\n                if i == layer_config.num_layers - 1 else dim_inner\n            has_act = layer_config.final_act \\\n                if i == layer_config.num_layers - 1 else True\n            inter_layer_config = copy.deepcopy(layer_config)\n            inter_layer_config.dim_in = d_in\n            inter_layer_config.dim_out = d_out\n            inter_layer_config.has_act = has_act\n            layer = GeneralLayer(name, inter_layer_config, **kwargs)\n            self.add_module(f'Layer_{i}', layer)\n\n    def forward(self, batch):\n        for layer in self.children():\n            batch = layer(batch)\n        return batch\n\n\n# ---------- Core basic layers. 
Input: batch; Output: batch ----------------- #\n\n\n@register_layer('linear')\nclass Linear(torch.nn.Module):\n    r\"\"\"A basic Linear layer.\n\n    Args:\n        layer_config (LayerConfig): The configuration of the layer.\n        **kwargs (optional): Additional keyword arguments.\n    \"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = Linear_pyg(\n            layer_config.dim_in,\n            layer_config.dim_out,\n            bias=layer_config.has_bias,\n        )\n\n    def forward(self, batch):\n        if isinstance(batch, torch.Tensor):\n            batch = self.model(batch)\n        else:\n            batch.x = self.model(batch.x)\n        return batch\n\n\nclass BatchNorm1dNode(torch.nn.Module):\n    r\"\"\"A batch normalization layer for node-level features.\n\n    Args:\n        layer_config (LayerConfig): The configuration of the layer.\n    \"\"\"\n    def __init__(self, layer_config: LayerConfig):\n        super().__init__()\n        self.bn = torch.nn.BatchNorm1d(\n            layer_config.dim_in,\n            eps=layer_config.bn_eps,\n            momentum=layer_config.bn_mom,\n        )\n\n    def forward(self, batch):\n        batch.x = self.bn(batch.x)\n        return batch\n\n\nclass BatchNorm1dEdge(torch.nn.Module):\n    r\"\"\"A batch normalization layer for edge-level features.\n\n    Args:\n        layer_config (LayerConfig): The configuration of the layer.\n    \"\"\"\n    def __init__(self, layer_config: LayerConfig):\n        super().__init__()\n        self.bn = torch.nn.BatchNorm1d(\n            layer_config.dim_in,\n            eps=layer_config.bn_eps,\n            momentum=layer_config.bn_mom,\n        )\n\n    def forward(self, batch):\n        batch.edge_attr = self.bn(batch.edge_attr)\n        return batch\n\n\n@register_layer('mlp')\nclass MLP(torch.nn.Module):\n    \"\"\"A basic MLP model.\n\n    Args:\n        layer_config (LayerConfig): The configuration of 
the layer.\n        **kwargs (optional): Additional keyword arguments.\n    \"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        if layer_config.dim_inner is None:\n            dim_inner = layer_config.dim_in\n        else:\n            dim_inner = layer_config.dim_inner\n\n        layer_config.has_bias = True\n        layers = []\n        if layer_config.num_layers > 1:\n            sub_layer_config = LayerConfig(\n                num_layers=layer_config.num_layers - 1,\n                dim_in=layer_config.dim_in, dim_out=dim_inner,\n                dim_inner=dim_inner, final_act=True)\n            layers.append(GeneralMultiLayer('linear', sub_layer_config))\n            layer_config = replace(layer_config, dim_in=dim_inner)\n            layers.append(Linear(layer_config))\n        else:\n            layers.append(Linear(layer_config))\n        self.model = torch.nn.Sequential(*layers)\n\n    def forward(self, batch):\n        if isinstance(batch, torch.Tensor):\n            batch = self.model(batch)\n        else:\n            batch.x = self.model(batch.x)\n        return batch\n\n\n@register_layer('gcnconv')\nclass GCNConv(torch.nn.Module):\n    r\"\"\"A Graph Convolutional Network (GCN) layer.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = pyg.nn.GCNConv(\n            layer_config.dim_in,\n            layer_config.dim_out,\n            bias=layer_config.has_bias,\n        )\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index)\n        return batch\n\n\n@register_layer('sageconv')\nclass SAGEConv(torch.nn.Module):\n    r\"\"\"A GraphSAGE layer.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = pyg.nn.SAGEConv(\n            layer_config.dim_in,\n            layer_config.dim_out,\n            bias=layer_config.has_bias,\n        )\n\n    
def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index)\n        return batch\n\n\n@register_layer('gatconv')\nclass GATConv(torch.nn.Module):\n    r\"\"\"A Graph Attention Network (GAT) layer.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = pyg.nn.GATConv(\n            layer_config.dim_in,\n            layer_config.dim_out,\n            bias=layer_config.has_bias,\n        )\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index)\n        return batch\n\n\n@register_layer('ginconv')\nclass GINConv(torch.nn.Module):\n    r\"\"\"A Graph Isomorphism Network (GIN) layer.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        gin_nn = torch.nn.Sequential(\n            Linear_pyg(layer_config.dim_in, layer_config.dim_out),\n            torch.nn.ReLU(),\n            Linear_pyg(layer_config.dim_out, layer_config.dim_out),\n        )\n        self.model = pyg.nn.GINConv(gin_nn)\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index)\n        return batch\n\n\n@register_layer('splineconv')\nclass SplineConv(torch.nn.Module):\n    r\"\"\"A SplineCNN layer.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = pyg.nn.SplineConv(\n            layer_config.dim_in,\n            layer_config.dim_out,\n            dim=1,\n            kernel_size=2,\n            bias=layer_config.has_bias,\n        )\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index, batch.edge_attr)\n        return batch\n\n\n@register_layer('generalconv')\nclass GeneralConv(torch.nn.Module):\n    r\"\"\"A general GNN layer.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = GeneralConvLayer(\n            layer_config.dim_in,\n          
  layer_config.dim_out,\n            bias=layer_config.has_bias,\n        )\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index)\n        return batch\n\n\n@register_layer('generaledgeconv')\nclass GeneralEdgeConv(torch.nn.Module):\n    r\"\"\"A general GNN layer with edge feature support.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = GeneralEdgeConvLayer(\n            layer_config.dim_in,\n            layer_config.dim_out,\n            layer_config.edge_dim,\n            bias=layer_config.has_bias,\n        )\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index,\n                             edge_feature=batch.edge_attr)\n        return batch\n\n\n@register_layer('generalsampleedgeconv')\nclass GeneralSampleEdgeConv(torch.nn.Module):\n    r\"\"\"A general GNN layer that supports edge features and edge sampling.\"\"\"\n    def __init__(self, layer_config: LayerConfig, **kwargs):\n        super().__init__()\n        self.model = GeneralEdgeConvLayer(\n            layer_config.dim_in,\n            layer_config.dim_out,\n            layer_config.edge_dim,\n            bias=layer_config.has_bias,\n        )\n        self.keep_edge = layer_config.keep_edge\n\n    def forward(self, batch):\n        edge_mask = torch.rand(batch.edge_index.shape[1]) < self.keep_edge\n        edge_index = batch.edge_index[:, edge_mask]\n        edge_feature = batch.edge_attr[edge_mask, :]\n        batch.x = self.model(batch.x, edge_index, edge_feature=edge_feature)\n        return batch\n"
  },
  {
    "path": "torch_geometric/graphgym/models/pooling.py",
    "content": "from torch_geometric.graphgym.register import register_pooling\nfrom torch_geometric.nn import (\n    global_add_pool,\n    global_max_pool,\n    global_mean_pool,\n)\n\nregister_pooling('add', global_add_pool)\nregister_pooling('mean', global_mean_pool)\nregister_pooling('max', global_max_pool)\n"
  },
  {
    "path": "torch_geometric/graphgym/models/transform.py",
    "content": "import torch\n\nfrom torch_geometric.utils import negative_sampling\n\n\ndef create_link_label(pos_edge_index, neg_edge_index):\n    \"\"\"Create labels for link prediction, based on positive and negative edges.\n\n    Args:\n        pos_edge_index (torch.tensor): Positive edge index [2, num_edges]\n        neg_edge_index (torch.tensor): Negative edge index [2, num_edges]\n\n    Returns: Link label tensor, [num_positive_edges + num_negative_edges]\n\n    \"\"\"\n    num_links = pos_edge_index.size(1) + neg_edge_index.size(1)\n    link_labels = torch.zeros(num_links, dtype=torch.float,\n                              device=pos_edge_index.device)\n    link_labels[:pos_edge_index.size(1)] = 1.\n    return link_labels\n\n\ndef neg_sampling_transform(data):\n    \"\"\"Do negative sampling for link prediction tasks.\n\n    Args:\n        data (torch_geometric.data): Input data object\n\n    Returns: Transformed data object with negative edges + link pred labels\n\n    \"\"\"\n    train_neg_edge_index = negative_sampling(\n        edge_index=data.train_pos_edge_index, num_nodes=data.num_nodes,\n        num_neg_samples=data.train_pos_edge_index.size(1))\n    data.train_edge_index = torch.cat(\n        [data.train_pos_edge_index, train_neg_edge_index], dim=-1)\n    data.train_edge_label = create_link_label(data.train_pos_edge_index,\n                                              train_neg_edge_index)\n\n    return data\n"
  },
  {
    "path": "torch_geometric/graphgym/optim.py",
    "content": "from dataclasses import dataclass, field\nfrom typing import Any, Iterator, List, Optional\n\nfrom torch.nn import Parameter\nfrom torch.optim import SGD, Adam, Optimizer\nfrom torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, StepLR\n\nimport torch_geometric.graphgym.register as register\nfrom torch_geometric.graphgym.config import from_config\n\n\n@dataclass\nclass OptimizerConfig:\n    optimizer: str = 'adam'  # ['sgd', 'adam']\n    base_lr: float = 0.01\n    weight_decay: float = 5e-4\n    momentum: float = 0.9  # 'sgd' policy\n\n\n@register.register_optimizer('adam')\ndef adam_optimizer(params: Iterator[Parameter], base_lr: float,\n                   weight_decay: float) -> Adam:\n    return Adam(params, lr=base_lr, weight_decay=weight_decay)\n\n\n@register.register_optimizer('sgd')\ndef sgd_optimizer(params: Iterator[Parameter], base_lr: float, momentum: float,\n                  weight_decay: float) -> SGD:\n    return SGD(params, lr=base_lr, momentum=momentum,\n               weight_decay=weight_decay)\n\n\ndef create_optimizer(params: Iterator[Parameter], cfg: Any) -> Any:\n    r\"\"\"Creates a config-driven optimizer.\"\"\"\n    params = filter(lambda p: p.requires_grad, params)\n    func = register.optimizer_dict.get(cfg.optimizer, None)\n    if func is not None:\n        return from_config(func)(params, cfg=cfg)\n    raise ValueError(f\"Optimizer '{cfg.optimizer}' not supported\")\n\n\n@dataclass\nclass SchedulerConfig:\n    scheduler: Optional[str] = 'cos'  # [None, 'steps', 'cos']\n    steps: List[int] = field(default_factory=[30, 60, 90])  # 'steps' policy\n    lr_decay: float = 0.1  # 'steps' policy\n    max_epoch: int = 200\n\n\n@register.register_scheduler(None)\n@register.register_scheduler('none')\ndef none_scheduler(optimizer: Optimizer, max_epoch: int) -> StepLR:\n    return StepLR(optimizer, step_size=max_epoch + 1)\n\n\n@register.register_scheduler('step')\ndef step_scheduler(optimizer: Optimizer, steps: 
List[int],\n                   lr_decay: float) -> MultiStepLR:\n    return MultiStepLR(optimizer, milestones=steps, gamma=lr_decay)\n\n\n@register.register_scheduler('cos')\ndef cos_scheduler(optimizer: Optimizer, max_epoch: int) -> CosineAnnealingLR:\n    return CosineAnnealingLR(optimizer, T_max=max_epoch)\n\n\ndef create_scheduler(optimizer: Optimizer, cfg: Any) -> Any:\n    r\"\"\"Creates a config-driven learning rate scheduler.\"\"\"\n    func = register.scheduler_dict.get(cfg.scheduler, None)\n    if func is not None:\n        return from_config(func)(optimizer, cfg=cfg)\n    raise ValueError(f\"Scheduler '{cfg.scheduler}' not supported\")\n"
  },
  {
    "path": "torch_geometric/graphgym/register.py",
    "content": "from typing import Any, Callable, Dict, Union\n\nact_dict: Dict[str, Any] = {}\nnode_encoder_dict: Dict[str, Any] = {}\nedge_encoder_dict: Dict[str, Any] = {}\nstage_dict: Dict[str, Any] = {}\nhead_dict: Dict[str, Any] = {}\nlayer_dict: Dict[str, Any] = {}\npooling_dict: Dict[str, Any] = {}\nnetwork_dict: Dict[str, Any] = {}\nconfig_dict: Dict[str, Any] = {}\ndataset_dict: Dict[str, Any] = {}\nloader_dict: Dict[str, Any] = {}\noptimizer_dict: Dict[str, Any] = {}\nscheduler_dict: Dict[str, Any] = {}\nloss_dict: Dict[str, Any] = {}\ntrain_dict: Dict[str, Any] = {}\nmetric_dict: Dict[str, Any] = {}\n\n\ndef register_base(mapping: Dict[str, Any], key: str,\n                  module: Any = None) -> Union[None, Callable]:\n    r\"\"\"Base function for registering a module in GraphGym.\n\n    Args:\n        mapping (dict): :python:`Python` dictionary to register the module.\n            hosting all the registered modules\n        key (str): The name of the module.\n        module (any, optional): The module. 
If set to :obj:`None`, will return\n            a decorator to register a module.\n    \"\"\"\n    if module is not None:\n        if key in mapping:\n            raise KeyError(f\"Module with '{key}' already defined\")\n        mapping[key] = module\n        return\n\n    # Other-wise, use it as a decorator:\n    def bounded_register(module):\n        register_base(mapping, key, module)\n        return module\n\n    return bounded_register\n\n\ndef register_act(key: str, module: Any = None):\n    r\"\"\"Registers an activation function in GraphGym.\"\"\"\n    return register_base(act_dict, key, module)\n\n\ndef register_node_encoder(key: str, module: Any = None):\n    r\"\"\"Registers a node feature encoder in GraphGym.\"\"\"\n    return register_base(node_encoder_dict, key, module)\n\n\ndef register_edge_encoder(key: str, module: Any = None):\n    r\"\"\"Registers an edge feature encoder in GraphGym.\"\"\"\n    return register_base(edge_encoder_dict, key, module)\n\n\ndef register_stage(key: str, module: Any = None):\n    r\"\"\"Registers a customized GNN stage in GraphGym.\"\"\"\n    return register_base(stage_dict, key, module)\n\n\ndef register_head(key: str, module: Any = None):\n    r\"\"\"Registers a GNN prediction head in GraphGym.\"\"\"\n    return register_base(head_dict, key, module)\n\n\ndef register_layer(key: str, module: Any = None):\n    r\"\"\"Registers a GNN layer in GraphGym.\"\"\"\n    return register_base(layer_dict, key, module)\n\n\ndef register_pooling(key: str, module: Any = None):\n    r\"\"\"Registers a GNN global pooling/readout layer in GraphGym.\"\"\"\n    return register_base(pooling_dict, key, module)\n\n\ndef register_network(key: str, module: Any = None):\n    r\"\"\"Registers a GNN model in GraphGym.\"\"\"\n    return register_base(network_dict, key, module)\n\n\ndef register_config(key: str, module: Any = None):\n    r\"\"\"Registers a configuration group in GraphGym.\"\"\"\n    return register_base(config_dict, key, 
module)\n\n\ndef register_dataset(key: str, module: Any = None):\n    r\"\"\"Registers a dataset in GraphGym.\"\"\"\n    return register_base(dataset_dict, key, module)\n\n\ndef register_loader(key: str, module: Any = None):\n    r\"\"\"Registers a data loader in GraphGym.\"\"\"\n    return register_base(loader_dict, key, module)\n\n\ndef register_optimizer(key: str, module: Any = None):\n    r\"\"\"Registers an optimizer in GraphGym.\"\"\"\n    return register_base(optimizer_dict, key, module)\n\n\ndef register_scheduler(key: str, module: Any = None):\n    r\"\"\"Registers a learning rate scheduler in GraphGym.\"\"\"\n    return register_base(scheduler_dict, key, module)\n\n\ndef register_loss(key: str, module: Any = None):\n    r\"\"\"Registers a loss function in GraphGym.\"\"\"\n    return register_base(loss_dict, key, module)\n\n\ndef register_train(key: str, module: Any = None):\n    r\"\"\"Registers a training function in GraphGym.\"\"\"\n    return register_base(train_dict, key, module)\n\n\ndef register_metric(key: str, module: Any = None):\n    r\"\"\"Register a metric function in GraphGym.\"\"\"\n    return register_base(metric_dict, key, module)\n"
  },
  {
    "path": "torch_geometric/graphgym/train.py",
    "content": "import warnings\nfrom typing import Any, Dict, Optional\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.data.lightning.datamodule import LightningDataModule\nfrom torch_geometric.graphgym import create_loader\nfrom torch_geometric.graphgym.checkpoint import get_ckpt_dir\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.imports import pl\nfrom torch_geometric.graphgym.logger import LoggerCallback\nfrom torch_geometric.graphgym.model_builder import GraphGymModule\n\n\nclass GraphGymDataModule(LightningDataModule):\n    r\"\"\"A :class:`pytorch_lightning.LightningDataModule` for handling data\n    loading routines in GraphGym.\n\n    This class provides data loaders for training, validation, and testing, and\n    can be accessed through the :meth:`train_dataloader`,\n    :meth:`val_dataloader`, and :meth:`test_dataloader` methods, respectively.\n    \"\"\"\n    def __init__(self):\n        self.loaders = create_loader()\n        super().__init__(has_val=True, has_test=True)\n\n    def train_dataloader(self) -> DataLoader:\n        return self.loaders[0]\n\n    def val_dataloader(self) -> DataLoader:\n        # better way would be to test after fit.\n        # First call trainer.fit(...) 
then trainer.test(...)\n        return self.loaders[1]\n\n    def test_dataloader(self) -> DataLoader:\n        return self.loaders[2]\n\n\ndef train(\n    model: GraphGymModule,\n    datamodule: GraphGymDataModule,\n    logger: bool = True,\n    trainer_config: Optional[Dict[str, Any]] = None,\n):\n    r\"\"\"Trains a GraphGym model using PyTorch Lightning.\n\n    Args:\n        model (GraphGymModule): The GraphGym model.\n        datamodule (GraphGymDataModule): The GraphGym data module.\n        logger (bool, optional): Whether to enable logging during training.\n            (default: :obj:`True`)\n        trainer_config (dict, optional): Additional trainer configuration.\n    \"\"\"\n    warnings.filterwarnings('ignore', '.*use `CSVLogger` as the default.*')\n\n    callbacks = []\n    if logger:\n        callbacks.append(LoggerCallback())\n    if cfg.train.enable_ckpt:\n        ckpt_cbk = pl.callbacks.ModelCheckpoint(dirpath=get_ckpt_dir())\n        callbacks.append(ckpt_cbk)\n\n    trainer_config = trainer_config or {}\n    trainer = pl.Trainer(\n        **trainer_config,\n        enable_checkpointing=cfg.train.enable_ckpt,\n        callbacks=callbacks,\n        default_root_dir=cfg.out_dir,\n        max_epochs=cfg.optim.max_epoch,\n        accelerator=cfg.accelerator,\n        devices='auto' if not torch.cuda.is_available() else cfg.devices,\n    )\n\n    trainer.fit(model, datamodule=datamodule)\n    trainer.test(model, datamodule=datamodule)\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/LICENSE",
    "content": ""
  },
  {
    "path": "torch_geometric/graphgym/utils/__init__.py",
    "content": "from .agg_runs import agg_runs, agg_batch\nfrom .comp_budget import params_count, match_baseline_cfg\nfrom .device import get_current_gpu_usage, auto_select_device\nfrom .epoch import is_eval_epoch, is_ckpt_epoch\nfrom .io import dict_to_json, dict_list_to_json, dict_to_tb, makedirs_rm_exist\nfrom .tools import dummy_context\n\n__all__ = [\n    'agg_runs',\n    'agg_batch',\n    'params_count',\n    'match_baseline_cfg',\n    'get_current_gpu_usage',\n    'auto_select_device',\n    'is_eval_epoch',\n    'is_ckpt_epoch',\n    'dict_to_json',\n    'dict_list_to_json',\n    'dict_to_tb',\n    'makedirs_rm_exist',\n    'dummy_context',\n]\n\nclasses = __all__\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/agg_runs.py",
    "content": "import logging\nimport os\nimport os.path as osp\n\nimport numpy as np\n\nfrom torch_geometric.graphgym.config import cfg\nfrom torch_geometric.graphgym.utils.io import (\n    dict_list_to_json,\n    dict_list_to_tb,\n    dict_to_json,\n    json_to_dict_list,\n    makedirs_rm_exist,\n    string_to_python,\n)\n\ntry:\n    from tensorboardX import SummaryWriter\nexcept ImportError:\n    SummaryWriter = None\n\n\ndef is_seed(s):\n    try:\n        int(s)\n        return True\n    except Exception:\n        return False\n\n\ndef is_split(s):\n    if s in ['train', 'val']:\n        return True\n    else:\n        return False\n\n\ndef join_list(l1, l2):\n    assert len(l1) == len(l2), \\\n        'Results with different seeds must have the save format'\n    for i in range(len(l1)):\n        l1[i] += l2[i]\n    return l1\n\n\ndef agg_dict_list(dict_list):\n    \"\"\"Aggregate a list of dictionaries: mean + std\n    Args:\n        dict_list: list of dictionaries.\n\n    \"\"\"\n    dict_agg = {'epoch': dict_list[0]['epoch']}\n    for key in dict_list[0]:\n        if key != 'epoch':\n            value = np.array([dict[key] for dict in dict_list])\n            dict_agg[key] = np.mean(value).round(cfg.round)\n            dict_agg[f'{key}_std'] = np.std(value).round(cfg.round)\n    return dict_agg\n\n\ndef name_to_dict(run):\n    run = run.split('-', 1)[-1]\n    cols = run.split('=')\n    keys, vals = [], []\n    keys.append(cols[0])\n    for col in cols[1:-1]:\n        try:\n            val, key = col.rsplit('-', 1)\n        except Exception:\n            print(col)\n        keys.append(key)\n        vals.append(string_to_python(val))\n    vals.append(cols[-1])\n    return dict(zip(keys, vals))\n\n\ndef rm_keys(dict, keys):\n    for key in keys:\n        dict.pop(key, None)\n\n\ndef agg_runs(dir, metric_best='auto'):\n    r\"\"\"Aggregate over different random seeds of a single experiment.\n\n    Args:\n        dir (str): Directory of the results, containing 
1 experiment\n        metric_best (str, optional): The metric for selecting the best\n        validation performance. Options: auto, accuracy, auc.\n\n    \"\"\"\n    results = {'train': None, 'val': None}\n    results_best = {'train': None, 'val': None}\n    for seed in os.listdir(dir):\n        if is_seed(seed):\n            dir_seed = osp.join(dir, seed)\n\n            split = 'val'\n            if split in os.listdir(dir_seed):\n                dir_split = osp.join(dir_seed, split)\n                fname_stats = osp.join(dir_split, 'stats.json')\n                stats_list = json_to_dict_list(fname_stats)\n                if metric_best == 'auto':\n                    metric = 'auc' if 'auc' in stats_list[0] else 'accuracy'\n                else:\n                    metric = metric_best\n                performance_np = np.array(  # noqa\n                    [stats[metric] for stats in stats_list])\n                best_epoch = \\\n                    stats_list[\n                        eval(f\"performance_np.{cfg.metric_agg}()\")][\n                        'epoch']\n                print(best_epoch)\n\n            for split in os.listdir(dir_seed):\n                if is_split(split):\n                    dir_split = osp.join(dir_seed, split)\n                    fname_stats = osp.join(dir_split, 'stats.json')\n                    stats_list = json_to_dict_list(fname_stats)\n                    stats_best = [\n                        stats for stats in stats_list\n                        if stats['epoch'] == best_epoch\n                    ][0]\n                    print(stats_best)\n                    stats_list = [[stats] for stats in stats_list]\n                    if results[split] is None:\n                        results[split] = stats_list\n                    else:\n                        results[split] = join_list(results[split], stats_list)\n                    if results_best[split] is None:\n                        results_best[split] = 
[stats_best]\n                    else:\n                        results_best[split] += [stats_best]\n    results = {k: v for k, v in results.items() if v is not None}\n    results_best = {k: v for k, v in results_best.items() if v is not None}\n    for key in results:\n        for i in range(len(results[key])):\n            results[key][i] = agg_dict_list(results[key][i])\n    for key in results_best:\n        results_best[key] = agg_dict_list(results_best[key])\n    # save aggregated results\n    for key, value in results.items():\n        dir_out = osp.join(dir, 'agg', key)\n        makedirs_rm_exist(dir_out)\n        fname = osp.join(dir_out, 'stats.json')\n        dict_list_to_json(value, fname)\n\n        if cfg.tensorboard_agg:\n            if SummaryWriter is None:\n                raise ImportError(\n                    'Tensorboard support requires `tensorboardX`.')\n            writer = SummaryWriter(dir_out)\n            dict_list_to_tb(value, writer)\n            writer.close()\n    for key, value in results_best.items():\n        dir_out = osp.join(dir, 'agg', key)\n        fname = osp.join(dir_out, 'best.json')\n        dict_to_json(value, fname)\n    logging.info('Results aggregated across runs saved in {}'.format(\n        osp.join(dir, 'agg')))\n\n\ndef agg_batch(dir, metric_best='auto'):\n    r\"\"\"Aggregate across results from multiple experiments via grid search.\n\n    Args:\n        dir (str): Directory of the results, containing multiple experiments\n        metric_best (str, optional): The metric for selecting the best\n        validation performance. 
Options: auto, accuracy, auc.\n\n    \"\"\"\n    import pandas as pd\n    results = {'train': [], 'val': [], 'test': []}\n    for run in os.listdir(dir):\n        if run != 'agg':\n            dict_name = name_to_dict(run)\n            dir_run = osp.join(dir, run, 'agg')\n            if osp.isdir(dir_run):\n                for split in os.listdir(dir_run):\n                    dir_split = osp.join(dir_run, split)\n                    fname_stats = osp.join(dir_split, 'best.json')\n                    dict_stats = json_to_dict_list(fname_stats)[\n                        -1]  # get best val epoch\n                    rm_keys(dict_stats,\n                            ['lr', 'lr_std', 'eta', 'eta_std', 'params_std'])\n                    results[split].append({**dict_name, **dict_stats})\n    dir_out = osp.join(dir, 'agg')\n    makedirs_rm_exist(dir_out)\n    for key in results:\n        if len(results[key]) > 0:\n            results[key] = pd.DataFrame(results[key])\n            results[key] = results[key].sort_values(\n                list(dict_name.keys()), ascending=[True] * len(dict_name))\n            fname = osp.join(dir_out, f'{key}_best.csv')\n            results[key].to_csv(fname, index=False)\n\n    results = {'train': [], 'val': [], 'test': []}\n    for run in os.listdir(dir):\n        if run != 'agg':\n            dict_name = name_to_dict(run)\n            dir_run = osp.join(dir, run, 'agg')\n            if osp.isdir(dir_run):\n                for split in os.listdir(dir_run):\n                    dir_split = osp.join(dir_run, split)\n                    fname_stats = osp.join(dir_split, 'stats.json')\n                    dict_stats = json_to_dict_list(fname_stats)[\n                        -1]  # get last epoch\n                    rm_keys(dict_stats,\n                            ['lr', 'lr_std', 'eta', 'eta_std', 'params_std'])\n                    results[split].append({**dict_name, **dict_stats})\n    dir_out = osp.join(dir, 'agg')\n    for key in 
results:\n        if len(results[key]) > 0:\n            results[key] = pd.DataFrame(results[key])\n            results[key] = results[key].sort_values(\n                list(dict_name.keys()), ascending=[True] * len(dict_name))\n            fname = osp.join(dir_out, f'{key}.csv')\n            results[key].to_csv(fname, index=False)\n\n    results = {'train': [], 'val': [], 'test': []}\n    for run in os.listdir(dir):\n        if run != 'agg':\n            dict_name = name_to_dict(run)\n            dir_run = osp.join(dir, run, 'agg')\n            if osp.isdir(dir_run):\n                for split in os.listdir(dir_run):\n                    dir_split = osp.join(dir_run, split)\n                    fname_stats = osp.join(dir_split, 'stats.json')\n                    dict_stats = json_to_dict_list(\n                        fname_stats)  # get best epoch\n                    if metric_best == 'auto':\n                        metric = 'auc' if 'auc' in dict_stats[0] \\\n                            else 'accuracy'\n                    else:\n                        metric = metric_best\n                    performance_np = np.array(  # noqa\n                        [stats[metric] for stats in dict_stats])\n                    dict_stats = dict_stats[eval(\"performance_np.{}()\".format(\n                        cfg.metric_agg))]\n                    rm_keys(dict_stats,\n                            ['lr', 'lr_std', 'eta', 'eta_std', 'params_std'])\n                    results[split].append({**dict_name, **dict_stats})\n    dir_out = osp.join(dir, 'agg')\n    for key in results:\n        if len(results[key]) > 0:\n            results[key] = pd.DataFrame(results[key])\n            results[key] = results[key].sort_values(\n                list(dict_name.keys()), ascending=[True] * len(dict_name))\n            fname = osp.join(dir_out, f'{key}_bestepoch.csv')\n            results[key].to_csv(fname, index=False)\n\n    print(f'Results aggregated across models saved in 
{dir_out}')\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/comp_budget.py",
    "content": "import math\n\nfrom torch_geometric.graphgym.config import cfg, set_cfg\nfrom torch_geometric.graphgym.model_builder import create_model\n\n\ndef params_count(model):\n    \"\"\"Computes the number of parameters.\n\n    Args:\n        model (nn.Module): PyTorch model\n    \"\"\"\n    return sum([p.numel() for p in model.parameters()])\n\n\ndef get_stats():\n    model = create_model(to_device=False, dim_in=1, dim_out=1)\n    return params_count(model)\n\n\ndef match_computation(stats_baseline, key=None, mode='sqrt'):\n    \"\"\"Match computation budget by modifying :obj:`cfg.gnn.dim_inner`.\"\"\"\n    key = key or ['gnn', 'dim_inner']\n    stats = get_stats()\n    if stats != stats_baseline:\n        # Phase 1: fast approximation\n        while True:\n            if mode == 'sqrt':\n                scale = math.sqrt(stats_baseline / stats)\n            elif mode == 'linear':\n                scale = stats_baseline / stats\n            step = int(round(cfg[key[0]][key[1]] * scale)) \\\n                - cfg[key[0]][key[1]]\n            cfg[key[0]][key[1]] += step\n            stats = get_stats()\n            if abs(step) <= 1:\n                break\n        # Phase 2: fine tune\n        flag_init = 1 if stats < stats_baseline else -1\n        step = 1\n        while True:\n            cfg[key[0]][key[1]] += flag_init * step\n            stats = get_stats()\n            flag = 1 if stats < stats_baseline else -1\n            if stats == stats_baseline:\n                return stats\n            if flag != flag_init:\n                if not cfg.model.match_upper:  # stats is SMALLER\n                    if flag < 0:\n                        cfg[key[0]][key[1]] -= flag_init * step\n                    return get_stats()\n                else:\n                    if flag > 0:\n                        cfg[key[0]][key[1]] -= flag_init * step\n                    return get_stats()\n    return stats\n\n\ndef dict_to_stats(cfg_dict):\n    from yacs.config 
import CfgNode as CN\n    set_cfg(cfg)\n    cfg_new = CN(cfg_dict)\n    cfg.merge_from_other_cfg(cfg_new)\n    stats = get_stats()\n    set_cfg(cfg)\n    return stats\n\n\ndef match_baseline_cfg(cfg_dict, cfg_dict_baseline, verbose=True):\n    \"\"\"Match the computational budget of a given baseline model. The current\n    configuration dictionary will be modified and returned.\n\n    Args:\n        cfg_dict (dict): Current experiment's configuration\n        cfg_dict_baseline (dict): Baseline configuration\n        verbose (bool, optional): If set, prints the matched parameter counts.\n    \"\"\"\n    from yacs.config import CfgNode as CN\n    stats_baseline = dict_to_stats(cfg_dict_baseline)\n    set_cfg(cfg)\n    cfg_new = CN(cfg_dict)\n    cfg.merge_from_other_cfg(cfg_new)\n    stats = match_computation(stats_baseline, key=['gnn', 'dim_inner'])\n    if 'gnn' in cfg_dict:\n        cfg_dict['gnn']['dim_inner'] = cfg.gnn.dim_inner\n    else:\n        cfg_dict['gnn'] = {'dim_inner': cfg.gnn.dim_inner}\n    set_cfg(cfg)\n    if verbose:\n        print(f\"Computational budget has matched - Baseline params: \"\n              f\"{stats_baseline}, Current params: {stats}\")\n    return cfg_dict\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/device.py",
    "content": "import os\nimport subprocess\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.graphgym.config import cfg\n\n\ndef get_gpu_memory_map():\n    \"\"\"Get the current GPU usage.\"\"\"\n    result = subprocess.check_output([\n        'nvidia-smi', '--query-gpu=memory.used',\n        '--format=csv,nounits,noheader'\n    ], encoding='utf-8')\n    gpu_memory = np.array([int(x) for x in result.strip().split('\\n')])\n    return gpu_memory\n\n\ndef get_current_gpu_usage():\n    \"\"\"Get the current GPU memory usage.\"\"\"\n    if cfg.gpu_mem and cfg.device != 'cpu' and torch.cuda.is_available():\n        result = subprocess.check_output([\n            'nvidia-smi', '--query-compute-apps=pid,used_memory',\n            '--format=csv,nounits,noheader'\n        ], encoding='utf-8')\n        current_pid = os.getpid()\n        used_memory = 0\n        for line in result.strip().split('\\n'):\n            line = line.split(', ')\n            if current_pid == int(line[0]):\n                used_memory += int(line[1])\n        return used_memory\n    else:\n        return -1\n\n\ndef auto_select_device():\n    r\"\"\"Auto select device for the current experiment.\"\"\"\n    if cfg.accelerator == 'auto':\n        if torch.cuda.is_available():\n            cfg.accelerator = 'cuda'\n            cfg.devices = 1\n        else:\n            cfg.accelerator = 'cpu'\n            cfg.devices = None\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/epoch.py",
    "content": "from torch_geometric.graphgym.config import cfg\n\n\ndef is_train_eval_epoch(cur_epoch):\n    \"\"\"Determines if the model should be evaluated at the training epoch.\"\"\"\n    return is_eval_epoch(cur_epoch) or not cfg.train.skip_train_eval\n\n\ndef is_eval_epoch(cur_epoch):\n    \"\"\"Determines if the model should be evaluated at the current epoch.\"\"\"\n    return ((cur_epoch + 1) % cfg.train.eval_period == 0 or cur_epoch == 0\n            or (cur_epoch + 1) == cfg.optim.max_epoch)\n\n\ndef is_ckpt_epoch(cur_epoch):\n    \"\"\"Determines if the model should be evaluated at the current epoch.\"\"\"\n    return ((cur_epoch + 1) % cfg.train.ckpt_period == 0\n            or (cur_epoch + 1) == cfg.optim.max_epoch)\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/io.py",
    "content": "import ast\nimport json\nimport os\nimport os.path as osp\n\nfrom torch_geometric.io import fs\n\n\ndef string_to_python(string):\n    try:\n        return ast.literal_eval(string)\n    except Exception:\n        return string\n\n\ndef dict_to_json(dict, fname):\n    \"\"\"Dump a :python:`Python` dictionary to a JSON file.\n\n    Args:\n        dict (dict): The :python:`Python` dictionary.\n        fname (str): The output file name.\n    \"\"\"\n    with open(fname, 'a') as f:\n        json.dump(dict, f)\n        f.write('\\n')\n\n\ndef dict_list_to_json(dict_list, fname):\n    \"\"\"Dump a list of :python:`Python` dictionaries to a JSON file.\n\n    Args:\n        dict_list (list of dict): List of :python:`Python` dictionaries.\n        fname (str): the output file name.\n    \"\"\"\n    with open(fname, 'a') as f:\n        for dict in dict_list:\n            json.dump(dict, f)\n            f.write('\\n')\n\n\ndef json_to_dict_list(fname):\n    dict_list = []\n    epoch_set = set()\n    with open(fname) as f:\n        lines = f.readlines()\n        for line in lines:\n            line = line.rstrip()\n            dict = json.loads(line)\n            if dict['epoch'] not in epoch_set:\n                dict_list.append(dict)\n            epoch_set.add(dict['epoch'])\n    return dict_list\n\n\ndef dict_to_tb(dict, writer, epoch):\n    \"\"\"Add a dictionary of statistics to a Tensorboard writer.\n\n    Args:\n        dict (dict): Statistics of experiments, the keys are attribute names,\n        the values are the attribute values\n        writer: Tensorboard writer object\n        epoch (int): The current epoch\n    \"\"\"\n    for key in dict:\n        writer.add_scalar(key, dict[key], epoch)\n\n\ndef dict_list_to_tb(dict_list, writer):\n    for dict in dict_list:\n        assert 'epoch' in dict, 'Key epoch must exist in stats dict'\n        dict_to_tb(dict, writer, dict['epoch'])\n\n\ndef makedirs_rm_exist(dir):\n    \"\"\"Make a directory, remove 
any existing data.\n\n    Args:\n        dir (str): The directory to be created.\n    \"\"\"\n    if osp.isdir(dir):\n        fs.rm(dir)\n    os.makedirs(dir, exist_ok=True)\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/plot.py",
    "content": "import os.path as osp\n\n\ndef view_emb(emb, dir):\n    \"\"\"Visualize a embedding matrix.\n\n    Args:\n        emb (torch.tensor): Embedding matrix with shape (N, D). D is the\n        feature dimension.\n        dir (str): Output directory for the embedding figure.\n    \"\"\"\n    import matplotlib.pyplot as plt\n    import seaborn as sns\n    from sklearn.decomposition import PCA\n\n    sns.set_context('poster')\n\n    if emb.shape[1] > 2:\n        pca = PCA(n_components=2)\n        emb = pca.fit_transform(emb)\n    plt.figure(figsize=(10, 10))\n    plt.scatter(emb[:, 0], emb[:, 1])\n    plt.savefig(osp.join(dir, 'emb_pca.png'), dpi=100)\n"
  },
  {
    "path": "torch_geometric/graphgym/utils/tools.py",
    "content": "class dummy_context():\n    \"\"\"Default context manager that does nothing.\"\"\"\n    def __enter__(self):\n        return None\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        return False\n"
  },
  {
    "path": "torch_geometric/hash_tensor.py",
    "content": "import functools\nimport warnings\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    Optional,\n    Tuple,\n    Type,\n    Union,\n)\n\nimport numpy as np\nimport torch\nimport torch.utils._pytree as pytree\nimport xxhash\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.typing import CPUHashMap, CUDAHashMap\n\naten = torch.ops.aten\n\nHANDLED_FUNCTIONS: Dict[Callable, Callable] = {}\n\n\ndef implements(torch_function: Callable) -> Callable:\n    r\"\"\"Registers a :pytorch:`PyTorch` function override.\"\"\"\n    @functools.wraps(torch_function)\n    def decorator(my_function: Callable) -> Callable:\n        HANDLED_FUNCTIONS[torch_function] = my_function\n        return my_function\n\n    return decorator\n\n\ndef as_key_tensor(\n    key: Any,\n    *,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n    try:\n        key = torch.as_tensor(key, device=device)\n    except Exception:\n        device = device or torch.get_default_device()\n        key = torch.tensor(\n            [xxhash.xxh64(x).intdigest() & 0x7FFFFFFFFFFFFFFF for x in key],\n            dtype=torch.int64, device=device)\n\n    if key.element_size() == 1:\n        key = key.view(torch.uint8)\n    elif key.element_size() == 2:\n        key = key.view(torch.int16)\n    elif key.element_size() == 4:\n        key = key.view(torch.int32)\n    elif key.element_size() == 8:\n        key = key.view(torch.int64)\n    else:\n        raise ValueError(f\"Received invalid dtype '{key.dtype}' with \"\n                         f\"{key.element_size()} bytes\")\n\n    return key\n\n\ndef get_hash_map(key: Tensor) -> Union[CPUHashMap, CUDAHashMap]:\n    if torch_geometric.typing.WITH_CUDA_HASH_MAP and key.is_cuda:\n        return CUDAHashMap(key, 0.5)\n\n    if key.is_cuda:\n        warnings.warn(\n            \"Fallback to CPU-based mapping algorithm which may \"\n            \"cause slowdowns and device synchronization. 
Please \"\n            \"install 'pyg-lib' for an accelerated 'HashTensor' \"\n            \"implementation.\", stacklevel=2)\n\n    if torch_geometric.typing.WITH_CPU_HASH_MAP:\n        return CPUHashMap(key.cpu(), -1)\n\n    import pandas as pd\n\n    return pd.CategoricalDtype(\n        categories=key.cpu().numpy(),\n        ordered=True,\n    )\n\n\nclass HashTensor(Tensor):\n    r\"\"\"A :pytorch:`null` :class:`torch.Tensor` that can be referenced by\n    arbitrary keys rather than indices in the first dimension.\n\n    :class:`HashTensor` sub-classes a general :pytorch:`null`\n    :class:`torch.Tensor`, and extends it by CPU- and GPU-accelerated mapping\n    routines. This allows for fast and efficient access to non-contiguous\n    indices/keys while the underlying data is stored in a compact format.\n\n    This representation is ideal for scenarios where one needs a fast mapping\n    routine without relying on CPU-based external packages, and can be used,\n    *e.g.*, to perform mapping of global indices to local indices during\n    subgraph creation, or in data-processing pipelines to map non-contiguous\n    input data into a contiguous space, such as\n\n    * mapping of hashed node IDs to range :obj:`[0, num_nodes - 1]`\n    * mapping of raw input data, *e.g.*, categorical data to range\n      :obj:`[0, num_categories - 1]`\n\n    Specifically, :class:`HashTensor` supports *any* keys of *any* type,\n    *e.g.*, strings, timestamps, etc.\n\n    .. 
code-block:: python\n\n        from torch_geometric import HashTensor\n\n        key = torch.tensor([1000, 100, 10000])\n        value = torch.randn(3, 4)\n\n        tensor = HashTensor(key, value)\n        assert tensor.size() == (3, 4)\n\n        # Filtering:\n        query = torch.tensor([10000, 1000])\n        out = tensor[query]\n        assert out.equal(value[[2, 0]])\n\n        # Accessing non-existing keys:\n        out = tensor[[10000, 0]]\n        out.isnan()\n        >>> tensor([[False, False, False, False],\n        ...         [True, True, True, True]])\n\n        # If `value` is not given, indexing returns the position of `query` in\n        # `key`, and `-1` otherwise:\n        key = ['Animation', 'Comedy', 'Fantasy']\n        tensor = HashTensor(key)\n\n        out = tensor[['Comedy', 'Romance']]\n        >>> tensor([1, -1])\n\n    Args:\n        key: The keys in the first dimension.\n        value: The values to hold.\n        dtype: The desired data type of the values of the returned tensor.\n        device: The device of the returned tensor.\n    \"\"\"\n    _map: Union[Tensor, CPUHashMap, CUDAHashMap]\n    _value: Optional[Tensor]\n    _min_key: Tensor\n    _max_key: Tensor\n\n    @staticmethod\n    def __new__(\n        cls: Type,\n        key: Any,\n        value: Optional[Any] = None,\n        *,\n        dtype: Optional[torch.dtype] = None,\n        device: Optional[torch.device] = None,\n    ) -> 'HashTensor':\n\n        if value is not None:\n            value = torch.as_tensor(value, dtype=dtype, device=device)\n            device = value.device\n\n        key = as_key_tensor(key, device=device)\n\n        if key.dim() != 1:\n            raise ValueError(f\"'key' data in '{cls.__name__}' needs to be \"\n                             f\"one-dimensional (got {key.dim()} dimensions)\")\n\n        if not key.is_contiguous():\n            raise ValueError(f\"'key' data in '{cls.__name__}' needs to be \"\n                             
f\"contiguous\")\n\n        if value is not None:\n            if key.device != value.device:\n                raise ValueError(f\"'key' and 'value' data in '{cls.__name__}' \"\n                                 f\"are expected to be on the same device (got \"\n                                 f\"'{key.device}' and '{value.device}')\")\n\n            if key.numel() != value.size(0):\n                raise ValueError(f\"'key' and 'value' data in '{cls.__name__}' \"\n                                 f\"are expected to have the same size in the \"\n                                 f\"first dimension (got {key.size(0)} and \"\n                                 f\"{value.size(0)})\")\n\n        min_key = key.min() if key.numel() > 0 else key.new_zeros(())\n        max_key = key.max() if key.numel() > 0 else key.new_zeros(())\n\n        _range = max_key - min_key\n        # TODO Expose fixed threshold as argument.\n        if (key.dtype in {torch.uint8, torch.int16} or _range <= 1_000_000\n                or _range <= 2 * key.numel()):\n            _map = torch.full(\n                size=(_range + 3, ),\n                fill_value=-1,\n                dtype=torch.int64,\n                device=key.device,\n            )\n            _map[key.long() - (min_key.long() - 1)] = torch.arange(\n                key.numel(),\n                dtype=_map.dtype,\n                device=_map.device,\n            )\n        else:\n            _map = get_hash_map(key)\n\n        return cls._from_data(\n            _map,\n            value,\n            min_key,\n            max_key,\n            num_keys=key.numel(),\n            dtype=dtype,\n        )\n\n    # Private Methods #########################################################\n\n    @classmethod\n    def _from_data(\n        cls,\n        _map: Union[Tensor, CPUHashMap, CUDAHashMap],\n        value: Optional[Tensor],\n        min_key: Tensor,\n        max_key: Tensor,\n        *,\n        num_keys: int,\n        dtype: 
Optional[torch.dtype],\n    ) -> 'HashTensor':\n\n        if value is not None:\n            dtype = value.dtype\n            size = value.size()\n            stride = value.stride()\n            layout = value.layout\n            requires_grad = value.requires_grad\n        else:\n            dtype = dtype or torch.int64\n            size = torch.Size([num_keys])\n            stride = (1, )\n            layout = torch.strided\n            requires_grad = False\n\n        out = Tensor._make_wrapper_subclass(\n            cls,\n            size=size,\n            strides=stride,\n            dtype=dtype,\n            device=min_key.device,\n            layout=layout,\n            requires_grad=requires_grad,\n        )\n        assert isinstance(out, HashTensor)\n\n        out._map = _map\n        out._value = value\n        out._min_key = min_key\n        out._max_key = max_key\n\n        return out\n\n    @property\n    def _key(self) -> Tensor:\n        if isinstance(self._map, Tensor):\n            mask = self._map >= 0\n            key = mask.nonzero().view(-1) - 1\n            key = key[self._map[mask]]\n        elif (torch_geometric.typing.WITH_CUDA_HASH_MAP\n              or torch_geometric.typing.WITH_CPU_HASH_MAP):\n            key = self._map.keys().to(self.device)\n        else:\n            key = torch.from_numpy(self._map.categories.to_numpy())\n\n        return key.to(self.device)\n\n    def _shallow_copy(self) -> 'HashTensor':\n        return self._from_data(\n            self._map,\n            self._value,\n            self._min_key,\n            self._max_key,\n            num_keys=self.size(0),\n            dtype=self.dtype,\n        )\n\n    def _get(self, query: Tensor) -> Tensor:\n        if isinstance(self._map, Tensor):\n            index = query.long() - (self._min_key.long() - 1)\n            index = self._map[index.clamp_(min=0, max=self._map.numel() - 1)]\n        elif torch_geometric.typing.WITH_CUDA_HASH_MAP and query.is_cuda:\n        
    index = self._map.get(query)\n        elif torch_geometric.typing.WITH_CPU_HASH_MAP:\n            index = self._map.get(query.cpu())\n        else:\n            import pandas as pd\n\n            ser = pd.Series(query.cpu().numpy(), dtype=self._map)\n            index = torch.from_numpy(ser.cat.codes.to_numpy().copy()).long()\n\n        index = index.to(self.device)\n\n        if self._value is None:\n            return index.to(self.dtype)\n\n        out = self._value[index]\n\n        mask = index != -1\n        mask = mask.view([-1] + [1] * (out.dim() - 1))\n        fill_value = float('NaN') if out.is_floating_point() else -1\n        if torch_geometric.typing.WITH_PT20:\n            other: Union[int, float, Tensor] = fill_value\n        else:\n            other = torch.full_like(out, fill_value)\n\n        return out.where(mask, other)\n\n    # Methods #################################################################\n\n    def as_tensor(self) -> Tensor:\n        r\"\"\"Zero-copies the :class:`HashTensor` representation back to a\n        :class:`torch.Tensor` representation.\n        \"\"\"\n        if self._value is not None:\n            return self._value\n        return torch.arange(self.size(0), dtype=self.dtype, device=self.device)\n\n    # PyTorch/Python builtins #################################################\n\n    # Prevent auto-wrapping outputs back into the proper subclass type:\n    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore\n\n    @classmethod\n    def __torch_dispatch__(  # type: ignore\n        cls: Type,\n        func: Callable[..., Any],\n        types: Iterable[Type[Any]],\n        args: Iterable[Tuple[Any, ...]] = (),\n        kwargs: Optional[Dict[Any, Any]] = None,\n    ) -> Any:\n        # Hold a number of `HANDLED_FUNCTIONS` that implement specific\n        # functions for valid `HashTensor` routines.\n        if func in HANDLED_FUNCTIONS:\n            return HANDLED_FUNCTIONS[func](*args, 
**(kwargs or {}))\n\n        # For all other PyTorch functions, we treat them as vanilla tensors.\n        args = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(), args)\n        if kwargs is not None:\n            kwargs = pytree.tree_map_only(HashTensor, lambda x: x.as_tensor(),\n                                          kwargs)\n        return func(*args, **(kwargs or {}))\n\n    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:\n        attrs = ['_map', '_min_key', '_max_key']\n        if self._value is not None:\n            attrs.append('_value')\n\n        ctx = (self.size(0), self.dtype)\n\n        return attrs, ctx\n\n    @staticmethod\n    def __tensor_unflatten__(\n        inner_tensors: Dict[str, Any],\n        ctx: Tuple[Any, ...],\n        outer_size: Tuple[int, ...],\n        outer_stride: Tuple[int, ...],\n    ) -> 'HashTensor':\n        return HashTensor._from_data(\n            inner_tensors['_map'],\n            inner_tensors.get('_value', None),\n            inner_tensors['_min_key'],\n            inner_tensors['_max_key'],\n            num_keys=ctx[0],\n            dtype=ctx[1],\n        )\n\n    def __repr__(self) -> str:  # type: ignore\n        indent = len(f'{self.__class__.__name__}(')\n        tensor_str = torch._tensor_str._tensor_str(self.as_tensor(), indent)\n        return torch._tensor_str._str_intern(self, tensor_contents=tensor_str)\n\n    def tolist(self) -> List[Any]:\n        \"\"\"\"\"\"  # noqa: D419\n        return self.as_tensor().tolist()\n\n    def numpy(self, *, force: bool = False) -> np.ndarray:\n        \"\"\"\"\"\"  # noqa: D419\n        return self.as_tensor().numpy(force=force)\n\n    def index_select(  # type: ignore\n        self,\n        dim: int,\n        index: Any,\n    ) -> Union['HashTensor', Tensor]:\n        \"\"\"\"\"\"  # noqa: D419\n        return torch.index_select(self, dim, index)\n\n    def select(  # type: ignore\n        self,\n        dim: int,\n        index: Any,\n    
) -> Union['HashTensor', Tensor]:\n        \"\"\"\"\"\"  # noqa: D419\n        return torch.select(self, dim, index)\n\n    def share_memory_(self) -> 'HashTensor':\n        \"\"\"\"\"\"  # noqa: D419\n        if isinstance(self._map, Tensor):\n            self._map.share_memory_()\n        if self._value is not None:\n            self._value.share_memory_()\n        self._min_key.share_memory_()\n        self._max_key.share_memory_()\n        return self\n\n    def is_shared(self) -> bool:\n        \"\"\"\"\"\"  # noqa: D419\n        return self._min_key.is_shared()\n\n    def detach_(self) -> 'HashTensor':\n        \"\"\"\"\"\"  # noqa: D419\n        if self._value is not None:\n            self._value.detach_()\n        return super().detach_()  # type: ignore\n\n    def __getitem__(self, indices: Any) -> Union['HashTensor', Tensor]:\n        if not isinstance(indices, tuple):\n            indices = (indices, )\n        assert len(indices) > 0\n\n        # We convert any index tensor in the first dimension into a tensor.\n        # This means that downstream handling (i.e. in `aten.index.Tensor`)\n        # needs to take this pre-conversion into account. 
However, detecting\n        # whether the first dimension is indexed can be tricky at times:\n        # * We need to take into account `Ellipsis`\n        # * We need to take any unsqueezing into account\n        if indices[0] is Ellipsis and len(indices) > 1:\n            nonempty_indices = [i for i in indices[1:] if i is not None]\n            if len(nonempty_indices) == self.dim():\n                indices = indices[1:]\n\n        if isinstance(indices[0], (int, bool)):\n            index: Union[int, Tensor] = int(as_key_tensor([indices[0]]))\n            indices = (index, ) + indices[1:]\n        elif isinstance(indices[0], (Tensor, list, np.ndarray)):\n            index = as_key_tensor(indices[0], device=self.device)\n            indices = (index, ) + indices[1:]\n\n        indices = indices[0] if len(indices) == 1 else indices\n\n        return super().__getitem__(indices)\n\n\n@implements(aten.alias.default)\ndef _alias(tensor: HashTensor) -> HashTensor:\n    return tensor._shallow_copy()\n\n\n@implements(aten.clone.default)\ndef _clone(\n    tensor: HashTensor,\n    *,\n    memory_format: torch.memory_format = torch.preserve_format,\n) -> HashTensor:\n\n    value = tensor._value\n    if value is not None:\n        value = aten.clone.default(value, memory_format=memory_format)\n\n    return tensor._from_data(\n        tensor._map,  # NOTE No reason to do clone since it is read-only.\n        value,\n        tensor._min_key,  # NOTE No reason to do clone since it is read-only.\n        tensor._max_key,  # NOTE No reason to do clone since it is read-only.\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n@implements(aten.detach.default)\ndef _detach(tensor: HashTensor) -> HashTensor:\n    value = tensor._value\n    if value is not None:\n        value = aten.detach.default(value)\n\n    return tensor._from_data(\n        tensor._map,\n        value,\n        tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n 
       dtype=tensor.dtype,\n    )\n\n\n@implements(aten._to_copy.default)\ndef _to_copy(\n    tensor: HashTensor,\n    *,\n    dtype: Optional[torch.dtype] = None,\n    layout: Optional[torch.layout] = None,\n    device: Optional[torch.device] = None,\n    pin_memory: bool = False,\n    non_blocking: bool = False,\n    memory_format: Optional[torch.memory_format] = None,\n) -> HashTensor:\n\n    value = tensor._value\n    if value is not None:\n        value = aten._to_copy.default(\n            value,\n            dtype=dtype,\n            layout=layout,\n            device=device,\n            pin_memory=pin_memory,\n            non_blocking=non_blocking,\n            memory_format=memory_format,\n        )\n\n    min_key = aten._to_copy.default(tensor._min_key, device=device)\n    max_key = aten._to_copy.default(tensor._max_key, device=device)\n\n    _map = tensor._map\n    if isinstance(_map, Tensor):\n        _map = aten._to_copy.default(_map, device=device)\n    # Only convert `_map` in case `CUDAHashMap` exists - otherwise we use\n    # CPU-based mapping anyway and there is no need for a copy.\n    elif (torch_geometric.typing.WITH_CUDA_HASH_MAP and tensor.is_cuda\n          and tensor.device != min_key.device):\n        key = _map.keys()\n        key = aten._to_copy.default(key, device=device)\n        _map = get_hash_map(key)\n\n    return tensor._from_data(\n        _map,\n        value,\n        min_key,\n        max_key,\n        num_keys=tensor.size(0),\n        dtype=dtype or tensor.dtype,\n    )\n\n\n@implements(aten._pin_memory.default)\ndef _pin_memory(tensor: HashTensor) -> HashTensor:\n    _map = tensor._map\n    if isinstance(_map, Tensor):\n        _map = aten._pin_memory.default(_map)\n\n    value = tensor._value\n    if value is not None:\n        value = aten._pin_memory.default(value)\n\n    return tensor._from_data(\n        _map,\n        value,\n        aten._pin_memory.default(tensor._min_key),\n        
aten._pin_memory.default(tensor._max_key),\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n@implements(aten.unsqueeze.default)\ndef _unsqueeze(tensor: HashTensor, dim: int) -> HashTensor:\n    if dim == 0 or dim == -(tensor.dim() + 1):\n        raise IndexError(f\"Cannot unsqueeze '{tensor.__class__.__name__}' in \"\n                         f\"the first dimension. Please call `as_tensor()` \"\n                         f\"beforehand\")\n\n    return tensor._from_data(\n        tensor._map,\n        aten.unsqueeze.default(tensor.as_tensor(), dim),\n        tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n@implements(aten.squeeze.default)\ndef _squeeze_default(tensor: HashTensor) -> HashTensor:\n    if tensor._value is None:\n        return tensor._shallow_copy()\n\n    value = tensor.as_tensor()\n    for d in range(tensor.dim() - 1, 0, -1):\n        value = value.squeeze(d)\n\n    return tensor._from_data(\n        tensor._map,\n        value,\n        tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n@implements(aten.squeeze.dim)\n@implements(getattr(aten.squeeze, 'dims', aten.squeeze.dim))\ndef _squeeze_dim(\n    tensor: HashTensor,\n    dim: Union[int, List[int]],\n) -> HashTensor:\n    if isinstance(dim, int):\n        dim = [dim]\n\n    for d in dim:\n        if d < -tensor.dim() or d >= tensor.dim():\n            raise IndexError(f\"Dimension out of range (expected to be in \"\n                             f\"range of [{-tensor.dim()}, {tensor.dim()-1}], \"\n                             f\"but got {d})\")\n\n    if tensor._value is None:\n        return tensor._shallow_copy()\n\n    value = tensor.as_tensor()\n    for d in dim[::-1]:\n        if d != 0 and d != -tensor.dim():\n            value = value.squeeze(d)\n\n    return tensor._from_data(\n        tensor._map,\n        value,\n        
tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n@implements(aten.slice.Tensor)\ndef _slice(\n    tensor: HashTensor,\n    dim: int,\n    start: Optional[int] = None,\n    end: Optional[int] = None,\n    step: int = 1,\n) -> HashTensor:\n\n    if dim == 0 or dim == -tensor.dim():\n        copy = start is None or start == 0 or start <= -tensor.size(0)\n        copy &= end is None or end > tensor.size(0)\n        copy &= step == 1\n        if copy:\n            return tensor._shallow_copy()\n\n        key = aten.slice.Tensor(tensor._key, 0, start, end, step)\n        value = aten.slice.Tensor(tensor.as_tensor(), 0, start, end, step)\n        return tensor.__class__(key, value)\n\n    return tensor._from_data(\n        tensor._map,\n        aten.slice.Tensor(tensor.as_tensor(), dim, start, end, step),\n        tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n# Since PyTorch does only allow PyTorch tensors as indices in `index_select`,\n# we need to create a wrapper function and monkey patch `index_select` :(\n_old_index_select = torch.index_select\n\n\ndef _new_index_select(\n    input: Tensor,\n    dim: Union[int, str],\n    index: Tensor,\n    out: Optional[Tensor] = None,\n) -> Tensor:\n\n    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):\n        raise IndexError(f\"Dimension out of range (expected to be in range of \"\n                         f\"[{-input.dim()}, {input.dim()-1}], but got {dim})\")\n\n    # We convert any index tensor in the first dimension into a tensor. This\n    # means that downstream handling (i.e. 
in `aten.index_select.default`)\n    # needs to take this pre-conversion into account.\n    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)\n            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):\n        index = as_key_tensor(index, device=input.device)\n\n    if isinstance(dim, int):  # Type narrowing...\n        if out is None:\n            return _old_index_select(input, dim, index)\n        else:\n            return _old_index_select(input, dim, index, out=out)\n    else:\n        if out is None:\n            return _old_index_select(input, dim, index)\n        else:\n            return _old_index_select(input, dim, index, out=out)\n\n\ntorch.index_select = _new_index_select  # type: ignore\n\n\n@implements(aten.index_select.default)\ndef _index_select(\n    tensor: HashTensor,\n    dim: int,\n    index: Tensor,\n) -> Union[HashTensor, Tensor]:\n\n    if dim == 0 or dim == -tensor.dim():\n        return tensor._get(index)\n\n    return tensor._from_data(\n        tensor._map,\n        aten.index_select.default(tensor.as_tensor(), dim, index),\n        tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n# Since PyTorch does only allow PyTorch tensors as indices in `select`, we need\n# to create a wrapper function and monkey patch `select` :(\n_old_select = torch.select\n\n\ndef _new_select(\n    input: Tensor,\n    dim: Union[int, str],\n    index: int,\n) -> Tensor:\n\n    if isinstance(dim, int) and (dim < -input.dim() or dim >= input.dim()):\n        raise IndexError(f\"Dimension out of range (expected to be in range of \"\n                         f\"[{-input.dim()}, {input.dim()-1}], but got {dim})\")\n\n    # We convert any index in the first dimension into an integer. This means\n    # that downstream handling (i.e. 
in `aten.select.int`) needs to take this\n    # pre-conversion into account.\n    if (not torch.jit.is_scripting() and isinstance(input, HashTensor)\n            and isinstance(dim, int) and (dim == 0 or dim == -input.dim())):\n        index = int(as_key_tensor([index]))\n\n    if isinstance(dim, int):  # Type narrowing...\n        return _old_select(input, dim, index)\n    else:\n        return _old_select(input, dim, index)\n\n\ntorch.select = _new_select  # type: ignore\n\n\n@implements(aten.select.int)\ndef _select(\n    tensor: HashTensor,\n    dim: int,\n    index: int,\n) -> Union[HashTensor, Tensor]:\n\n    if dim == 0 or dim == -tensor.dim():\n        key = torch.tensor(\n            [index],\n            dtype=tensor._min_key.dtype,\n            device=tensor.device,\n        )\n        return tensor._get(key).squeeze(0)\n\n    return tensor._from_data(\n        tensor._map,\n        aten.select.int(tensor.as_tensor(), dim, index),\n        tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n\n\n@implements(aten.index.Tensor)\ndef _index(\n    tensor: HashTensor,\n    indices: List[Optional[Tensor]],\n) -> Union[HashTensor, Tensor]:\n\n    assert len(indices) > 0\n\n    if indices[0] is not None:\n        out = tensor._get(indices[0])\n        if len(indices) > 1:\n            out = aten.index.Tensor(out, [None] + indices[1:])\n        return out\n\n    return tensor._from_data(\n        tensor._map,\n        aten.index.Tensor(tensor.as_tensor(), indices),\n        tensor._min_key,\n        tensor._max_key,\n        num_keys=tensor.size(0),\n        dtype=tensor.dtype,\n    )\n"
  },
  {
    "path": "torch_geometric/home.py",
    "content": "import os\nimport os.path as osp\nfrom typing import Optional\n\nENV_PYG_HOME = 'PYG_HOME'\nDEFAULT_CACHE_DIR = osp.join('~', '.cache', 'pyg')\n\n_home_dir: Optional[str] = None\n\n\ndef get_home_dir() -> str:\n    r\"\"\"Get the cache directory used for storing all :pyg:`PyG`-related data.\n\n    If :meth:`set_home_dir` is not called, the path is given by the environment\n    variable :obj:`$PYG_HOME` which defaults to :obj:`\"~/.cache/pyg\"`.\n    \"\"\"\n    if _home_dir is not None:\n        return _home_dir\n\n    return osp.expanduser(os.getenv(ENV_PYG_HOME, DEFAULT_CACHE_DIR))\n\n\ndef set_home_dir(path: str) -> None:\n    r\"\"\"Set the cache directory used for storing all :pyg:`PyG`-related data.\n\n    Args:\n        path (str): The path to a local folder.\n    \"\"\"\n    global _home_dir\n    _home_dir = path\n"
  },
  {
    "path": "torch_geometric/index.py",
    "content": "import functools\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    List,\n    NamedTuple,\n    Optional,\n    Tuple,\n    Type,\n    Union,\n)\n\nimport numpy as np\nimport torch\nimport torch.utils._pytree as pytree\nfrom torch import Tensor\n\nfrom torch_geometric.typing import INDEX_DTYPES\n\naten = torch.ops.aten\n\nHANDLED_FUNCTIONS: Dict[Callable, Callable] = {}\n\n\ndef ptr2index(ptr: Tensor, output_size: Optional[int] = None) -> Tensor:\n    index = torch.arange(ptr.numel() - 1, dtype=ptr.dtype, device=ptr.device)\n    return index.repeat_interleave(ptr.diff(), output_size=output_size)\n\n\ndef index2ptr(index: Tensor, size: Optional[int] = None) -> Tensor:\n    if size is None:\n        size = int(index.max()) + 1 if index.numel() > 0 else 0\n\n    return torch._convert_indices_from_coo_to_csr(\n        index, size, out_int32=index.dtype != torch.int64)\n\n\nclass CatMetadata(NamedTuple):\n    nnz: List[int]\n    dim_size: List[Optional[int]]\n    is_sorted: List[bool]\n\n\ndef implements(torch_function: Callable) -> Callable:\n    r\"\"\"Registers a :pytorch:`PyTorch` function override.\"\"\"\n    @functools.wraps(torch_function)\n    def decorator(my_function: Callable) -> Callable:\n        HANDLED_FUNCTIONS[torch_function] = my_function\n        return my_function\n\n    return decorator\n\n\ndef assert_valid_dtype(tensor: Tensor) -> None:\n    if tensor.dtype not in INDEX_DTYPES:\n        raise ValueError(f\"'Index' holds an unsupported data type \"\n                         f\"(got '{tensor.dtype}', but expected one of \"\n                         f\"{INDEX_DTYPES})\")\n\n\ndef assert_one_dimensional(tensor: Tensor) -> None:\n    if tensor.dim() != 1:\n        raise ValueError(f\"'Index' needs to be one-dimensional \"\n                         f\"(got {tensor.dim()} dimensions)\")\n\n\ndef assert_contiguous(tensor: Tensor) -> None:\n    if not tensor.is_contiguous():\n        raise ValueError(\"'Index' 
needs to be contiguous. Please call \"\n                         \"`index.contiguous()` before proceeding.\")\n\n\ndef assert_sorted(func: Callable) -> Callable:\n    @functools.wraps(func)\n    def wrapper(self: 'Index', *args: Any, **kwargs: Any) -> Any:\n        if not self.is_sorted:\n            cls_name = self.__class__.__name__\n            raise ValueError(\n                f\"Cannot call '{func.__name__}' since '{cls_name}' is not \"\n                f\"sorted. Please call `{cls_name}.sort()` first.\")\n        return func(self, *args, **kwargs)\n\n    return wrapper\n\n\nclass Index(Tensor):\n    r\"\"\"A one-dimensional :obj:`index` tensor with additional (meta)data\n    attached.\n\n    :class:`Index` is a :pytorch:`null` :class:`torch.Tensor` that holds\n    indices of shape :obj:`[num_indices]`.\n\n    While :class:`Index` sub-classes a general :pytorch:`null`\n    :class:`torch.Tensor`, it can hold additional (meta)data, *i.e.*:\n\n    * :obj:`dim_size`: The size of the underlying sparse vector size, *i.e.*,\n      the size of a dimension that can be indexed via :obj:`index`.\n      By default, it is inferred as :obj:`dim_size=index.max() + 1`.\n    * :obj:`is_sorted`: Whether indices are sorted in ascending order.\n\n    Additionally, :class:`Index` caches data via :obj:`indptr` for fast CSR\n    conversion in case its representation is sorted.\n    Caches are filled based on demand (*e.g.*, when calling\n    :meth:`Index.get_indptr`), or when explicitly requested via\n    :meth:`Index.fill_cache_`, and are maintained and adjusted over its\n    lifespan.\n\n    This representation ensures optimal computation in GNN message passing\n    schemes, while preserving the ease-of-use of regular COO-based :pyg:`PyG`\n    workflows.\n\n    .. 
code-block:: python\n\n        import torch\n\n        from torch_geometric import Index\n\n        index = Index([0, 1, 1, 2], dim_size=3, is_sorted=True)\n        >>> Index([0, 1, 1, 2], dim_size=3, is_sorted=True)\n        assert index.dim_size == 3\n        assert index.is_sorted\n\n        # Flipping order:\n        out = index.flip(0)\n        >>> Index([2, 1, 1, 0], dim_size=3)\n        assert not out.is_sorted\n\n        # Filtering:\n        mask = torch.tensor([True, True, True, False])\n        out = index[mask]\n        >>> Index([0, 1, 1], dim_size=3, is_sorted=True)\n        assert out.is_sorted\n    \"\"\"\n    # See \"https://pytorch.org/docs/stable/notes/extending.html\"\n    # for a basic tutorial on how to subclass `torch.Tensor`.\n\n    # The underlying tensor representation:\n    _data: Tensor\n\n    # The size of the underlying sparse vector, e.g., `_data.max() + 1`:\n    _dim_size: Optional[int] = None\n\n    # Whether the `index` representation is sorted:\n    _is_sorted: bool = False\n\n    # A cache for its compressed representation:\n    _indptr: Optional[Tensor] = None\n\n    # Whenever we perform a concatenation of indices, we cache the original\n    # metadata to be able to reconstruct individual indices:\n    _cat_metadata: Optional[CatMetadata] = None\n\n    @staticmethod\n    def __new__(\n        cls: Type,\n        data: Any,\n        *args: Any,\n        dim_size: Optional[int] = None,\n        is_sorted: bool = False,\n        **kwargs: Any,\n    ) -> 'Index':\n        if not isinstance(data, Tensor):\n            data = torch.tensor(data, *args, **kwargs)\n        elif len(args) > 0:\n            raise TypeError(\n                f\"new() received an invalid combination of arguments - got \"\n                f\"(Tensor, {', '.join(str(type(arg)) for arg in args)})\")\n        elif len(kwargs) > 0:\n            raise TypeError(f\"new() received invalid keyword arguments - got \"\n                            f\"{set(kwargs.keys())})\")\n\n        assert 
isinstance(data, Tensor)\n\n        indptr: Optional[Tensor] = None\n\n        if isinstance(data, cls):  # If passed `Index`, inherit metadata:\n            indptr = data._indptr\n            dim_size = dim_size or data.dim_size\n            is_sorted = is_sorted or data.is_sorted\n\n        assert_valid_dtype(data)\n        assert_one_dimensional(data)\n        assert_contiguous(data)\n\n        out = Tensor._make_wrapper_subclass(\n            cls,\n            size=data.size(),\n            strides=data.stride(),\n            dtype=data.dtype,\n            device=data.device,\n            layout=data.layout,\n            requires_grad=False,\n        )\n        assert isinstance(out, Index)\n\n        # Attach metadata:\n        out._data = data\n        out._dim_size = dim_size\n        out._is_sorted = is_sorted\n        out._indptr = indptr\n\n        if isinstance(data, cls):\n            out._data = data._data\n\n            # Reset metadata if cache is invalidated:\n            if dim_size is not None and dim_size != data.dim_size:\n                out._indptr = None\n\n        return out\n\n    # Validation ##############################################################\n\n    def validate(self) -> 'Index':\n        r\"\"\"Validates the :class:`Index` representation.\n\n        In particular, it ensures that\n\n        * it only holds valid indices.\n        * the sort order is correctly set.\n        \"\"\"\n        assert_valid_dtype(self._data)\n        assert_one_dimensional(self._data)\n        assert_contiguous(self._data)\n\n        if self.numel() > 0 and self._data.min() < 0:\n            raise ValueError(f\"'{self.__class__.__name__}' contains negative \"\n                             f\"indices (got {int(self.min())})\")\n\n        if (self.numel() > 0 and self.dim_size is not None\n                and self._data.max() >= self.dim_size):\n            raise ValueError(f\"'{self.__class__.__name__}' contains larger \"\n                            
 f\"indices than its registered size \"\n                             f\"(got {int(self._data.max())}, but expected \"\n                             f\"values smaller than {self.dim_size})\")\n\n        if self.is_sorted and (self._data.diff() < 0).any():\n            raise ValueError(f\"'{self.__class__.__name__}' is not sorted\")\n\n        return self\n\n    # Properties ##############################################################\n\n    @property\n    def dim_size(self) -> Optional[int]:\n        r\"\"\"The size of the underlying sparse vector.\"\"\"\n        return self._dim_size\n\n    @property\n    def is_sorted(self) -> bool:\n        r\"\"\"Returns whether indices are sorted in ascending order.\"\"\"\n        return self._is_sorted\n\n    @property\n    def dtype(self) -> torch.dtype:  # type: ignore\n        # TODO Remove once PyTorch does not override `dtype` in `DataLoader`.\n        return self._data.dtype\n\n    # Cache Interface #########################################################\n\n    def get_dim_size(self) -> int:\n        r\"\"\"The size of the underlying sparse vector.\n        Automatically computed and cached when not explicitly set.\n        \"\"\"\n        if self._dim_size is None:\n            dim_size = int(self._data.max()) + 1 if self.numel() > 0 else 0\n            self._dim_size = dim_size\n\n        assert isinstance(self._dim_size, int)\n        return self._dim_size\n\n    def dim_resize_(self, dim_size: Optional[int]) -> 'Index':\n        r\"\"\"Assigns or re-assigns the size of the underlying sparse vector.\"\"\"\n        if self.is_sorted and self._indptr is not None:\n            if dim_size is None:\n                self._indptr = None\n\n            elif self._indptr.numel() - 1 >= dim_size:\n                self._indptr = self._indptr[:dim_size + 1]\n\n            else:\n                fill_value = self._indptr.new_full(\n                    (dim_size - self._indptr.numel() + 1, ),\n                    
fill_value=self._indptr[-1],  # type: ignore\n                )\n                self._indptr = torch.cat([self._indptr, fill_value], dim=0)\n\n        self._dim_size = dim_size\n\n        return self\n\n    @assert_sorted\n    def get_indptr(self) -> Tensor:\n        r\"\"\"Returns the compressed index representation in case :class:`Index`\n        is sorted.\n        \"\"\"\n        if self._indptr is None:\n            self._indptr = index2ptr(self._data, self.get_dim_size())\n\n        assert isinstance(self._indptr, Tensor)\n        return self._indptr\n\n    def fill_cache_(self) -> 'Index':\n        r\"\"\"Fills the cache with (meta)data information.\"\"\"\n        self.get_dim_size()\n\n        if self.is_sorted:\n            self.get_indptr()\n\n        return self\n\n    # Methods #################################################################\n\n    def share_memory_(self) -> 'Index':\n        \"\"\"\"\"\"  # noqa: D419\n        self._data.share_memory_()\n        if self._indptr is not None:\n            self._indptr.share_memory_()\n        return self\n\n    def is_shared(self) -> bool:\n        \"\"\"\"\"\"  # noqa: D419\n        return self._data.is_shared()\n\n    def as_tensor(self) -> Tensor:\n        r\"\"\"Zero-copies the :class:`Index` representation back to a\n        :class:`torch.Tensor` representation.\n        \"\"\"\n        return self._data\n\n    # PyTorch/Python builtins #################################################\n\n    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:\n        attrs = ['_data']\n        if self._indptr is not None:\n            attrs.append('_indptr')\n\n        ctx = (\n            self._dim_size,\n            self._is_sorted,\n            self._cat_metadata,\n        )\n\n        return attrs, ctx\n\n    @staticmethod\n    def __tensor_unflatten__(\n        inner_tensors: Dict[str, Any],\n        ctx: Tuple[Any, ...],\n        outer_size: Tuple[int, ...],\n        outer_stride: Tuple[int, 
...],\n    ) -> 'Index':\n        index = Index(\n            inner_tensors['_data'],\n            dim_size=ctx[0],\n            is_sorted=ctx[1],\n        )\n\n        index._indptr = inner_tensors.get('_indptr', None)\n        index._cat_metadata = ctx[2]\n\n        return index\n\n    # Prevent auto-wrapping outputs back into the proper subclass type:\n    __torch_function__ = torch._C._disabled_torch_function_impl  # type: ignore\n\n    @classmethod\n    def __torch_dispatch__(  # type: ignore\n        cls: Type,\n        func: Callable[..., Any],\n        types: Iterable[Type[Any]],\n        args: Iterable[Tuple[Any, ...]] = (),\n        kwargs: Optional[Dict[Any, Any]] = None,\n    ) -> Any:\n        # `Index` should be treated as a regular PyTorch tensor for all\n        # standard PyTorch functionalities. However,\n        # * some of its metadata can be transferred to new functions, e.g.,\n        #   `torch.narrow()` can inherit the `is_sorted` property.\n        # * not all operations lead to valid `Index` tensors again, e.g.,\n        #   `torch.sum()` does not yield an `Index` as its output, or\n        #   `torch.stack()` violates the [*] shape assumption.\n\n        # To account for this, we hold a number of `HANDLED_FUNCTIONS` that\n        # implement specific functions for valid `Index` routines.\n        if func in HANDLED_FUNCTIONS:\n            return HANDLED_FUNCTIONS[func](*args, **(kwargs or {}))\n\n        # For all other PyTorch functions, we treat them as vanilla tensors.\n        args = pytree.tree_map_only(Index, lambda x: x._data, args)\n        if kwargs is not None:\n            kwargs = pytree.tree_map_only(Index, lambda x: x._data, kwargs)\n        return func(*args, **(kwargs or {}))\n\n    def __repr__(self) -> str:  # type: ignore\n        prefix = f'{self.__class__.__name__}('\n        indent = len(prefix)\n        tensor_str = torch._tensor_str._tensor_str(self._data, indent)\n\n        suffixes = []\n        if self.dim_size is 
not None:\n            suffixes.append(f'dim_size={self.dim_size}')\n        if (self.device.type != torch._C._get_default_device()\n                or (self.device.type == 'cuda'\n                    and torch.cuda.current_device() != self.device.index)\n                or (self.device.type == 'mps')):\n            suffixes.append(f\"device='{self.device}'\")\n        if self.dtype != torch.int64:\n            suffixes.append(f'dtype={self.dtype}')\n        if self.is_sorted:\n            suffixes.append('is_sorted=True')\n\n        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,\n                                               indent, force_newline=False)\n\n    def tolist(self) -> List[Any]:\n        \"\"\"\"\"\"  # noqa: D419\n        return self._data.tolist()\n\n    def numpy(self, *, force: bool = False) -> np.ndarray:\n        \"\"\"\"\"\"  # noqa: D419\n        return self._data.numpy(force=force)\n\n    # Helpers #################################################################\n\n    def _shallow_copy(self) -> 'Index':\n        out = Index(self._data)\n        out._dim_size = self._dim_size\n        out._is_sorted = self._is_sorted\n        out._indptr = self._indptr\n        out._cat_metadata = self._cat_metadata\n        return out\n\n    def _clear_metadata(self) -> 'Index':\n        self._dim_size = None\n        self._is_sorted = False\n        self._indptr = None\n        self._cat_metadata = None\n        return self\n\n\ndef apply_(\n    tensor: Index,\n    fn: Callable,\n    *args: Any,\n    **kwargs: Any,\n) -> Union[Index, Tensor]:\n\n    data = fn(tensor._data, *args, **kwargs)\n\n    if data.dtype not in INDEX_DTYPES:\n        return data\n\n    if tensor._data.data_ptr() != data.data_ptr():\n        out = Index(data)\n    else:  # In-place:\n        tensor._data = data\n        out = tensor\n\n    # Copy metadata:\n    out._dim_size = tensor._dim_size\n    out._is_sorted = tensor._is_sorted\n    out._cat_metadata = 
tensor._cat_metadata\n\n    # Convert cache:\n    if tensor._indptr is not None:\n        out._indptr = fn(tensor._indptr, *args, **kwargs)\n\n    return out\n\n\n@implements(aten.clone.default)\ndef _clone(\n    tensor: Index,\n    *,\n    memory_format: torch.memory_format = torch.preserve_format,\n) -> Index:\n    out = apply_(tensor, aten.clone.default, memory_format=memory_format)\n    assert isinstance(out, Index)\n    return out\n\n\n@implements(aten._to_copy.default)\ndef _to_copy(\n    tensor: Index,\n    *,\n    dtype: Optional[torch.dtype] = None,\n    layout: Optional[torch.layout] = None,\n    device: Optional[torch.device] = None,\n    pin_memory: bool = False,\n    non_blocking: bool = False,\n    memory_format: Optional[torch.memory_format] = None,\n) -> Union[Index, Tensor]:\n    return apply_(\n        tensor,\n        aten._to_copy.default,\n        dtype=dtype,\n        layout=layout,\n        device=device,\n        pin_memory=pin_memory,\n        non_blocking=non_blocking,\n        memory_format=memory_format,\n    )\n\n\n@implements(aten.alias.default)\ndef _alias(tensor: Index) -> Index:\n    return tensor._shallow_copy()\n\n\n@implements(aten._pin_memory.default)\ndef _pin_memory(tensor: Index) -> Index:\n    out = apply_(tensor, aten._pin_memory.default)\n    assert isinstance(out, Index)\n    return out\n\n\n@implements(aten.sort.default)\ndef _sort(\n    tensor: Index,\n    dim: int = -1,\n    descending: bool = False,\n) -> Tuple[Index, Tensor]:\n\n    if tensor.is_sorted and not descending:\n        return tensor, torch.arange(tensor._data.numel(),\n                                    device=tensor._data.device)\n\n    data, perm = aten.sort.default(tensor._data, dim, descending)\n\n    out = Index(data)\n    out._dim_size = tensor._dim_size\n\n    if not descending:\n        out._is_sorted = True\n\n    return out, perm\n\n\n@implements(aten.sort.stable)\ndef _sort_stable(\n    tensor: Index,\n    *,\n    stable: bool = False,\n    
dim: int = -1,\n    descending: bool = False,\n) -> Tuple[Index, Tensor]:\n\n    if tensor.is_sorted and not descending:\n        return tensor, torch.arange(tensor._data.numel(),\n                                    device=tensor._data.device)\n\n    data, perm = aten.sort.stable(tensor._data, stable=stable, dim=dim,\n                                  descending=descending)\n\n    out = Index(data)\n    out._dim_size = tensor._dim_size\n\n    if not descending:\n        out._is_sorted = True\n\n    return out, perm\n\n\n@implements(aten.cat.default)\ndef _cat(\n    tensors: List[Union[Index, Tensor]],\n    dim: int = 0,\n) -> Union[Index, Tensor]:\n\n    data_list = pytree.tree_map_only(Index, lambda x: x._data, tensors)\n    data = aten.cat.default(data_list, dim=dim)\n\n    if any([not isinstance(tensor, Index) for tensor in tensors]):\n        return data\n\n    out = Index(data)\n\n    nnz_list = [t.numel() for t in tensors]\n    dim_size_list = [t.dim_size for t in tensors]  # type: ignore\n    is_sorted_list = [t.is_sorted for t in tensors]  # type: ignore\n\n    # Post-process `dim_size`:\n    total_dim_size: Optional[int] = 0\n    for dim_size in dim_size_list:\n        if dim_size is None:\n            total_dim_size = None\n            break\n        assert isinstance(total_dim_size, int)\n        total_dim_size = max(dim_size, total_dim_size)\n\n    out._dim_size = total_dim_size\n\n    out._cat_metadata = CatMetadata(\n        nnz=nnz_list,\n        dim_size=dim_size_list,\n        is_sorted=is_sorted_list,\n    )\n\n    return out\n\n\n@implements(aten.flip.default)\ndef _flip(\n    input: Index,\n    dims: Union[List[int], Tuple[int, ...]],\n) -> Index:\n\n    data = aten.flip.default(input._data, dims)\n\n    out = Index(data)\n    out._dim_size = input.dim_size\n\n    return out\n\n\n@implements(aten.index_select.default)\ndef _index_select(\n    input: Union[Index, Tensor],\n    dim: int,\n    index: Union[Index, Tensor],\n) -> Union[Index, 
Tensor]:\n\n    out = aten.index_select.default(\n        input._data if isinstance(input, Index) else input,\n        dim,\n        index._data if isinstance(index, Index) else index,\n    )\n\n    if isinstance(input, Index):\n        out = Index(out)\n        out._dim_size = input.dim_size\n\n    return out\n\n\n@implements(aten.slice.Tensor)\ndef _slice(\n    input: Index,\n    dim: int,\n    start: Optional[int] = None,\n    end: Optional[int] = None,\n    step: int = 1,\n) -> Index:\n\n    if ((start is None or start <= 0 or start <= -input.size(dim))\n            and (end is None or end > input.size(dim)) and step == 1):\n        return input._shallow_copy()  # No-op.\n\n    data = aten.slice.Tensor(input._data, dim, start, end, step)\n\n    if step != 1:\n        data = data.contiguous()\n\n    out = Index(data)\n    out._dim_size = input.dim_size\n    # NOTE We could potentially maintain the `indptr` attribute here,\n    # but it is not really clear if this is worth it. The most important\n    # information `is_sorted` needs to be maintained though:\n    if step >= 0:\n        out._is_sorted = input.is_sorted\n\n    return out\n\n\n@implements(aten.index.Tensor)\ndef _index(\n    input: Union[Index, Tensor],\n    indices: List[Optional[Union[Tensor, Index]]],\n) -> Union[Index, Tensor]:\n\n    if not isinstance(input, Index):\n        indices = pytree.tree_map_only(Index, lambda x: x._data, indices)\n        return aten.index.Tensor(input, indices)\n\n    data = aten.index.Tensor(input._data, indices)\n\n    if data.dim() != 1:\n        return data\n\n    assert len(indices) == 1\n    index = indices[0]\n    assert index is not None\n\n    out = Index(data)\n\n    if index.dtype in (torch.bool, torch.uint8):  # 1. `index[mask]`.\n        out._dim_size = input.dim_size\n        out._is_sorted = input.is_sorted\n\n    else:  # 2. 
`index[index]`.\n        out._dim_size = input.dim_size\n\n    return out\n\n\n@implements(aten.add.Tensor)\ndef _add(\n    input: Union[int, Tensor, Index],\n    other: Union[int, Tensor, Index],\n    *,\n    alpha: int = 1,\n) -> Union[Index, Tensor]:\n\n    data = aten.add.Tensor(\n        input._data if isinstance(input, Index) else input,\n        other._data if isinstance(other, Index) else other,\n        alpha=alpha,\n    )\n\n    if data.dtype not in INDEX_DTYPES:\n        return data\n    if data.dim() != 1:\n        return data\n\n    out = Index(data)\n\n    if isinstance(input, Tensor) and input.numel() <= 1:\n        input = int(input)\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        assert isinstance(input, Index)\n        if input.dim_size is not None:\n            out._dim_size = input.dim_size + alpha * other\n        out._is_sorted = input.is_sorted\n\n    elif isinstance(input, int):\n        assert isinstance(other, Index)\n        if other.dim_size is not None:\n            out._dim_size = input + alpha * other.dim_size\n        out._is_sorted = other.is_sorted\n\n    elif isinstance(input, Index) and isinstance(other, Index):\n        if input.dim_size is not None and other.dim_size is not None:\n            out._dim_size = input.dim_size + alpha * other.dim_size\n\n    return out\n\n\n@implements(aten.add_.Tensor)\ndef add_(\n    input: Index,\n    other: Union[int, Tensor, Index],\n    *,\n    alpha: int = 1,\n) -> Index:\n\n    dim_size = input.dim_size\n    is_sorted = input.is_sorted\n    input._clear_metadata()\n\n    aten.add_.Tensor(\n        input._data,\n        other._data if isinstance(other, Index) else other,\n        alpha=alpha,\n    )\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        if dim_size is not None:\n            input._dim_size = dim_size + alpha * other\n  
      input._is_sorted = is_sorted\n\n    elif isinstance(other, Index):\n        if dim_size is not None and other.dim_size is not None:\n            input._dim_size = dim_size + alpha * other.dim_size\n\n    return input\n\n\n@implements(aten.sub.Tensor)\ndef _sub(\n    input: Union[int, Tensor, Index],\n    other: Union[int, Tensor, Index],\n    *,\n    alpha: int = 1,\n) -> Union[Index, Tensor]:\n\n    data = aten.sub.Tensor(\n        input._data if isinstance(input, Index) else input,\n        other._data if isinstance(other, Index) else other,\n        alpha=alpha,\n    )\n\n    if data.dtype not in INDEX_DTYPES:\n        return data\n    if data.dim() != 1:\n        return data\n\n    out = Index(data)\n\n    if not isinstance(input, Index):\n        return out\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        if input.dim_size is not None:\n            out._dim_size = input.dim_size - alpha * other\n        out._is_sorted = input.is_sorted\n\n    return out\n\n\n@implements(aten.sub_.Tensor)\ndef sub_(\n    input: Index,\n    other: Union[int, Tensor, Index],\n    *,\n    alpha: int = 1,\n) -> Index:\n\n    dim_size = input.dim_size\n    is_sorted = input.is_sorted\n    input._clear_metadata()\n\n    aten.sub_.Tensor(\n        input._data,\n        other._data if isinstance(other, Index) else other,\n        alpha=alpha,\n    )\n\n    if isinstance(other, Tensor) and other.numel() <= 1:\n        other = int(other)\n\n    if isinstance(other, int):\n        if dim_size is not None:\n            input._dim_size = dim_size - alpha * other\n        input._is_sorted = is_sorted\n\n    return input\n"
  },
  {
    "path": "torch_geometric/inspector.py",
    "content": "import inspect\nimport re\nimport sys\nimport typing\nfrom typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union\n\nimport torch\nfrom torch import Tensor\n\n\nclass Parameter(NamedTuple):\n    name: str\n    type: Type\n    type_repr: str\n    default: Any\n\n\nclass Signature(NamedTuple):\n    param_dict: Dict[str, Parameter]\n    return_type: Type\n    return_type_repr: str\n\n\nclass Inspector:\n    r\"\"\"Inspects a given class and collects information about its instance\n    methods.\n\n    Args:\n        cls (Type): The class to inspect.\n    \"\"\"\n    def __init__(self, cls: Type):\n        self._cls = cls\n        self._signature_dict: Dict[str, Signature] = {}\n        self._source_dict: Dict[str, str] = {}\n\n    def _get_modules(self, cls: Type) -> List[str]:\n        from torch_geometric.nn import MessagePassing\n\n        modules: List[str] = []\n        for base_cls in cls.__bases__:\n            if base_cls not in {object, torch.nn.Module, MessagePassing}:\n                modules.extend(self._get_modules(base_cls))\n\n        modules.append(cls.__module__)\n        return modules\n\n    @property\n    def _modules(self) -> List[str]:\n        return self._get_modules(self._cls)\n\n    @property\n    def _globals(self) -> Dict[str, Any]:\n        out: Dict[str, Any] = {}\n        for module in self._modules:\n            out.update(sys.modules[module].__dict__)\n        return out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self._cls.__name__})'\n\n    def eval_type(self, value: Any) -> Type:\n        r\"\"\"Returns the type hint of a string.\"\"\"\n        return eval_type(value, self._globals)\n\n    def type_repr(self, obj: Any) -> str:\n        r\"\"\"Returns the type hint representation of an object.\"\"\"\n        return type_repr(obj, self._globals)\n\n    def implements(self, func_name: str) -> bool:\n        r\"\"\"Returns :obj:`True` in case the inspected class 
implements the\n        :obj:`func_name` method.\n\n        Args:\n            func_name (str): The function name to check for existence.\n        \"\"\"\n        func = getattr(self._cls, func_name, None)\n        if not callable(func):\n            return False\n        return not getattr(func, '__isabstractmethod__', False)\n\n    # Inspecting Method Signatures ############################################\n\n    def inspect_signature(\n        self,\n        func: Union[Callable, str],\n        exclude: Optional[List[Union[str, int]]] = None,\n    ) -> Signature:\n        r\"\"\"Inspects the function signature of :obj:`func` and returns a tuple\n        of parameter types and return type.\n\n        Args:\n            func (callable or str): The function.\n            exclude (list[int or str]): A list of parameters to exclude, either\n                given by their name or index. (default: :obj:`None`)\n        \"\"\"\n        if isinstance(func, str):\n            func = getattr(self._cls, func)\n        assert callable(func)\n\n        if func.__name__ in self._signature_dict:\n            return self._signature_dict[func.__name__]\n\n        signature = inspect.signature(func)\n        params = [p for p in signature.parameters.values() if p.name != 'self']\n\n        param_dict: Dict[str, Parameter] = {}\n        for i, param in enumerate(params):\n            if exclude is not None and (i in exclude or param.name in exclude):\n                continue\n\n            param_type = param.annotation\n            # Mimic TorchScript to auto-infer `Tensor` on non-present types:\n            param_type = Tensor if param_type is inspect._empty else param_type\n\n            param_dict[param.name] = Parameter(\n                name=param.name,\n                type=self.eval_type(param_type),\n                type_repr=self.type_repr(param_type),\n                default=param.default,\n            )\n\n        return_type = signature.return_annotation\n        # 
Mimic TorchScript to auto-infer `Tensor` on non-present types:\n        return_type = Tensor if return_type is inspect._empty else return_type\n\n        self._signature_dict[func.__name__] = Signature(\n            param_dict=param_dict,\n            return_type=self.eval_type(return_type),\n            return_type_repr=self.type_repr(return_type),\n        )\n\n        return self._signature_dict[func.__name__]\n\n    def get_signature(\n        self,\n        func: Union[Callable, str],\n        exclude: Optional[List[str]] = None,\n    ) -> Signature:\n        r\"\"\"Returns the function signature of the inspected function\n        :obj:`func`.\n\n        Args:\n            func (callable or str): The function.\n            exclude (list[str], optional): The parameter names to exclude.\n                (default: :obj:`None`)\n        \"\"\"\n        func_name = func if isinstance(func, str) else func.__name__\n        signature = self._signature_dict.get(func_name)\n        if signature is None:\n            raise IndexError(f\"Could not access signature for function \"\n                             f\"'{func_name}'. 
Did you forget to inspect it?\")\n\n        if exclude is None:\n            return signature\n\n        param_dict = {\n            name: param\n            for name, param in signature.param_dict.items()\n            if name not in exclude\n        }\n        return Signature(\n            param_dict=param_dict,\n            return_type=signature.return_type,\n            return_type_repr=signature.return_type_repr,\n        )\n\n    def remove_signature(\n        self,\n        func: Union[Callable, str],\n    ) -> Optional[Signature]:\n        r\"\"\"Removes the inspected function signature :obj:`func`.\n\n        Args:\n            func (callable or str): The function.\n        \"\"\"\n        func_name = func if isinstance(func, str) else func.__name__\n        return self._signature_dict.pop(func_name, None)\n\n    def get_param_dict(\n        self,\n        func: Union[Callable, str],\n        exclude: Optional[List[str]] = None,\n    ) -> Dict[str, Parameter]:\n        r\"\"\"Returns the parameters of the inspected function :obj:`func`.\n\n        Args:\n            func (str or callable): The function.\n            exclude (list[str], optional): The parameter names to exclude.\n                (default: :obj:`None`)\n        \"\"\"\n        return self.get_signature(func, exclude).param_dict\n\n    def get_params(\n        self,\n        func: Union[Callable, str],\n        exclude: Optional[List[str]] = None,\n    ) -> List[Parameter]:\n        r\"\"\"Returns the parameters of the inspected function :obj:`func`.\n\n        Args:\n            func (str or callable): The function.\n            exclude (list[str], optional): The parameter names to exclude.\n                (default: :obj:`None`)\n        \"\"\"\n        return list(self.get_param_dict(func, exclude).values())\n\n    def get_flat_param_dict(\n        self,\n        funcs: List[Union[Callable, str]],\n        exclude: Optional[List[str]] = None,\n    ) -> Dict[str, Parameter]:\n        
r\"\"\"Returns the union of parameters of all inspected functions in\n        :obj:`funcs`.\n\n        Args:\n            funcs (list[str or callable]): The functions.\n            exclude (list[str], optional): The parameter names to exclude.\n                (default: :obj:`None`)\n        \"\"\"\n        param_dict: Dict[str, Parameter] = {}\n        for func in funcs:\n            params = self.get_params(func, exclude)\n            for param in params:\n                expected = param_dict.get(param.name)\n                if expected is not None and param.type != expected.type:\n                    raise ValueError(f\"Found inconsistent types for argument \"\n                                     f\"'{param.name}'. Expected type \"\n                                     f\"'{expected.type}' but found type \"\n                                     f\"'{param.type}'.\")\n\n                if expected is not None and param.default != expected.default:\n                    if (param.default is not inspect._empty\n                            and expected.default is not inspect._empty):\n                        raise ValueError(f\"Found inconsistent defaults for \"\n                                         f\"argument '{param.name}'. 
Expected \"\n                                         f\"'{expected.default}' but found \"\n                                         f\"'{param.default}'.\")\n\n                    default = expected.default\n                    if default is inspect._empty:\n                        default = param.default\n\n                    param_dict[param.name] = Parameter(\n                        name=param.name,\n                        type=param.type,\n                        type_repr=param.type_repr,\n                        default=default,\n                    )\n\n                if expected is None:\n                    param_dict[param.name] = param\n\n        return param_dict\n\n    def get_flat_params(\n        self,\n        funcs: List[Union[Callable, str]],\n        exclude: Optional[List[str]] = None,\n    ) -> List[Parameter]:\n        r\"\"\"Returns the union of parameters of all inspected functions in\n        :obj:`funcs`.\n\n        Args:\n            funcs (list[str or callable]): The functions.\n            exclude (list[str], optional): The parameter names to exclude.\n                (default: :obj:`None`)\n        \"\"\"\n        return list(self.get_flat_param_dict(funcs, exclude).values())\n\n    def get_param_names(\n        self,\n        func: Union[Callable, str],\n        exclude: Optional[List[str]] = None,\n    ) -> List[str]:\n        r\"\"\"Returns the parameter names of the inspected function :obj:`func`.\n\n        Args:\n            func (str or callable): The function.\n            exclude (list[str], optional): The parameter names to exclude.\n                (default: :obj:`None`)\n        \"\"\"\n        return list(self.get_param_dict(func, exclude).keys())\n\n    def get_flat_param_names(\n        self,\n        funcs: List[Union[Callable, str]],\n        exclude: Optional[List[str]] = None,\n    ) -> List[str]:\n        r\"\"\"Returns the union of parameter names of all inspected functions in\n        :obj:`funcs`.\n\n       
 Args:\n            funcs (list[str or callable]): The functions.\n            exclude (list[str], optional): The parameter names to exclude.\n                (default: :obj:`None`)\n        \"\"\"\n        return list(self.get_flat_param_dict(funcs, exclude).keys())\n\n    def collect_param_data(\n        self,\n        func: Union[Callable, str],\n        kwargs: Dict[str, Any],\n    ) -> Dict[str, Any]:\n        r\"\"\"Collects the input data of the inspected function :obj:`func`\n        according to its function signature from a data blob.\n\n        Args:\n            func (callable or str): The function.\n            kwargs (dict[str, Any]): The data blob which may serve as inputs.\n        \"\"\"\n        out_dict: Dict[str, Any] = {}\n        for param in self.get_params(func):\n            if param.name not in kwargs:\n                if param.default is inspect._empty:\n                    raise TypeError(f\"Parameter '{param.name}' is required\")\n                out_dict[param.name] = param.default\n            else:\n                out_dict[param.name] = kwargs[param.name]\n        return out_dict\n\n    # Inspecting Method Bodies ################################################\n\n    def get_source(self, cls: Optional[Type] = None) -> str:\n        r\"\"\"Returns the source code of :obj:`cls`.\"\"\"\n        from torch_geometric.nn import MessagePassing\n\n        cls = cls or self._cls\n        if cls.__name__ in self._source_dict:\n            return self._source_dict[cls.__name__]\n        if cls in {object, torch.nn.Module, MessagePassing}:\n            return ''\n        source = inspect.getsource(cls)\n        self._source_dict[cls.__name__] = source\n        return source\n\n    def get_params_from_method_call(\n        self,\n        func: Union[Callable, str],\n        exclude: Optional[List[Union[int, str]]] = None,\n    ) -> Dict[str, Parameter]:\n        r\"\"\"Parses a method call of :obj:`func` and returns its keyword\n        
arguments.\n\n        .. note::\n            The method is required to be called via keyword arguments in case\n            type annotations are not found.\n\n        Args:\n            func (callable or str): The function.\n            exclude (list[int or str]): A list of parameters to exclude, either\n                given by their name or index. (default: :obj:`None`)\n        \"\"\"\n        func_name = func if isinstance(func, str) else func.__name__\n        param_dict: Dict[str, Parameter] = {}\n\n        # Three ways to specify the parameters of an unknown function header:\n        # 1. Defined as class attributes in `{func_name}_type`.\n        # 2. Defined via type annotations in `# {func_name}_type: (...)`.\n        # 3. Defined via parsing of the function call.\n\n        # (1) Find class attribute:\n        if hasattr(self._cls, f'{func_name}_type'):\n            type_dict = getattr(self._cls, f'{func_name}_type')\n            if not isinstance(type_dict, dict):\n                raise ValueError(f\"'{func_name}_type' is expected to be a \"\n                                 f\"dictionary (got '{type(type_dict)}')\")\n\n            for name, param_type in type_dict.items():\n                param_dict[name] = Parameter(\n                    name=name,\n                    type=self.eval_type(param_type),\n                    type_repr=self.type_repr(param_type),\n                    default=inspect._empty,\n                )\n            return param_dict\n\n        # (2) Find type annotation:\n        for cls in self._cls.__mro__:\n            source = self.get_source(cls)\n            match = find_parenthesis_content(source, f'{func_name}_type:')\n            if match is not None:\n                for arg in split(match, sep=','):\n                    name_and_type_repr = re.split(r'\\s*:\\s*', arg)\n                    if len(name_and_type_repr) != 2:\n                        raise ValueError(f\"Could not parse argument '{arg}' \"\n                   
                      f\"of '{func_name}_type' annotation\")\n\n                    name, type_repr = name_and_type_repr\n                    param_dict[name] = Parameter(\n                        name=name,\n                        type=self.eval_type(type_repr),\n                        type_repr=type_repr,\n                        default=inspect._empty,\n                    )\n                return param_dict\n\n        # (3) Parse the function call:\n        for cls in self._cls.__mro__:\n            source = self.get_source(cls)\n            source = remove_comments(source)\n            match = find_parenthesis_content(source, f'self.{func_name}')\n            if match is not None:\n                for i, kwarg in enumerate(split(match, sep=',')):\n                    if ('=' not in kwarg and exclude is not None\n                            and i in exclude):\n                        continue\n\n                    name_and_content = re.split(r'\\s*=\\s*', kwarg)\n                    if len(name_and_content) != 2:\n                        raise ValueError(f\"Could not parse keyword argument \"\n                                         f\"'{kwarg}' in 'self.{func_name}()'\")\n\n                    name, _ = name_and_content\n\n                    if exclude is not None and name in exclude:\n                        continue\n\n                    param_dict[name] = Parameter(\n                        name=name,\n                        type=Tensor,\n                        type_repr=self.type_repr(Tensor),\n                        default=inspect._empty,\n                    )\n                return param_dict\n\n        return {}  # (4) No function call found:\n\n\ndef eval_type(value: Any, _globals: Dict[str, Any]) -> Type:\n    r\"\"\"Returns the type hint of a string.\"\"\"\n    if isinstance(value, str):\n        value = typing.ForwardRef(value)\n    return typing._eval_type(value, _globals, None)  # type: ignore\n\n\ndef type_repr(obj: Any, _globals: 
Dict[str, Any]) -> str:\n    r\"\"\"Returns the type hint representation of an object.\"\"\"\n    def _get_name(name: str, module: str) -> str:\n        return name if name in _globals else f'{module}.{name}'\n\n    if isinstance(obj, str):\n        return obj\n\n    if obj is type(None):\n        return 'None'\n\n    if obj is ...:\n        return '...'\n\n    if obj.__module__ == 'typing':  # Special logic for `typing.*` types:\n\n        if not hasattr(obj, '_name'):\n            return repr(obj)\n\n        name = obj._name\n        if name is None:  # In some cases, `_name` is not populated.\n            name = str(obj.__origin__).split('.')[-1]\n\n        args = getattr(obj, '__args__', None)\n        if args is None or len(args) == 0:\n            return _get_name(name, obj.__module__)\n        if all(isinstance(arg, typing.TypeVar) for arg in args):\n            return _get_name(name, obj.__module__)\n\n        # Convert `Union[*, None]` to `Optional[*]`.\n        # This is only necessary for old Python versions, e.g. 
3.8.\n        # TODO Only convert to `Optional` if `Optional` is importable.\n        if (name == 'Union' and len(args) == 2\n                and any([arg is type(None) for arg in args])):\n            name = 'Optional'\n\n        if name == 'Optional':  # Remove `None` from `Optional` arguments:\n            args = [arg for arg in obj.__args__ if arg is not type(None)]\n\n        args_repr = ', '.join([type_repr(arg, _globals) for arg in args])\n        return f'{_get_name(name, obj.__module__)}[{args_repr}]'\n\n    if obj.__module__ == 'builtins':\n        return obj.__qualname__\n\n    return _get_name(obj.__qualname__, obj.__module__)\n\n\ndef find_parenthesis_content(source: str, prefix: str) -> Optional[str]:\n    r\"\"\"Returns the content of :obj:`{prefix}.*(...)` within :obj:`source`.\"\"\"\n    match = re.search(prefix, source)\n    if match is None:\n        return None\n\n    offset = source[match.start():].find('(')\n    if offset < 0:\n        return None\n\n    source = source[match.start() + offset:]\n\n    depth = 0\n    for end, char in enumerate(source):\n        if char == '(':\n            depth += 1\n        if char == ')':\n            depth -= 1\n        if depth == 0:\n            content = source[1:end]\n            # Properly handle line breaks and multiple white-spaces:\n            content = content.replace('\\n', ' ')\n            content = content.replace('#', ' ')\n            content = re.sub(' +', ' ', content)\n            content = content.strip()\n            return content\n\n    return None\n\n\ndef split(content: str, sep: str) -> List[str]:\n    r\"\"\"Splits :obj:`content` based on :obj:`sep`.\n    :obj:`sep` inside parentheses or square brackets are ignored.\n    \"\"\"\n    assert len(sep) == 1\n    outs: List[str] = []\n\n    start = depth = 0\n    for end, char in enumerate(content):\n        if char == '[' or char == '(':\n            depth += 1\n        elif char == ']' or char == ')':\n            depth -= 1\n        
elif char == sep and depth == 0:\n            outs.append(content[start:end].strip())\n            start = end + 1\n    if start != len(content):  # Respect dangling `sep`:\n        outs.append(content[start:].strip())\n    return outs\n\n\ndef remove_comments(content: str) -> str:\n    content = re.sub(r'\\s*#.*', '', content)\n    content = re.sub(re.compile(r'r\"\"\"(.*?)\"\"\"', re.DOTALL), '', content)\n    content = re.sub(re.compile(r'\"\"\"(.*?)\"\"\"', re.DOTALL), '', content)\n    content = re.sub(re.compile(r\"r'''(.*?)'''\", re.DOTALL), '', content)\n    content = re.sub(re.compile(r\"'''(.*?)'''\", re.DOTALL), '', content)\n    return content\n"
  },
  {
    "path": "torch_geometric/io/__init__.py",
    "content": "from .txt_array import parse_txt_array, read_txt_array\nfrom .tu import read_tu_data\nfrom .planetoid import read_planetoid_data\nfrom .ply import read_ply\nfrom .obj import read_obj\nfrom .sdf import read_sdf, parse_sdf\nfrom .off import read_off, write_off\nfrom .npz import read_npz, parse_npz\n\n__all__ = [\n    'read_off',\n    'write_off',\n    'parse_txt_array',\n    'read_txt_array',\n    'read_tu_data',\n    'read_planetoid_data',\n    'read_ply',\n    'read_obj',\n    'read_sdf',\n    'parse_sdf',\n    'read_npz',\n    'parse_npz',\n]\n"
  },
  {
    "path": "torch_geometric/io/fs.py",
    "content": "import io\nimport os\nimport os.path as osp\nimport pickle\nimport re\nimport sys\nimport warnings\nfrom typing import Any, Dict, List, Literal, Optional, Union, overload\nfrom uuid import uuid4\n\nimport fsspec\nimport torch\n\nimport torch_geometric\n\nDEFAULT_CACHE_PATH = '/tmp/pyg_simplecache'\n\n\ndef get_fs(path: str) -> fsspec.AbstractFileSystem:\n    r\"\"\"Get filesystem backend given a path URI to the resource.\n\n    Here are some common example paths and dispatch result:\n\n    * :obj:`\"/home/file\"` ->\n      :class:`fsspec.implementations.local.LocalFileSystem`\n    * :obj:`\"memory://home/file\"` ->\n      :class:`fsspec.implementations.memory.MemoryFileSystem`\n    * :obj:`\"https://home/file\"` ->\n      :class:`fsspec.implementations.http.HTTPFileSystem`\n    * :obj:`\"gs://home/file\"` -> :class:`gcsfs.GCSFileSystem`\n    * :obj:`\"s3://home/file\"` -> :class:`s3fs.S3FileSystem`\n\n    A full list of supported backend implementations of :class:`fsspec` can be\n    found `here <https://github.com/fsspec/filesystem_spec/blob/master/fsspec/\n    registry.py#L62>`_.\n\n    The backend dispatch logic can be updated with custom backends following\n    `this tutorial <https://filesystem-spec.readthedocs.io/en/latest/\n    developer.html#implementing-a-backend>`_.\n\n    Args:\n        path (str): The URI to the filesystem location, *e.g.*,\n            :obj:`\"gs://home/me/file\"`, :obj:`\"s3://...\"`.\n    \"\"\"\n    return fsspec.core.url_to_fs(path)[0]\n\n\ndef normpath(path: str) -> str:\n    if isdisk(path):\n        return osp.normpath(path)\n    return path\n\n\ndef exists(path: str) -> bool:\n    return get_fs(path).exists(path)\n\n\ndef makedirs(path: str, exist_ok: bool = True) -> None:\n    return get_fs(path).makedirs(path, exist_ok)\n\n\ndef isdir(path: str) -> bool:\n    return get_fs(path).isdir(path)\n\n\ndef isfile(path: str) -> bool:\n    return get_fs(path).isfile(path)\n\n\ndef isdisk(path: str) -> bool:\n    return 
'file' in get_fs(path).protocol\n\n\ndef islocal(path: str) -> bool:\n    return isdisk(path) or 'memory' in get_fs(path).protocol\n\n\n@overload\ndef ls(path: str, detail: Literal[False] = False) -> List[str]:\n    pass\n\n\n@overload\ndef ls(path: str, detail: Literal[True]) -> List[Dict[str, Any]]:\n    pass\n\n\ndef ls(\n    path: str,\n    detail: bool = False,\n) -> Union[List[str], List[Dict[str, Any]]]:\n    fs = get_fs(path)\n    outputs = fs.ls(path, detail=detail)\n\n    if not isdisk(path):\n        if detail:\n            for output in outputs:\n                output['name'] = fs.unstrip_protocol(output['name'])\n        else:\n            outputs = [fs.unstrip_protocol(output) for output in outputs]\n\n    return outputs\n\n\ndef cp(\n    path1: str,\n    path2: str,\n    extract: bool = False,\n    log: bool = True,\n    use_cache: bool = True,\n    clear_cache: bool = True,\n) -> None:\n    kwargs: Dict[str, Any] = {}\n\n    is_path1_dir = isdir(path1)\n    is_path2_dir = isdir(path2)\n\n    # Cache result if the protocol is not local:\n    cache_dir: Optional[str] = None\n    if not islocal(path1):\n        if log and 'PYTEST_CURRENT_TEST' not in os.environ:\n            print(f'Downloading {path1}', file=sys.stderr)\n\n        if extract and use_cache:  # Cache seems to confuse the gcs filesystem.\n            home_dir = torch_geometric.get_home_dir()\n            cache_dir = osp.join(home_dir, 'simplecache', uuid4().hex)\n            kwargs.setdefault('simplecache', dict(cache_storage=cache_dir))\n            path1 = f'simplecache::{path1}'\n\n    # Handle automatic extraction:\n    multiple_files = False\n    if extract and path1.endswith('.tar.gz'):\n        kwargs.setdefault('tar', dict(compression='gzip'))\n        path1 = f'tar://**::{path1}'\n        multiple_files = True\n    elif extract and path1.endswith('.zip'):\n        path1 = f'zip://**::{path1}'\n        multiple_files = True\n    elif extract and path1.endswith('.gz'):\n        
kwargs.setdefault('compression', 'infer')\n    elif extract:\n        raise NotImplementedError(\n            f\"Automatic extraction of '{path1}' not yet supported\")\n\n    # If the source path points to a directory, we need to make sure to\n    # recursively copy all files within this directory. Additionally, if the\n    # destination folder does not yet exist, we inherit the basename from the\n    # source folder.\n    if is_path1_dir:\n        if exists(path2):\n            path2 = osp.join(path2, osp.basename(path1))\n        path1 = osp.join(path1, '**')\n        multiple_files = True\n\n    # Perform the copy:\n    for open_file in fsspec.open_files(path1, **kwargs):\n        with open_file as f_from:\n            if not multiple_files:\n                if is_path2_dir:\n                    basename = osp.basename(path1)\n                    if extract and path1.endswith('.gz'):\n                        basename = '.'.join(basename.split('.')[:-1])\n                    to_path = osp.join(path2, basename)\n                else:\n                    to_path = path2\n            else:\n                # Open file has protocol stripped.\n                common_path = osp.commonprefix(\n                    [fsspec.core.strip_protocol(path1), open_file.path])\n                to_path = osp.join(path2, open_file.path[len(common_path):])\n            with fsspec.open(to_path, 'wb') as f_to:\n                while True:\n                    chunk = f_from.read(10 * 1024 * 1024)\n                    if not chunk:\n                        break\n                    f_to.write(chunk)\n\n    if use_cache and clear_cache and cache_dir is not None:\n        try:\n            rm(cache_dir)\n        except Exception:  # FIXME\n            # Windows test yield \"PermissionError: The process cannot access\n            # the file because it is being used by another process\".\n            # Users may also observe \"OSError: Directory not empty\".\n            # This is a quick 
workaround until we figure out the deeper issue.\n            pass\n\n\ndef rm(path: str, recursive: bool = True) -> None:\n    get_fs(path).rm(path, recursive)\n\n\ndef mv(path1: str, path2: str) -> None:\n    fs1 = get_fs(path1)\n    fs2 = get_fs(path2)\n    assert fs1.protocol == fs2.protocol\n    fs1.mv(path1, path2)\n\n\ndef glob(path: str) -> List[str]:\n    fs = get_fs(path)\n    paths = fs.glob(path)\n\n    if not isdisk(path):\n        paths = [fs.unstrip_protocol(path) for path in paths]\n\n    return paths\n\n\ndef torch_save(data: Any, path: str) -> None:\n    buffer = io.BytesIO()\n    torch.save(data, buffer)\n    with fsspec.open(path, 'wb') as f:\n        f.write(buffer.getvalue())\n\n\ndef torch_load(path: str, map_location: Any = None) -> Any:\n    if torch_geometric.typing.WITH_PT24:\n        try:\n            with fsspec.open(path, 'rb') as f:\n                return torch.load(f, map_location, weights_only=True)\n        except pickle.UnpicklingError as e:\n            error_msg = str(e)\n            if \"add_safe_globals\" in error_msg:\n                warn_msg = (\"Weights only load failed. Please file an issue \"\n                            \"to make `torch.load(weights_only=True)` \"\n                            \"compatible in your case.\")\n                match = re.search(r'add_safe_globals\\(.*?\\)', error_msg)\n                if match is not None:\n                    warnings.warn(\n                        f\"{warn_msg} Please use \"\n                        f\"`torch.serialization.{match.group()}` to \"\n                        f\"allowlist this global.\", stacklevel=2)\n                else:\n                    warnings.warn(warn_msg, stacklevel=2)\n\n                with fsspec.open(path, 'rb') as f:\n                    return torch.load(f, map_location, weights_only=False)\n            else:\n                raise e\n\n    with fsspec.open(path, 'rb') as f:\n        return torch.load(f, map_location)\n"
  },
  {
    "path": "torch_geometric/io/npz.py",
    "content": "from typing import Any, Dict\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.utils import remove_self_loops\nfrom torch_geometric.utils import to_undirected as to_undirected_fn\n\n\ndef read_npz(path: str, to_undirected: bool = True) -> Data:\n    with np.load(path) as f:\n        return parse_npz(f, to_undirected=to_undirected)\n\n\ndef parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Data:\n    import scipy.sparse as sp\n\n    x = sp.csr_matrix((f['attr_data'], f['attr_indices'], f['attr_indptr']),\n                      f['attr_shape']).todense()\n    x = torch.from_numpy(x).to(torch.float)\n    x[x > 0] = 1\n\n    adj = sp.csr_matrix((f['adj_data'], f['adj_indices'], f['adj_indptr']),\n                        f['adj_shape']).tocoo()\n    row = torch.from_numpy(adj.row).to(torch.long)\n    col = torch.from_numpy(adj.col).to(torch.long)\n    edge_index = torch.stack([row, col], dim=0)\n    edge_index, _ = remove_self_loops(edge_index)\n    if to_undirected:\n        edge_index = to_undirected_fn(edge_index, num_nodes=x.size(0))\n\n    y = torch.from_numpy(f['labels']).to(torch.long)\n\n    return Data(x=x, edge_index=edge_index, y=y)\n"
  },
  {
    "path": "torch_geometric/io/obj.py",
    "content": "from typing import Iterator, List, Optional, Tuple, Union\n\nimport torch\n\nfrom torch_geometric.data import Data\n\n\ndef yield_file(in_file: str) -> Iterator[Tuple[str, List[Union[int, float]]]]:\n\n    f = open(in_file)\n    buf = f.read()\n    f.close()\n    for b in buf.split('\\n'):\n        if b.startswith('v '):\n            yield 'v', [float(x) for x in b.split(\" \")[1:]]\n        elif b.startswith('f '):\n            triangles = b.split(' ')[1:]\n            # -1 as .obj is base 1 but the Data class expects base 0 indices\n            yield 'f', [int(t.split(\"/\")[0]) - 1 for t in triangles]\n        else:\n            yield '', []\n\n\ndef read_obj(in_file: str) -> Optional[Data]:\n    vertices = []\n    faces = []\n\n    for k, v in yield_file(in_file):\n        if k == 'v':\n            vertices.append(v)\n        elif k == 'f':\n            faces.append(v)\n\n    if not len(faces) or not len(vertices):\n        return None\n\n    pos = torch.tensor(vertices, dtype=torch.float)\n    face = torch.tensor(faces, dtype=torch.long).t().contiguous()\n\n    data = Data(pos=pos, face=face)\n\n    return data\n"
  },
  {
    "path": "torch_geometric/io/off.py",
    "content": "import re\nfrom typing import List\n\nimport torch\nfrom torch import Tensor\nfrom torch._tensor_str import PRINT_OPTS, _tensor_str\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.io import parse_txt_array\n\n\ndef parse_off(src: List[str]) -> Data:\n    # Some files may contain a bug and do not have a carriage return after OFF.\n    if src[0] == 'OFF':\n        src = src[1:]\n    else:\n        src[0] = src[0][3:]\n\n    num_nodes, num_faces = (int(item) for item in src[0].split()[:2])\n\n    pos = parse_txt_array(src[1:1 + num_nodes])\n\n    face = face_to_tri(src[1 + num_nodes:1 + num_nodes + num_faces])\n\n    data = Data(pos=pos)\n    data.face = face\n\n    return data\n\n\ndef face_to_tri(face: List[str]) -> Tensor:\n    face_index = [[int(x) for x in line.strip().split()] for line in face]\n\n    triangle = torch.tensor([line[1:] for line in face_index if line[0] == 3])\n    triangle = triangle.to(torch.int64)\n\n    rect = torch.tensor([line[1:] for line in face_index if line[0] == 4])\n    rect = rect.to(torch.int64)\n\n    if rect.numel() > 0:\n        first, second = rect[:, [0, 1, 2]], rect[:, [0, 2, 3]]\n        return torch.cat([triangle, first, second], dim=0).t().contiguous()\n\n    return triangle.t().contiguous()\n\n\ndef read_off(path: str) -> Data:\n    r\"\"\"Reads an OFF (Object File Format) file, returning both the position of\n    nodes and their connectivity in a :class:`torch_geometric.data.Data`\n    object.\n\n    Args:\n        path (str): The path to the file.\n    \"\"\"\n    with open(path) as f:\n        src = f.read().split('\\n')[:-1]\n    return parse_off(src)\n\n\ndef write_off(data: Data, path: str) -> None:\n    r\"\"\"Writes a :class:`torch_geometric.data.Data` object to an OFF (Object\n    File Format) file.\n\n    Args:\n        data (:class:`torch_geometric.data.Data`): The data object.\n        path (str): The path to the file.\n    \"\"\"\n    assert data.pos is not None\n    assert 
data.face is not None\n\n    num_nodes, num_faces = data.pos.size(0), data.face.size(1)\n\n    pos = data.pos.to(torch.float)\n    face = data.face.t()\n    num_vertices = torch.full((num_faces, 1), face.size(1), dtype=torch.long)\n    face = torch.cat([num_vertices, face], dim=-1)\n\n    threshold = PRINT_OPTS.threshold\n    torch.set_printoptions(threshold=float('inf'))\n\n    pos_repr = re.sub(',', '', _tensor_str(pos, indent=0))\n    pos_repr = '\\n'.join([x[2:-1] for x in pos_repr.split('\\n')])[:-1]\n\n    face_repr = re.sub(',', '', _tensor_str(face, indent=0))\n    face_repr = '\\n'.join([x[2:-1] for x in face_repr.split('\\n')])[:-1]\n\n    with open(path, 'w') as f:\n        f.write(f'OFF\\n{num_nodes} {num_faces} 0\\n')\n        f.write(pos_repr)\n        f.write('\\n')\n        f.write(face_repr)\n        f.write('\\n')\n    torch.set_printoptions(threshold=threshold)\n"
  },
  {
    "path": "torch_geometric/io/planetoid.py",
    "content": "import os.path as osp\nimport warnings\nfrom itertools import repeat\nfrom typing import Dict, List, Optional\n\nimport fsspec\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.io import read_txt_array\nfrom torch_geometric.utils import (\n    coalesce,\n    index_to_mask,\n    remove_self_loops,\n    to_torch_csr_tensor,\n)\n\ntry:\n    import cPickle as pickle\nexcept ImportError:\n    import pickle\n\n\ndef read_planetoid_data(folder: str, prefix: str) -> Data:\n    names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']\n    items = [read_file(folder, prefix, name) for name in names]\n    x, tx, allx, y, ty, ally, graph, test_index = items\n    train_index = torch.arange(y.size(0), dtype=torch.long)\n    val_index = torch.arange(y.size(0), y.size(0) + 500, dtype=torch.long)\n    sorted_test_index = test_index.sort()[0]\n\n    if prefix.lower() == 'citeseer':\n        # There are some isolated nodes in the Citeseer graph, resulting in\n        # none consecutive test indices. 
We need to identify them and add them\n        # as zero vectors to `tx` and `ty`.\n        len_test_indices = int(test_index.max() - test_index.min()) + 1\n\n        tx_ext = torch.zeros(len_test_indices, tx.size(1), dtype=tx.dtype)\n        tx_ext[sorted_test_index - test_index.min(), :] = tx\n        ty_ext = torch.zeros(len_test_indices, ty.size(1), dtype=ty.dtype)\n        ty_ext[sorted_test_index - test_index.min(), :] = ty\n\n        tx, ty = tx_ext, ty_ext\n\n    if prefix.lower() == 'nell.0.001':\n        tx_ext = torch.zeros(len(graph) - allx.size(0), x.size(1))\n        tx_ext[sorted_test_index - allx.size(0)] = tx\n\n        ty_ext = torch.zeros(len(graph) - ally.size(0), y.size(1))\n        ty_ext[sorted_test_index - ally.size(0)] = ty\n\n        tx, ty = tx_ext, ty_ext\n\n        x = torch.cat([allx, tx], dim=0)\n        x[test_index] = x[sorted_test_index]\n\n        # Creating feature vectors for relations.\n        row, col = x.nonzero(as_tuple=True)\n        value = x[row, col]\n\n        mask = ~index_to_mask(test_index, size=len(graph))\n        mask[:allx.size(0)] = False\n        isolated_idx = mask.nonzero().view(-1)\n\n        row = torch.cat([row, isolated_idx])\n        col = torch.cat([col, torch.arange(isolated_idx.size(0)) + x.size(1)])\n        value = torch.cat([value, value.new_ones(isolated_idx.size(0))])\n\n        x = to_torch_csr_tensor(\n            edge_index=torch.stack([row, col], dim=0),\n            edge_attr=value,\n            size=(x.size(0), isolated_idx.size(0) + x.size(1)),\n        )\n    else:\n        x = torch.cat([allx, tx], dim=0)\n        x[test_index] = x[sorted_test_index]\n\n    y = torch.cat([ally, ty], dim=0).max(dim=1)[1]\n    y[test_index] = y[sorted_test_index]\n\n    train_mask = index_to_mask(train_index, size=y.size(0))\n    val_mask = index_to_mask(val_index, size=y.size(0))\n    test_mask = index_to_mask(test_index, size=y.size(0))\n\n    edge_index = edge_index_from_dict(\n        
graph_dict=graph,  # type: ignore\n        num_nodes=y.size(0),\n    )\n\n    data = Data(x=x, edge_index=edge_index, y=y)\n    data.train_mask = train_mask\n    data.val_mask = val_mask\n    data.test_mask = test_mask\n\n    return data\n\n\ndef read_file(folder: str, prefix: str, name: str) -> Tensor:\n    path = osp.join(folder, f'ind.{prefix.lower()}.{name}')\n\n    if name == 'test.index':\n        return read_txt_array(path, dtype=torch.long)\n\n    with fsspec.open(path, 'rb') as f:\n        warnings.filterwarnings('ignore', '.*`scipy.sparse.csr` name.*')\n        out = pickle.load(f, encoding='latin1')\n\n    if name == 'graph':\n        return out\n\n    out = out.todense() if hasattr(out, 'todense') else out\n    out = torch.from_numpy(out).to(torch.float)\n    return out\n\n\ndef edge_index_from_dict(\n    graph_dict: Dict[int, List[int]],\n    num_nodes: Optional[int] = None,\n) -> Tensor:\n    rows: List[int] = []\n    cols: List[int] = []\n    for key, value in graph_dict.items():\n        rows += repeat(key, len(value))\n        cols += value\n    row = torch.tensor(rows)\n    col = torch.tensor(cols)\n    edge_index = torch.stack([row, col], dim=0)\n\n    # `torch.compile` is not yet ready for `EdgeIndex` :(\n    # from torch_geometric import EdgeIndex\n    # edge_index: Union[EdgeIndex, Tensor] = EdgeIndex(\n    #     torch.stack([row, col], dim=0),\n    #     is_undirected=True,\n    #     sparse_size=(num_nodes, num_nodes),\n    # )\n\n    # NOTE: There are some duplicated edges and self loops in the datasets.\n    #       Other implementations do not remove them!\n    edge_index, _ = remove_self_loops(edge_index)\n    edge_index = coalesce(edge_index, num_nodes=num_nodes, sort_by_row=False)\n\n    return edge_index\n"
  },
  {
    "path": "torch_geometric/io/ply.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\n\ntry:\n    import openmesh\nexcept ImportError:\n    openmesh = None\n\n\ndef read_ply(path: str) -> Data:\n    if openmesh is None:\n        raise ImportError('`read_ply` requires the `openmesh` package.')\n\n    mesh = openmesh.read_trimesh(path)\n    pos = torch.from_numpy(mesh.points()).to(torch.float)\n    face = torch.from_numpy(mesh.face_vertex_indices())\n    face = face.t().to(torch.long).contiguous()\n    return Data(pos=pos, face=face)\n"
  },
  {
    "path": "torch_geometric/io/sdf.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.io import parse_txt_array\nfrom torch_geometric.utils import coalesce, one_hot\n\nelems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}\n\n\ndef parse_sdf(src: str) -> Data:\n    lines = src.split('\\n')[3:]\n    num_atoms, num_bonds = (int(item) for item in lines[0].split()[:2])\n\n    atom_block = lines[1:num_atoms + 1]\n    pos = parse_txt_array(atom_block, end=3)\n    x = torch.tensor([elems[item.split()[3]] for item in atom_block])\n    x = one_hot(x, num_classes=len(elems))\n\n    bond_block = lines[1 + num_atoms:1 + num_atoms + num_bonds]\n    row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1\n    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)\n    edge_index = torch.stack([row, col], dim=0)\n    edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1\n    edge_attr = torch.cat([edge_attr, edge_attr], dim=0)\n    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms)\n\n    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)\n\n\ndef read_sdf(path: str) -> Data:\n    with open(path) as f:\n        return parse_sdf(f.read())\n"
  },
  {
    "path": "torch_geometric/io/tu.py",
    "content": "import os.path as osp\nfrom typing import Dict, List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.io import fs, read_txt_array\nfrom torch_geometric.utils import coalesce, cumsum, one_hot, remove_self_loops\n\nnames = [\n    'A', 'graph_indicator', 'node_labels', 'node_attributes',\n    'edge_labels', 'edge_attributes', 'graph_labels', 'graph_attributes'\n]\n\n\ndef read_tu_data(\n    folder: str,\n    prefix: str,\n) -> Tuple[Data, Dict[str, Tensor], Dict[str, int]]:\n    files = fs.glob(osp.join(folder, f'{prefix}_*.txt'))\n    names = [osp.basename(f)[len(prefix) + 1:-4] for f in files]\n\n    edge_index = read_file(folder, prefix, 'A', torch.long).t() - 1\n    batch = read_file(folder, prefix, 'graph_indicator', torch.long) - 1\n\n    node_attribute = torch.empty((batch.size(0), 0))\n    if 'node_attributes' in names:\n        node_attribute = read_file(folder, prefix, 'node_attributes')\n        if node_attribute.dim() == 1:\n            node_attribute = node_attribute.unsqueeze(-1)\n\n    node_label = torch.empty((batch.size(0), 0))\n    if 'node_labels' in names:\n        node_label = read_file(folder, prefix, 'node_labels', torch.long)\n        if node_label.dim() == 1:\n            node_label = node_label.unsqueeze(-1)\n        node_label = node_label - node_label.min(dim=0)[0]\n        node_labels = list(node_label.unbind(dim=-1))\n        node_labels = [one_hot(x) for x in node_labels]\n        if len(node_labels) == 1:\n            node_label = node_labels[0]\n        else:\n            node_label = torch.cat(node_labels, dim=-1)\n\n    edge_attribute = torch.empty((edge_index.size(1), 0))\n    if 'edge_attributes' in names:\n        edge_attribute = read_file(folder, prefix, 'edge_attributes')\n        if edge_attribute.dim() == 1:\n            edge_attribute = edge_attribute.unsqueeze(-1)\n\n    edge_label = torch.empty((edge_index.size(1), 0))\n    if 
'edge_labels' in names:\n        edge_label = read_file(folder, prefix, 'edge_labels', torch.long)\n        if edge_label.dim() == 1:\n            edge_label = edge_label.unsqueeze(-1)\n        edge_label = edge_label - edge_label.min(dim=0)[0]\n        edge_labels = list(edge_label.unbind(dim=-1))\n        edge_labels = [one_hot(e) for e in edge_labels]\n        if len(edge_labels) == 1:\n            edge_label = edge_labels[0]\n        else:\n            edge_label = torch.cat(edge_labels, dim=-1)\n\n    x = cat([node_attribute, node_label])\n    edge_attr = cat([edge_attribute, edge_label])\n\n    y = None\n    if 'graph_attributes' in names:  # Regression problem.\n        y = read_file(folder, prefix, 'graph_attributes')\n    elif 'graph_labels' in names:  # Classification problem.\n        y = read_file(folder, prefix, 'graph_labels', torch.long)\n        _, y = y.unique(sorted=True, return_inverse=True)\n\n    num_nodes = int(edge_index.max()) + 1 if x is None else x.size(0)\n    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)\n    edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes)\n\n    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)\n    data, slices = split(data, batch)\n\n    sizes = {\n        'num_node_attributes': node_attribute.size(-1),\n        'num_node_labels': node_label.size(-1),\n        'num_edge_attributes': edge_attribute.size(-1),\n        'num_edge_labels': edge_label.size(-1),\n    }\n\n    return data, slices, sizes\n\n\ndef read_file(\n    folder: str,\n    prefix: str,\n    name: str,\n    dtype: Optional[torch.dtype] = None,\n) -> Tensor:\n    path = osp.join(folder, f'{prefix}_{name}.txt')\n    return read_txt_array(path, sep=',', dtype=dtype)\n\n\ndef cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]:\n    values = [v for v in seq if v is not None]\n    values = [v for v in values if v.numel() > 0]\n    values = [v.unsqueeze(-1) if v.dim() == 1 else v for v in values]\n    
return torch.cat(values, dim=-1) if len(values) > 0 else None\n\n\ndef split(data: Data, batch: Tensor) -> Tuple[Data, Dict[str, Tensor]]:\n    node_slice = cumsum(torch.bincount(batch))\n\n    assert data.edge_index is not None\n    row, _ = data.edge_index\n    edge_slice = cumsum(torch.bincount(batch[row]))\n\n    # Edge indices should start at zero for every graph.\n    data.edge_index -= node_slice[batch[row]].unsqueeze(0)\n\n    slices = {'edge_index': edge_slice}\n    if data.x is not None:\n        slices['x'] = node_slice\n    else:\n        # Imitate `collate` functionality:\n        data._num_nodes = torch.bincount(batch).tolist()\n        data.num_nodes = batch.numel()\n    if data.edge_attr is not None:\n        slices['edge_attr'] = edge_slice\n    if data.y is not None:\n        assert isinstance(data.y, Tensor)\n        if data.y.size(0) == batch.size(0):\n            slices['y'] = node_slice\n        else:\n            slices['y'] = torch.arange(0, int(batch[-1]) + 2, dtype=torch.long)\n\n    return data, slices\n"
  },
  {
    "path": "torch_geometric/io/txt_array.py",
    "content": "from typing import List, Optional\n\nimport fsspec\nimport torch\nfrom torch import Tensor\n\n\ndef parse_txt_array(\n    src: List[str],\n    sep: Optional[str] = None,\n    start: int = 0,\n    end: Optional[int] = None,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n    empty = torch.empty(0, dtype=dtype)\n    to_number = float if empty.is_floating_point() else int\n\n    return torch.tensor([[to_number(x) for x in line.split(sep)[start:end]]\n                         for line in src], dtype=dtype,\n                        device=device).squeeze()\n\n\ndef read_txt_array(\n    path: str,\n    sep: Optional[str] = None,\n    start: int = 0,\n    end: Optional[int] = None,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n    with fsspec.open(path, 'r') as f:\n        src = f.read().split('\\n')[:-1]\n    return parse_txt_array(src, sep, start, end, dtype, device)\n"
  },
  {
    "path": "torch_geometric/isinstance.py",
    "content": "from typing import Any, Tuple, Type, Union\n\nimport torch\n\nimport torch_geometric.typing\n\nif torch_geometric.typing.WITH_PT20:\n    import torch._dynamo\n\n\ndef is_torch_instance(obj: Any, cls: Union[Type, Tuple[Type]]) -> bool:\n    r\"\"\"Checks if the :obj:`obj` is an instance of a :obj:`cls`.\n\n    This function extends :meth:`isinstance` to be applicable during\n    :meth:`torch.compile` usage by checking against the original class of\n    compiled models.\n    \"\"\"\n    # `torch.compile` removes the model inheritance and converts the model to\n    # a `torch._dynamo.OptimizedModule` instance, leading to `isinstance` being\n    # unable to check the model's inheritance. This function unwraps the\n    # compiled model before evaluating via `isinstance`.\n    if (torch_geometric.typing.WITH_PT20\n            and isinstance(obj, torch._dynamo.OptimizedModule)):\n        return isinstance(obj._orig_mod, cls)\n    return isinstance(obj, cls)\n"
  },
  {
    "path": "torch_geometric/lazy_loader.py",
    "content": "from importlib import import_module\nfrom types import ModuleType\nfrom typing import Any, Dict, List\n\n\n# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/\n# python/util/lazy_loader.py\nclass LazyLoader(ModuleType):\n    def __init__(\n        self,\n        local_name: str,\n        parent_module_globals: Dict[str, Any],\n        name: str,\n    ) -> None:\n        self._local_name = local_name\n        self._parent_module_globals = parent_module_globals\n        super().__init__(name)\n\n    def _load(self) -> Any:\n        module = import_module(self.__name__)\n        self._parent_module_globals[self._local_name] = module\n        self.__dict__.update(module.__dict__)\n        return module\n\n    def __getattr__(self, item: str) -> Any:\n        module = self._load()\n        return getattr(module, item)\n\n    def __dir__(self) -> List[str]:\n        module = self._load()\n        return dir(module)\n"
  },
  {
    "path": "torch_geometric/llm/__init__.py",
    "content": "from .large_graph_indexer import LargeGraphIndexer\nfrom .rag_loader import RAGQueryLoader\nfrom .utils import *  # noqa\nfrom .models import *  # noqa\n\n__all__ = classes = [\n    'LargeGraphIndexer',\n    'RAGQueryLoader',\n]\n"
  },
  {
    "path": "torch_geometric/llm/large_graph_indexer.py",
    "content": "import os\nimport pickle as pkl\nimport shutil\nfrom dataclasses import dataclass\nfrom itertools import chain, islice, tee\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    Iterator,\n    List,\n    Optional,\n    Sequence,\n    Set,\n    Tuple,\n    Union,\n)\n\nimport torch\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.io import fs\nfrom torch_geometric.typing import WITH_PT24\n\n# Could be any hashable type\nTripletLike = Tuple[str, str, str]\n\nKnowledgeGraphLike = Iterable[TripletLike]\n\n\ndef ordered_set(values: Iterable[str]) -> List[str]:\n    return list(dict.fromkeys(values))\n\n\n# TODO: Refactor Node and Edge funcs and attrs to be accessible via an Enum?\n\nNODE_PID = \"pid\"  # Encodes node id\n\nNODE_KEYS = {NODE_PID}\n\nEDGE_PID = \"e_pid\"  # Encodes source node, relation, destination node\nEDGE_HEAD = \"h\"  # Encodes source node\nEDGE_RELATION = \"r\"  # Encodes relation\nEDGE_TAIL = \"t\"  # Encodes destination node\nEDGE_INDEX = \"edge_idx\"  # Encodes source node, destination node\n\nEDGE_KEYS = {EDGE_PID, EDGE_HEAD, EDGE_RELATION, EDGE_TAIL, EDGE_INDEX}\n\nFeatureValueType = Union[Sequence[Any], Tensor]\n\n\n@dataclass\nclass MappedFeature:\n    name: str\n    values: FeatureValueType\n\n    def __eq__(self, value: \"MappedFeature\") -> bool:\n        eq = self.name == value.name\n        if isinstance(self.values, torch.Tensor):\n            eq &= torch.equal(self.values, value.values)\n        else:\n            eq &= self.values == value.values\n        return eq\n\n\nif WITH_PT24:\n    torch.serialization.add_safe_globals([MappedFeature])\n\n\nclass LargeGraphIndexer:\n    \"\"\"For a dataset that consists of multiple subgraphs that are assumed to\n    be part of a much larger graph, collate the values into a large graph store\n    to save resources.\n    \"\"\"\n    def __init__(\n        self,\n        nodes: Iterable[str],\n   
     edges: KnowledgeGraphLike,\n        node_attr: Optional[Dict[str, List[Any]]] = None,\n        edge_attr: Optional[Dict[str, List[Any]]] = None,\n    ) -> None:\n        r\"\"\"Constructs a new index that uniquely catalogs each node and edge\n        by id. Not meant to be used directly.\n\n        Args:\n            nodes (Iterable[str]): Node ids in the graph.\n            edges (KnowledgeGraphLike): Edge ids in the graph.\n                Example: [(\"cats\", \"eat\", \"dogs\")]\n            node_attr (Optional[Dict[str, List[Any]]], optional): Mapping node\n                attribute name and list of their values in order of unique node\n                ids. Defaults to None.\n            edge_attr (Optional[Dict[str, List[Any]]], optional): Mapping edge\n                attribute name and list of their values in order of unique edge\n                ids. Defaults to None.\n        \"\"\"\n        self._nodes: Dict[str, int] = dict()\n        self._edges: Dict[TripletLike, int] = dict()\n\n        self._mapped_node_features: Set[str] = set()\n        self._mapped_edge_features: Set[str] = set()\n\n        if len(nodes) != len(set(nodes)):\n            raise AttributeError(\"Nodes need to be unique\")\n        if len(edges) != len(set(edges)):\n            raise AttributeError(\"Edges need to be unique\")\n\n        if node_attr is not None:\n            # TODO: Validity checks btw nodes and node_attr\n            self.node_attr = node_attr\n            if NODE_KEYS & set(self.node_attr.keys()) != NODE_KEYS:\n                raise AttributeError(\n                    \"Invalid node_attr object. 
Missing \" +\n                    f\"{NODE_KEYS - set(self.node_attr.keys())}\")\n            elif self.node_attr[NODE_PID] != nodes:\n                raise AttributeError(\n                    \"Nodes provided do not match those in node_attr\")\n        else:\n            self.node_attr = dict()\n            self.node_attr[NODE_PID] = nodes\n\n        for i, node in enumerate(self.node_attr[NODE_PID]):\n            self._nodes[node] = i\n\n        if edge_attr is not None:\n            # TODO: Validity checks btw edges and edge_attr\n            self.edge_attr = edge_attr\n\n            if EDGE_KEYS & set(self.edge_attr.keys()) != EDGE_KEYS:\n                raise AttributeError(\n                    \"Invalid edge_attr object. Missing \" +\n                    f\"{EDGE_KEYS - set(self.edge_attr.keys())}\")\n            elif self.edge_attr[EDGE_PID] != edges:\n                raise AttributeError(\n                    \"Edges provided do not match those in edge_attr\")\n\n        else:\n            self.edge_attr = dict()\n            for default_key in EDGE_KEYS:\n                self.edge_attr[default_key] = list()\n            self.edge_attr[EDGE_PID] = edges\n\n            for tup in edges:\n                h, r, t = tup\n                self.edge_attr[EDGE_HEAD].append(h)\n                self.edge_attr[EDGE_RELATION].append(r)\n                self.edge_attr[EDGE_TAIL].append(t)\n                self.edge_attr[EDGE_INDEX].append(\n                    (self._nodes[h], self._nodes[t]))\n        for i, tup in enumerate(edges):\n            self._edges[tup] = i\n\n    @classmethod\n    def from_triplets(\n        cls,\n        triplets: KnowledgeGraphLike,\n        pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,\n    ) -> \"LargeGraphIndexer\":\n        r\"\"\"Generate a new index from a series of triplets that represent edge\n        relations between nodes.\n        Formatted like (source_node, edge, dest_node).\n\n        Args:\n         
   triplets (KnowledgeGraphLike): Series of triplets representing\n                knowledge graph relations. Example: [(\"cats\", \"eat\", \"dogs\")].\n                Note: Please ensure triplets are unique.\n            pre_transform (Optional[Callable[[TripletLike], TripletLike]]):\n                Optional preprocessing function to apply to triplets.\n                Defaults to None.\n\n        Returns:\n            LargeGraphIndexer: Index of unique nodes and edges.\n        \"\"\"\n        # NOTE: Right now assumes that all trips can be loaded into memory\n        nodes = []\n        edges = []\n\n        if pre_transform is not None:\n\n            def apply_transform(\n                    trips: KnowledgeGraphLike) -> Iterator[TripletLike]:\n                for trip in trips:\n                    yield pre_transform(trip)\n\n            triplets = list(apply_transform(triplets))\n\n        for h, r, t in triplets:\n\n            for node in (h, t):\n                nodes.append(node)\n\n            edge_idx = (h, r, t)\n            edges.append(edge_idx)\n        nodes = ordered_set(nodes)\n        edges = ordered_set(edges)\n        return cls(list(nodes), list(edges))\n\n    @classmethod\n    def collate(cls,\n                graphs: Iterable[\"LargeGraphIndexer\"]) -> \"LargeGraphIndexer\":\n        r\"\"\"Combines a series of large graph indexes into a single large graph\n        index.\n\n        Args:\n            graphs (Iterable[LargeGraphIndexer]): Indices to be\n                combined.\n\n        Returns:\n            LargeGraphIndexer: Singular unique index for all nodes and edges\n                in input indices.\n        \"\"\"\n        # FIXME Needs to merge node attrs and edge attrs?\n        trips = chain.from_iterable([graph.to_triplets() for graph in graphs])\n        return cls.from_triplets(trips)\n\n    def get_unique_node_features(self,\n                                 feature_name: str = NODE_PID) -> List[str]:\n        r\"\"\"Get 
all the unique values for a specific node attribute.\n\n        Args:\n            feature_name (str, optional): Name of feature to get.\n                Defaults to NODE_PID.\n\n        Returns:\n            List[str]: List of unique values for the specified feature.\n        \"\"\"\n        try:\n            if feature_name in self._mapped_node_features:\n                raise IndexError(\n                    \"Only non-mapped features can be retrieved uniquely.\")\n            return ordered_set(self.get_node_features(feature_name))\n\n        except KeyError as e:\n            raise AttributeError(\n                f\"Nodes do not have a feature called {feature_name}\") from e\n\n    def add_node_feature(\n        self,\n        new_feature_name: str,\n        new_feature_vals: FeatureValueType,\n        map_from_feature: str = NODE_PID,\n    ) -> None:\n        r\"\"\"Adds a new feature that corresponds to each unique node in\n            the graph.\n\n        Args:\n            new_feature_name (str): Name to call the new feature.\n            new_feature_vals (FeatureValueType): Values to map for that\n                new feature.\n            map_from_feature (str, optional): Key of feature to map from.\n                Size must match the number of feature values.\n                Defaults to NODE_PID.\n        \"\"\"\n        if new_feature_name in self.node_attr:\n            raise AttributeError(\"Features cannot be overridden once created\")\n        if map_from_feature in self._mapped_node_features:\n            raise AttributeError(\n                f\"{map_from_feature} is already a feature mapping.\")\n\n        feature_keys = self.get_unique_node_features(map_from_feature)\n        if len(feature_keys) != len(new_feature_vals):\n            raise AttributeError(\n                f\"Expected encodings for {len(feature_keys)} unique features, \"\n                + f\"but got {len(new_feature_vals)} encodings.\")\n\n        if map_from_feature == 
NODE_PID:\n            self.node_attr[new_feature_name] = new_feature_vals\n        else:\n            self.node_attr[new_feature_name] = MappedFeature(\n                name=map_from_feature, values=new_feature_vals)\n            self._mapped_node_features.add(new_feature_name)\n\n    def get_node_features(\n        self,\n        feature_name: str = NODE_PID,\n        pids: Optional[Iterable[str]] = None,\n    ) -> List[Any]:\n        r\"\"\"Get node feature values for a given set of unique node ids.\n            Returned values are not necessarily unique.\n\n        Args:\n            feature_name (str, optional): Name of feature to fetch. Defaults\n                to NODE_PID.\n            pids (Optional[Iterable[str]], optional): Node ids to fetch\n                for. Defaults to None, which fetches all nodes.\n\n        Returns:\n            List[Any]: Node features corresponding to the specified ids.\n        \"\"\"\n        if feature_name in self._mapped_node_features:\n            values = self.node_attr[feature_name].values\n        else:\n            values = self.node_attr[feature_name]\n        # TODO: torch_geometric.utils.select\n        if isinstance(values, torch.Tensor):\n            idxs = list(\n                self.get_node_features_iter(feature_name, pids,\n                                            index_only=True))\n            return values[torch.tensor(idxs).long()]\n        return list(self.get_node_features_iter(feature_name, pids))\n\n    def get_node_features_iter(\n        self,\n        feature_name: str = NODE_PID,\n        pids: Optional[Iterable[str]] = None,\n        index_only: bool = False,\n    ) -> Iterator[Any]:\n        \"\"\"Iterator version of get_node_features. 
If index_only is True,\n        yields indices instead of values.\n        \"\"\"\n        if pids is None:\n            pids = self.node_attr[NODE_PID]\n\n        if feature_name in self._mapped_node_features:\n            feature_map_info = self.node_attr[feature_name]\n            from_feature_name, to_feature_vals = (\n                feature_map_info.name,\n                feature_map_info.values,\n            )\n            from_feature_vals = self.get_unique_node_features(\n                from_feature_name)\n            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}\n\n            for pid in pids:\n                idx = self._nodes[pid]\n                from_feature_val = self.node_attr[from_feature_name][idx]\n                to_feature_idx = feature_mapping[from_feature_val]\n                if index_only:\n                    yield to_feature_idx\n                else:\n                    yield to_feature_vals[to_feature_idx]\n        else:\n            for pid in pids:\n                idx = self._nodes[pid]\n                if index_only:\n                    yield idx\n                else:\n                    yield self.node_attr[feature_name][idx]\n\n    def get_unique_edge_features(self,\n                                 feature_name: str = EDGE_PID) -> List[str]:\n        r\"\"\"Get all the unique values for a specific edge attribute.\n\n        Args:\n            feature_name (str, optional): Name of feature to get.\n                Defaults to EDGE_PID.\n\n        Returns:\n            List[str]: List of unique values for the specified feature.\n        \"\"\"\n        try:\n            if feature_name in self._mapped_edge_features:\n                raise IndexError(\n                    \"Only non-mapped features can be retrieved uniquely.\")\n            return ordered_set(self.get_edge_features(feature_name))\n        except KeyError as e:\n            raise AttributeError(\n                f\"Edges do not have a feature 
called {feature_name}\") from e\n\n    def add_edge_feature(\n        self,\n        new_feature_name: str,\n        new_feature_vals: FeatureValueType,\n        map_from_feature: str = EDGE_PID,\n    ) -> None:\n        r\"\"\"Adds a new feature that corresponds to each unique edge in\n        the graph.\n\n        Args:\n            new_feature_name (str): Name to call the new feature.\n            new_feature_vals (FeatureValueType): Values to map for that new\n                feature.\n            map_from_feature (str, optional): Key of feature to map from.\n                Size must match the number of feature values.\n                Defaults to EDGE_PID.\n        \"\"\"\n        if new_feature_name in self.edge_attr:\n            raise AttributeError(\"Features cannot be overridden once created\")\n        if map_from_feature in self._mapped_edge_features:\n            raise AttributeError(\n                f\"{map_from_feature} is already a feature mapping.\")\n\n        feature_keys = self.get_unique_edge_features(map_from_feature)\n        if len(feature_keys) != len(new_feature_vals):\n            raise AttributeError(\n                f\"Expected encodings for {len(feature_keys)} unique features, \"\n                + f\"but got {len(new_feature_vals)} encodings.\")\n\n        if map_from_feature == EDGE_PID:\n            self.edge_attr[new_feature_name] = new_feature_vals\n        else:\n            self.edge_attr[new_feature_name] = MappedFeature(\n                name=map_from_feature, values=new_feature_vals)\n            self._mapped_edge_features.add(new_feature_name)\n\n    def get_edge_features(\n        self,\n        feature_name: str = EDGE_PID,\n        pids: Optional[Iterable[str]] = None,\n    ) -> List[Any]:\n        r\"\"\"Get edge feature values for a given set of unique edge ids.\n            Returned values are not necessarily unique.\n\n        Args:\n            feature_name (str, optional): Name of feature to fetch.\n              
  Defaults to EDGE_PID.\n            pids (Optional[Iterable[str]], optional): Edge ids to fetch\n                for. Defaults to None, which fetches all edges.\n\n        Returns:\n            List[Any]: Edge features corresponding to the specified ids.\n        \"\"\"\n        if feature_name in self._mapped_edge_features:\n            values = self.edge_attr[feature_name].values\n        else:\n            values = self.edge_attr[feature_name]\n\n        # TODO: torch_geometric.utils.select\n        if isinstance(values, torch.Tensor):\n            idxs = list(\n                self.get_edge_features_iter(feature_name, pids,\n                                            index_only=True))\n            return values[torch.tensor(idxs).long()]\n        return list(self.get_edge_features_iter(feature_name, pids))\n\n    def get_edge_features_iter(\n        self,\n        feature_name: str = EDGE_PID,\n        pids: Optional[KnowledgeGraphLike] = None,\n        index_only: bool = False,\n    ) -> Iterator[Any]:\n        \"\"\"Iterator version of get_edge_features. 
If index_only is True,\n        yields indices instead of values.\n        \"\"\"\n        if pids is None:\n            pids = self.edge_attr[EDGE_PID]\n\n        if feature_name in self._mapped_edge_features:\n            feature_map_info = self.edge_attr[feature_name]\n            from_feature_name, to_feature_vals = (\n                feature_map_info.name,\n                feature_map_info.values,\n            )\n            from_feature_vals = self.get_unique_edge_features(\n                from_feature_name)\n            feature_mapping = {k: i for i, k in enumerate(from_feature_vals)}\n\n            for pid in pids:\n                idx = self._edges[pid]\n                from_feature_val = self.edge_attr[from_feature_name][idx]\n                to_feature_idx = feature_mapping[from_feature_val]\n                if index_only:\n                    yield to_feature_idx\n                else:\n                    yield to_feature_vals[to_feature_idx]\n        else:\n            for pid in pids:\n                idx = self._edges[pid]\n                if index_only:\n                    yield idx\n                else:\n                    yield self.edge_attr[feature_name][idx]\n\n    def to_triplets(self) -> Iterator[TripletLike]:\n        return iter(self.edge_attr[EDGE_PID])\n\n    def save(self, path: str) -> None:\n        if os.path.exists(path):\n            shutil.rmtree(path)\n        os.makedirs(path, exist_ok=True)\n        with open(path + \"/edges\", \"wb\") as f:\n            pkl.dump(self._edges, f)\n        with open(path + \"/nodes\", \"wb\") as f:\n            pkl.dump(self._nodes, f)\n\n        with open(path + \"/mapped_edges\", \"wb\") as f:\n            pkl.dump(self._mapped_edge_features, f)\n        with open(path + \"/mapped_nodes\", \"wb\") as f:\n            pkl.dump(self._mapped_node_features, f)\n\n        node_attr_path = path + \"/node_attr\"\n        os.makedirs(node_attr_path, exist_ok=True)\n        for attr_name, vals in 
self.node_attr.items():\n            torch.save(vals, node_attr_path + f\"/{attr_name}.pt\")\n\n        edge_attr_path = path + \"/edge_attr\"\n        os.makedirs(edge_attr_path, exist_ok=True)\n        for attr_name, vals in self.edge_attr.items():\n            torch.save(vals, edge_attr_path + f\"/{attr_name}.pt\")\n\n    @classmethod\n    def from_disk(cls, path: str) -> \"LargeGraphIndexer\":\n        indexer = cls(list(), list())\n        with open(path + \"/edges\", \"rb\") as f:\n            indexer._edges = pkl.load(f)\n        with open(path + \"/nodes\", \"rb\") as f:\n            indexer._nodes = pkl.load(f)\n\n        with open(path + \"/mapped_edges\", \"rb\") as f:\n            indexer._mapped_edge_features = pkl.load(f)\n        with open(path + \"/mapped_nodes\", \"rb\") as f:\n            indexer._mapped_node_features = pkl.load(f)\n\n        node_attr_path = path + \"/node_attr\"\n        for fname in os.listdir(node_attr_path):\n            full_fname = f\"{node_attr_path}/{fname}\"\n            key = fname.split(\".\")[0]\n            indexer.node_attr[key] = fs.torch_load(full_fname)\n\n        edge_attr_path = path + \"/edge_attr\"\n        for fname in os.listdir(edge_attr_path):\n            full_fname = f\"{edge_attr_path}/{fname}\"\n            key = fname.split(\".\")[0]\n            indexer.edge_attr[key] = fs.torch_load(full_fname)\n\n        return indexer\n\n    def to_data(self, node_feature_name: str,\n                edge_feature_name: Optional[str] = None) -> Data:\n        \"\"\"Return a Data object containing all the specified node and\n            edge features and the graph.\n\n        Args:\n            node_feature_name (str): Feature to use for nodes\n            edge_feature_name (Optional[str], optional): Feature to use for\n                edges. 
Defaults to None.\n\n        Returns:\n            Data: Data object containing the specified node and\n                edge features and the graph.\n        \"\"\"\n        x = torch.Tensor(self.get_node_features(node_feature_name))\n        node_id = torch.LongTensor(range(len(x)))\n        edge_index = torch.t(\n            torch.LongTensor(self.get_edge_features(EDGE_INDEX)))\n\n        edge_attr = (self.get_edge_features(edge_feature_name)\n                     if edge_feature_name is not None else None)\n        # Count edges from the index itself so that `edge_attr=None` works.\n        edge_id = torch.LongTensor(range(len(self.edge_attr[EDGE_PID])))\n\n        return Data(x=x, edge_index=edge_index, edge_attr=edge_attr,\n                    edge_id=edge_id, node_id=node_id)\n\n    def __eq__(self, value: \"LargeGraphIndexer\") -> bool:\n        eq = True\n        eq &= self._nodes == value._nodes\n        eq &= self._edges == value._edges\n        eq &= self.node_attr.keys() == value.node_attr.keys()\n        eq &= self.edge_attr.keys() == value.edge_attr.keys()\n        eq &= self._mapped_node_features == value._mapped_node_features\n        eq &= self._mapped_edge_features == value._mapped_edge_features\n\n        for k in self.node_attr:\n            eq &= isinstance(self.node_attr[k], type(value.node_attr[k]))\n            if isinstance(self.node_attr[k], torch.Tensor):\n                eq &= torch.equal(self.node_attr[k], value.node_attr[k])\n            else:\n                eq &= self.node_attr[k] == value.node_attr[k]\n        for k in self.edge_attr:\n            eq &= isinstance(self.edge_attr[k], type(value.edge_attr[k]))\n            if isinstance(self.edge_attr[k], torch.Tensor):\n                eq &= torch.equal(self.edge_attr[k], value.edge_attr[k])\n            else:\n                eq &= self.edge_attr[k] == value.edge_attr[k]\n        return eq\n\n\ndef get_features_for_triplets_groups(\n    indexer: LargeGraphIndexer,\n    triplet_groups: Iterable[KnowledgeGraphLike],\n    node_feature_name: str = \"x\",\n    
edge_feature_name: str = \"edge_attr\",\n    pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,\n    verbose: bool = False,\n    max_batch_size: int = 250,\n    num_workers: Optional[int] = None,\n) -> Iterator[Data]:\n    \"\"\"Given an indexer and a series of triplet groups (like a dataset),\n    retrieve the specified node and edge features for each triplet from the\n    index.\n\n    Args:\n        indexer (LargeGraphIndexer): Indexer containing desired features\n        triplet_groups (Iterable[KnowledgeGraphLike]): List of lists of\n            triplets to fetch features for\n        node_feature_name (str, optional): Node feature to fetch.\n            Defaults to \"x\".\n        edge_feature_name (str, optional): edge feature to fetch.\n            Defaults to \"edge_attr\".\n        pre_transform (Callable[[TripletLike], TripletLike]):\n            Optional preprocessing to perform on triplets.\n            Defaults to None.\n        verbose (bool, optional): Whether to print progress.\n            Defaults to False.\n        max_batch_size (int, optional):\n            Maximum batch size for fetching features.\n            Defaults to 250.\n        num_workers (int, optional):\n            Number of workers to use for fetching features.\n            Defaults to None (all available).\n\n    Yields:\n        Iterator[Data]: For each triplet group, yield a data object containing\n            the unique graph and features from the index.\n    \"\"\"\n    def apply_transform(trips: Iterable[TripletLike]) -> Iterator[TripletLike]:\n        for trip in trips:\n            yield pre_transform(tuple(trip))\n\n    # Carefully trying to avoid loading all triplets into memory at once\n    # While also still tracking the number of elements for tqdm\n    triplet_groups: List[Iterator[TripletLike]] = [\n        apply_transform(triplets) for triplets in triplet_groups\n    ]\n\n    node_keys = []\n    edge_keys = []\n    edge_index = []\n    
\"\"\"\n    For each KG, we gather the node_indices, edge_keys,\n    and edge_indices needed to construct each Data object\n    \"\"\"\n\n    for kg_triplets in tqdm(triplet_groups, disable=not verbose):\n        kg_triplets_nodes, kg_triplets_edge_keys, kg_triplets_edge_index = tee(\n            kg_triplets, 3)\n        \"\"\"\n        Don't apply pre_transform here,\n        because it has already been applied on the triplet groups/\n        \"\"\"\n        small_graph_indexer = LargeGraphIndexer.from_triplets(\n            kg_triplets_nodes)\n\n        node_keys.append(small_graph_indexer.get_node_features())\n        edge_keys.append(\n            small_graph_indexer.get_edge_features(pids=kg_triplets_edge_keys))\n        edge_index.append(\n            small_graph_indexer.get_edge_features(\n                EDGE_INDEX,\n                kg_triplets_edge_index,\n            ))\n    \"\"\"\n    We get the embeddings for each node and edge key in the KG,\n    but we need to do so in batches.\n    Batches that are too small waste compute time,\n    as each call to get features has an upfront cost.\n    Batches that are too large waste memory,\n    as we need to store all the result embeddings in memory.\n    \"\"\"\n\n    def _fetch_feature_batch(batches):\n        node_key_batch, edge_key_batch, edge_index_batch = batches\n        node_feats = indexer.get_node_features(\n            feature_name=node_feature_name,\n            pids=chain.from_iterable(node_key_batch))\n        edge_feats = indexer.get_edge_features(\n            feature_name=edge_feature_name,\n            pids=chain.from_iterable(edge_key_batch))\n\n        last_node_idx, last_edge_idx = 0, 0\n        for (nkeys, ekeys, eidx) in zip(node_key_batch, edge_key_batch,\n                                        edge_index_batch):\n            nlen, elen = len(nkeys), len(ekeys)\n            x = torch.Tensor(node_feats[last_node_idx:last_node_idx + nlen])\n            last_node_idx += len(nkeys)\n\n      
      edge_attr = torch.Tensor(edge_feats[last_edge_idx:last_edge_idx +\n                                                elen])\n            last_edge_idx += len(ekeys)\n\n            edge_idx = torch.LongTensor(eidx).T\n\n            data_obj = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)\n            data_obj[NODE_PID] = node_keys\n            data_obj[EDGE_PID] = edge_keys\n            data_obj[\"node_idx\"] = [indexer._nodes[k] for k in nkeys]\n            data_obj[\"edge_idx\"] = [indexer._edges[e] for e in ekeys]\n\n            yield data_obj\n\n    # NOTE: Backport of itertools.batched from Python 3.12\n    def batched(iterable, n, *, strict=False):\n        # batched('ABCDEFG', 3) → ABC DEF G\n        if n < 1:\n            raise ValueError('n must be at least one')\n        iterator = iter(iterable)\n        while batch := tuple(islice(iterator, n)):\n            if strict and len(batch) != n:\n                raise ValueError('batched(): incomplete batch')\n            yield batch\n\n    import multiprocessing as mp\n    import multiprocessing.pool as mpp\n    num_workers = num_workers if num_workers is not None else mp.cpu_count()\n    ideal_batch_size = min(max_batch_size,\n                           max(1,\n                               len(triplet_groups) // num_workers))\n\n    node_key_batches = batched(node_keys, ideal_batch_size)\n    edge_key_batches = batched(edge_keys, ideal_batch_size)\n    edge_index_batches = batched(edge_index, ideal_batch_size)\n    batches = zip(node_key_batches, edge_key_batches, edge_index_batches)\n\n    with mpp.ThreadPool() as pool:\n        result = pool.map(_fetch_feature_batch, batches)\n    yield from chain.from_iterable(result)\n\n\ndef get_features_for_triplets(\n    indexer: LargeGraphIndexer,\n    triplets: KnowledgeGraphLike,\n    node_feature_name: str = \"x\",\n    edge_feature_name: str = \"edge_attr\",\n    pre_transform: Callable[[TripletLike], TripletLike] = lambda trip: trip,\n    verbose: 
bool = False,\n) -> Data:\n    \"\"\"For a given set of triplets, retrieve a Data object containing the\n    unique graph and features from the index.\n\n    Args:\n        indexer (LargeGraphIndexer): Indexer containing desired features\n        triplets (KnowledgeGraphLike): Triplets to fetch features for\n        node_feature_name (str, optional): Feature to use for node features.\n            Defaults to \"x\".\n        edge_feature_name (str, optional): Feature to use for edge features.\n            Defaults to \"edge_attr\".\n        pre_transform (Callable[[TripletLike], TripletLike]):\n            Optional preprocessing function for triplets.\n            Defaults to the identity function.\n        verbose (bool, optional): Whether to print progress. Defaults to False.\n\n    Returns:\n        Data: Data object containing the unique graph and features from the\n            index for the given triplets.\n    \"\"\"\n    gen = get_features_for_triplets_groups(indexer, [triplets],\n                                           node_feature_name,\n                                           edge_feature_name, pre_transform,\n                                           verbose, max_batch_size=1)\n    return next(gen)\n"
  },
  {
    "path": "torch_geometric/llm/models/__init__.py",
    "content": "from .sentence_transformer import SentenceTransformer\nfrom .vision_transformer import VisionTransformer\nfrom .llm import LLM\nfrom .txt2kg import TXT2KG\nfrom .llm_judge import LLMJudge\nfrom .g_retriever import GRetriever\nfrom .molecule_gpt import MoleculeGPT\nfrom .glem import GLEM\nfrom .protein_mpnn import ProteinMPNN\nfrom .git_mol import GITMol\n\n__all__ = classes = [\n    'SentenceTransformer',\n    'VisionTransformer',\n    'LLM',\n    'LLMJudge',\n    'TXT2KG',\n    'GRetriever',\n    'MoleculeGPT',\n    'GLEM',\n    'ProteinMPNN',\n    'GITMol',\n]\n"
  },
  {
    "path": "torch_geometric/llm/models/g_retriever.py",
    "content": "from typing import List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.llm.models.llm import LLM, MAX_NEW_TOKENS\nfrom torch_geometric.utils import scatter\n\n\nclass GRetriever(torch.nn.Module):\n    r\"\"\"The G-Retriever model from the `\"G-Retriever: Retrieval-Augmented\n    Generation for Textual Graph Understanding and Question Answering\"\n    <https://arxiv.org/abs/2402.07630>`_ paper.\n\n    Args:\n        llm (LLM): The LLM to use.\n        gnn (torch.nn.Module): The GNN to use.\n        use_lora (bool, optional): If set to :obj:`True`, will use LORA from\n            :obj:`peft` for training the LLM, see\n            `here <https://huggingface.co/docs/peft/en/index>`_ for details.\n            (default: :obj:`False`)\n        mlp_out_tokens (int, optional): Number of LLM prefix tokens to\n            reserve for GNN output. (default: :obj:`1`)\n\n    .. warning::\n        This module has been tested with the following HuggingFace models\n        * :obj:`llm_to_use=\"meta-llama/Meta-Llama-3.1-8B-Instruct\"`\n        * :obj:`llm_to_use=\"Qwen/Qwen3-0.6B\"`\n\n\n        This module should work with any HuggingFace model.\n        See other models at `HuggingFace\n        Models <https://huggingface.co/models>`_\n        and let us know if you\n        encounter any issues.\n\n    .. 
note::\n        For an example of using :class:`GRetriever`, see\n        `examples/llm/g_retriever.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/llm/g_retriever.py>`_.\n    \"\"\"\n    def __init__(\n        self,\n        llm: LLM,\n        gnn: torch.nn.Module = None,\n        use_lora: bool = False,\n        mlp_out_tokens: int = 1,\n    ) -> None:\n        super().__init__()\n\n        self.llm = llm\n        self.gnn = gnn.to(self.llm.device) if gnn is not None else None\n\n        self.word_embedding = self.llm.word_embedding\n        self.llm_generator = self.llm.llm\n        if use_lora:\n            from peft import (\n                LoraConfig,\n                get_peft_model,\n                prepare_model_for_kbit_training,\n            )\n            self.llm_generator = prepare_model_for_kbit_training(\n                self.llm_generator)\n            lora_r: int = 8\n            lora_alpha: int = 16\n            lora_dropout: float = 0.05\n            lora_target_modules = ['q_proj', 'v_proj']\n            config = LoraConfig(\n                r=lora_r,\n                lora_alpha=lora_alpha,\n                target_modules=lora_target_modules,\n                lora_dropout=lora_dropout,\n                bias='none',\n                task_type='CAUSAL_LM',\n            )\n            self.llm_generator = get_peft_model(self.llm_generator, config)\n\n        if self.gnn is not None:\n            mlp_out_channels = llm.word_embedding.embedding_dim\n            mlp_hidden_channels = self.gnn.out_channels\n            self.projector = torch.nn.Sequential(\n                torch.nn.Linear(mlp_hidden_channels, mlp_hidden_channels),\n                torch.nn.Sigmoid(),\n                torch.nn.Linear(mlp_hidden_channels,\n                                mlp_out_channels * mlp_out_tokens),\n                torch.nn.Unflatten(-1, (mlp_out_tokens, mlp_out_channels)),\n            ).to(self.llm.device)\n\n        
self.seq_length_stats = []\n\n    def encode(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_attr: Optional[Tensor],\n    ) -> Tensor:\n        x = x.to(self.llm.device)\n        edge_index = edge_index.to(self.llm.device)\n        if edge_attr is not None:\n            edge_attr = edge_attr.to(self.llm.device)\n        batch = batch.to(self.llm.device)\n\n        model_specific_kwargs = {}\n\n        # duck typing for SGFormer to get around circular import\n        if (hasattr(self.gnn, 'trans_conv')\n                and hasattr(self.gnn, 'graph_conv')):\n            model_specific_kwargs['batch'] = batch\n        else:\n            model_specific_kwargs['edge_attr'] = edge_attr\n\n        out = self.gnn(x, edge_index, **model_specific_kwargs)\n        return scatter(out, batch, dim=0, reduce='mean')\n\n    def forward(\n        self,\n        question: List[str],\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        label: List[str],\n        edge_attr: Optional[Tensor] = None,\n        additional_text_context: Optional[List[str]] = None,\n    ):\n        r\"\"\"The forward pass.\n\n        Args:\n            question (List[str]): The questions/prompts.\n            x (torch.Tensor): The input node features.\n            edge_index (torch.Tensor): The edge indices.\n            batch (torch.Tensor): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example.\n            label (List[str]): The answers/labels.\n            edge_attr (torch.Tensor, optional): The edge features (if supported\n                by the GNN). 
(default: :obj:`None`)\n            additional_text_context (List[str], optional): Additional context\n                to give to the LLM, such as textified knowledge graphs.\n                (default: :obj:`None`)\n        \"\"\"\n        xs = None\n        if self.gnn is not None:\n            x = self.encode(x, edge_index, batch, edge_attr)\n            x = self.projector(x)\n            x = self._align_dtype(x, self.llm_generator)\n            xs = x.split(1, dim=0)\n\n            # Handle the case of more than one embedding per sample\n            xs = [x.squeeze(0) for x in xs]\n\n            # Handle questions without node features:\n            batch_unique = batch.unique()\n            batch_size = len(question)\n            if len(batch_unique) < batch_size:\n                xs = [\n                    xs[i] if i in batch_unique else None\n                    for i in range(batch_size)\n                ]\n        (\n            inputs_embeds,\n            attention_mask,\n            label_input_ids,\n        ) = self.llm._get_embeds(question, additional_text_context, xs, label)\n\n        max_seq_len = inputs_embeds.size(1)\n        self.seq_length_stats.append(max_seq_len)\n\n        with self.llm.autocast_context:\n            outputs = self.llm_generator(\n                inputs_embeds=inputs_embeds,\n                attention_mask=attention_mask,\n                return_dict=True,\n                labels=label_input_ids,\n            )\n\n        return outputs.loss\n\n    @torch.no_grad()\n    def inference(\n        self,\n        question: List[str],\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_attr: Optional[Tensor] = None,\n        additional_text_context: Optional[List[str]] = None,\n        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,\n    ):\n        r\"\"\"The inference pass.\n\n        Args:\n            question (List[str]): The questions/prompts.\n            x (torch.Tensor): The 
input node features.\n            edge_index (torch.Tensor): The edge indices.\n            batch (torch.Tensor): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example.\n            edge_attr (torch.Tensor, optional): The edge features (if supported\n                by the GNN). (default: :obj:`None`)\n            additional_text_context (List[str], optional): Additional context\n                to give to the LLM, such as textified knowledge graphs.\n                (default: :obj:`None`)\n            max_out_tokens (int, optional): How many tokens for the LLM to\n                generate. (default: :obj:`MAX_NEW_TOKENS`)\n        \"\"\"\n        xs = None\n        if self.gnn is not None:\n            x = self.encode(x, edge_index, batch, edge_attr)\n            x = self.projector(x)\n            xs = x.split(1, dim=0)\n\n            # Handle the case of more than one embedding per sample\n            xs = [x.squeeze(0) for x in xs]\n\n            # Handle questions without node features:\n            batch_unique = batch.unique()\n            batch_size = len(question)\n            if len(batch_unique) < batch_size:\n                xs = [\n                    xs[i] if i in batch_unique else None\n                    for i in range(batch_size)\n                ]\n\n        inputs_embeds, attention_mask, _ = self.llm._get_embeds(\n            question, additional_text_context, xs)\n\n        # bos_token = self.llm.tokenizer(\n        #     self.llm.tokenizer.bos_token_id,\n        #     add_special_tokens=False,\n        # ).input_ids[0]\n\n        with self.llm.autocast_context:\n            outputs = self.llm_generator.generate(\n                inputs_embeds=inputs_embeds,\n                max_new_tokens=max_out_tokens,\n                attention_mask=attention_mask,\n                bos_token_id=self.llm.tokenizer.bos_token_id,\n                
pad_token_id=self.llm.tokenizer.eos_token_id,\n                use_cache=True  # Important to set!\n            )\n\n        return self.llm.tokenizer.batch_decode(\n            outputs,\n            skip_special_tokens=True,\n        )\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(\\n'\n                f'  llm={self.llm},\\n'\n                f'  gnn={self.gnn},\\n'\n                f')')\n\n    def _align_dtype(\n        self,\n        x: torch.Tensor,\n        llm_generator: torch.nn.Module,\n    ) -> torch.Tensor:\n        llm_dtype = next(iter(llm_generator.parameters())).dtype\n        if x.dtype != llm_dtype:\n            x = x.to(llm_dtype)\n\n        return x\n"
  },
  {
    "path": "torch_geometric/llm/models/git_mol.py",
    "content": "from typing import List, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import BatchNorm1d, LayerNorm, Linear, ReLU, Sequential\n\nfrom torch_geometric.llm.models import SentenceTransformer, VisionTransformer\nfrom torch_geometric.nn import GINEConv\nfrom torch_geometric.utils import add_self_loops, to_dense_batch\n\n\nclass GraphEncoder(torch.nn.Module):\n    def __init__(\n        self,\n        num_layers: int,\n        in_channels: int,\n        dropout: float = 0.,\n        num_atom_type: int = 120,\n        num_chirality_tag: int = 3,\n        num_bond_type: int = 6,\n        num_bond_direction: int = 3,\n    ) -> None:\n        super().__init__()\n\n        self.num_layers = num_layers\n        self.dropout = dropout\n\n        self.x_embed1 = torch.nn.Embedding(num_atom_type, in_channels)\n        self.x_embed2 = torch.nn.Embedding(num_chirality_tag, in_channels)\n        self.edge_embed1 = torch.nn.Embedding(num_bond_type, in_channels)\n        self.edge_embed2 = torch.nn.Embedding(num_bond_direction, in_channels)\n\n        self.gnns = torch.nn.ModuleList()\n        self.batch_norms = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            self.gnns.append(\n                GINEConv(\n                    nn=Sequential(\n                        Linear(in_channels, in_channels * 2),\n                        ReLU(),\n                        Linear(in_channels * 2, in_channels),\n                    ),\n                    train_eps=True,\n                    edge_dim=in_channels,\n                ))\n            self.batch_norms.append(BatchNorm1d(in_channels))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.x_embed1.weight.data)\n        torch.nn.init.xavier_uniform_(self.x_embed2.weight.data)\n        torch.nn.init.xavier_uniform_(self.edge_embed1.weight.data)\n        
torch.nn.init.xavier_uniform_(self.edge_embed2.weight.data)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_attr: Tensor,\n    ) -> Tensor:\n        x = self.x_embed1(x[:, 0].long()) + self.x_embed2(x[:, 1].long())\n        edge_index, edge_attr = add_self_loops(\n            edge_index,\n            edge_attr,\n            fill_value=0,\n            num_nodes=x.size(0),\n        )\n        edge_attr = self.edge_embed1(edge_attr[:, 0]) + self.edge_embed2(\n            edge_attr[:, 1])\n        for i, (gnn, bn) in enumerate(zip(self.gnns, self.batch_norms)):\n            x = gnn(x, edge_index, edge_attr)\n            x = bn(x)\n            if i < self.num_layers - 1:\n                x = F.relu(x)\n            x = F.dropout(x, self.dropout, training=self.training)\n\n        x, mask = to_dense_batch(x, batch)\n        return x, mask\n\n\nclass GITFormer(torch.nn.Module):\n    def __init__(\n        self,\n        num_query_token: int,\n        vision_graph_width: int,\n        cross_attention_freq: int = 2,\n    ):\n        super().__init__()\n        from transformers import AutoConfig, AutoModel\n\n        config = AutoConfig.from_pretrained(\"allenai/scibert_scivocab_uncased\")\n        config.encoder_width = vision_graph_width\n        # insert cross-attention layer every other block\n        config.add_cross_attention = True\n        config.is_decoder = True\n        config.cross_attention_freq = cross_attention_freq\n        config.query_length = num_query_token\n        self.Qformer = AutoModel.from_pretrained(\n            \"allenai/scibert_scivocab_uncased\", config=config)\n        self.query_tokens = torch.nn.Parameter(\n            torch.zeros(1, num_query_token, config.hidden_size))\n        self.query_tokens.data.normal_(mean=0.0, std=config.initializer_range)\n\n\nclass GITMol(torch.nn.Module):\n    r\"\"\"The GITMol model from the `\"GIT-Mol: A Multi-modal Large Language\n    
Model for Molecular Science with Graph, Image, and Text\"\n    <https://arxiv.org/pdf/2308.06911>`_ paper.\n\n    .. note::\n        For an example of using :class:`GITMol`, see\n        `examples/llm/git_mol.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/llm/git_mol.py>`_.\n    \"\"\"\n    def __init__(self) -> None:\n        super().__init__()\n        # graph\n        self.graph_encoder = GraphEncoder(num_layers=2, in_channels=16)\n        self.graph_proj = Linear(16, 768)\n        self.ln_graph = LayerNorm(768)\n        # text\n        self.text_encoder = SentenceTransformer(\n            model_name='allenai/scibert_scivocab_uncased',\n            pooling_strategy='last_hidden_state',\n        )\n        self.text_proj = Linear(768, 768)\n        self.ln_text = LayerNorm(768)\n        # vision\n        self.vision_encoder = VisionTransformer(\n            model_name='microsoft/swin-base-patch4-window7-224', )\n        self.vision_proj = Linear(1024, 768)\n        self.ln_vision = LayerNorm(768)\n        # cross-attention\n        self.gitformer = GITFormer(384, 768)\n\n        self.xtm_head = torch.nn.ModuleDict({\n            'image':\n            Linear(self.gitformer.Qformer.config.hidden_size, 2),\n            'graph':\n            Linear(self.gitformer.Qformer.config.hidden_size, 2),\n            'cs_text':\n            Linear(self.gitformer.Qformer.config.hidden_size, 2),\n        })\n\n        self.xtc_proj = torch.nn.ModuleDict({\n            'image':\n            Linear(self.gitformer.Qformer.config.hidden_size, 768),\n            'graph':\n            Linear(self.gitformer.Qformer.config.hidden_size, 768),\n            'cs_text':\n            Linear(self.gitformer.Qformer.config.hidden_size, 768),\n        })\n        self.temp = torch.nn.Parameter(0.07 * torch.ones([]))\n        self.model_freeze()\n\n    def model_freeze(self) -> None:\n        for param in self.graph_encoder.parameters():\n            
param.requires_grad = False\n\n        for param in self.vision_encoder.parameters():\n            param.requires_grad = False\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_attr: Optional[Tensor],\n        smiles: List[str],\n        images: Tensor,\n        captions: List[str],\n    ) -> Tensor:\n        batch_size = len(smiles)\n\n        x_vision = self.vision_encoder(images)\n        x_vision = self.vision_proj(x_vision)\n        x_vision = self.ln_vision(x_vision)  # [bs, patch_len, d]\n        vision_atts = torch.ones(x_vision.size()[:-1],\n                                 dtype=torch.long).to(x_vision.device)\n        vision_targets = torch.arange(batch_size).to(x_vision.device)\n\n        x_graph, graph_atts = self.graph_encoder(x, edge_index, batch,\n                                                 edge_attr)\n        x_graph = self.graph_proj(x_graph)\n        x_graph = self.ln_graph(x_graph)  # [bs, node_len, d]\n        graph_targets = torch.arange(batch_size).to(x_graph.device)\n\n        x_smiles = self.text_encoder.encode(smiles)  # [bs, seq_len, d]\n        smiles_atts = torch.ones(x_smiles.size()[:-1],\n                                 dtype=torch.long).to(x_smiles.device)\n        smiles_targets = torch.arange(batch_size).to(x_smiles.device)\n\n        caption_input_ids, caption_attention_masks = self.text_encoder.get_input_ids(  # noqa: E501\n            captions)\n\n        text_output = self.gitformer.Qformer(\n            caption_input_ids,\n            attention_mask=caption_attention_masks,\n            return_dict=True,\n        )\n        text_feat = F.normalize(\n            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)\n\n        loss = 0\n        for x_embed, x_atts, x_targets, modal in zip(\n            [x_graph, x_smiles, x_vision],\n            [graph_atts, smiles_atts, vision_atts],\n            [graph_targets, smiles_targets, 
vision_targets],\n            ['graph', 'cs_text', 'image'],\n        ):\n            loss += self._calc_xtc_loss(x_embed, x_atts, x_targets, text_feat,\n                                        modal)\n            loss += self._calc_xtm_loss(x_embed, caption_input_ids,\n                                        caption_attention_masks, modal)\n\n        return loss / 6\n\n    def _calc_xtm_loss(\n        self,\n        x_embeds: Tensor,\n        input_ids: Tensor,\n        attention_mask: Tensor,\n        modal: str,\n    ) -> Tensor:\n        # Initializing lists to hold the original and negative samples\n        x_embeds_list = []\n        text_input_ids_list = []\n        text_attention_mask_list = []\n\n        batch_size = x_embeds.size(0)\n        for i in range(batch_size):\n            # Original samples\n            x_embeds_list.append(x_embeds[i])\n            text_input_ids_list.append(input_ids[i, :])\n            text_attention_mask_list.append(attention_mask[i, :])\n\n            if batch_size > 1:\n                # Negative samples (neg_text_input_ids corresponds to x_embeds)\n                neg_text_input_ids = input_ids[i - 1 if i == batch_size -\n                                               1 else i + 1, :]\n                neg_text_attention_mask = attention_mask[i -\n                                                         1 if i == batch_size -\n                                                         1 else i + 1, :]\n                text_input_ids_list.append(neg_text_input_ids)\n                text_attention_mask_list.append(neg_text_attention_mask)\n                x_embeds_list.append(x_embeds[i, :])\n\n                # Negative samples (text_input_ids corresponds to neg_x_embeds)\n                neg_x_embeds = x_embeds[i - 1 if i == batch_size - 1 else i +\n                                        1, :]\n                x_embeds_list.append(neg_x_embeds)\n                text_input_ids_list.append(input_ids[i, :])\n                
text_attention_mask_list.append(attention_mask[i, :])\n\n        # Stack all samples into two large tensors\n        x_embeds_all = torch.stack(x_embeds_list, dim=1) \\\n            .reshape(-1, x_embeds.size(1), x_embeds.size(2))\n        text_input_ids_all = torch.stack(text_input_ids_list, dim=1) \\\n            .reshape(-1, input_ids.size(1))\n        # Create image attention masks for the concatenated tensor\n        image_attns_all = torch.ones(x_embeds_all.size()[:-1],\n                                     dtype=torch.long).to(x_embeds_all.device)\n        query_tokens_xtm = self.gitformer.query_tokens.expand(\n            text_input_ids_all.shape[0], -1, -1)\n        query_attns_xtm = torch.ones(query_tokens_xtm.size()[:-1],\n                                     dtype=torch.long).to(x_embeds_all.device)\n\n        output_xtm = self.gitformer.Qformer(\n            inputs_embeds=query_tokens_xtm,\n            attention_mask=query_attns_xtm,\n            encoder_hidden_states=x_embeds_all,\n            encoder_attention_mask=image_attns_all,\n            return_dict=True,\n        ).last_hidden_state\n\n        xtm_embeddings = output_xtm[:, :query_tokens_xtm.size(1), :]\n\n        xtm_logit = self.xtm_head[modal](xtm_embeddings).mean(dim=1)\n        # Create labels: 1 for the original samples, 0 for the negative samples\n        if batch_size > 1:\n            labels = torch.cat(\n                [torch.ones(batch_size),\n                 torch.zeros(batch_size * 2)], dim=0)\n        else:\n            labels = torch.ones(batch_size)\n        labels = labels.long().to(xtm_logit.device)\n\n        # Calculate cross entropy loss\n        return F.cross_entropy(xtm_logit, labels)\n\n    def _calc_xtc_loss(\n        self,\n        x_embeds: Tensor,\n        x_atts: Tensor,\n        x_targets: Tensor,\n        text_feat: Tensor,\n        modal: str,\n    ) -> Tensor:\n        query_tokens = self.gitformer.query_tokens.expand(\n            x_embeds.shape[0], -1, 
-1)\n\n        query_output = self.gitformer.Qformer(\n            inputs_embeds=query_tokens,\n            encoder_hidden_states=x_embeds,\n            encoder_attention_mask=x_atts,\n            return_dict=True,\n        ).last_hidden_state\n\n        x_feats = F.normalize(self.xtc_proj[modal](query_output), dim=-1)\n\n        sim_q2t = torch.matmul(\n            x_feats.unsqueeze(1),\n            text_feat.unsqueeze(-1),\n        ).squeeze(-1)\n\n        # modal-text similarity: aggregate across all query tokens\n        sim_x2t, _ = sim_q2t.max(-1)\n        sim_x2t = sim_x2t / self.temp\n\n        # text-query similarity\n        sim_t2q = torch.matmul(\n            text_feat.unsqueeze(1).unsqueeze(1),\n            x_feats.permute(0, 2, 1),\n        ).squeeze(-2)\n\n        # text-modal similarity: aggregate across all query tokens\n        sim_t2x, _ = sim_t2q.max(-1)\n        sim_t2x = sim_t2x / self.temp\n\n        loss_itc = (\n            F.cross_entropy(sim_x2t, x_targets, label_smoothing=0.1) +\n            F.cross_entropy(sim_t2x, x_targets, label_smoothing=0.1)) / 2\n\n        return loss_itc\n"
  },
  {
    "path": "torch_geometric/llm/models/glem.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nimport torch.nn as nn\nfrom tqdm import tqdm\n\nfrom torch_geometric.loader import DataLoader, NeighborLoader\nfrom torch_geometric.nn.models import GraphSAGE, basic_gnn\n\n\ndef deal_nan(x):\n    if isinstance(x, torch.Tensor):\n        x = x.clone()\n        x[torch.isnan(x)] = 0.0\n    return x\n\n\nclass GLEM(torch.nn.Module):\n    r\"\"\"This GNN+LM co-training model is based on GLEM from the `\"Learning on\n    Large-scale Text-attributed Graphs via Variational Inference\"\n    <https://arxiv.org/abs/2210.14709>`_ paper.\n\n    Args:\n        lm_to_use (str): A TextEncoder from huggingface model repo\n                with a classifier(default: TinyBERT)\n        gnn_to_use (torch_geometric.nn.models): (default: GraphSAGE)\n        out_channels (int): output channels for LM and GNN, should be same\n        num_gnn_heads Optional[int]: Number of heads for attention, if needed\n        num_gnn_layers (int): number of gnn layers\n        gnn_loss: loss function for gnn, (default: CrossEntropyLoss)\n        lm_loss: loss function for Language Model, (default: CrossEntropyLoss)\n        alpha (float): pseudo label weight of E-step, LM optimization,\n            (default: 0.5)\n        beta (float): pseudo label weight of M-step, GNN optimization,\n            (default: 0.5)\n        lm_dtype (torch.dtype): the data type once you load LM into memory,\n            (default: torch.bfloat16)\n        lm_use_lora (bool): choose if LM use Lora peft for fine tune,\n            (default: True)\n        lora_target_modules: The names of the target modules to apply the lora\n            adapter to, e.g. ['q_proj', 'v_proj'] for LLM , (default: None)\n\n    .. 
note::\n        See `examples/llm_plus_gnn/glem.py` for example usage.\n    \"\"\"\n    def __init__(\n        self,\n        lm_to_use: str = 'prajjwal1/bert-tiny',\n        gnn_to_use: basic_gnn = GraphSAGE,\n        out_channels: int = 47,\n        gnn_loss: Optional[nn.Module] = None,\n        lm_loss: Optional[nn.Module] = None,\n        alpha: float = 0.5,\n        beta: float = 0.5,\n        lm_dtype: torch.dtype = torch.bfloat16,\n        lm_use_lora: bool = True,\n        lora_target_modules: Optional[Union[List[str], str]] = None,\n        device: Optional[Union[str, torch.device]] = None,\n    ):\n        super().__init__()\n\n        if gnn_loss is None:\n            gnn_loss = nn.CrossEntropyLoss(reduction='mean')\n        if lm_loss is None:\n            lm_loss = nn.CrossEntropyLoss(reduction='mean')\n        if device is None:\n            device = torch.device('cpu')\n\n        self.device = device\n        self.lm_loss = lm_loss\n        self.gnn = gnn_to_use\n        self.gnn_loss = gnn_loss\n        self.alpha = alpha\n        self.beta = beta\n        self.gnn_loss = gnn_loss\n        self.lm = lm_to_use\n        from transformers import AutoModelForSequenceClassification\n        self.lm = AutoModelForSequenceClassification.from_pretrained(\n            lm_to_use, num_labels=out_channels, dtype=lm_dtype,\n            offload_folder=\"offload\", trust_remote_code=True)\n        if lm_use_lora:\n            from peft import (\n                LoraConfig,\n                TaskType,\n                get_peft_model,\n                prepare_model_for_kbit_training,\n            )\n            print(\"Training LM with LORA!\")\n            self.lm = prepare_model_for_kbit_training(self.lm)\n            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,\n                                lora_alpha=16, lora_dropout=0.05, bias=\"none\",\n                                target_modules=lora_target_modules)\n            self.lm = get_peft_model(self.lm, 
config)\n            self.lm.print_trainable_parameters()\n        self.lm.config.pad_token_id = self.lm.config.eos_token_id\n        self.lm_device = self.lm.device\n\n        if self.lm.num_labels != self.gnn.out_channels:\n            raise ValueError('''The output channel of language model \\\n                             and gnn should be the same''')\n\n    def pre_train_gnn(self, train_loader: NeighborLoader,\n                      optimizer: torch.optim.Optimizer, num_epochs: int,\n                      patience: int, ext_pseudo_labels: torch.Tensor = None,\n                      is_augmented: bool = False, verbose: bool = True):\n        # Pretrain GNN, optional steps if you do not have pseudo labels.\n        best_acc = 0\n        early_stopping = 0\n        # training only based on gold data\n        for epoch in range(0, num_epochs):\n            acc, loss = self.train_gnn(train_loader, optimizer, epoch,\n                                       ext_pseudo_labels, is_augmented,\n                                       verbose)\n            if acc < best_acc:\n                early_stopping += 1\n                if early_stopping > patience:\n                    print(f'Early stopped by Epoch: {epoch}, '\n                          f'Best acc: {best_acc}')\n                    break\n            best_acc = max(best_acc, acc)\n\n    def pre_train_lm(self, train_loader: DataLoader,\n                     optimizer: torch.optim.Optimizer, num_epochs: int,\n                     patience: int, ext_pseudo_labels: torch.Tensor = None,\n                     is_augmented: bool = False, verbose: bool = True):\n        # Pretrain language model\n        best_acc = 0\n        early_stopping = 0\n        for epoch in range(1, num_epochs + 1):\n            acc, loss = self.train_lm(train_loader, optimizer, epoch,\n                                      ext_pseudo_labels, is_augmented, verbose)\n            if acc < best_acc:\n                early_stopping += 1\n            
    if early_stopping > patience:\n                    print(f'Early stopped by Epoch: {epoch}, '\n                          f'Best acc: {best_acc}')\n                    break\n            best_acc = max(best_acc, acc)\n\n    def train(self, em_phase: str, train_loader: Union[DataLoader,\n                                                       NeighborLoader],\n              optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor,\n              epoch: int, is_augmented: bool = False, verbose: bool = False):\n        r\"\"\"GLEM training step, EM steps.\n\n        Args:\n            em_phase(str): 'gnn' or 'lm' choose which phase you are training on\n            train_loader(Union[DataLoader, NeighborLoader]): use DataLoader for\n                lm training, include tokenized data, labels, is_gold mask.\n                use NeighborLoader for gnn training, include x, edge_index.\n            optimizer (torch.optim.Optimizer): optimizer for training\n            pseudo_labels(torch.Tensor): the predicted labels used as pseudo\n                labels\n            epoch (int): current epoch\n            is_augmented (bool): will use pseudo_labels or not\n            verbose (bool): print training progress bar or not\n\n        Returns:\n            acc (float): training accuracy\n            loss (float): loss value\n        \"\"\"\n        if pseudo_labels is not None:\n            pseudo_labels = pseudo_labels.to(self.device)\n        if em_phase == 'gnn':\n            acc, loss = self.train_gnn(train_loader, optimizer, epoch,\n                                       pseudo_labels, is_augmented, verbose)\n        elif em_phase == 'lm':\n            acc, loss = self.train_lm(train_loader, optimizer, epoch,\n                                      pseudo_labels, is_augmented, verbose)\n        else:\n            raise ValueError(f\"Unknown em_phase '{em_phase}', \"\n                             \"expected 'gnn' or 'lm'\")\n        return acc, loss\n\n    def train_lm(self, train_loader: DataLoader,\n                 optimizer: torch.optim.Optimizer, epoch: int,\n                 pseudo_labels: 
torch.Tensor = None,\n                 is_augmented: bool = False, verbose: bool = True):\n        r\"\"\"Language model Training in every epoch.\n\n        Args:\n            train_loader (loader.dataloader.DataLoader): text token dataloader\n            optimizer (torch.optim.Optimizer): model optimizer\n            epoch (int): current train epoch\n            pseudo_labels (torch.Tensor): 1-D tensor, predictions from gnn\n            is_augmented (bool): train with pseudo labels or not\n            verbose (bool): print training progress bar or not\n\n        Returns:\n            approx_acc (torch.tensor): training accuracy\n            loss (torch.float): loss value\n\n        \"\"\"\n        all_out = []\n        total_loss = total_correct = 0\n        num_nodes = train_loader.dataset.indices.size(0)\n        self.lm.train()\n        if verbose:\n            pbar = tqdm(total=num_nodes)\n            pbar.set_description(f'Epoch {epoch:02d}')\n        for batch in train_loader:\n            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}\n            out = self.lm(**inputs).logits\n            labels = batch['labels'].to(self.device).squeeze()\n            # training with pseudo labels or not\n            if is_augmented:\n                pl_batch = pseudo_labels[batch['n_id']].to(self.device)\n            else:\n                pl_batch = None\n            loss = self.loss(out, labels, self.lm_loss,\n                             batch['is_gold'].to(self.device), pl_batch,\n                             self.alpha, is_augmented)\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n            all_out.append(out)\n            total_correct += int(out.argmax(dim=-1).eq(labels).sum())\n            total_loss += float(loss.detach())\n            if verbose:\n                pbar.update(batch['n_id'].size(0))\n\n        all_out = torch.cat(all_out, dim=0)\n        approx_acc = total_correct / num_nodes\n  
      loss = total_loss / len(train_loader)\n        if verbose:\n            pbar.close()\n        print(f'Epoch {epoch:02d} Loss: {loss:.4f} '\n              f'Approx. Train: {approx_acc:.4f}')\n        return approx_acc, loss\n\n    def train_gnn(self, train_loader: NeighborLoader,\n                  optimizer: torch.optim.Optimizer, epoch: int,\n                  pseudo_labels: torch.Tensor = None,\n                  is_augmented: bool = False, verbose: bool = True):\n        r\"\"\"GNN training step in every epoch.\n\n        Args:\n            train_loader (loader.NeighborLoader): gnn Neighbor node loader\n            optimizer (torch.optim.Optimizer): model optimizer\n            epoch (int): current train epoch\n            pseudo_labels(torch.tensor): 1-D tensor, predictions from lm\n            is_augmented(bool): use pseudo labeled node or not\n            verbose (bool): print training progress or not\n\n        Returns:\n            approx_acc (torch.tensor): training accuracy\n            loss (torch.float): loss value\n        \"\"\"\n        self.gnn.train()\n        num_nodes = train_loader.input_nodes.size(0)\n        if verbose:\n            pbar = tqdm(total=num_nodes)\n            pbar.set_description(f'Epoch {epoch:02d}')\n        total_loss = total_correct = 0\n        all_out = []\n        for batch in train_loader:\n            batch = batch.to(self.device)\n            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]\n            all_out.append(out)\n            labels = batch.y[:batch.batch_size].squeeze()\n            is_gold_batch = batch.is_gold[:batch.batch_size].squeeze()\n            # training with pseudo labels or not\n            if is_augmented and pseudo_labels is not None:\n                pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]\n            else:\n                pl_batch = None\n            loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,\n                             pl_batch, 
self.beta, is_augmented)\n            loss.backward()\n            optimizer.step()\n            optimizer.zero_grad()\n            total_loss += float(loss.detach())\n            total_correct += int(out.argmax(dim=-1).eq(labels).sum())\n            if verbose:\n                pbar.update(batch.batch_size)\n\n        all_out = torch.cat(all_out, dim=0)\n        loss = total_loss / len(train_loader)\n        approx_acc = total_correct / num_nodes\n        if verbose:\n            pbar.close()\n        print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '\n              f'Approx. Train: {approx_acc:.4f}')\n        return approx_acc, loss\n\n    @torch.no_grad()\n    def inference(self, em_phase: str, data_loader: Union[NeighborLoader,\n                                                          DataLoader],\n                  verbose: bool = False):\n        r\"\"\"GLEM inference step.\n\n        Args:\n            em_phase(str): 'gnn' or 'lm'\n            data_loader(DataLoader or NeighborLoader):\n                DataLoader: for lm inference, include tokenized data\n                NeighborLoader: for gnn inference, include x, edge_index\n            verbose(bool): print inference progress or not\n\n        Returns:\n            out (torch.Tensor): n * m tensor, m is number of classes,\n                n is number of nodes\n        \"\"\"\n        out = None\n        if em_phase == 'gnn':\n            self.gnn.eval()\n            out = self.inference_gnn(data_loader, verbose)\n        elif em_phase == 'lm':\n            self.lm.eval()\n            out = self.inference_lm(data_loader, verbose)\n        return out\n\n    @torch.no_grad()\n    def inference_lm(self, data_loader: DataLoader, verbose: bool = True):\n        r\"\"\"LM inference step.\n\n        Args:\n            data_loader (DataLoader): include token, labels, and gold mask\n            verbose (bool): print progress bar or not\n\n        Returns:\n            preds (tensor): prediction from LM, convert to pseudo 
labels\n                by preds.argmax(dim=-1).unsqueeze(1)\n        \"\"\"\n        if verbose:\n            pbar = tqdm(total=data_loader.dataset._data.num_nodes)\n            pbar.set_description('LM inference stage')\n        self.lm.eval()\n        preds = []\n        for batch in data_loader:\n            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}\n            logits = self.lm(**inputs).logits\n            preds.append(logits)\n            if verbose:\n                pbar.update(batch['n_id'].size(0))\n        if verbose:\n            pbar.close()\n        preds = torch.cat(preds)\n        return preds\n\n    @torch.no_grad()\n    def inference_gnn(self, data_loader: NeighborLoader, verbose: bool = True):\n        r\"\"\"GNN inference step.\n\n        Args:\n            data_loader(NeighborLoader): include x, edge_index\n            verbose (bool): print progress bar or not\n\n        Returns:\n            preds (tensor): prediction from GNN,\n                convert to pseudo labels by preds.argmax(dim=-1).unsqueeze(1)\n        \"\"\"\n        if verbose:\n            pbar = tqdm(total=data_loader.data.num_nodes)\n            pbar.set_description('GNN inference stage')\n        preds = []\n        self.gnn.eval()\n        for batch in data_loader:\n            batch = batch.to(self.device)\n            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]\n            preds.append(out)\n            if verbose:\n                pbar.update(batch.batch_size)\n        if verbose:\n            pbar.close()\n        preds = torch.cat(preds, dim=0)\n        return preds\n\n    def loss(self, logits: torch.Tensor, labels: torch.Tensor,\n             loss_func: torch.nn.functional, is_gold: torch.Tensor,\n             pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,\n             is_augmented: bool = True):\n        r\"\"\"Core function of variational EM inference; this function aims\n        at combining the loss value 
on gold (original train) and the loss value on\n        pseudo labels.\n\n        Reference:\n        <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py> # noqa\n\n        Args:\n            logits(torch.tensor): predicted results from LM or GNN\n            labels(torch.tensor): combined node labels from ground truth and\n                pseudo labels (if provided)\n            loss_func(torch.nn.modules.loss): loss function for classification\n            is_gold(tensor): a boolean tensor that masks ground truth\n                    labels during training, thus ~is_gold masks pseudo labels\n            pseudo_labels(torch.tensor): predictions from the other model\n            pl_weight: the pseudo label weight used in E-step and M-step\n                        optimization: alpha in E-step, beta in M-step\n            is_augmented: use EM or just train GNN and LM with gold data\n\n        \"\"\"\n        if is_augmented and (sum(~is_gold) > 0):\n            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))\n            # all other labels besides the ground truth (gold labels)\n            pseudo_label_loss = deal_nan(\n                loss_func(logits[~is_gold], pseudo_labels[~is_gold]))\n            loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss\n        else:\n            loss = loss_func(logits, labels)\n        return loss\n"
  },
  {
    "path": "torch_geometric/llm/models/llm.py",
    "content": "import warnings\nfrom contextlib import nullcontext\nfrom typing import Any, Dict, List, Optional\n\nimport torch\nfrom torch import Tensor\n\ntry:\n    from transformers.tokenization_utils_base import BatchEncoding\nexcept ImportError:\n    BatchEncoding = Dict\n\nIGNORE_INDEX = -100\nMAX_TXT_LEN = 512\nMAX_NEW_TOKENS = 128\nPAD_TOKEN_ID = 0\nPADDING_SIDE = 'left'\n\n# legacy constants - used for Llama 2 style prompting\nBOS = '<s>[INST]'\nEOS_USER = '[/INST]'\nEOS = '[/s]'\n\n\ndef get_llm_kwargs(\n    required_memory: int,\n    dtype: torch.dtype = torch.bfloat16,\n) -> Dict[str, Any]:\n    torch.cuda.empty_cache()\n\n    gpu_memory: List[int] = []\n    for i in range(torch.cuda.device_count()):\n        gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)\n        # Use the minimum number of GPUs to fit the LLM on.\n        if sum(gpu_memory) >= required_memory:\n            break\n\n    if sum(gpu_memory) < required_memory:\n        gpu_memory = []  # If not enough VRAM, use pure CPU.\n\n    kwargs = dict(revision='main')\n    if len(gpu_memory) > 0:\n        kwargs['max_memory'] = {\n            i: f'{memory}GiB'\n            for i, memory in enumerate(gpu_memory)\n        }\n        kwargs['low_cpu_mem_usage'] = True\n        kwargs['device_map'] = 'auto'\n        kwargs['dtype'] = dtype\n\n    return kwargs\n\n\nclass LLM(torch.nn.Module):\n    r\"\"\"A wrapper around a Large Language Model (LLM) from HuggingFace.\n\n    Args:\n        model_name (str): The HuggingFace model name.\n        num_params (float, optional): A number representing how many params\n            the HuggingFace model has, in billions. This is used to\n            automatically allocate the correct number of GPUs needed (using a\n            rough heuristic), given the available GPU memory of your GPUs.  If\n            not specified, the number of parameters is determined using the\n            `huggingface_hub` module.\n        n_gpus (int, optional): Number of GPUs to use. 
Designed for advanced\n            users who want to set the number of GPUs manually and\n            override the automatic setup mechanism.\n        dtype (torch.dtype, optional): The data type to use for the LLM.\n            (default: :obj:`torch.bfloat16`)\n        sys_prompt (str, optional): A system prompt to use for the LLM.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        model_name: str,\n        num_params: Optional[float] = None,\n        n_gpus: Optional[int] = None,\n        dtype: Optional[torch.dtype] = torch.bfloat16,\n        sys_prompt: Optional[str] = None,\n    ) -> None:\n        super().__init__()\n\n        self.model_name = model_name\n\n        from transformers import AutoModelForCausalLM, AutoTokenizer\n        if n_gpus is None:\n            if num_params is None:\n                from huggingface_hub import get_safetensors_metadata\n                safetensors_metadata = get_safetensors_metadata(model_name)\n                param_count = safetensors_metadata.parameter_count\n                num_params = float(list(param_count.values())[0] // 10**9)\n\n            # A rough heuristic on GPU memory requirements, e.g., we found that\n            # LLAMA3 (8B parameters) fits on a 96GB GPU.\n            required_memory = 96.0 * num_params / 8.0\n            kwargs = get_llm_kwargs(required_memory, dtype)\n        else:\n            gpu_memory: List[int] = []\n            for i in range(n_gpus):\n                gpu_memory.append(torch.cuda.mem_get_info(i)[0] // 1024**3)\n            kwargs = dict(revision='main')\n            kwargs['max_memory'] = {\n                i: f'{memory}GiB'\n                for i, memory in enumerate(gpu_memory)\n            }\n            kwargs['low_cpu_mem_usage'] = True\n            kwargs['device_map'] = 'auto'\n            kwargs['dtype'] = dtype\n\n        print(f\"Setting up '{model_name}' with configuration: {kwargs}\")\n        self.tokenizer = 
AutoTokenizer.from_pretrained(\n            model_name,\n            use_fast=False,\n        )\n        if self.tokenizer.chat_template and self.tokenizer.bos_token is None:\n            dummy_convo = [\n                {\n                    \"role\": \"system\",\n                    \"content\": \"dummy\"\n                },\n                {\n                    \"role\": \"user\",\n                    \"content\": \"convo\"\n                },\n            ]\n            text = self.tokenizer.apply_chat_template(\n                dummy_convo,\n                tokenize=True,\n            )\n            self.tokenizer.bos_token = self._safe_decode(self.tokenizer, text)\n        if self.tokenizer.pad_token_id is None:\n            self.tokenizer.pad_token_id = PAD_TOKEN_ID\n        if self.tokenizer.padding_side is None:\n            self.tokenizer.padding_side = PADDING_SIDE\n        self.llm = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)\n        self.llm = self.llm.to(dtype)\n        self.word_embedding = self.llm.model.get_input_embeddings()\n        if sys_prompt is not None:\n            self.sys_prompt = sys_prompt\n        else:\n            self.sys_prompt = \"\"\n        if 'max_memory' not in kwargs:  # Pure CPU:\n            warnings.warn(\n                \"LLM is being used on CPU, which may be slow. This decision \"\n                \"was made by a rough heuristic that assumes your GPU setup \"\n                \"does not have enough GPU RAM. This is done to avoid GPU OOM \"\n                \"errors. 
If you think this is a mistake, please initialize \"\n                \"your LLM with the n_gpus param to dictate how many gpus to \"\n                \"use for the LLM.\", stacklevel=2)\n            self.device = torch.device('cpu')\n            self.autocast_context = nullcontext()\n        else:\n            self.device = self.llm.device\n            if dtype == torch.float32:\n                self.autocast_context = nullcontext()\n            else:\n                self.autocast_context = torch.amp.autocast('cuda', dtype=dtype)\n\n    @staticmethod\n    def _safe_decode(tokenizer, tokens) -> str:\n        \"\"\"Decode token IDs from various Hugging Face tokenizer outputs.\n\n        Supports:\n            - list[int]\n            - list[list[int]]\n            - BatchEncoding\n            - tokenizers.Encoding\n        \"\"\"\n        if isinstance(tokens, dict):\n            tokens = tokens.get(\"input_ids\", tokens)\n\n        if hasattr(tokens, \"ids\"):\n            tokens = tokens.ids\n\n        if isinstance(tokens, list) and tokens and isinstance(tokens[0], list):\n            tokens = tokens[0]\n\n        return tokenizer.decode(tokens)\n\n    # legacy function - used for Llama 2 style prompting\n    def _encode_inputs(\n        self,\n        question: List[str],\n        context: Optional[List[str]] = None,\n    ) -> tuple:\n        batch_size = len(question)\n        questions = self.tokenizer(question, add_special_tokens=False)\n        if context is not None:\n            context = self.tokenizer(context, add_special_tokens=False)\n\n        eos_user_tokens = self.tokenizer(EOS_USER, add_special_tokens=False)\n        bos_token = self.tokenizer(\n            BOS,\n            add_special_tokens=False,\n            return_tensors='pt',\n        ).input_ids[0].to(self.device)\n        bos_embeds = self.word_embedding(bos_token)\n        pad_token = torch.tensor(self.tokenizer.pad_token_id,\n                                 device=self.device)\n       
 pad_embeds = self.word_embedding(pad_token).unsqueeze(0)\n        return (batch_size, questions, context, eos_user_tokens, bos_embeds,\n                pad_embeds)\n\n    def _label_input_ids(\n        self,\n        i: int,\n        label: BatchEncoding,\n        eos_tokens: BatchEncoding,\n    ) -> List[int]:\n        label_input_ids = label.input_ids[i][:MAX_NEW_TOKENS]\n        label_input_ids = label_input_ids + eos_tokens.input_ids\n        return label_input_ids\n\n    # legacy function - used for Llama 2 style prompting\n    def _input_ids(\n        self,\n        i: int,\n        context: BatchEncoding,\n        question: BatchEncoding,\n        eos_user_tokens: BatchEncoding,\n    ) -> List[int]:\n        input_ids: List[int] = []\n        if context is not None:\n            input_ids += context.input_ids[i][:MAX_TXT_LEN]\n        input_ids += question.input_ids[i]\n        input_ids += eos_user_tokens.input_ids\n        return input_ids\n\n    # legacy function - used for Llama 2 style prompting\n    def _inputs_embeds(\n        self,\n        i: int,\n        input_ids: List[int],\n        bos_embeds: Tensor,\n        embedding: Optional[List[Tensor]] = None,\n    ) -> Tensor:\n        inputs_embeds = self.word_embedding(\n            torch.tensor(input_ids, device=self.device))\n\n        to_cat = [bos_embeds]\n        if embedding is not None and embedding[i] is not None:\n            to_cat.append(embedding[i])\n        to_cat.append(inputs_embeds)\n        return torch.cat(to_cat, dim=0).to(self.device)\n\n    def _append_embeds(\n        self,\n        inputs_embeds: Tensor,\n        batch_inputs_embeds: List[Tensor],\n        batch_attention_mask: List[List[int]],\n        label_input_ids: List[int] = None,\n        batch_label_input_ids: Optional[List[List[int]]] = None,\n    ) -> tuple:\n        batch_inputs_embeds.append(inputs_embeds)\n        batch_attention_mask.append([1] * inputs_embeds.size(0))\n        if label_input_ids is not None:\n 
           pad = inputs_embeds.size(0) - len(label_input_ids)\n            label_input_ids = [IGNORE_INDEX] * pad + label_input_ids\n            batch_label_input_ids.append(label_input_ids)\n        return batch_inputs_embeds, batch_attention_mask, batch_label_input_ids\n\n    def _pad_embeds(\n        self,\n        pad_embeds: Tensor,\n        batch_inputs_embeds: List[Tensor],\n        batch_attention_mask: List[List[int]],\n        batch_label_input_ids: Optional[List[List[int]]] = None,\n    ) -> tuple:\n        max_length = max([x.size(0) for x in batch_inputs_embeds])\n        batch_size = len(batch_inputs_embeds)\n        for i in range(batch_size):\n            pad = max_length - batch_inputs_embeds[i].size(0)\n            batch_inputs_embeds[i] = torch.cat([\n                pad_embeds.repeat(pad, 1),\n                batch_inputs_embeds[i],\n            ])\n            batch_attention_mask[i] = [0] * pad + batch_attention_mask[i]\n            if batch_label_input_ids is not None:\n                tmp = [IGNORE_INDEX] * pad + batch_label_input_ids[i]\n                batch_label_input_ids[i] = tmp\n        inputs_embeds = torch.stack(batch_inputs_embeds, dim=0)\n        attention_mask = torch.tensor(batch_attention_mask, device=self.device)\n        label_input_ids = None\n        if batch_label_input_ids is not None:\n            label_input_ids = torch.tensor(batch_label_input_ids,\n                                           device=self.device)\n        return inputs_embeds, attention_mask, label_input_ids\n\n    # legacy function - used for Llama 2 style prompting\n    def _get_embeds_old(\n        self,\n        question: List[str],\n        context: Optional[List[str]] = None,\n        embedding: Optional[List[Tensor]] = None,\n        answer: Optional[List[str]] = None,\n    ) -> tuple:\n        (batch_size, question, context, eos_user_tokens, bos_embeds,\n         pad_embeds) = self._encode_inputs(question, context)\n\n        
batch_label_input_ids = None\n        if answer is not None:\n            label = self.tokenizer(answer, add_special_tokens=False)\n            eos_tokens = self.tokenizer(EOS, add_special_tokens=False)\n            batch_label_input_ids = []\n\n        batch_inputs_embeds = []\n        batch_attention_mask = []\n        for i in range(batch_size):\n            input_ids = self._input_ids(i, context, question, eos_user_tokens)\n            if answer is not None:\n                label_input_ids = self._label_input_ids(i, label, eos_tokens)\n                input_ids += label_input_ids\n            else:\n                label_input_ids = None\n\n            inputs_embeds = self._inputs_embeds(i, input_ids, bos_embeds,\n                                                embedding)\n\n            (\n                batch_inputs_embeds,\n                batch_attention_mask,\n                batch_label_input_ids,\n            ) = self._append_embeds(\n                inputs_embeds,\n                batch_inputs_embeds,\n                batch_attention_mask,\n                label_input_ids,\n                batch_label_input_ids,\n            )\n\n        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(\n            pad_embeds, batch_inputs_embeds, batch_attention_mask,\n            batch_label_input_ids)\n\n        return inputs_embeds, attention_mask, label_input_ids\n\n    def _get_embeds(\n        self,\n        question: List[str],\n        context: Optional[List[str]] = None,\n        embedding: Optional[List[Tensor]] = None,\n        answer: Optional[List[str]] = None,\n    ) -> tuple:\n        if not self.tokenizer.chat_template or not self.sys_prompt:\n            warnings.warn(\n                f\"HuggingFace model {self.model_name} is not using a \"\n                \"chat template, using Llama 2 style prompting. 
Please \"\n                \"consider using a more recent model and initialize the \"\n                \"LLM with `sys_prompt`.\", stacklevel=2)\n            return self._get_embeds_old(question, context, embedding, answer)\n        batch_label_input_ids = None\n        if answer is not None:\n            label = self.tokenizer(answer, add_special_tokens=False)\n            eos_tokens = self.tokenizer(self.tokenizer.eos_token,\n                                        add_special_tokens=False)\n            batch_label_input_ids = []\n\n        batch_inputs_embeds = []\n        batch_attention_mask = []\n        for i in range(len(question)):\n            ctx = f\"{context[i]} - \" if context else \"\"\n            messages = [\n                {\n                    \"role\": \"system\",\n                    \"content\": self.sys_prompt\n                },\n                {\n                    \"role\": \"user\",\n                    \"content\": f\"{ctx} - {question[i]}\"\n                },\n            ]\n            text = self.tokenizer.apply_chat_template(\n                messages,\n                tokenize=False,\n                add_generation_prompt=True,\n                enable_thinking=True,\n            )\n            text = text[len(self.tokenizer.bos_token):]\n            input_ids = self.tokenizer(text,\n                                       add_special_tokens=False).input_ids\n            if answer is not None:\n                label_input_ids = self._label_input_ids(i, label, eos_tokens)\n                input_ids += label_input_ids\n            else:\n                label_input_ids = None\n\n            bos_token = self.tokenizer(\n                self.tokenizer.bos_token,\n                add_special_tokens=False,\n                return_tensors='pt',\n            ).input_ids[0].to(self.device)\n\n            bos_embeds = self.word_embedding(bos_token)\n\n            inputs_embeds = self.word_embedding(\n                
torch.tensor(input_ids, device=self.device))\n\n            to_cat = [bos_embeds]\n            if embedding is not None and embedding[i] is not None:\n                to_cat.append(embedding[i])\n            to_cat.append(inputs_embeds)\n            inputs_embeds = torch.cat(to_cat, dim=0).to(self.device)\n\n            (\n                batch_inputs_embeds,\n                batch_attention_mask,\n                batch_label_input_ids,\n            ) = self._append_embeds(\n                inputs_embeds,\n                batch_inputs_embeds,\n                batch_attention_mask,\n                label_input_ids,\n                batch_label_input_ids,\n            )\n\n        pad_token = torch.tensor(self.tokenizer.pad_token_id,\n                                 device=self.device)\n        pad_embeds = self.word_embedding(pad_token).unsqueeze(0)\n\n        inputs_embeds, attention_mask, label_input_ids = self._pad_embeds(\n            pad_embeds, batch_inputs_embeds, batch_attention_mask,\n            batch_label_input_ids)\n\n        return inputs_embeds, attention_mask, label_input_ids\n\n    def forward(\n        self,\n        question: List[str],\n        answer: List[str],\n        context: Optional[List[str]] = None,\n        embedding: Optional[List[Tensor]] = None,\n    ) -> Tensor:\n        r\"\"\"The forward pass.\n\n        Args:\n            question (list[str]): The questions/prompts.\n            answer (list[str]): The answers/labels.\n            context (list[str], optional): Additional context to give to the\n                LLM, such as textified knowledge graphs. (default: :obj:`None`)\n            embedding (list[torch.Tensor], optional): RAG embedding\n                tensors, *i.e.* the embedded form of :obj:`context`. Either\n                :obj:`context` or :obj:`embedding` should be used, not\n                both. 
(default: :obj:`None`)\n        \"\"\"\n        inputs_embeds, attention_mask, label_input_ids = self._get_embeds(\n            question, context, embedding, answer)\n\n        with self.autocast_context:\n            outputs = self.llm(\n                inputs_embeds=inputs_embeds,\n                attention_mask=attention_mask,\n                return_dict=True,\n                labels=label_input_ids,\n            )\n        return outputs.loss\n\n    @torch.no_grad()\n    def inference(\n        self,\n        question: List[str],\n        context: Optional[List[str]] = None,\n        embedding: Optional[List[Tensor]] = None,\n        max_tokens: Optional[int] = MAX_NEW_TOKENS,\n    ) -> List[str]:\n        r\"\"\"The inference pass.\n\n        Args:\n            question (list[str]): The questions/prompts.\n            context (list[str], optional): Additional context to give to the\n                LLM, such as textified knowledge graphs. (default: :obj:`None`)\n            embedding (list[torch.Tensor], optional): RAG embedding\n                tensors, *i.e.* the embedded form of :obj:`context`. Either\n                :obj:`context` or :obj:`embedding` should be used, not\n                both. (default: :obj:`None`)\n            max_tokens (int, optional): How many tokens for the LLM to\n                generate. 
(default: :obj:`32`)\n        \"\"\"\n        inputs_embeds, attention_mask, _ = self._get_embeds(\n            question, context, embedding)\n\n        with self.autocast_context:\n            outputs = self.llm.generate(\n                inputs_embeds=inputs_embeds,\n                bos_token_id=self.tokenizer.bos_token_id,\n                max_new_tokens=max_tokens,\n                attention_mask=attention_mask,\n                pad_token_id=self.tokenizer.eos_token_id,\n                use_cache=True,\n            )\n\n        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.model_name})'\n"
  },
  {
    "path": "torch_geometric/llm/models/llm_judge.py",
    "content": "from math import isnan\nfrom typing import Optional\n\nfrom torch_geometric.llm.models.txt2kg import \\\n    _chunk_to_triples_str_cloud as call_NIM\n\n# Credit for original \"Marlin Accuracy\" system goes to:\n# Gilberto Titericz (NVIDIA)\n# This work is an adaptation of his for PyG\nSYSTEM_PROMPT_1 = (\n    \"Instruction: You are a world class state of the art \" +\n    \"assistant for rating \" +\n    \"a User Answer given a Question. The Question is completely\" +\n    \" answered by the Reference Answer.\\n\" +\n    \"Say 4, if User Answer is full contained and equivalent to\" +\n    \" Reference Answer\" +\n    \"in all terms, topics, numbers, metrics, dates and units.\\n\" +\n    \"Say 2, if User Answer is partially contained and almost \" +\n    \"equivalent to Reference Answer\" +\n    \"in all terms, topics, numbers, metrics, dates and units.\\n\" +\n    \"Say 0, if User Answer is not contained in Reference Answer\" +\n    \" or not accurate in all terms, topics,\" +\n    \"numbers, metrics, dates and units or the User Answer do not\" +\n    \" answer the question.\\n\" +\n    \"Do not explain or justify your rating. 
Your rating must be \" +\n    \"only 4, 2 or 0 according to the instructions above.\\n\" +\n    \"### Question: \\\"{question}\\\"\\n\" + \"### User Answer: \\\"{model_pred}\\\"\\n\" +\n    \"### Reference Answer: \\\"{correct_answer}\\\"\\n\" + \"The rating is:\\n\")\n\nSYSTEM_PROMPT_2 = (\n    \"I will rate the User Answer in comparison to the Reference \" +\n    \"Answer for a given Question.\\n\" +\n    \"A rating of 4 indicates that the User Answer is entirely \" +\n    \"consistent with the Reference Answer, covering all aspects,\" +\n    \" topics, numbers, metrics, dates, and units.\\n\" +\n    \"A rating of 2 signifies that the User Answer is mostly \" +\n    \"aligned with the Reference Answer, with minor discrepancies\" +\n    \" in some areas.\\n\" +\n    \"A rating of 0 means that the User Answer is either \" +\n    \"inaccurate, incomplete, or unrelated to the Reference \" +\n    \"Answer, or it fails to address the Question.\\n\" +\n    \"I will provide the rating without any explanation or \" +\n    \"justification, adhering to the following scale: \" +\n    \"0 (no match), 2 (partial match), 4 (exact match).\\n\" +\n    \"Do not explain or justify my rating. 
My rating must\" +\n    \" be only 4, 2 or 0 only.\\n\\n\" + \"Question: \\\"{question}\\\"\\n\\n\" +\n    \"Reference Answer: \\\"{model_pred}\\\"\\n\\n\" +\n    \"User Answer: \\\"{correct_answer}\\\"\\n\\n\" + \"Rating: \")\n\n\n# TODO: add support for Local LM\n# TODO: add multiproc support like txt2kg\nclass LLMJudge():\n    \"\"\"Uses NIMs to score a triple of (question, model_pred, correct_answer).\n    This whole class is an adaptation of Gilberto's work for PyG.\n\n    Args:\n        NVIDIA_NIM_MODEL : (str, optional)\n            The name of the NVIDIA NIM model to use.\n            (default: \"nvidia/llama-3.1-nemotron-70b-instruct\").\n        NVIDIA_API_KEY : (str, optional)\n            The API key for accessing NVIDIA's NIM models.\n            (default: \"\").\n        ENDPOINT_URL : (str, optional)\n            The URL hosting your model, in case you are not using\n            the public NIM.\n            (default: \"https://integrate.api.nvidia.com/v1\").\n    \"\"\"\n    def __init__(\n        self,\n        NVIDIA_NIM_MODEL: Optional[\n            str] = \"nvidia/llama-3.1-nemotron-70b-instruct\",\n        NVIDIA_API_KEY: Optional[str] = \"\",\n        ENDPOINT_URL: Optional[str] = \"https://integrate.api.nvidia.com/v1\",\n    ) -> None:\n        self.NVIDIA_API_KEY = NVIDIA_API_KEY\n        self.NIM_MODEL = NVIDIA_NIM_MODEL\n        self.ENDPOINT_URL = ENDPOINT_URL\n\n    def _process_score(self, response: str) -> float:\n        \"\"\"Uses 3 and 1 even though prompt says only 0, 2, 4.\n        This is because LLMs don't always follow instructions.\n        Credit to Gilberto.\n        \"\"\"\n        for i in [4, 3, 2, 1, 0]:\n            if str(i) in response:\n                return i / 4\n        return float(\"nan\")\n\n    def _average_scores(self, score0: float, score1: float) -> float:\n        \"\"\"Take the average of score0 and score1.\n        Sometimes the LLM fails to respond or has no score in the response.\n        In those cases the 
failed score is discarded.\n        Credit to Gilberto.\n\n        Args:\n         score0 (float): judge accuracy score.\n         score1 (float): judge accuracy score by permuting agent answer and\n         ground truth.\n\n        Returns:\n            (float) average of score0 and score1 if both contain scores,\n            otherwise the valid score (nan if both failed).\n        \"\"\"\n        # `max()` is order-dependent when one argument is nan, so failed\n        # scores are discarded explicitly:\n        if isnan(score0):\n            return score1\n        if isnan(score1):\n            return score0\n        return (score0 + score1) / 2\n\n    def score(\n        self,\n        question: str,\n        model_pred: str,\n        correct_answer: str,\n    ) -> float:\n        \"\"\"Args:\n            question (str): The original question asked to the model.\n            model_pred (str): The prediction made by the model.\n            correct_answer (str): The actual correct answer to the question.\n\n        Returns:\n            score (float): score of 0-1, may be nan due to LLM judge failure.\n                Evals should skip nan's when aggregating score.\n        \"\"\"\n        prompt1 = SYSTEM_PROMPT_1.format(question=question,\n                                         model_pred=model_pred,\n                                         correct_answer=correct_answer)\n        prompt2 = SYSTEM_PROMPT_2.format(question=question,\n                                         model_pred=model_pred,\n                                         correct_answer=correct_answer)\n        score1 = float(\"nan\")\n        score2 = float(\"nan\")\n        for _retry in range(200):\n            try:\n                score1 = self._process_score(\n                    call_NIM(prompt1, self.NVIDIA_API_KEY, self.NIM_MODEL,\n                             self.ENDPOINT_URL, post_text=\"\"))\n                if not isnan(score1):\n                    break\n            except ImportError:\n                raise\n            except:  # noqa\n                pass\n        
for _retry in range(20):\n            try:\n                score2 = self._process_score(\n                    call_NIM(prompt2, self.NVIDIA_API_KEY, self.NIM_MODEL,\n                             self.ENDPOINT_URL, post_text=\"\"))\n                if not isnan(score2):\n                    break\n            except ImportError:\n                raise\n            except:  # noqa\n                pass\n\n        return self._average_scores(score1, score2)\n"
  },
  {
    "path": "torch_geometric/llm/models/molecule_gpt.py",
    "content": "from typing import List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.llm.models.llm import BOS, LLM, MAX_NEW_TOKENS\nfrom torch_geometric.nn.attention import QFormer\nfrom torch_geometric.utils import to_dense_batch\n\n\ndef pad_or_truncate(embeddings: Tensor, max_seq_len: int,\n                    padding_value: int = 0) -> Tensor:\n    batch_size, current_seq_len, d = embeddings.size()\n\n    if current_seq_len > max_seq_len:\n        return embeddings[:, :max_seq_len, :]\n    elif current_seq_len < max_seq_len:\n        pad_tensor = torch.full((batch_size, max_seq_len - current_seq_len, d),\n                                padding_value, dtype=embeddings.dtype,\n                                device=embeddings.device)\n        return torch.cat([embeddings, pad_tensor], dim=1)\n    else:\n        return embeddings\n\n\nclass MoleculeGPT(torch.nn.Module):\n    r\"\"\"The MoleculeGPT model from the `\"MoleculeGPT: Instruction\n    Following Large Language Models for Molecular Property Prediction\"\n    <https://ai4d3.github.io/papers/34.pdf>`_ paper.\n\n    Args:\n        llm (LLM): The LLM to use.\n        graph_encoder (torch.nn.Module): Encode 2D molecule graph.\n        smiles_encoder (torch.nn.Module): Encode 1D SMILES.\n        mlp_out_channels (int, optional): The size of each embedding\n            after qformer encoding. (default: :obj:`32`)\n        max_tokens (int, optional): Max output tokens of 1D/2D encoder.\n            (default: :obj:`20`)\n\n    .. warning::\n        This module has been tested with the following HuggingFace models\n\n        * :obj:`llm_to_use=\"lmsys/vicuna-7b-v1.5\"`\n\n        and may not work with other models. See other models at `HuggingFace\n        Models <https://huggingface.co/models>`_ and let us know if you\n        encounter any issues.\n\n    .. 
note::\n        For an example of using :class:`MoleculeGPT`, see\n        `examples/llm/molecule_gpt.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/llm/molecule_gpt.py>`_.\n    \"\"\"\n    def __init__(\n        self,\n        llm: LLM,\n        graph_encoder: torch.nn.Module,\n        smiles_encoder: torch.nn.Module,\n        mlp_out_channels: int = 32,\n        max_tokens: Optional[int] = 20,\n    ) -> None:\n        super().__init__()\n        self.llm = llm\n        self.graph_encoder = graph_encoder.to(self.llm.device)\n        self.smiles_encoder = smiles_encoder.to(self.llm.device)\n\n        self.graph_qformer = QFormer(\n            input_dim=self.graph_encoder.nn[-1].out_features,\n            hidden_dim=mlp_out_channels,\n            output_dim=mlp_out_channels,\n            num_heads=4,\n            num_layers=2,\n        ).to(self.llm.device)\n\n        self.smiles_qformer = QFormer(\n            input_dim=self.smiles_encoder.model.pooler.dense.out_features,\n            hidden_dim=mlp_out_channels,\n            output_dim=mlp_out_channels,\n            num_heads=4,\n            num_layers=2,\n        ).to(self.llm.device)\n\n        self.max_tokens = max_tokens\n\n        self.word_embedding = self.llm.word_embedding\n        self.llm_generator = self.llm.llm\n\n        # LLMs\n        in_dim = 2 * mlp_out_channels * max_tokens\n        out_dim = self.llm.llm.model.embed_tokens.embedding_dim\n        self.projector = torch.nn.Sequential(\n            torch.nn.Linear(in_dim, in_dim),\n            torch.nn.Sigmoid(),\n            torch.nn.Linear(in_dim, out_dim),\n        ).to(self.llm.device)\n\n    def encode(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_attr: Optional[Tensor],\n        smiles: List[str],\n    ) -> Tensor:\n        batch_size = len(smiles)\n        # 2D Graph Branch: [bs, node_len, d]\n        x = x.to(self.llm.device)\n        edge_index = 
edge_index.to(self.llm.device)\n        if edge_attr is not None:\n            edge_attr = edge_attr.to(self.llm.device)\n        batch = batch.to(self.llm.device)\n\n        x_graph = self.graph_encoder(x, edge_index, edge_attr=edge_attr)\n        x_graph = to_dense_batch(x_graph, batch)[0]\n        out_graph = self.graph_qformer(x_graph)\n        out_graph = pad_or_truncate(out_graph, max_seq_len=self.max_tokens,\n                                    padding_value=0)\n        out_graph = out_graph.view(batch_size, -1)\n\n        # 1D SMILES Branch: [bs, seq_len, d]\n        x_smiles = self.smiles_encoder.encode(smiles,\n                                              output_device=self.llm.device)\n        out_smiles = self.smiles_qformer(x_smiles)\n        out_smiles = pad_or_truncate(out_smiles, max_seq_len=self.max_tokens,\n                                     padding_value=0)\n        out_smiles = out_smiles.view(batch_size, -1)\n\n        # Merge into LLMs\n        x_cat = torch.cat([out_graph, out_smiles], dim=1)\n        return x_cat\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_attr: Optional[Tensor],\n        smiles: List[str],\n        instructions: List[str],\n        label: List[str],\n        additional_text_context: Optional[List[str]] = None,\n    ):\n        x = self.encode(x, edge_index, batch, edge_attr, smiles)\n        x = self.projector(x)\n        xs = x.split(1, dim=0)\n\n        batch_unique = batch.unique()\n        batch_size = len(instructions)\n        if len(batch_unique) < batch_size:\n            xs = [\n                xs[i] if i in batch_unique else None for i in range(batch_size)\n            ]\n\n        (\n            inputs_embeds,\n            attention_mask,\n            label_input_ids,\n        ) = self.llm._get_embeds(instructions, additional_text_context, xs,\n                                 label)\n\n        with self.llm.autocast_context:\n       
     outputs = self.llm_generator(\n                inputs_embeds=inputs_embeds,\n                attention_mask=attention_mask,\n                return_dict=True,\n                labels=label_input_ids,\n            )\n\n        return outputs.loss\n\n    @torch.no_grad()\n    def inference(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_attr: Optional[Tensor],\n        smiles: List[str],\n        instructions: List[str],\n        additional_text_context: Optional[List[str]] = None,\n        max_out_tokens: Optional[int] = MAX_NEW_TOKENS,\n    ):\n        x = self.encode(x, edge_index, batch, edge_attr, smiles)\n        x = self.projector(x)\n        xs = x.split(1, dim=0)\n\n        # Handle questions without node features:\n        batch_unique = batch.unique()\n        batch_size = len(instructions)\n        if len(batch_unique) < batch_size:\n            xs = [\n                xs[i] if i in batch_unique else None for i in range(batch_size)\n            ]\n\n        inputs_embeds, attention_mask, _ = self.llm._get_embeds(\n            instructions, additional_text_context, xs)\n\n        bos_token = self.llm.tokenizer(\n            BOS,\n            add_special_tokens=False,\n        ).input_ids[0]\n\n        with self.llm.autocast_context:\n            outputs = self.llm_generator.generate(\n                inputs_embeds=inputs_embeds,\n                max_new_tokens=max_out_tokens,\n                attention_mask=attention_mask,\n                bos_token_id=bos_token,\n                use_cache=True  # Important to set!\n            )\n\n        return self.llm.tokenizer.batch_decode(\n            outputs,\n            skip_special_tokens=True,\n        )\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(\\n'\n                f'  llm={self.llm},\\n'\n                f'  graph={self.graph_encoder.__class__.__name__},\\n'\n                f'  
smiles={self.smiles_encoder},\\n'\n                f')')\n"
  },
  {
    "path": "torch_geometric/llm/models/protein_mpnn.py",
    "content": "from itertools import product\nfrom typing import Tuple\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.nn import knn_graph\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.utils import to_dense_adj, to_dense_batch\n\n\nclass PositionWiseFeedForward(torch.nn.Module):\n    def __init__(self, in_channels: int, hidden_channels: int) -> None:\n        super().__init__()\n        self.out = torch.nn.Sequential(\n            torch.nn.Linear(in_channels, hidden_channels),\n            torch.nn.GELU(),\n            torch.nn.Linear(hidden_channels, in_channels),\n        )\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        return self.out(x)\n\n\nclass PositionalEncoding(torch.nn.Module):\n    def __init__(self, hidden_channels: int,\n                 max_relative_feature: int = 32) -> None:\n        super().__init__()\n        self.max_relative_feature = max_relative_feature\n        self.emb = torch.nn.Embedding(2 * max_relative_feature + 2,\n                                      hidden_channels)\n\n    def forward(self, offset, mask) -> torch.Tensor:\n        d = torch.clip(offset + self.max_relative_feature, 0,\n                       2 * self.max_relative_feature) * mask + (1 - mask) * (\n                           2 * self.max_relative_feature + 1)  # noqa: E501\n        return self.emb(d.long())\n\n\nclass Encoder(MessagePassing):\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        dropout: float = 0.1,\n        scale: float = 30,\n    ) -> None:\n        super().__init__()\n        self.out_v = torch.nn.Sequential(\n            torch.nn.Linear(in_channels, hidden_channels),\n            torch.nn.GELU(),\n            torch.nn.Linear(hidden_channels, hidden_channels),\n            torch.nn.GELU(),\n            torch.nn.Linear(hidden_channels, hidden_channels),\n        )\n        self.out_e = torch.nn.Sequential(\n            
torch.nn.Linear(in_channels, hidden_channels),\n            torch.nn.GELU(),\n            torch.nn.Linear(hidden_channels, hidden_channels),\n            torch.nn.GELU(),\n            torch.nn.Linear(hidden_channels, hidden_channels),\n        )\n        self.dropout1 = torch.nn.Dropout(dropout)\n        self.dropout2 = torch.nn.Dropout(dropout)\n        self.dropout3 = torch.nn.Dropout(dropout)\n        self.norm1 = torch.nn.LayerNorm(hidden_channels)\n        self.norm2 = torch.nn.LayerNorm(hidden_channels)\n        self.norm3 = torch.nn.LayerNorm(hidden_channels)\n        self.scale = scale\n        self.dense = PositionWiseFeedForward(hidden_channels,\n                                             hidden_channels * 4)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        edge_index: torch.Tensor,\n        edge_attr: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        # x: [N, d_v]\n        # edge_index: [2, E]\n        # edge_attr: [E, d_e]\n        # update node features\n        h_message = self.propagate(x=x, edge_index=edge_index,\n                                   edge_attr=edge_attr)\n        dh = h_message / self.scale\n        x = self.norm1(x + self.dropout1(dh))\n        dh = self.dense(x)\n        x = self.norm2(x + self.dropout2(dh))\n        # update edge features\n        row, col = edge_index\n        x_i, x_j = x[row], x[col]\n        h_e = torch.cat([x_i, x_j, edge_attr], dim=-1)\n        h_e = self.out_e(h_e)\n        edge_attr = self.norm3(edge_attr + self.dropout3(h_e))\n        return x, edge_attr\n\n    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,\n                edge_attr: torch.Tensor) -> torch.Tensor:\n        h = torch.cat([x_i, x_j, edge_attr], dim=-1)  # [E, 2*d_v + d_e]\n        h = self.out_e(h)  # [E, d_e]\n        return h\n\n\nclass Decoder(MessagePassing):\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        dropout: float = 0.1,\n        scale: float = 
30,\n    ) -> None:\n        super().__init__()\n        self.out_v = torch.nn.Sequential(\n            torch.nn.Linear(in_channels, hidden_channels),\n            torch.nn.GELU(),\n            torch.nn.Linear(hidden_channels, hidden_channels),\n            torch.nn.GELU(),\n            torch.nn.Linear(hidden_channels, hidden_channels),\n        )\n        self.dropout1 = torch.nn.Dropout(dropout)\n        self.dropout2 = torch.nn.Dropout(dropout)\n        self.norm1 = torch.nn.LayerNorm(hidden_channels)\n        self.norm2 = torch.nn.LayerNorm(hidden_channels)\n        self.scale = scale\n        self.dense = PositionWiseFeedForward(hidden_channels,\n                                             hidden_channels * 4)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        edge_index: torch.Tensor,\n        edge_attr: torch.Tensor,\n        x_label: torch.Tensor,\n        mask: torch.Tensor,\n    ) -> torch.Tensor:\n        # x: [N, d_v]\n        # edge_index: [2, E]\n        # edge_attr: [E, d_e]\n        h_message = self.propagate(x=x, x_label=x_label, edge_index=edge_index,\n                                   edge_attr=edge_attr, mask=mask)\n        dh = h_message / self.scale\n        x = self.norm1(x + self.dropout1(dh))\n        dh = self.dense(x)\n        x = self.norm2(x + self.dropout2(dh))\n        return x\n\n    def message(self, x_i: torch.Tensor, x_j: torch.Tensor,\n                x_label_j: torch.Tensor, edge_attr: torch.Tensor,\n                mask: torch.Tensor) -> torch.Tensor:\n        h_1 = torch.cat([x_j, edge_attr, x_label_j], dim=-1)\n        h_0 = torch.cat([x_j, edge_attr, torch.zeros_like(x_label_j)], dim=-1)\n        h = h_1 * mask + h_0 * (1 - mask)\n        h = torch.concat([x_i, h], dim=-1)\n        h = self.out_v(h)\n        return h\n\n\nclass ProteinMPNN(torch.nn.Module):\n    r\"\"\"The ProteinMPNN model from the `\"Robust deep learning--based\n    protein sequence design using ProteinMPNN\"\n    
<https://www.biorxiv.org/content/10.1101/2022.06.03.494563v1>`_ paper.\n\n    Args:\n        hidden_dim (int): Hidden channels.\n            (default: :obj:`128`)\n        num_encoder_layers (int): Number of encoder layers.\n            (default: :obj:`3`)\n        num_decoder_layers (int): Number of decoder layers.\n            (default: :obj:`3`)\n        num_neighbors (int): Number of neighbors for each atom.\n            (default: :obj:`30`)\n        num_rbf (int): Number of radial basis functions.\n            (default: :obj:`16`)\n        dropout (float): Dropout rate.\n            (default: :obj:`0.1`)\n        augment_eps (float): Augmentation epsilon for input coordinates.\n            (default: :obj:`0.2`)\n        num_positional_embedding (int): Size of the positional embedding.\n            (default: :obj:`16`)\n        vocab_size (int): Size of the vocabulary.\n            (default: :obj:`21`)\n\n    .. note::\n        For an example of using :class:`ProteinMPNN`, see\n        `examples/llm/protein_mpnn.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/llm/protein_mpnn.py>`_.\n    \"\"\"\n    def __init__(\n        self,\n        hidden_dim: int = 128,\n        num_encoder_layers: int = 3,\n        num_decoder_layers: int = 3,\n        num_neighbors: int = 30,\n        num_rbf: int = 16,\n        dropout: float = 0.1,\n        augment_eps: float = 0.2,\n        num_positional_embedding: int = 16,\n        vocab_size: int = 21,\n    ) -> None:\n        super().__init__()\n        self.augment_eps = augment_eps\n        self.hidden_dim = hidden_dim\n        self.num_neighbors = num_neighbors\n        self.num_rbf = num_rbf\n        self.embedding = PositionalEncoding(num_positional_embedding)\n        self.edge_mlp = torch.nn.Sequential(\n            # 25 atom pairs (5 x 5), each with `num_rbf` RBF features:\n            torch.nn.Linear(num_positional_embedding + 25 * num_rbf,\n                            hidden_dim),\n            torch.nn.LayerNorm(hidden_dim),\n            torch.nn.Linear(hidden_dim, hidden_dim),\n        )\n        
self.label_embedding = torch.nn.Embedding(vocab_size, hidden_dim)\n        self.encoder_layers = torch.nn.ModuleList([\n            Encoder(hidden_dim * 3, hidden_dim, dropout)\n            for _ in range(num_encoder_layers)\n        ])\n\n        self.decoder_layers = torch.nn.ModuleList([\n            Decoder(hidden_dim * 4, hidden_dim, dropout)\n            for _ in range(num_decoder_layers)\n        ])\n        self.output = torch.nn.Linear(hidden_dim, vocab_size)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for p in self.parameters():\n            if p.dim() > 1:\n                torch.nn.init.xavier_uniform_(p)\n\n    def _featurize(\n        self,\n        x: torch.Tensor,\n        mask: torch.Tensor,\n        batch: torch.Tensor,\n    ) -> Tuple[torch.Tensor, torch.Tensor]:\n        N, Ca, C, O = (x[:, i, :] for i in range(4))  # noqa: E741\n        b = Ca - N\n        c = C - Ca\n        a = torch.cross(b, c, dim=-1)\n        Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca\n\n        valid_mask = mask.bool()\n        valid_Ca = Ca[valid_mask]\n        valid_batch = batch[valid_mask]\n\n        edge_index = knn_graph(valid_Ca, k=self.num_neighbors,\n                               batch=valid_batch, loop=True)\n\n        row, col = edge_index\n        original_indices = torch.arange(Ca.size(0),\n                                        device=x.device)[valid_mask]\n        edge_index_original = torch.stack(\n            [original_indices[row], original_indices[col]], dim=0)\n        row, col = edge_index_original\n\n        rbf_all = []\n        for A, B in list(product([N, Ca, C, O, Cb], repeat=2)):\n            distances = torch.sqrt(torch.sum((A[row] - B[col])**2, 1) + 1e-6)\n            rbf = self._rbf(distances)\n            rbf_all.append(rbf)\n\n        return edge_index_original, torch.cat(rbf_all, dim=-1)\n\n    def _rbf(self, D: torch.Tensor) -> torch.Tensor:\n        D_min, D_max, D_count = 2., 22., 
self.num_rbf\n        D_mu = torch.linspace(D_min, D_max, D_count, device=D.device)\n        D_mu = D_mu.view([1, -1])\n        D_sigma = (D_max - D_min) / D_count\n        D_expand = torch.unsqueeze(D, -1)\n        RBF = torch.exp(-((D_expand - D_mu) / D_sigma)**2)\n        return RBF\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        chain_seq_label: torch.Tensor,\n        mask: torch.Tensor,\n        chain_mask_all: torch.Tensor,\n        residue_idx: torch.Tensor,\n        chain_encoding_all: torch.Tensor,\n        batch: torch.Tensor,\n    ) -> torch.Tensor:\n        device = x.device\n        if self.training and self.augment_eps > 0:\n            x = x + self.augment_eps * torch.randn_like(x)\n\n        edge_index, edge_attr = self._featurize(x, mask, batch)\n\n        row, col = edge_index\n        offset = residue_idx[row] - residue_idx[col]\n        # find self vs non-self interaction\n        e_chains = ((chain_encoding_all[row] -\n                     chain_encoding_all[col]) == 0).long()\n        e_pos = self.embedding(offset, e_chains)\n        h_e = self.edge_mlp(torch.cat([edge_attr, e_pos], dim=-1))\n        h_v = torch.zeros(x.size(0), self.hidden_dim, device=x.device)\n\n        # encoder\n        for encoder in self.encoder_layers:\n            h_v, h_e = encoder(h_v, edge_index, h_e)\n\n        # mask\n        h_label = self.label_embedding(chain_seq_label)\n        batch_chain_mask_all, _ = to_dense_batch(chain_mask_all * mask,\n                                                 batch)  # [B, N]\n        # 0 - visible - encoder, 1 - masked - decoder\n        decoding_order = torch.argsort(\n            (batch_chain_mask_all + 1e-4) * (torch.abs(\n                torch.randn(batch_chain_mask_all.shape, device=device))))\n        mask_size = batch_chain_mask_all.size(1)\n        permutation_matrix_reverse = F.one_hot(decoding_order,\n                                               num_classes=mask_size).float()\n        
order_mask_backward = torch.einsum(\n            'ij, biq, bjp->bqp',\n            1 - torch.triu(torch.ones(mask_size, mask_size, device=device)),\n            permutation_matrix_reverse,\n            permutation_matrix_reverse,\n        )\n        adj = to_dense_adj(edge_index, batch)\n        mask_attend = order_mask_backward[adj.bool()].unsqueeze(-1)\n\n        # decoder\n        for decoder in self.decoder_layers:\n            h_v = decoder(\n                h_v,\n                edge_index,\n                h_e,\n                h_label,\n                mask_attend,\n            )\n\n        logits = self.output(h_v)\n        return F.log_softmax(logits, dim=-1)\n"
  },
  {
    "path": "torch_geometric/llm/models/sentence_transformer.py",
    "content": "from enum import Enum\nfrom typing import List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom tqdm import tqdm\n\n\nclass PoolingStrategy(Enum):\n    MEAN = 'mean'\n    LAST = 'last'\n    CLS = 'cls'\n    LAST_HIDDEN_STATE = 'last_hidden_state'\n\n\nclass SentenceTransformer(torch.nn.Module):\n    r\"\"\"A wrapper around a Sentence-Transformer from HuggingFace.\n\n    Args:\n        model_name (str): The HuggingFace model name, *e.g.*, :obj:`\"BERT\"`.\n        pooling_strategy (str, optional): The pooling strategy to use\n            for generating node embeddings. (default: :obj:`\"mean\"`)\n    \"\"\"\n    def __init__(\n        self,\n        model_name: str,\n        pooling_strategy: Union[PoolingStrategy, str] = 'mean',\n    ) -> None:\n        super().__init__()\n\n        self.model_name = model_name\n        self.pooling_strategy = PoolingStrategy(pooling_strategy)\n\n        from transformers import AutoModel, AutoTokenizer\n\n        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n        self.model = AutoModel.from_pretrained(model_name)\n        if self.tokenizer.pad_token is None:\n            self.tokenizer.pad_token = self.tokenizer.eos_token\n\n        # Maximum sequence length from the model configuration (e.g. 8192 for\n        # models like ModernBERT)\n        self.max_seq_length = self.model.config.max_position_embeddings\n        \"\"\"\n        Some models define a max sequence length in their configuration. Others\n        only in the tokenizer. 
This is a hacky heuristic to find the max\n        sequence length that works for the model.\n        \"\"\"\n        probe_tokens = self.tokenizer(\"hacky heuristic\", padding='max_length',\n                                      return_tensors='pt')\n        self.max_seq_length = min(self.max_seq_length,\n                                  probe_tokens.input_ids.shape[1])\n\n    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:\n        out = self.model(input_ids=input_ids, attention_mask=attention_mask)\n\n        emb = out[0]  # First element contains all token embeddings.\n        if self.pooling_strategy == PoolingStrategy.MEAN:\n            emb = mean_pooling(emb, attention_mask)\n        elif self.pooling_strategy == PoolingStrategy.LAST:\n            emb = last_pooling(emb, attention_mask)\n        elif self.pooling_strategy == PoolingStrategy.LAST_HIDDEN_STATE:\n            emb = out.last_hidden_state\n        else:\n            assert self.pooling_strategy == PoolingStrategy.CLS\n            emb = emb[:, 0, :]\n\n        emb = F.normalize(emb, p=2, dim=1)\n        return emb\n\n    def get_input_ids(\n        self,\n        text: List[str],\n        batch_size: Optional[int] = None,\n        output_device: Optional[Union[torch.device, str]] = None,\n    ) -> Tensor:\n        is_empty = len(text) == 0\n        text = ['dummy'] if is_empty else text\n\n        batch_size = len(text) if batch_size is None else batch_size\n\n        input_ids: List[Tensor] = []\n        attention_masks: List[Tensor] = []\n        for start in range(0, len(text), batch_size):\n            token = self.tokenizer(\n                text[start:start + batch_size],\n                padding=True,\n                truncation=True,\n                return_tensors='pt',\n                max_length=self.max_seq_length,\n            )\n            input_ids.append(token.input_ids.to(self.device))\n            
attention_masks.append(token.attention_mask.to(self.device))\n\n        def _out(x: List[Tensor]) -> Tensor:\n            out = torch.cat(x, dim=0) if len(x) > 1 else x[0]\n            out = out[:0] if is_empty else out\n            return out.to(output_device)\n\n        return _out(input_ids), _out(attention_masks)\n\n    @property\n    def device(self) -> torch.device:\n        return next(iter(self.model.parameters())).device\n\n    @torch.no_grad()\n    def encode(\n        self,\n        text: List[str],\n        batch_size: Optional[int] = None,\n        output_device: Optional[Union[torch.device, str]] = None,\n        verbose=False,\n    ) -> Tensor:\n        r\"\"\"Main function for users. Converts strings to embeddings.\n\n        Args:\n            text (List[str]): List of strings to embed.\n            batch_size (int, optional): How many strings to process.\n                Defaults to processing all at once, but this may lead to\n                OOM errors. (default: :obj:`None`)\n            output_device (Union[torch.device, str], optional):\n                By default outputs cpu pytorch tensor, but can choose\n                to output to specific cuda devices. 
(default: :obj:`None`)\n            verbose (bool, optional): Controls the verbosity of outputs.\n                (default: :obj:`False`)\n        \"\"\"\n        is_empty = len(text) == 0\n        text = ['dummy'] if is_empty else text\n\n        batch_size = len(text) if batch_size is None else batch_size\n\n        embs: List[Tensor] = []\n        loader = range(0, len(text), batch_size)\n        if verbose:\n            loader = tqdm(\n                loader, desc=\"Encoding \" + str(len(text)) +\n                \" strings w/ SentenceTransformer\")\n        for start in loader:\n            token = self.tokenizer(\n                text[start:start + batch_size],\n                padding=True,\n                truncation=True,\n                return_tensors='pt',\n                max_length=self.max_seq_length,\n            )\n            try:\n                emb = self(\n                    input_ids=token.input_ids.to(self.device),\n                    attention_mask=token.attention_mask.to(self.device),\n                ).to(output_device)\n\n                embs.append(emb)\n            except:  # noqa\n                # fallback to using CPU for huge strings that cause OOMs\n                print(\"Sentence Transformer failed on cuda, trying w/ cpu...\")\n                previous_device = self.device\n                self.model = self.model.to(\"cpu\")\n                emb = self(\n                    input_ids=token.input_ids.to(self.device),\n                    attention_mask=token.attention_mask.to(self.device),\n                ).to(output_device)\n\n                embs.append(emb)\n                self.model = self.model.to(previous_device)\n\n        out = torch.cat(embs, dim=0) if len(embs) > 1 else embs[0]\n        out = out[:0] if is_empty else out\n        return out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(model_name={self.model_name})'\n\n\ndef mean_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:\n 
   mask = attention_mask.unsqueeze(-1).expand(emb.size()).to(emb.dtype)\n    return (emb * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)\n\n\ndef last_pooling(emb: Tensor, attention_mask: Tensor) -> Tensor:\n    # Check whether language model uses left padding,\n    # which is always used for decoder LLMs\n    left_padding = attention_mask[:, -1].sum() == attention_mask.size(0)\n    if left_padding:\n        return emb[:, -1]\n\n    seq_indices = attention_mask.sum(dim=1) - 1\n    return emb[torch.arange(emb.size(0), device=emb.device), seq_indices]\n"
  },
  {
    "path": "torch_geometric/llm/models/txt2kg.py",
    "content": "import os\nimport time\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.multiprocessing as mp\n\nCLIENT_INITD = False\n\nCLIENT = None\nGLOBAL_NIM_KEY = \"\"\nSYSTEM_PROMPT = \"Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Separate each with a new line. Do not output anything else. Try to focus on key triples that form a connected graph.\"  # noqa\nMAX_OUTER_RETRIES = 5  # Maximum number of times the entire multiprocessing job is retried. # noqa\nRETRY_DELAY = 5  # Fixed sleep time (in seconds) between outer retries.\nMAX_NIM_RETRIES = 200  # Maximum number of attempts to call the NIM API inside one worker.  # noqa\nBASE_DELAY = 0.5  # Initial wait time before retrying a failed network call.\n\n\nclass TXT2KG():\n    \"\"\"A class to convert text data into a Knowledge Graph (KG) format.\n    Uses NVIDIA NIMs + Prompt engineering by default.\n    Default model `nvidia/llama-3.1-nemotron-70b-instruct`\n    is on par or better than GPT4o in benchmarks.\n    We need a high quality model to ensure high quality KG.\n    Otherwise we have garbage in garbage out for the rest of the\n    GNN+LLM RAG pipeline.\n\n    Use the local_LM flag for local debugging/dev. 
You still need to be able to\n    inference a 14B param LLM, 'VAGOsolutions/SauerkrautLM-v2-14b-DPO'.\n    Smaller LLMs did not work at all in testing.\n    Note this 14B model requires a considerable amount of GPU memory.\n    See examples/llm/txt2kg_rag.py for an example.\n\n    Args:\n        NVIDIA_NIM_MODEL : str, optional\n            The name of the NVIDIA NIM model to use.\n            (default: \"nvidia/llama-3.1-nemotron-70b-instruct\").\n        NVIDIA_API_KEY : str, optional\n            The API key for accessing NVIDIA's NIM models (default: \"\").\n        ENDPOINT_URL : str, optional\n            The URL hosting your model, in case you are not using\n            the public NIM.\n            (default: \"https://integrate.api.nvidia.com/v1\").\n        local_LM : bool, optional\n            A flag indicating whether a local Language Model (LM)\n            should be used. This uses HuggingFace and will be slower\n            than deploying your own private NIM endpoint. This flag\n            is mainly recommended for dev/debug.\n            (default: False).\n        chunk_size : int, optional\n            The size of the chunks in which the text data is processed\n            (default: 512).\n    \"\"\"\n    def __init__(\n        self,\n        NVIDIA_NIM_MODEL: Optional[\n            str] = \"nvidia/llama-3.1-nemotron-70b-instruct\",\n        NVIDIA_API_KEY: Optional[str] = \"\",\n        ENDPOINT_URL: Optional[str] = \"https://integrate.api.nvidia.com/v1\",\n        local_LM: bool = False,\n        chunk_size: int = 512,\n    ) -> None:\n        self.local_LM = local_LM\n        # Initialize the local LM flag and the NIM model info accordingly\n        if self.local_LM:\n            # If using a local LM, set the initd_LM flag to False\n            self.initd_LM = False\n        else:\n            # If not using a local LM, store the provided NIM model info\n            self.NVIDIA_API_KEY = NVIDIA_API_KEY\n            self.NIM_MODEL = 
NVIDIA_NIM_MODEL\n            self.ENDPOINT_URL = ENDPOINT_URL\n\n        # Set the chunk size for processing text data\n        self.chunk_size = chunk_size\n\n        # Initialize counters and storage for parsing results\n        self.doc_id_counter = 0\n        self.relevant_triples = {}\n        self.total_chars_parsed = 0\n        self.time_to_parse = 0.0\n\n    def save_kg(self, path: str) -> None:\n        \"\"\"Saves the relevant triples in the knowledge graph (KG) to a file.\n\n        Args:\n            path (str): The file path where the KG will be saved.\n\n        Returns:\n            None\n        \"\"\"\n        torch.save(self.relevant_triples, path)\n\n    def _chunk_to_triples_str_local(self, txt: str) -> str:\n        # call LLM on text\n        chunk_start_time = time.time()\n        if not self.initd_LM:\n            from torch_geometric.llm.models import LLM\n            LM_name = \"VAGOsolutions/SauerkrautLM-v2-14b-DPO\"\n            self.model = LLM(LM_name).eval()\n            self.initd_LM = True\n        out_str = self.model.inference(question=[txt + '\\n' + SYSTEM_PROMPT],\n                                       max_tokens=self.chunk_size)[0]\n        # for debug\n        self.total_chars_parsed += len(txt)\n        self.time_to_parse += round(time.time() - chunk_start_time, 2)\n        self.avg_chars_parsed_per_sec = self.total_chars_parsed / (\n            self.time_to_parse + 1e-6)  # noqa\n        return out_str\n\n    def add_doc_2_KG(\n        self,\n        txt: str,\n        QA_pair: Optional[Tuple[str, str]] = None,\n    ) -> None:\n        \"\"\"Add a document to the Knowledge Graph (KG).\n\n        Args:\n            txt (str): The text to extract triples from.\n            QA_pair (Tuple[str, str]], optional):\n                A QA pair to associate with the extracted triples.\n                Useful for downstream evaluation.\n\n        Returns:\n        - None\n        \"\"\"\n        if not self.local_LM:\n            # 
Ensure NVIDIA_API_KEY is set before proceeding\n            assert self.NVIDIA_API_KEY != '', \\\n                \"Please init TXT2KG w/ NVIDIA_API_KEY or set local_LM=True\"\n        if QA_pair:\n            # QA_pairs should be unique keys, check if already exists in KG\n            if QA_pair in self.relevant_triples.keys():\n                print(\"Warning: QA_Pair was already added to the set\")\n                print(\"Q=\", QA_pair[0])\n                print(\"A=\", QA_pair[1])\n                print(\"Previously parsed triples=\",\n                      self.relevant_triples[QA_pair])\n                print(\"Skipping...\")\n            key = QA_pair\n        else:\n            # If no QA_pair, use the current doc_id_counter as the key\n            key = self.doc_id_counter\n\n        self.relevant_triples[key] = self._extract_relevant_triples(txt)\n\n        # Increment the doc_id_counter for the next document\n        self.doc_id_counter += 1\n\n    def _extract_relevant_triples(\n        self,\n        txt: str,\n        max_retries: int = MAX_OUTER_RETRIES,\n        retry_delay: float = RETRY_DELAY,\n    ) -> List[Tuple[str, str, str]]:\n        # Handle empty text (context-less QA pairs)\n        if txt == \"\":\n            return []\n\n        # Chunk the text into smaller pieces for processing\n        chunks = _chunk_text(txt, chunk_size=self.chunk_size)\n\n        if self.local_LM:\n            # For debugging purposes...\n            # process chunks sequentially on the local LM\n            return _llm_then_python_parse(chunks, _parse_n_check_triples,\n                                          self._chunk_to_triples_str_local)\n\n        # Create deterministic chunk assignment\n        import math\n        num_procs = min(len(chunks), _get_num_procs())\n        chunk_size = math.ceil(len(chunks) / num_procs)\n        in_chunks_per_proc = [\n            chunks[j * chunk_size:min((j + 1) * chunk_size, len(chunks))]\n            for j in 
range(num_procs)\n        ]\n\n        # Run workers via starmap for deterministic ordering\n        worker_args = [(\n            rank,\n            in_chunks_per_proc[rank],\n            _parse_n_check_triples,\n            _chunk_to_triples_str_cloud,\n            self.NVIDIA_API_KEY,\n            self.NIM_MODEL,\n            self.ENDPOINT_URL,\n        ) for rank in range(num_procs)]\n\n        for attempt in range(max_retries):\n            try:\n                with mp.get_context(\"spawn\").Pool(num_procs) as pool:\n                    results = pool.starmap(_multiproc_helper, worker_args)\n                break  # success\n\n            except Exception as e:\n                if attempt == max_retries - 1:\n                    raise  # re-raise on final failure\n\n                print(f\"[Retry {attempt+1}/{max_retries}] \"\n                      f\"Multiprocessing failed: {e}\")\n            time.sleep(retry_delay)\n\n        return _merge_triples_deterministically(results)\n\n\nknown_reasoners = [\n    \"llama-3.1-nemotron-ultra-253b-v1\",\n    \"kimi-k2-instruct\",\n    \"nemotron-super-49b-v1_5\",\n    \"gpt-oss\",\n]\n\n\ndef _chunk_to_triples_str_cloud(\n        txt: str, GLOBAL_NIM_KEY='',\n        NIM_MODEL=\"nvidia/llama-3.1-nemotron-ultra-253b-v1\",\n        ENDPOINT_URL=\"https://integrate.api.nvidia.com/v1\",\n        post_text=SYSTEM_PROMPT) -> str:\n    global CLIENT_INITD\n    if not CLIENT_INITD:\n        # We use NIMs since most PyG users may not be able to run a 70B+ model\n        try:\n            from openai import OpenAI\n        except ImportError:\n            quit(\n                \"Failed to import `openai` package, please install it and rerun the script\"  # noqa\n            )\n        global CLIENT\n        CLIENT = OpenAI(base_url=ENDPOINT_URL, api_key=GLOBAL_NIM_KEY)\n        CLIENT_INITD = True\n    txt_input = txt\n    if post_text != \"\":\n        txt_input += '\\n' + post_text\n    messages = []\n    if 
any([model_name_str in NIM_MODEL\n            for model_name_str in known_reasoners]):\n        messages.append({\"role\": \"system\", \"content\": \"detailed thinking on\"})\n    messages.append({\"role\": \"user\", \"content\": txt_input})\n    completion = CLIENT.chat.completions.create(model=NIM_MODEL,\n                                                messages=messages,\n                                                temperature=0, top_p=1,\n                                                max_tokens=1024, stream=True)\n    out_str = \"\"\n    for chunk in completion:\n        if chunk.choices[0].delta.content is not None:\n            out_str += chunk.choices[0].delta.content\n    return out_str\n\n\ndef _parse_n_check_triples(triples_str: str) -> List[Tuple[str, str, str]]:\n    # use pythonic checks for triples\n    processed = []\n    split_by_newline = triples_str.split(\"\\n\")\n    # sometimes LLM fails to obey the prompt\n    if len(split_by_newline) > 1:\n        split_triples = split_by_newline\n        llm_obeyed = True\n    else:\n        # handles form \"(e, r, e) (e, r, e) ... 
(e, r, e)\"\"\n        split_triples = triples_str[1:-1].split(\") (\")\n        llm_obeyed = False\n    for triple_str in split_triples:\n        try:\n            if llm_obeyed:\n                # remove parenthesis and single quotes for parsing\n                triple_str = triple_str.replace(\"(\", \"\").replace(\")\",\n                                                                 \"\").replace(\n                                                                     \"'\", \"\")\n            split_trip = triple_str.split(',')\n            # remove blank space at beginning or end\n            split_trip = [(i[1:] if i[0] == \" \" else i) for i in split_trip]\n            split_trip = [(i[:-1].lower() if i[-1] == \" \" else i)\n                          for i in split_trip]\n            potential_trip = tuple(split_trip)\n        except:  # noqa\n            continue\n        if 'tuple' in str(type(potential_trip)) and len(\n                potential_trip\n        ) == 3 and \"note:\" not in potential_trip[0].lower():\n            # additional check for empty node/edge attrs\n            if potential_trip[0] != '' and potential_trip[\n                    1] != '' and potential_trip[2] != '':\n                processed.append(potential_trip)\n    return processed\n\n\ndef _llm_then_python_parse(chunks, py_fn, llm_fn, **kwargs):\n    relevant_triples = []\n    for chunk in chunks:\n        relevant_triples += py_fn(llm_fn(chunk, **kwargs))\n    return relevant_triples\n\n\ndef _multiproc_helper(\n    rank,\n    chunks_for_rank,\n    py_fn,\n    llm_fn,\n    NIM_KEY,\n    NIM_MODEL,\n    ENDPOINT_URL,\n    max_retries=MAX_NIM_RETRIES,\n    base_delay=BASE_DELAY,\n):\n\n    for attempt in range(max_retries):\n        try:\n            return _llm_then_python_parse(\n                chunks_for_rank,\n                py_fn,\n                llm_fn,\n                GLOBAL_NIM_KEY=NIM_KEY,\n                NIM_MODEL=NIM_MODEL,\n                
ENDPOINT_URL=ENDPOINT_URL,\n            )\n\n        except Exception:\n            # Optional: restrict to network-related exceptions only\n            if attempt == max_retries - 1:\n                raise\n\n            # exponential backoff with jitter\n            from random import uniform\n            sleep_time = base_delay * (2**min(attempt, 6))\n            sleep_time += uniform(0, 0.1)\n            time.sleep(sleep_time)\n\n\ndef _get_num_procs():\n    num_proc = None\n    if hasattr(os, \"sched_getaffinity\"):\n        try:\n            num_proc = len(os.sched_getaffinity(0)) / (2)\n        except Exception:\n            pass\n\n    if num_proc is None:\n        # os.cpu_count() may return None if the count is undetermined\n        num_proc = (os.cpu_count() or 2) / (2)\n\n    # Always use at least one process (e.g. on single-core machines)\n    return max(1, int(num_proc))\n\n\ndef _chunk_text(text: str, chunk_size: int = 512) -> list[str]:\n    \"\"\"Function to chunk text into sentence-based segments.\n    Co-authored with Claude AI.\n    \"\"\"\n    # If the input text is empty or None, return an empty list\n    if not text:\n        return []\n\n    # List of punctuation marks that typically end sentences\n    sentence_endings = '.!?'\n\n    # List to store the resulting chunks\n    chunks = []\n\n    # Continue processing the entire text\n    while text:\n        # If the remaining text is shorter than chunk_size, add it and break\n        if len(text) <= chunk_size:\n            chunks.append(text.strip())\n            break\n\n        # Start with the maximum possible chunk\n        chunk = text[:chunk_size]\n\n        # Try to find the last sentence ending within the chunk\n        best_split = chunk_size\n        for ending in sentence_endings:\n            # Find the last occurrence of the ending punctuation\n            last_ending = chunk.rfind(ending)\n            if last_ending != -1:\n                # Ensure we include the punctuation and any following space\n                best_split = min(\n                    best_split, last_ending + 1 +\n                    (1 if last_ending + 1 < 
len(chunk)\n                     and chunk[last_ending + 1].isspace() else 0))\n\n        # Adjust to ensure we don't break words\n        # If the next character is a letter, find the last space\n        if best_split < len(text) and text[best_split].isalpha():\n            # Find the last space before the current split point\n            space_split = text[:best_split].rfind(' ')\n            if space_split != -1:\n                best_split = space_split\n\n        # Append the chunk, ensuring it's stripped\n        chunks.append(text[:best_split].strip())\n\n        # Remove the processed part from the text\n        text = text[best_split:].lstrip()\n\n    return chunks\n\n\nTriple = Union[List[str], Tuple[str, ...]]\n\n\ndef _merge_triples_deterministically(\n        triples: List[List[Triple]]) -> List[Tuple[str, ...]]:\n    \"\"\"Flatten a list of lists of triples and return a deterministic,\n    reproducible sorted list of tuples.\n\n    Args:\n        triples (List[List[Triple]]): A list of lists of triples, where each\n            triple is a list or tuple of strings or other comparable values.\n            Typically, each inner list comes from a worker.\n\n    Returns:\n        List[Tuple[str, ...]]: A flattened list of triples as tuples, sorted\n            deterministically. Sorting is Unicode-safe and reproducible across\n            Python versions using `str.casefold()`. Tuples are immutable to\n            ensure hashability and stability in dicts/sets.\n    \"\"\"\n    # Flatten all sublists and convert inner lists to tuples\n    flat_triples = [tuple(t) for sublist in triples for t in sublist]\n\n    # Deterministic sort (Unicode-safe, casefold for strings)\n    flat_triples.sort(key=lambda triple: tuple(\n        s.casefold() if isinstance(s, str) else s for s in triple))\n\n    return flat_triples\n"
  },
  {
    "path": "torch_geometric/llm/models/vision_transformer.py",
    "content": "from typing import Optional, Union\n\nimport torch\nfrom torch import Tensor\n\n\nclass VisionTransformer(torch.nn.Module):\n    r\"\"\"A wrapper around a Vision-Transformer from HuggingFace.\n\n    Args:\n        model_name (str): The HuggingFace model name, *e.g.*, :obj:`\"ViT\"`.\n    \"\"\"\n    def __init__(\n        self,\n        model_name: str,\n    ) -> None:\n        super().__init__()\n        self.model_name = model_name\n\n        from transformers import SwinConfig, SwinModel\n\n        self.config = SwinConfig.from_pretrained(model_name)\n        self.model = SwinModel(self.config)\n\n    @torch.no_grad()\n    def forward(\n        self,\n        images: Tensor,\n        output_device: Optional[Union[torch.device, str]] = None,\n    ) -> Tensor:\n        return self.model(images).last_hidden_state.to(output_device)\n\n    @property\n    def device(self) -> torch.device:\n        return next(iter(self.model.parameters())).device\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(model_name={self.model_name})'\n"
  },
  {
    "path": "torch_geometric/llm/rag_loader.py",
    "content": "from abc import abstractmethod\nfrom typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union\n\nfrom torch_geometric.data import Data, FeatureStore, HeteroData\nfrom torch_geometric.llm.utils.vectorrag import VectorRetriever\nfrom torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput\nfrom torch_geometric.typing import InputEdges, InputNodes\n\n\nclass RAGFeatureStore(Protocol):\n    \"\"\"Feature store template for remote GNN RAG backend.\"\"\"\n    @abstractmethod\n    def retrieve_seed_nodes(self, query: Any, **kwargs) -> InputNodes:\n        \"\"\"Makes a comparison between the query and all the nodes to get all\n        the closest nodes. Return the indices of the nodes that are to be seeds\n        for the RAG Sampler.\n        \"\"\"\n        ...\n\n    @property\n    @abstractmethod\n    def config(self) -> Dict[str, Any]:\n        \"\"\"Get the config for the RAGFeatureStore.\"\"\"\n        ...\n\n    @config.setter\n    @abstractmethod\n    def config(self, config: Dict[str, Any]):\n        \"\"\"Set the config for the RAGFeatureStore.\"\"\"\n        ...\n\n    @abstractmethod\n    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:\n        \"\"\"Makes a comparison between the query and all the edges to get all\n        the closest edges. 
Returns the edge indices that are to be the seeds\n        for the RAG Sampler.\n        \"\"\"\n        ...\n\n    @abstractmethod\n    def load_subgraph(\n        self, sample: Union[SamplerOutput, HeteroSamplerOutput]\n    ) -> Union[Data, HeteroData]:\n        \"\"\"Combines sampled subgraph output with features in a Data object.\"\"\"\n        ...\n\n\nclass RAGGraphStore(Protocol):\n    \"\"\"Graph store template for remote GNN RAG backend.\"\"\"\n    @abstractmethod\n    def sample_subgraph(self, seed_nodes: InputNodes, seed_edges: InputEdges,\n                        **kwargs) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        \"\"\"Sample a subgraph using the seeded nodes and edges.\"\"\"\n        ...\n\n    @property\n    @abstractmethod\n    def config(self) -> Dict[str, Any]:\n        \"\"\"Get the config for the RAGGraphStore.\"\"\"\n        ...\n\n    @config.setter\n    @abstractmethod\n    def config(self, config: Dict[str, Any]):\n        \"\"\"Set the config for the RAGGraphStore.\"\"\"\n        ...\n\n    @abstractmethod\n    def register_feature_store(self, feature_store: FeatureStore):\n        \"\"\"Register a feature store to be used with the sampler. 
Samplers need\n        info from the feature store in order to work properly on HeteroGraphs.\n        \"\"\"\n        ...\n\n\n# TODO: Make compatible with Heterographs\n\n\nclass RAGQueryLoader:\n    \"\"\"Loader meant for making RAG queries from a remote backend.\"\"\"\n    def __init__(self, graph_data: Tuple[RAGFeatureStore, RAGGraphStore],\n                 subgraph_filter: Optional[Callable[[Data, Any], Data]] = None,\n                 augment_query: bool = False,\n                 vector_retriever: Optional[VectorRetriever] = None,\n                 config: Optional[Dict[str, Any]] = None):\n        \"\"\"Loader meant for making queries from a remote backend.\n\n        Args:\n            graph_data (Tuple[RAGFeatureStore, RAGGraphStore]):\n                Remote FeatureStore and GraphStore to load from.\n                Assumed to conform to the protocols listed above.\n            subgraph_filter (Optional[Callable[[Data, Any], Data]], optional):\n                Optional local transform to apply to data after retrieval.\n                Defaults to None.\n            augment_query (bool, optional): Whether to augment the query with\n                retrieved documents. Defaults to False.\n            vector_retriever (Optional[VectorRetriever], optional):\n                VectorRetriever to use for retrieving documents.\n                Defaults to None.\n            config (Optional[Dict[str, Any]], optional): Config to pass into\n                the RAGQueryLoader. 
Defaults to None.\n        \"\"\"\n        fstore, gstore = graph_data\n        self.vector_retriever = vector_retriever\n        self.augment_query = augment_query\n        self.feature_store = fstore\n        self.graph_store = gstore\n        self.graph_store.edge_index = self.graph_store.edge_index.contiguous()\n        self.graph_store.register_feature_store(self.feature_store)\n        self.subgraph_filter = subgraph_filter\n        self.config = config\n\n    def _propagate_config(self, config: Dict[str, Any]):\n        \"\"\"Propagate the config to the relevant components.\"\"\"\n        self.feature_store.config = config\n        self.graph_store.config = config\n\n    @property\n    def config(self):\n        \"\"\"Get the config for the RAGQueryLoader.\"\"\"\n        return self._config\n\n    @config.setter\n    def config(self, config: Dict[str, Any]):\n        \"\"\"Set the config for the RAGQueryLoader.\n\n        Args:\n            config (Dict[str, Any]): The config to set.\n        \"\"\"\n        self._propagate_config(config)\n        self._config = config\n\n    def query(self, query: Any) -> Data:\n        \"\"\"Retrieve a subgraph associated with the query with all its feature\n        attributes.\n        \"\"\"\n        retrieved_docs = None\n        if self.vector_retriever:\n            retrieved_docs = self.vector_retriever.query(query)\n\n        # Only augment the query if documents were actually retrieved;\n        # avoids a NameError when no vector retriever is configured.\n        if self.augment_query and retrieved_docs is not None:\n            query = [query] + retrieved_docs\n\n        seed_nodes, query_enc = self.feature_store.retrieve_seed_nodes(query)\n\n        subgraph_sample = self.graph_store.sample_subgraph(seed_nodes)\n\n        data = self.feature_store.load_subgraph(sample=subgraph_sample)\n\n        # apply local filter\n        if self.subgraph_filter:\n            data = self.subgraph_filter(data, query)\n        if self.vector_retriever:\n            data.text_context = retrieved_docs\n        return data\n"
  },
  {
    "path": "torch_geometric/llm/utils/__init__.py",
    "content": "from .backend_utils import *  # noqa\nfrom .feature_store import KNNRAGFeatureStore\nfrom .graph_store import NeighborSamplingRAGGraphStore\nfrom .vectorrag import DocumentRetriever\n\n__all__ = classes = [\n    'KNNRAGFeatureStore',\n    'NeighborSamplingRAGGraphStore',\n    'DocumentRetriever',\n]\n"
  },
  {
    "path": "torch_geometric/llm/utils/backend_utils.py",
    "content": "import os\nfrom dataclasses import dataclass\nfrom enum import Enum, auto\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Iterable,\n    Iterator,\n    List,\n    Optional,\n    Protocol,\n    Tuple,\n    Type,\n    Union,\n    no_type_check,\n    runtime_checkable,\n)\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module\n\nfrom torch_geometric.data import Data, FeatureStore, GraphStore\nfrom torch_geometric.distributed import (\n    LocalFeatureStore,\n    LocalGraphStore,\n    Partitioner,\n)\nfrom torch_geometric.llm.large_graph_indexer import (\n    EDGE_RELATION,\n    LargeGraphIndexer,\n    TripletLike,\n)\nfrom torch_geometric.llm.models import SentenceTransformer\nfrom torch_geometric.typing import EdgeType, NodeType\n\ntry:\n    from pandas import DataFrame\nexcept ImportError:\n    DataFrame = None\nRemoteGraphBackend = Tuple[FeatureStore, GraphStore]\n\n# TODO: Make everything compatible with Hetero graphs as well\n\n\ndef preprocess_triplet(triplet: TripletLike) -> TripletLike:\n    h, r, t = triplet\n    return str(h).lower(), str(r).lower(), str(t).lower()\n\n\n@no_type_check\ndef retrieval_via_pcst(\n    data: Data,\n    q_emb: Tensor,\n    textual_nodes: Any,\n    textual_edges: Any,\n    topk: int = 3,\n    topk_e: int = 5,\n    cost_e: float = 0.5,\n    num_clusters: int = 1,\n) -> Tuple[Data, str]:\n\n    # skip PCST for bad graphs\n    booly = data.edge_attr is None or data.edge_attr.numel() == 0\n    booly = booly or data.x is None or data.x.numel() == 0\n    booly = booly or data.edge_index is None or data.edge_index.numel() == 0\n    if not booly:\n        c = 0.01\n\n        from pcst_fast import pcst_fast\n\n        root = -1\n        pruning = 'gw'\n        verbosity_level = 0\n        if topk > 0:\n            n_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.x)\n            topk = min(topk, data.num_nodes)\n            _, topk_n_indices = torch.topk(n_prizes, 
topk, largest=True)\n\n            n_prizes = torch.zeros_like(n_prizes)\n            n_prizes[topk_n_indices] = torch.arange(topk, 0, -1,\n                                                    device=n_prizes.device,\n                                                    dtype=n_prizes.dtype)\n        else:\n            n_prizes = torch.zeros(data.num_nodes)\n\n        if topk_e > 0:\n            e_prizes = torch.nn.CosineSimilarity(dim=-1)(q_emb, data.edge_attr)\n            topk_e = min(topk_e, e_prizes.unique().size(0))\n\n            topk_e_values, _ = torch.topk(e_prizes.unique(), topk_e,\n                                          largest=True)\n            e_prizes[e_prizes < topk_e_values[-1]] = 0.0\n            last_topk_e_value = topk_e\n            for k in range(topk_e):\n                indices = e_prizes == topk_e_values[k]\n                value = min((topk_e - k) / sum(indices), last_topk_e_value - c)\n                e_prizes[indices] = value\n                last_topk_e_value = value * (1 - c)\n            # reduce the cost of the edges so that at least one edge is chosen\n            cost_e = min(cost_e, e_prizes.max().item() * (1 - c / 2))\n        else:\n            e_prizes = torch.zeros(data.num_edges)\n\n        costs = []\n        edges = []\n        virtual_n_prizes = []\n        virtual_edges = []\n        virtual_costs = []\n        mapping_n = {}\n        mapping_e = {}\n        for i, (src, dst) in enumerate(data.edge_index.t().numpy()):\n            prize_e = e_prizes[i]\n            if prize_e <= cost_e:\n                mapping_e[len(edges)] = i\n                edges.append((src, dst))\n                costs.append(cost_e - prize_e)\n            else:\n                virtual_node_id = data.num_nodes + len(virtual_n_prizes)\n                mapping_n[virtual_node_id] = i\n                virtual_edges.append((src, virtual_node_id))\n                virtual_edges.append((virtual_node_id, dst))\n                virtual_costs.append(0)\n  
              virtual_costs.append(0)\n                virtual_n_prizes.append(prize_e - cost_e)\n\n        prizes = np.concatenate([n_prizes, np.array(virtual_n_prizes)])\n        num_edges = len(edges)\n        if len(virtual_costs) > 0:\n            costs = np.array(costs + virtual_costs)\n            edges = np.array(edges + virtual_edges)\n\n        vertices, edges = pcst_fast(edges, prizes, costs, root, num_clusters,\n                                    pruning, verbosity_level)\n\n        selected_nodes = vertices[vertices < data.num_nodes]\n        selected_edges = [mapping_e[e] for e in edges if e < num_edges]\n        virtual_vertices = vertices[vertices >= data.num_nodes]\n        if len(virtual_vertices) > 0:\n            virtual_vertices = vertices[vertices >= data.num_nodes]\n            virtual_edges = [mapping_n[i] for i in virtual_vertices]\n            selected_edges = np.array(selected_edges + virtual_edges)\n\n        edge_index = data.edge_index[:, selected_edges]\n        selected_nodes = np.unique(\n            np.concatenate(\n                [selected_nodes, edge_index[0].numpy(),\n                 edge_index[1].numpy()]))\n\n        n = textual_nodes.iloc[selected_nodes]\n        e = textual_edges.iloc[selected_edges]\n    else:\n        n = textual_nodes\n        e = textual_edges\n\n    desc = n.to_csv(index=False) + '\\n' + e.to_csv(\n        index=False, columns=['src', 'edge_attr', 'dst'])\n\n    if booly:\n        return data, desc\n\n    mapping = {n: i for i, n in enumerate(selected_nodes.tolist())}\n    src = [mapping[i] for i in edge_index[0].tolist()]\n    dst = [mapping[i] for i in edge_index[1].tolist()]\n\n    # HACK Added so that the subset of nodes and edges selected can be tracked\n    node_idx = np.array(data.node_idx)[selected_nodes]\n    edge_idx = np.array(data.edge_idx)[selected_edges]\n\n    data = Data(\n        x=data.x[selected_nodes],\n        edge_index=torch.tensor([src, dst]).to(torch.long),\n        
edge_attr=data.edge_attr[selected_edges],\n        # HACK: track subset of selected nodes/edges\n        node_idx=node_idx,\n        edge_idx=edge_idx,\n    )\n\n    return data, desc\n\n\ndef batch_knn(query_enc: Tensor, embeds: Tensor,\n              k: int) -> Iterator[Tuple[Tensor, Tensor]]:\n    from torchmetrics.functional import pairwise_cosine_similarity\n    prizes = pairwise_cosine_similarity(query_enc, embeds.to(query_enc.device))\n    topk = min(k, len(embeds))\n    for i, q in enumerate(prizes):\n        _, indices = torch.topk(q, topk, largest=True)\n        yield indices, query_enc[i].unsqueeze(0)\n\n\n# Adapted from LocalGraphStore\n@runtime_checkable\nclass ConvertableGraphStore(Protocol):\n    @classmethod\n    def from_data(\n        cls,\n        edge_id: Tensor,\n        edge_index: Tensor,\n        num_nodes: int,\n        is_sorted: bool = False,\n    ) -> GraphStore:\n        ...\n\n    @classmethod\n    def from_hetero_data(\n        cls,\n        edge_id_dict: Dict[EdgeType, Tensor],\n        edge_index_dict: Dict[EdgeType, Tensor],\n        num_nodes_dict: Dict[NodeType, int],\n        is_sorted: bool = False,\n    ) -> GraphStore:\n        ...\n\n    @classmethod\n    def from_partition(cls, root: str, pid: int) -> GraphStore:\n        ...\n\n\n# Adapted from LocalFeatureStore\n@runtime_checkable\nclass ConvertableFeatureStore(Protocol):\n    @classmethod\n    def from_data(\n        cls,\n        node_id: Tensor,\n        x: Optional[Tensor] = None,\n        y: Optional[Tensor] = None,\n        edge_id: Optional[Tensor] = None,\n        edge_attr: Optional[Tensor] = None,\n    ) -> FeatureStore:\n        ...\n\n    @classmethod\n    def from_hetero_data(\n        cls,\n        node_id_dict: Dict[NodeType, Tensor],\n        x_dict: Optional[Dict[NodeType, Tensor]] = None,\n        y_dict: Optional[Dict[NodeType, Tensor]] = None,\n        edge_id_dict: Optional[Dict[EdgeType, Tensor]] = None,\n        edge_attr_dict: 
Optional[Dict[EdgeType, Tensor]] = None,\n    ) -> FeatureStore:\n        ...\n\n    @classmethod\n    def from_partition(cls, root: str, pid: int) -> FeatureStore:\n        ...\n\n\nclass RemoteDataType(Enum):\n    DATA = auto()\n    PARTITION = auto()\n\n\n@dataclass\nclass RemoteGraphBackendLoader:\n    \"\"\"Utility class to load triplets into a RAG Backend.\"\"\"\n    path: str\n    datatype: RemoteDataType\n    graph_store_type: Type[ConvertableGraphStore]\n    feature_store_type: Type[ConvertableFeatureStore]\n\n    def load(self, pid: Optional[int] = None) -> RemoteGraphBackend:\n        if self.datatype == RemoteDataType.DATA:\n            data_obj = torch.load(self.path, weights_only=False)\n            # is_sorted=true since assume nodes come sorted from indexer\n            graph_store = self.graph_store_type.from_data(\n                edge_id=data_obj['edge_id'], edge_index=data_obj.edge_index,\n                num_nodes=data_obj.num_nodes, is_sorted=True)\n            feature_store = self.feature_store_type.from_data(\n                node_id=data_obj['node_id'], x=data_obj.x,\n                edge_id=data_obj['edge_id'], edge_attr=data_obj.edge_attr)\n        elif self.datatype == RemoteDataType.PARTITION:\n            if pid is None:\n                assert pid is not None, \\\n                    \"Partition ID must be defined for loading from a \" \\\n                    + \"partitioned store.\"\n            graph_store = self.graph_store_type.from_partition(self.path, pid)\n            feature_store = self.feature_store_type.from_partition(\n                self.path, pid)\n        else:\n            raise NotImplementedError\n        return (feature_store, graph_store)\n\n    def __del__(self) -> None:\n        if os.path.exists(self.path):\n            os.remove(self.path)\n\n\ndef create_graph_from_triples(\n    triples: Iterable[TripletLike],\n    embedding_model: Union[Module, Callable],\n    embedding_method_kwargs: Optional[Dict[str, 
Any]] = None,\n    pre_transform: Optional[Callable[[TripletLike], TripletLike]] = None,\n) -> Data:\n    \"\"\"Utility function that can be used to create a graph from triples.\"\"\"\n    # Resolve callable methods\n    embedding_method_kwargs = embedding_method_kwargs \\\n        if embedding_method_kwargs is not None else dict()\n\n    indexer = LargeGraphIndexer.from_triplets(triples,\n                                              pre_transform=pre_transform)\n    node_feats = embedding_model(indexer.get_unique_node_features(),\n                                 **embedding_method_kwargs)\n    indexer.add_node_feature('x', node_feats)\n\n    edge_feats = embedding_model(\n        indexer.get_unique_edge_features(feature_name=EDGE_RELATION),\n        **embedding_method_kwargs)\n    indexer.add_edge_feature(new_feature_name=\"edge_attr\",\n                             new_feature_vals=edge_feats,\n                             map_from_feature=EDGE_RELATION)\n\n    data = indexer.to_data(node_feature_name='x',\n                           edge_feature_name='edge_attr')\n    data = data.to(\"cpu\")\n    return data\n\n\ndef create_remote_backend_from_graph_data(\n    graph_data: Data,\n    graph_db: Type[ConvertableGraphStore] = LocalGraphStore,\n    feature_db: Type[ConvertableFeatureStore] = LocalFeatureStore,\n    path: str = '',\n    n_parts: int = 1,\n) -> RemoteGraphBackendLoader:\n    \"\"\"Utility function that can be used to create a RAG Backend from triples.\n\n    Args:\n        graph_data (Data): Graph data to load into the RAG Backend.\n        graph_db (Type[ConvertableGraphStore], optional): GraphStore class to\n            use. Defaults to LocalGraphStore.\n        feature_db (Type[ConvertableFeatureStore], optional): FeatureStore\n            class to use. Defaults to LocalFeatureStore.\n        path (str, optional): path to save resulting stores. 
Defaults to ''.\n        n_parts (int, optional): Number of partitions to store in.\n            Defaults to 1.\n\n    Returns:\n        RemoteGraphBackendLoader: Loader to load RAG backend from disk or\n            memory.\n    \"\"\"\n    # Attribute access raises AttributeError for any missing constructor.\n    # Both stores are checked independently of each other.\n    if not issubclass(graph_db, ConvertableGraphStore):\n        _ = graph_db.from_data\n        _ = graph_db.from_hetero_data\n        _ = graph_db.from_partition\n    if not issubclass(feature_db, ConvertableFeatureStore):\n        _ = feature_db.from_data\n        _ = feature_db.from_hetero_data\n        _ = feature_db.from_partition\n\n    if n_parts == 1:\n        torch.save(graph_data, path)\n        return RemoteGraphBackendLoader(path, RemoteDataType.DATA, graph_db,\n                                        feature_db)\n    else:\n        partitioner = Partitioner(data=graph_data, num_parts=n_parts,\n                                  root=path)\n        partitioner.generate_partition()\n        return RemoteGraphBackendLoader(path, RemoteDataType.PARTITION,\n                                        graph_db, feature_db)\n\n\ndef make_pcst_filter(triples: List[Tuple[str, str,\n                                         str]], model: SentenceTransformer,\n                     topk: int = 5, topk_e: int = 5, cost_e: float = 0.5,\n                     num_clusters: int = 1) -> Callable[[Data, str], Data]:\n    \"\"\"Creates a PCST (Prize-Collecting Steiner Tree) filter.\n\n    :param triples: List of triples (head, relation, tail) representing KG data\n    :param model: SentenceTransformer model for embedding text\n    :param topk: Number of top-K node results to return (default: 5)\n    :param topk_e: Number of top-K edge results to return (default: 5)\n    :param cost_e: Cost of edges (default: 0.5)\n    :param num_clusters: Number of connected components in the PCST output.\n    :return: PCST Filter function\n    \"\"\"\n    if DataFrame is None:\n        raise Exception(\"PCST 
requires `pip install pandas`\"\n                        )  # Check if pandas is installed\n\n    # Remove duplicate triples to ensure unique set\n    triples = list(dict.fromkeys(triples))\n\n    # Initialize empty list to store nodes (entities) from triples\n    nodes = []\n\n    # Iterate over triples to extract unique nodes (entities)\n    for h, _, t in triples:\n        for node in (h, t):  # Extract head and tail entities from each triple\n            nodes.append(node)\n\n    # Remove duplicates and create final list of unique nodes\n    nodes = list(dict.fromkeys(nodes))\n\n    # Create full list of textual nodes (entities) for filtering\n    full_textual_nodes = nodes\n\n    def apply_retrieval_via_pcst(\n            graph: Data,  # Input graph data\n            query: str,  # Search query\n    ) -> Data:\n        \"\"\"Applies PCST filtering for retrieval.\n\n        :param graph: Input graph data\n        :param query: Search query\n        :return: Retrieved graph/query data\n        \"\"\"\n        # PCST relies on numpy and pcst_fast pypi libs, hence to(\"cpu\")\n        with torch.no_grad():\n            q_emb = model.encode([query]).to(\"cpu\")\n        textual_nodes = [(int(i), full_textual_nodes[i])\n                         for i in graph[\"node_idx\"]]\n        textual_nodes = DataFrame(textual_nodes,\n                                  columns=[\"node_id\", \"node_attr\"])\n        textual_edges = [triples[i] for i in graph[\"edge_idx\"]]\n        textual_edges = DataFrame(textual_edges,\n                                  columns=[\"src\", \"edge_attr\", \"dst\"])\n        out_graph, desc = retrieval_via_pcst(graph.to(q_emb.device), q_emb,\n                                             textual_nodes, textual_edges,\n                                             topk=topk, topk_e=topk_e,\n                                             cost_e=cost_e,\n                                             num_clusters=num_clusters)\n        out_graph[\"desc\"] 
= desc\n        where_trips_start = desc.find(\"src,edge_attr,dst\")\n        parsed_trips = []\n        for trip in desc[where_trips_start + 18:-1].split(\"\\n\"):\n            parsed_trips.append(tuple(trip.split(\",\")))\n\n        # Handle case where PCST returns an isolated node\n        # TODO: find a better solution, since these failed subgraphs\n        # severely hurt accuracy.\n        if str(parsed_trips) == \"[('',)]\" or out_graph.edge_index.numel() == 0:\n            out_graph[\"triples\"] = []\n        else:\n            out_graph[\"triples\"] = parsed_trips\n        out_graph[\"question\"] = query\n        return out_graph\n\n    return apply_retrieval_via_pcst\n"
  },
  {
    "path": "torch_geometric/llm/utils/feature_store.py",
    "content": "import gc\nfrom collections.abc import Iterable, Iterator\nfrom typing import Any, Dict, List, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.distributed.local_feature_store import LocalFeatureStore\nfrom torch_geometric.llm.utils.backend_utils import batch_knn\nfrom torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput\nfrom torch_geometric.typing import InputNodes\n\n\n# NOTE: Only compatible with Homogeneous graphs for now\nclass KNNRAGFeatureStore(LocalFeatureStore):\n    \"\"\"A feature store that uses a KNN-based retrieval.\"\"\"\n    def __init__(self) -> None:\n        \"\"\"Initializes the feature store.\"\"\"\n        # to be set by the config\n        self.encoder_model = None\n        self.k_nodes = None\n        self._config: Dict[str, Any] = {}\n        super().__init__()\n\n    @property\n    def config(self) -> Dict[str, Any]:\n        \"\"\"Get the config for the feature store.\"\"\"\n        return self._config\n\n    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:\n        \"\"\"Set an attribute from the config.\n\n        Args:\n            config (Dict[str, Any]): Config dictionary\n            attr_name (str): Name of attribute to set\n\n        Raises:\n            ValueError: If required attribute not found in config\n        \"\"\"\n        if attr_name not in config:\n            raise ValueError(\n                f\"Required config parameter '{attr_name}' not found\")\n        setattr(self, attr_name, config[attr_name])\n\n    @config.setter  # type: ignore\n    def config(self, config: Dict[str, Any]) -> None:\n        \"\"\"Set the config for the feature store.\n\n        Args:\n            config (Dict[str, Any]):\n                Config dictionary containing required parameters\n\n        Raises:\n            ValueError: If required parameters missing from config\n        \"\"\"\n        
self._set_from_config(config, \"k_nodes\")\n        self._set_from_config(config, \"encoder_model\")\n        assert self.encoder_model is not None, \\\n            \"Need to define encoder model from config\"\n        self.encoder_model.eval()\n\n        self._config = config\n\n    @property\n    def x(self) -> Tensor:\n        \"\"\"Returns the node features.\"\"\"\n        return Tensor(self.get_tensor(group_name=None, attr_name='x'))\n\n    @property\n    def edge_attr(self) -> Tensor:\n        \"\"\"Returns the edge attributes.\"\"\"\n        return Tensor(\n            self.get_tensor(group_name=(None, None), attr_name='edge_attr'))\n\n    def retrieve_seed_nodes(  # noqa: D417\n            self, query: Union[str, List[str],\n                               Tuple[str]]) -> Tuple[InputNodes, Tensor]:\n        \"\"\"Retrieves the k_nodes most similar nodes to the given query.\n\n        Args:\n            query (Union[str, List[str], Tuple[str]]): The query\n                or list of queries to search for.\n\n        Returns:\n            The indices of the most similar nodes and the encoded query\n        \"\"\"\n        if not isinstance(query, (list, tuple)):\n            query = [query]\n        assert self.k_nodes is not None, \"please set k_nodes via config\"\n        if len(query) == 1:\n            result, query_enc = next(\n                self._retrieve_seed_nodes_batch(query, self.k_nodes))\n            gc.collect()\n            torch.cuda.empty_cache()\n            return result, query_enc\n        else:\n            out_dict = {}\n            for i, out in enumerate(\n                    self._retrieve_seed_nodes_batch(query, self.k_nodes)):\n                out_dict[query[i]] = out\n            gc.collect()\n            torch.cuda.empty_cache()\n            return out_dict\n\n    def _retrieve_seed_nodes_batch(  # noqa: D417\n            self, query: Iterable[Any],\n            k_nodes: int) -> Iterator[Tuple[InputNodes, Tensor]]:\n        
\"\"\"Retrieves the k_nodes most similar nodes to each query in the batch.\n\n        Args:\n        - query (Iterable[Any]: The batch of queries to search for.\n        - k_nodes (int): The number of nodes to retrieve.\n\n        Yields:\n        - The indices of the most similar nodes for each query.\n        \"\"\"\n        if isinstance(self.meta, dict) and self.meta.get(\"is_hetero\", False):\n            raise NotImplementedError\n        assert self.encoder_model is not None, \\\n            \"Need to define encoder model from config\"\n        query_enc = self.encoder_model.encode(query)\n        return batch_knn(query_enc, self.x, k_nodes)\n\n    def load_subgraph(  # noqa\n        self,\n        sample: Union[SamplerOutput, HeteroSamplerOutput],\n        induced: bool = True,\n    ) -> Union[Data, HeteroData]:\n        \"\"\"Loads a subgraph from the given sample.\n\n        Args:\n            sample: The sample to load the subgraph from.\n            induced: Whether to return the induced subgraph.\n                Resets node and edge ids.\n\n        Returns:\n            The loaded subgraph.\n        \"\"\"\n        if isinstance(sample, HeteroSamplerOutput):\n            raise NotImplementedError\n        \"\"\"\n        NOTE: torch_geometric.loader.utils.filter_custom_store\n        can be used here if it supported edge features.\n        \"\"\"\n        edge_id = sample.edge\n        x = self.x[sample.node]\n        edge_attr = self.edge_attr[edge_id]\n\n        edge_idx = torch.stack(\n            [sample.row, sample.col], dim=0) if induced else torch.stack(\n                [sample.global_row, sample.global_col], dim=0)\n        result = Data(x=x, edge_attr=edge_attr, edge_index=edge_idx)\n\n        # useful for tracking what subset of the graph was sampled\n        result.node_idx = sample.node\n        result.edge_idx = edge_id\n\n        return result\n\n\n\"\"\"\nTODO: make class CuVSKNNRAGFeatureStore(KNNRAGFeatureStore)\ninclude a 
approximate knn flag for CuVS.\nConnect this with a CuGraphGraphStore\nfor enabling an accelerated boolean flag for RAGQueryLoader.\nOn by default if CuGraph+CuVS are available.\nIf not, raise a note mentioning the speedup.\n\"\"\"\n"
  },
  {
    "path": "torch_geometric/llm/utils/graph_store.py",
    "content": "from typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import FeatureStore\nfrom torch_geometric.distributed.local_graph_store import LocalGraphStore\nfrom torch_geometric.sampler import (\n    BidirectionalNeighborSampler,\n    NodeSamplerInput,\n    SamplerOutput,\n)\nfrom torch_geometric.utils import index_sort\n\n# A representation of an edge index, following the possible formats:\n#    * default: Tensor, size = [2, num_edges]\n#    *     Tensor[0, :] == row, Tensor[1, :] == col\n#    * COO: (row, col)\n#    * CSC: (row, colptr)\n#    * CSR: (rowptr, col)\n_EdgeTensorType = Union[Tensor, Tuple[Tensor, Tensor]]\n\n\nclass NeighborSamplingRAGGraphStore(LocalGraphStore):\n    \"\"\"Neighbor sampling based graph-store to store & retrieve graph data.\"\"\"\n    def __init__(  # type: ignore[no-untyped-def]\n        self,\n        feature_store: Optional[FeatureStore] = None,\n        **kwargs,\n    ):\n        \"\"\"Initializes the graph store.\n        Optional feature store and neighbor sampling settings.\n\n        Args:\n        feature_store (optional): The feature store to use.\n            None if not yet registered.\n        **kwargs (optional):\n            Additional keyword arguments for neighbor sampling.\n        \"\"\"\n        self.feature_store = feature_store\n        self.sample_kwargs = kwargs\n        self._sampler_is_initialized = False\n        self._config: Dict[str, Any] = {}\n\n        # to be set by the config\n        self.num_neighbors = None\n        super().__init__()\n\n    @property\n    def config(self) -> Dict[str, Any]:\n        \"\"\"Get the config for the feature store.\"\"\"\n        return self._config\n\n    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:\n        \"\"\"Set an attribute from the config.\n\n        Args:\n            config (Dict[str, Any]): Config dictionary\n            attr_name (str): Name of 
attribute to set\n\n        Raises:\n            ValueError: If required attribute not found in config\n        \"\"\"\n        if attr_name not in config:\n            raise ValueError(\n                f\"Required config parameter '{attr_name}' not found\")\n        setattr(self, attr_name, config[attr_name])\n\n    @config.setter  # type: ignore\n    def config(self, config: Dict[str, Any]) -> None:\n        \"\"\"Set the config for the feature store.\n\n        Args:\n            config (Dict[str, Any]):\n                Config dictionary containing required parameters\n\n        Raises:\n            ValueError: If required parameters missing from config\n        \"\"\"\n        self._set_from_config(config, \"num_neighbors\")\n        if hasattr(self, 'sampler'):\n            self.sampler.num_neighbors = (  # type: ignore[has-type]\n                self.num_neighbors)\n\n        self._config = config\n\n    def _init_sampler(self) -> None:\n        \"\"\"Initializes neighbor sampler with the registered feature store.\"\"\"\n        if self.feature_store is None:\n            raise AttributeError(\"Feature store not registered yet.\")\n        assert self.num_neighbors is not None, \\\n            \"Please set num_neighbors through config\"\n        self.sampler = BidirectionalNeighborSampler(\n            data=(self.feature_store, self), num_neighbors=self.num_neighbors,\n            **self.sample_kwargs)\n        self._sampler_is_initialized = True\n\n    def register_feature_store(self, feature_store: FeatureStore) -> None:\n        \"\"\"Registers a feature store with the graph store.\n\n        :param feature_store: The feature store to register.\n        \"\"\"\n        self.feature_store = feature_store\n        self._sampler_is_initialized = False\n\n    def put_edge_id(  # type: ignore[no-untyped-def]\n            self, edge_id: Tensor, *args, **kwargs) -> bool:\n        \"\"\"Stores an edge ID in the graph store.\n\n        :param edge_id: The edge ID 
to store.\n        :return: Whether the operation was successful.\n        \"\"\"\n        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)\n        self._sampler_is_initialized = False\n        return ret\n\n    @property\n    def edge_index(self) -> _EdgeTensorType:\n        \"\"\"Gets the edge index of the graph.\n\n        :return: The edge index as a tensor.\n        \"\"\"\n        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)\n\n    def put_edge_index(  # type: ignore[no-untyped-def]\n            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:\n        \"\"\"Stores an edge index in the graph store.\n\n        :param edge_index: The edge index to store.\n        :return: Whether the operation was successful.\n        \"\"\"\n        ret = super().put_edge_index(edge_index, *args, **kwargs)\n        # HACK\n        self.edge_idx_args = args\n        self.edge_idx_kwargs = kwargs\n        self._sampler_is_initialized = False\n        return ret\n\n    # HACKY\n    @edge_index.setter  # type: ignore\n    def edge_index(self, edge_index: _EdgeTensorType) -> None:\n        \"\"\"Sets the edge index of the graph.\n\n        :param edge_index: The edge index to set.\n        \"\"\"\n        # correct since we make node list from triples\n        if isinstance(edge_index, Tensor):\n            num_nodes = int(edge_index.max()) + 1\n        else:\n            assert isinstance(edge_index, tuple) \\\n                and isinstance(edge_index[0], Tensor) \\\n                and isinstance(edge_index[1], Tensor), \\\n                \"edge_index must be a Tensor of [2, num_edges] \\\n                or a tuple of Tensors, (row, col).\"\n\n            num_nodes = int(edge_index[0].max()) + 1\n        attr = dict(\n            edge_type=None,\n            layout='coo',\n            size=(num_nodes, num_nodes),\n            is_sorted=False,\n        )\n        # edge index needs to be sorted here and the perm saved for 
later\n        col_sorted, self.perm = index_sort(edge_index[1], num_nodes,\n                                           stable=True)\n        row_sorted = edge_index[0][self.perm]\n        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)\n        self.put_edge_index(edge_index_sorted, **attr)\n\n    def sample_subgraph(\n        self,\n        seed_nodes: Tensor,\n    ) -> SamplerOutput:\n        \"\"\"Sample the graph starting from the given nodes using the\n        built-in :class:`BidirectionalNeighborSampler`.\n\n        The number of hops and the number of neighbors per hop are taken\n        from :obj:`num_neighbors`, which is set through :obj:`config`.\n\n        Args:\n            seed_nodes (Tensor): Seed nodes to start sampling from.\n\n        Returns:\n            SamplerOutput: The sampled subgraph, with edge IDs remapped to\n                the original (unsorted) edge order.\n        \"\"\"\n        # TODO add support for Hetero\n        if not self._sampler_is_initialized:\n            self._init_sampler()\n\n        seed_nodes = seed_nodes.unique().contiguous()\n        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)\n        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]\n            node_sample_input)\n\n        # edge ids need to be remapped to the original indices\n        out.edge = self.perm[out.edge]\n\n        return out\n"
  },
  {
    "path": "torch_geometric/llm/utils/vectorrag.py",
    "content": "# mypy: ignore-errors\nimport os\nfrom abc import abstractmethod\nfrom typing import Any, Callable, Dict, List, Optional, Protocol, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.llm.models import SentenceTransformer\nfrom torch_geometric.llm.utils.backend_utils import batch_knn\n\n\nclass VectorRetriever(Protocol):\n    \"\"\"Protocol for VectorRAG.\"\"\"\n    @abstractmethod\n    def query(self, query: Any, **kwargs: Optional[Dict[str, Any]]) -> Data:\n        \"\"\"Retrieve a context for a given query.\"\"\"\n        ...\n\n\nclass DocumentRetriever(VectorRetriever):\n    \"\"\"Retrieve documents from a vector database.\"\"\"\n    def __init__(self, raw_docs: List[str],\n                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,\n                 model: Optional[Union[SentenceTransformer, torch.nn.Module,\n                                       Callable]] = None,\n                 model_kwargs: Optional[Dict[str, Any]] = None):\n        \"\"\"Retrieve documents from a vector database.\n\n        Args:\n            raw_docs: List[str]: List of raw documents.\n            embedded_docs: Optional[Tensor]: Embedded documents.\n            k_for_docs: int: Number of documents to retrieve.\n            model: Optional[Union[SentenceTransformer, torch.nn.Module]]:\n                Model to use for encoding.\n            model_kwargs: Optional[Dict[str, Any]]:\n                Keyword arguments to pass to the model.\n        \"\"\"\n        self.raw_docs = raw_docs\n        self.embedded_docs = embedded_docs\n        self.k_for_docs = k_for_docs\n        self.model = model\n\n        if self.model is not None:\n            self.encoder = self.model\n            self.model_kwargs = model_kwargs\n\n        if self.embedded_docs is None:\n            assert self.model is not None, \\\n                \"Model must be provided if embedded_docs is not provided\"\n            
self.model_kwargs = model_kwargs or {}\n            self.embedded_docs = self.encoder(self.raw_docs,\n                                              **self.model_kwargs)\n            # we don't want to print the verbose output in `query`\n            self.model_kwargs.pop(\"verbose\", None)\n\n    def query(self, query: Union[str, Tensor]) -> List[str]:\n        \"\"\"Retrieve documents from the vector database.\n\n        Args:\n            query: Union[str, Tensor]: Query to retrieve documents for.\n\n        Returns:\n            List[str]: Documents retrieved from the vector database.\n        \"\"\"\n        if isinstance(query, str):\n            with torch.no_grad():\n                query_enc = self.encoder(query, **self.model_kwargs)\n        else:\n            query_enc = query\n\n        selected_doc_idxs, _ = next(\n            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))\n        return [self.raw_docs[i] for i in selected_doc_idxs]\n\n    def save(self, path: str) -> None:\n        \"\"\"Save the DocumentRetriever instance to disk.\n\n        Args:\n            path: str: Path where to save the retriever.\n        \"\"\"\n        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)\n\n        # Prepare data to save\n        save_dict = {\n            'raw_docs': self.raw_docs,\n            'embedded_docs': self.embedded_docs,\n            'k_for_docs': self.k_for_docs,\n        }\n\n        # We do not serialize the model\n        torch.save(save_dict, path)\n\n    @classmethod\n    def load(cls, path: str, model: Union[SentenceTransformer, torch.nn.Module,\n                                          Callable],\n             model_kwargs: Optional[Dict[str, Any]] = None) -> VectorRetriever:\n        \"\"\"Load a DocumentRetriever instance from disk.\n\n        Args:\n            path: str: Path to the saved retriever.\n            model: Union[SentenceTransformer, torch.nn.Module, Callable]:\n                Model to use for 
encoding.\n                The model is not serialized by :meth:`save`, so it must\n                be provided here.\n            model_kwargs: Optional[Dict[str, Any]]:\n                Keyword arguments to pass to the model.\n\n        Returns:\n            DocumentRetriever: The loaded retriever.\n        \"\"\"\n        if not os.path.exists(path):\n            raise FileNotFoundError(\n                f\"No saved document retriever found at {path}\")\n\n        save_dict = torch.load(path, weights_only=False)\n        if save_dict['embedded_docs'] is not None \\\n                and isinstance(save_dict['embedded_docs'], Tensor)\\\n                and model_kwargs is not None:\n            model_kwargs.pop(\"verbose\", None)\n        # Create a new DocumentRetriever with the loaded data\n        return cls(raw_docs=save_dict['raw_docs'],\n                   embedded_docs=save_dict['embedded_docs'],\n                   k_for_docs=save_dict['k_for_docs'], model=model,\n                   model_kwargs=model_kwargs)\n"
  },
  {
    "path": "torch_geometric/loader/__init__.py",
    "content": "from torch_geometric.deprecation import deprecated\n\nfrom .dataloader import DataLoader\nfrom .node_loader import NodeLoader\nfrom .link_loader import LinkLoader\nfrom .neighbor_loader import NeighborLoader\nfrom .link_neighbor_loader import LinkNeighborLoader\nfrom .hgt_loader import HGTLoader\nfrom .cluster import ClusterData, ClusterLoader\nfrom .graph_saint import (GraphSAINTSampler, GraphSAINTNodeSampler,\n                          GraphSAINTEdgeSampler, GraphSAINTRandomWalkSampler)\nfrom .shadow import ShaDowKHopSampler\nfrom .random_node_loader import RandomNodeLoader\n# from .ibmb_loader import IBMBBatchLoader, IBMBNodeLoader\nfrom .zip_loader import ZipLoader\nfrom .data_list_loader import DataListLoader\nfrom .dense_data_loader import DenseDataLoader\nfrom .temporal_dataloader import TemporalDataLoader\nfrom .neighbor_sampler import NeighborSampler\nfrom .imbalanced_sampler import ImbalancedSampler\nfrom .dynamic_batch_sampler import DynamicBatchSampler\nfrom .prefetch import PrefetchLoader\nfrom .cache import CachedLoader\nfrom .mixin import AffinityMixin\n\n__all__ = classes = [\n    'DataLoader',\n    'NodeLoader',\n    'LinkLoader',\n    'NeighborLoader',\n    'LinkNeighborLoader',\n    'HGTLoader',\n    'ClusterData',\n    'ClusterLoader',\n    'GraphSAINTSampler',\n    'GraphSAINTNodeSampler',\n    'GraphSAINTEdgeSampler',\n    'GraphSAINTRandomWalkSampler',\n    'ShaDowKHopSampler',\n    'RandomNodeLoader',\n    # 'IBMBBatchLoader',\n    # 'IBMBNodeLoader',\n    'ZipLoader',\n    'DataListLoader',\n    'DenseDataLoader',\n    'TemporalDataLoader',\n    'NeighborSampler',\n    'ImbalancedSampler',\n    'DynamicBatchSampler',\n    'PrefetchLoader',\n    'CachedLoader',\n    'AffinityMixin',\n]\n\nRandomNodeSampler = deprecated(\n    details=\"use 'loader.RandomNodeLoader' instead\",\n    func_name='loader.RandomNodeSampler',\n)(RandomNodeLoader)\n"
  },
  {
    "path": "torch_geometric/loader/base.py",
    "content": "from typing import Any, Callable\n\nfrom torch.utils.data.dataloader import (\n    _BaseDataLoaderIter,\n    _MultiProcessingDataLoaderIter,\n)\n\n\nclass DataLoaderIterator:\n    r\"\"\"A data loader iterator extended by a simple post transformation\n    function :meth:`transform_fn`. While the iterator may request items from\n    different sub-processes, :meth:`transform_fn` will always be executed in\n    the main process.\n\n    This iterator is used in PyG's sampler classes, and is responsible for\n    feature fetching and filtering data objects after sampling has taken place\n    in a sub-process. This has the following advantages:\n\n    * We do not need to share feature matrices across processes which may\n      prevent any errors due to too many open file handles.\n    * We can execute any expensive post-processing commands on the main thread\n      with full parallelization power (which usually executes faster).\n    * It lets us naturally support data already being present on the GPU.\n    \"\"\"\n    def __init__(self, iterator: _BaseDataLoaderIter, transform_fn: Callable):\n        self.iterator = iterator\n        self.transform_fn = transform_fn\n\n    def __iter__(self) -> 'DataLoaderIterator':\n        return self\n\n    def _reset(self, loader: Any, first_iter: bool = False):\n        self.iterator._reset(loader, first_iter)\n\n    def __len__(self) -> int:\n        return len(self.iterator)\n\n    def __next__(self) -> Any:\n        return self.transform_fn(next(self.iterator))\n\n    def __del__(self) -> Any:\n        if isinstance(self.iterator, _MultiProcessingDataLoaderIter):\n            self.iterator.__del__()\n"
  },
  {
    "path": "torch_geometric/loader/cache.py",
    "content": "from collections.abc import Mapping\nfrom typing import Any, Callable, List, Optional, Sequence\n\nimport torch\nfrom torch.utils.data import DataLoader\n\n\ndef to_device(inputs: Any, device: Optional[torch.device] = None) -> Any:\n    if hasattr(inputs, 'to'):\n        return inputs.to(device)\n    elif isinstance(inputs, Mapping):\n        return {key: to_device(value, device) for key, value in inputs.items()}\n    elif isinstance(inputs, tuple) and hasattr(inputs, '_fields'):\n        return type(inputs)(*(to_device(s, device) for s in zip(*inputs)))\n    elif isinstance(inputs, Sequence) and not isinstance(inputs, str):\n        return [to_device(s, device) for s in zip(*inputs)]\n\n    return inputs\n\n\nclass CachedLoader:\n    r\"\"\"A loader to cache mini-batch outputs, e.g., obtained during\n    :class:`NeighborLoader` iterations.\n\n    Args:\n        loader (torch.utils.data.DataLoader): The data loader.\n        device (torch.device, optional): The device to load the data to.\n            (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in\n            a sampled mini-batch and returns a transformed version.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        loader: DataLoader,\n        device: Optional[torch.device] = None,\n        transform: Optional[Callable] = None,\n    ):\n        self.loader = loader\n        self.device = device\n        self.transform = transform\n\n        self._cache: List[Any] = []\n\n    def clear(self):\n        r\"\"\"Clears the cache.\"\"\"\n        self._cache = []\n\n    def __iter__(self) -> Any:\n        if len(self._cache):\n            for batch in self._cache:\n                yield batch\n            return\n\n        for batch in self.loader:\n\n            if self.transform is not None:\n                batch = self.transform(batch)\n\n            batch = to_device(batch, self.device)\n\n            
self._cache.append(batch)\n\n            yield batch\n\n    def __len__(self) -> int:\n        return len(self.loader)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.loader})'\n"
  },
  {
    "path": "torch_geometric/loader/cluster.py",
    "content": "import copy\nimport os\nimport os.path as osp\nimport sys\nfrom dataclasses import dataclass\nfrom typing import List, Literal, Optional\n\nimport torch\nimport torch.utils.data\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data\nfrom torch_geometric.index import index2ptr, ptr2index\nfrom torch_geometric.io import fs\nfrom torch_geometric.typing import pyg_lib\nfrom torch_geometric.utils import index_sort, narrow, select, sort_edge_index\nfrom torch_geometric.utils.map import map_index\n\n\n@dataclass\nclass Partition:\n    indptr: Tensor\n    index: Tensor\n    partptr: Tensor\n    node_perm: Tensor\n    edge_perm: Tensor\n    sparse_format: Literal['csr', 'csc']\n\n\nclass ClusterData(torch.utils.data.Dataset):\n    r\"\"\"Clusters/partitions a graph data object into multiple subgraphs, as\n    motivated by the `\"Cluster-GCN: An Efficient Algorithm for Training Deep\n    and Large Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1905.07953>`_ paper.\n\n    .. note::\n        The underlying METIS algorithm requires undirected graphs as input.\n\n    Args:\n        data (torch_geometric.data.Data): The graph data object.\n        num_parts (int): The number of partitions.\n        recursive (bool, optional): If set to :obj:`True`, will use multilevel\n            recursive bisection instead of multilevel k-way partitioning.\n            (default: :obj:`False`)\n        save_dir (str, optional): If set, will save the partitioned data to the\n            :obj:`save_dir` directory for faster re-use. (default: :obj:`None`)\n        filename (str, optional): Name of the stored partitioned file.\n            (default: :obj:`None`)\n        log (bool, optional): If set to :obj:`False`, will not log any\n            progress. (default: :obj:`True`)\n        keep_inter_cluster_edges (bool, optional): If set to :obj:`True`,\n            will keep inter-cluster edge connections. 
(default: :obj:`False`)\n        sparse_format (str, optional): The sparse format to use for computing\n            partitions. (default: :obj:`\"csr\"`)\n    \"\"\"\n    def __init__(\n        self,\n        data,\n        num_parts: int,\n        recursive: bool = False,\n        save_dir: Optional[str] = None,\n        filename: Optional[str] = None,\n        log: bool = True,\n        keep_inter_cluster_edges: bool = False,\n        sparse_format: Literal['csr', 'csc'] = 'csr',\n    ):\n        assert data.edge_index is not None\n        assert sparse_format in ['csr', 'csc']\n\n        self.num_parts = num_parts\n        self.recursive = recursive\n        self.keep_inter_cluster_edges = keep_inter_cluster_edges\n        self.sparse_format = sparse_format\n\n        recursive_str = '_recursive' if recursive else ''\n        root_dir = osp.join(save_dir or '', f'part_{num_parts}{recursive_str}')\n        path = osp.join(root_dir, filename or 'metis.pt')\n\n        if save_dir is not None and osp.exists(path):\n            self.partition = fs.torch_load(path)\n        else:\n            if log:  # pragma: no cover\n                print('Computing METIS partitioning...', file=sys.stderr)\n\n            cluster = self._metis(data.edge_index, data.num_nodes)\n            self.partition = self._partition(data.edge_index, cluster)\n\n            if save_dir is not None:\n                os.makedirs(root_dir, exist_ok=True)\n                torch.save(self.partition, path)\n\n            if log:  # pragma: no cover\n                print('Done!', file=sys.stderr)\n\n        self.data = self._permute_data(data, self.partition)\n\n    def _metis(self, edge_index: Tensor, num_nodes: int) -> Tensor:\n        # Computes a node-level partition assignment vector via METIS.\n        if self.sparse_format == 'csr':  # Calculate CSR representation:\n            row, index = sort_edge_index(edge_index, num_nodes=num_nodes)\n            indptr = index2ptr(row, size=num_nodes)\n  
      else:  # Calculate CSC representation:\n            index, col = sort_edge_index(edge_index, num_nodes=num_nodes,\n                                         sort_by_row=False)\n            indptr = index2ptr(col, size=num_nodes)\n\n        # Compute METIS partitioning:\n        cluster: Optional[Tensor] = None\n\n        if torch_geometric.typing.WITH_TORCH_SPARSE:\n            try:\n                cluster = torch.ops.torch_sparse.partition(\n                    indptr.cpu(),\n                    index.cpu(),\n                    None,\n                    self.num_parts,\n                    self.recursive,\n                ).to(edge_index.device)\n            except (AttributeError, RuntimeError):\n                pass\n\n        if cluster is None and torch_geometric.typing.WITH_METIS:\n            cluster = pyg_lib.partition.metis(\n                indptr.cpu(),\n                index.cpu(),\n                self.num_parts,\n                recursive=self.recursive,\n            ).to(edge_index.device)\n\n        if cluster is None:\n            raise ImportError(f\"'{self.__class__.__name__}' requires either \"\n                              f\"'pyg-lib' or 'torch-sparse'\")\n\n        return cluster\n\n    def _partition(self, edge_index: Tensor, cluster: Tensor) -> Partition:\n        # Computes node-level and edge-level permutations and permutes the edge\n        # connectivity accordingly:\n\n        # Sort `cluster` and compute boundaries `partptr`:\n        cluster, node_perm = index_sort(cluster, max_value=self.num_parts)\n        partptr = index2ptr(cluster, size=self.num_parts)\n\n        # Permute `edge_index` based on node permutation:\n        edge_perm = torch.arange(edge_index.size(1), device=edge_index.device)\n        arange = torch.empty_like(node_perm)\n        arange[node_perm] = torch.arange(cluster.numel(),\n                                         device=cluster.device)\n        edge_index = arange[edge_index]\n\n        # Compute 
final CSR representation:\n        (row, col), edge_perm = sort_edge_index(\n            edge_index,\n            edge_attr=edge_perm,\n            num_nodes=cluster.numel(),\n            sort_by_row=self.sparse_format == 'csr',\n        )\n        if self.sparse_format == 'csr':\n            indptr, index = index2ptr(row, size=cluster.numel()), col\n        else:\n            indptr, index = index2ptr(col, size=cluster.numel()), row\n\n        return Partition(indptr, index, partptr, node_perm, edge_perm,\n                         self.sparse_format)\n\n    def _permute_data(self, data: Data, partition: Partition) -> Data:\n        # Permute node-level and edge-level attributes according to the\n        # calculated permutations in `Partition`:\n        out = copy.copy(data)\n        for key, value in data.items():\n            if key == 'edge_index':\n                continue\n            elif data.is_node_attr(key):\n                cat_dim = data.__cat_dim__(key, value)\n                out[key] = select(value, partition.node_perm, dim=cat_dim)\n            elif data.is_edge_attr(key):\n                cat_dim = data.__cat_dim__(key, value)\n                out[key] = select(value, partition.edge_perm, dim=cat_dim)\n        out.edge_index = None\n\n        return out\n\n    def __len__(self) -> int:\n        return self.partition.partptr.numel() - 1\n\n    def __getitem__(self, idx: int) -> Data:\n        node_start = int(self.partition.partptr[idx])\n        node_end = int(self.partition.partptr[idx + 1])\n        node_length = node_end - node_start\n\n        indptr = self.partition.indptr[node_start:node_end + 1]\n        edge_start = int(indptr[0])\n        edge_end = int(indptr[-1])\n        edge_length = edge_end - edge_start\n        indptr = indptr - edge_start\n\n        if self.sparse_format == 'csr':\n            row = ptr2index(indptr)\n            col = self.partition.index[edge_start:edge_end]\n            if not self.keep_inter_cluster_edges:\n   
             edge_mask = (col >= node_start) & (col < node_end)\n                row = row[edge_mask]\n                col = col[edge_mask] - node_start\n        else:\n            col = ptr2index(indptr)\n            row = self.partition.index[edge_start:edge_end]\n            if not self.keep_inter_cluster_edges:\n                edge_mask = (row >= node_start) & (row < node_end)\n                col = col[edge_mask]\n                row = row[edge_mask] - node_start\n\n        out = copy.copy(self.data)\n\n        for key, value in self.data.items():\n            if key == 'num_nodes':\n                out.num_nodes = node_length\n            elif self.data.is_node_attr(key):\n                cat_dim = self.data.__cat_dim__(key, value)\n                out[key] = narrow(value, cat_dim, node_start, node_length)\n            elif self.data.is_edge_attr(key):\n                cat_dim = self.data.__cat_dim__(key, value)\n                out[key] = narrow(value, cat_dim, edge_start, edge_length)\n                if not self.keep_inter_cluster_edges:\n                    out[key] = out[key][edge_mask]\n\n        out.edge_index = torch.stack([row, col], dim=0)\n\n        return out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.num_parts})'\n\n\nclass ClusterLoader(torch.utils.data.DataLoader):\n    r\"\"\"The data loader scheme from the `\"Cluster-GCN: An Efficient Algorithm\n    for Training Deep and Large Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1905.07953>`_ paper which merges partitioned\n    subgraphs and their between-cluster links from a large-scale graph data\n    object to form a mini-batch.\n\n    .. 
note::\n\n        Use :class:`~torch_geometric.loader.ClusterData` and\n        :class:`~torch_geometric.loader.ClusterLoader` in conjunction to\n        form mini-batches of clusters.\n        For an example of using Cluster-GCN, see\n        `examples/cluster_gcn_reddit.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/cluster_gcn_reddit.py>`_ or\n        `examples/cluster_gcn_ppi.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/cluster_gcn_ppi.py>`_.\n\n    Args:\n        cluster_data (torch_geometric.loader.ClusterData): The already\n            partitioned data object.\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(self, cluster_data, **kwargs):\n        self.cluster_data = cluster_data\n        iterator = range(len(cluster_data))\n        super().__init__(iterator, collate_fn=self._collate, **kwargs)\n\n    def _collate(self, batch: List[int]) -> Data:\n        if not isinstance(batch, torch.Tensor):\n            batch = torch.tensor(batch)\n\n        global_indptr = self.cluster_data.partition.indptr\n        global_index = self.cluster_data.partition.index\n\n        # Get all node-level and edge-level start and end indices for the\n        # current mini-batch:\n        node_start = self.cluster_data.partition.partptr[batch]\n        node_end = self.cluster_data.partition.partptr[batch + 1]\n        edge_start = global_indptr[node_start]\n        edge_end = global_indptr[node_end]\n\n        # Iterate over each partition in the batch and calculate new edge\n        # connectivity. 
This is done by slicing the corresponding source and\n        # destination indices for each partition and adjusting their indices to\n        # start from zero:\n        rows, cols, nodes, cumsum = [], [], [], 0\n        for i in range(batch.numel()):\n            nodes.append(torch.arange(node_start[i], node_end[i]))\n            indptr = global_indptr[node_start[i]:node_end[i] + 1]\n            indptr = indptr - edge_start[i]\n            if self.cluster_data.partition.sparse_format == 'csr':\n                row = ptr2index(indptr) + cumsum\n                col = global_index[edge_start[i]:edge_end[i]]\n\n            else:\n                col = ptr2index(indptr) + cumsum\n                row = global_index[edge_start[i]:edge_end[i]]\n\n            rows.append(row)\n            cols.append(col)\n            cumsum += indptr.numel() - 1\n\n        node = torch.cat(nodes, dim=0)\n        row = torch.cat(rows, dim=0)\n        col = torch.cat(cols, dim=0)\n\n        # Map `col` vector to valid entries and remove any entries that do not\n        # connect two nodes within the same mini-batch:\n        if self.cluster_data.partition.sparse_format == 'csr':\n            col, edge_mask = map_index(col, node)\n            row = row[edge_mask]\n        else:\n            row, edge_mask = map_index(row, node)\n            col = col[edge_mask]\n        out = copy.copy(self.cluster_data.data)\n\n        # Slice node-level and edge-level attributes according to its offsets:\n        for key, value in self.cluster_data.data.items():\n            if key == 'num_nodes':\n                out.num_nodes = cumsum\n            elif self.cluster_data.data.is_node_attr(key):\n                cat_dim = self.cluster_data.data.__cat_dim__(key, value)\n                out[key] = torch.cat([\n                    narrow(out[key], cat_dim, s, e - s)\n                    for s, e in zip(node_start, node_end)\n                ], dim=cat_dim)\n            elif 
self.cluster_data.data.is_edge_attr(key):\n                cat_dim = self.cluster_data.data.__cat_dim__(key, value)\n                value = torch.cat([\n                    narrow(out[key], cat_dim, s, e - s)\n                    for s, e in zip(edge_start, edge_end)\n                ], dim=cat_dim)\n                out[key] = select(value, edge_mask, dim=cat_dim)\n\n        out.edge_index = torch.stack([row, col], dim=0)\n\n        return out\n"
  },
  {
    "path": "torch_geometric/loader/data_list_loader.py",
    "content": "from typing import List, Union\n\nimport torch\n\nfrom torch_geometric.data import Dataset\nfrom torch_geometric.data.data import BaseData\n\n\ndef collate_fn(data_list):\n    return data_list\n\n\nclass DataListLoader(torch.utils.data.DataLoader):\n    r\"\"\"A data loader which batches data objects from a\n    :class:`torch_geometric.data.dataset` to a :python:`Python` list.\n    Data objects can be either of type :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData`.\n\n    .. note::\n\n        This data loader should be used for multi-GPU support via\n        :class:`torch_geometric.nn.DataParallel`.\n\n    Args:\n        dataset (Dataset): The dataset from which to load the data.\n        batch_size (int, optional): How many samples per batch to load.\n            (default: :obj:`1`)\n        shuffle (bool, optional): If set to :obj:`True`, the data will be\n            reshuffled at every epoch. (default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`drop_last` or\n            :obj:`num_workers`.\n    \"\"\"\n    def __init__(self, dataset: Union[Dataset, List[BaseData]],\n                 batch_size: int = 1, shuffle: bool = False, **kwargs):\n        # Remove for PyTorch Lightning:\n        kwargs.pop('collate_fn', None)\n\n        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,\n                         collate_fn=collate_fn, **kwargs)\n"
  },
  {
    "path": "torch_geometric/loader/dataloader.py",
    "content": "from collections.abc import Mapping\nfrom typing import Any, List, Optional, Sequence, Union\n\nimport torch.utils.data\nfrom torch.utils.data.dataloader import default_collate\n\nfrom torch_geometric.data import Batch, Dataset\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.data.datapipes import DatasetAdapter\nfrom torch_geometric.typing import TensorFrame, torch_frame\n\n\nclass Collater:\n    def __init__(\n        self,\n        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],\n        follow_batch: Optional[List[str]] = None,\n        exclude_keys: Optional[List[str]] = None,\n    ):\n        self.dataset = dataset\n        self.follow_batch = follow_batch\n        self.exclude_keys = exclude_keys\n\n    def __call__(self, batch: List[Any]) -> Any:\n        elem = batch[0]\n        if isinstance(elem, BaseData):\n            return Batch.from_data_list(\n                batch,\n                follow_batch=self.follow_batch,\n                exclude_keys=self.exclude_keys,\n            )\n        elif isinstance(elem, torch.Tensor):\n            return default_collate(batch)\n        elif isinstance(elem, TensorFrame):\n            return torch_frame.cat(batch, dim=0)\n        elif isinstance(elem, float):\n            return torch.tensor(batch, dtype=torch.float)\n        elif isinstance(elem, int):\n            return torch.tensor(batch)\n        elif isinstance(elem, str):\n            return batch\n        elif isinstance(elem, Mapping):\n            return {key: self([data[key] for data in batch]) for key in elem}\n        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):\n            return type(elem)(*(self(s) for s in zip(*batch)))\n        elif isinstance(elem, Sequence) and not isinstance(elem, str):\n            return [self(s) for s in zip(*batch)]\n\n        raise TypeError(f\"DataLoader found invalid type: '{type(elem)}'\")\n\n\nclass DataLoader(torch.utils.data.DataLoader):\n    
r\"\"\"A data loader which merges data objects from a\n    :class:`torch_geometric.data.Dataset` to a mini-batch.\n    Data objects can be either of type :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData`.\n\n    Args:\n        dataset (Dataset): The dataset from which to load the data.\n        batch_size (int, optional): How many samples per batch to load.\n            (default: :obj:`1`)\n        shuffle (bool, optional): If set to :obj:`True`, the data will be\n            reshuffled at every epoch. (default: :obj:`False`)\n        follow_batch (List[str], optional): Creates assignment batch\n            vectors for each key in the list. (default: :obj:`None`)\n        exclude_keys (List[str], optional): Will exclude each key in the\n            list. (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        dataset: Union[Dataset, Sequence[BaseData], DatasetAdapter],\n        batch_size: int = 1,\n        shuffle: bool = False,\n        follow_batch: Optional[List[str]] = None,\n        exclude_keys: Optional[List[str]] = None,\n        **kwargs,\n    ):\n        # Remove for PyTorch Lightning:\n        kwargs.pop('collate_fn', None)\n\n        # Save for PyTorch Lightning < 1.6:\n        self.follow_batch = follow_batch\n        self.exclude_keys = exclude_keys\n\n        super().__init__(\n            dataset,\n            batch_size,\n            shuffle,\n            collate_fn=Collater(dataset, follow_batch, exclude_keys),\n            **kwargs,\n        )\n"
  },
  {
    "path": "torch_geometric/loader/dense_data_loader.py",
    "content": "from typing import List, Union\n\nimport torch\nfrom torch.utils.data.dataloader import default_collate\n\nfrom torch_geometric.data import Batch, Data, Dataset\n\n\ndef collate_fn(data_list: List[Data]) -> Batch:\n    batch = Batch()\n    for key in data_list[0].keys():\n        batch[key] = default_collate([data[key] for data in data_list])\n    return batch\n\n\nclass DenseDataLoader(torch.utils.data.DataLoader):\n    r\"\"\"A data loader which batches data objects from a\n    :class:`torch_geometric.data.dataset` to a\n    :class:`torch_geometric.data.Batch` object by stacking all attributes in a\n    new dimension.\n\n    .. note::\n\n        To make use of this data loader, all graph attributes in the dataset\n        need to have the same shape.\n        In particular, this data loader should only be used when working with\n        *dense* adjacency matrices.\n\n    Args:\n        dataset (Dataset): The dataset from which to load the data.\n        batch_size (int, optional): How many samples per batch to load.\n            (default: :obj:`1`)\n        shuffle (bool, optional): If set to :obj:`True`, the data will be\n            reshuffled at every epoch. (default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`drop_last` or\n            :obj:`num_workers`.\n    \"\"\"\n    def __init__(self, dataset: Union[Dataset, List[Data]],\n                 batch_size: int = 1, shuffle: bool = False, **kwargs):\n        # Remove for PyTorch Lightning:\n        kwargs.pop('collate_fn', None)\n\n        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,\n                         collate_fn=collate_fn, **kwargs)\n"
  },
  {
    "path": "torch_geometric/loader/dynamic_batch_sampler.py",
    "content": "from typing import Iterator, List, Optional\n\nimport torch\n\nfrom torch_geometric.data import Dataset\n\n\nclass DynamicBatchSampler(torch.utils.data.sampler.Sampler):\n    r\"\"\"Dynamically adds samples to a mini-batch up to a maximum size (either\n    based on number of nodes or number of edges). When data samples have a\n    wide range in sizes, specifying a mini-batch size in terms of number of\n    samples is not ideal and can cause CUDA OOM errors.\n\n    Within the :class:`DynamicBatchSampler`, the number of steps per epoch is\n    ambiguous, depending on the order of the samples. By default the\n    :meth:`__len__` will be undefined. This is fine for most cases but\n    progress bars will be infinite. Alternatively, :obj:`num_steps` can be\n    supplied to cap the number of mini-batches produced by the sampler.\n\n    .. code-block:: python\n\n        from torch_geometric.loader import DataLoader, DynamicBatchSampler\n\n        sampler = DynamicBatchSampler(dataset, max_num=10000, mode=\"node\")\n        loader = DataLoader(dataset, batch_sampler=sampler, ...)\n\n    Args:\n        dataset (Dataset): Dataset to sample from.\n        max_num (int): Size of mini-batch to aim for in number of nodes or\n            edges.\n        mode (str, optional): :obj:`\"node\"` or :obj:`\"edge\"` to measure\n            batch size. (default: :obj:`\"node\"`)\n        shuffle (bool, optional): If set to :obj:`True`, will have the data\n            reshuffled at every epoch. (default: :obj:`False`)\n        skip_too_big (bool, optional): If set to :obj:`True`, skip samples\n            which cannot fit in a batch by itself. (default: :obj:`False`)\n        num_steps (int, optional): The number of mini-batches to draw for a\n            single epoch. If set to :obj:`None`, will iterate through all the\n            underlying examples, but :meth:`__len__` will be :obj:`None` since\n            it is ambiguous. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        dataset: Dataset,\n        max_num: int,\n        mode: str = 'node',\n        shuffle: bool = False,\n        skip_too_big: bool = False,\n        num_steps: Optional[int] = None,\n    ):\n        if max_num <= 0:\n            raise ValueError(f\"`max_num` should be a positive integer value \"\n                             f\"(got {max_num})\")\n        if mode not in ['node', 'edge']:\n            raise ValueError(f\"`mode` choice should be either \"\n                             f\"'node' or 'edge' (got '{mode}')\")\n\n        self.dataset = dataset\n        self.max_num = max_num\n        self.mode = mode\n        self.shuffle = shuffle\n        self.skip_too_big = skip_too_big\n        self.num_steps = num_steps\n        self.max_steps = num_steps or len(dataset)\n\n    def __iter__(self) -> Iterator[List[int]]:\n        if self.shuffle:\n            indices = torch.randperm(len(self.dataset)).tolist()\n        else:\n            indices = range(len(self.dataset))\n\n        samples: List[int] = []\n        current_num: int = 0\n        num_steps: int = 0\n        num_processed: int = 0\n\n        while (num_processed < len(self.dataset)\n               and num_steps < self.max_steps):\n\n            for i in indices[num_processed:]:\n                data = self.dataset[i]\n                num = data.num_nodes if self.mode == 'node' else data.num_edges\n\n                if current_num + num > self.max_num:\n                    if current_num == 0:\n                        if self.skip_too_big:\n                            # Count skipped samples as processed to keep the\n                            # offset into `indices` aligned:\n                            num_processed += 1\n                            continue\n                    else:  # Mini-batch filled:\n                        break\n\n                samples.append(i)\n                num_processed += 1\n                current_num += num\n\n            if len(samples) == 0:  # Only skipped samples were left:\n                break\n\n            yield samples\n            samples = []\n            current_num = 0\n            num_steps += 1\n\n    def __len__(self) -> int:\n        if 
self.num_steps is None:\n            raise ValueError(f\"The length of '{self.__class__.__name__}' is \"\n                             f\"undefined since the number of steps per epoch \"\n                             f\"is ambiguous. Either specify `num_steps` or \"\n                             f\"use a static batch sampler.\")\n\n        return self.num_steps\n"
  },
  {
    "path": "torch_geometric/loader/graph_saint.py",
    "content": "import os.path as osp\nfrom typing import Optional\n\nimport torch\nfrom tqdm import tqdm\n\nfrom torch_geometric.io import fs\nfrom torch_geometric.typing import SparseTensor\n\n\nclass GraphSAINTSampler(torch.utils.data.DataLoader):\n    r\"\"\"The GraphSAINT sampler base class from the `\"GraphSAINT: Graph\n    Sampling Based Inductive Learning Method\"\n    <https://arxiv.org/abs/1907.04931>`_ paper.\n    Given a graph in a :obj:`data` object, this class samples nodes and\n    constructs subgraphs that can be processed in a mini-batch fashion.\n    Normalization coefficients for each mini-batch are given via\n    :obj:`node_norm` and :obj:`edge_norm` data attributes.\n\n    .. note::\n\n        See :class:`~torch_geometric.loader.GraphSAINTNodeSampler`,\n        :class:`~torch_geometric.loader.GraphSAINTEdgeSampler` and\n        :class:`~torch_geometric.loader.GraphSAINTRandomWalkSampler` for\n        currently supported samplers.\n        For an example of using GraphSAINT sampling, see\n        `examples/graph_saint.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/graph_saint.py>`_.\n\n    Args:\n        data (torch_geometric.data.Data): The graph data object.\n        batch_size (int): The approximate number of samples per batch.\n        num_steps (int, optional): The number of iterations per epoch.\n            (default: :obj:`1`)\n        sample_coverage (int): How many samples per node should be used to\n            compute normalization statistics. (default: :obj:`0`)\n        save_dir (str, optional): If set, will save normalization statistics to\n            the :obj:`save_dir` directory for faster re-use.\n            (default: :obj:`None`)\n        log (bool, optional): If set to :obj:`False`, will not log any\n            pre-processing progress. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or\n            :obj:`num_workers`.\n    \"\"\"\n    def __init__(self, data, batch_size: int, num_steps: int = 1,\n                 sample_coverage: int = 0, save_dir: Optional[str] = None,\n                 log: bool = True, **kwargs):\n\n        # Remove for PyTorch Lightning:\n        kwargs.pop('dataset', None)\n        kwargs.pop('collate_fn', None)\n\n        assert data.edge_index is not None\n        assert 'node_norm' not in data\n        assert 'edge_norm' not in data\n        assert not data.edge_index.is_cuda\n\n        self.num_steps = num_steps\n        self._batch_size = batch_size\n        self.sample_coverage = sample_coverage\n        self.save_dir = save_dir\n        self.log = log\n\n        self.N = N = data.num_nodes\n        self.E = data.num_edges\n\n        self.adj = SparseTensor(\n            row=data.edge_index[0], col=data.edge_index[1],\n            value=torch.arange(self.E, device=data.edge_index.device),\n            sparse_sizes=(N, N))\n\n        self.data = data\n\n        super().__init__(self, batch_size=1, collate_fn=self._collate,\n                         **kwargs)\n\n        if self.sample_coverage > 0:\n            path = osp.join(save_dir or '', self._filename)\n            if save_dir is not None and osp.exists(path):  # pragma: no cover\n                self.node_norm, self.edge_norm = fs.torch_load(path)\n            else:\n                self.node_norm, self.edge_norm = self._compute_norm()\n                if save_dir is not None:  # pragma: no cover\n                    torch.save((self.node_norm, self.edge_norm), path)\n\n    @property\n    def _filename(self):\n        return f'{self.__class__.__name__.lower()}_{self.sample_coverage}.pt'\n\n    def __len__(self):\n        return self.num_steps\n\n    def _sample_nodes(self, batch_size):\n        raise 
NotImplementedError\n\n    def __getitem__(self, idx):\n        node_idx = self._sample_nodes(self._batch_size).unique()\n        adj, _ = self.adj.saint_subgraph(node_idx)\n        return node_idx, adj\n\n    def _collate(self, data_list):\n        assert len(data_list) == 1\n        node_idx, adj = data_list[0]\n\n        data = self.data.__class__()\n        data.num_nodes = node_idx.size(0)\n        row, col, edge_idx = adj.coo()\n        data.edge_index = torch.stack([row, col], dim=0)\n\n        for key, item in self.data:\n            if key in ['edge_index', 'num_nodes']:\n                continue\n            if isinstance(item, torch.Tensor) and item.size(0) == self.N:\n                data[key] = item[node_idx]\n            elif isinstance(item, torch.Tensor) and item.size(0) == self.E:\n                data[key] = item[edge_idx]\n            else:\n                data[key] = item\n\n        if self.sample_coverage > 0:\n            data.node_norm = self.node_norm[node_idx]\n            data.edge_norm = self.edge_norm[edge_idx]\n\n        return data\n\n    def _compute_norm(self):\n        node_count = torch.zeros(self.N, dtype=torch.float)\n        edge_count = torch.zeros(self.E, dtype=torch.float)\n\n        loader = torch.utils.data.DataLoader(self, batch_size=200,\n                                             collate_fn=lambda x: x,\n                                             num_workers=self.num_workers)\n\n        if self.log:  # pragma: no cover\n            pbar = tqdm(total=self.N * self.sample_coverage)\n            pbar.set_description('Compute GraphSAINT normalization')\n\n        num_samples = total_sampled_nodes = 0\n        while total_sampled_nodes < self.N * self.sample_coverage:\n            for data in loader:\n                for node_idx, adj in data:\n                    edge_idx = adj.storage.value()\n                    node_count[node_idx] += 1\n                    edge_count[edge_idx] += 1\n                    
total_sampled_nodes += node_idx.size(0)\n\n                    if self.log:  # pragma: no cover\n                        pbar.update(node_idx.size(0))\n            num_samples += self.num_steps\n\n        if self.log:  # pragma: no cover\n            pbar.close()\n\n        row, _, edge_idx = self.adj.coo()\n        t = torch.empty_like(edge_count).scatter_(0, edge_idx, node_count[row])\n        edge_norm = (t / edge_count).clamp_(0, 1e4)\n        edge_norm[torch.isnan(edge_norm)] = 0.1\n\n        node_count[node_count == 0] = 0.1\n        node_norm = num_samples / node_count / self.N\n\n        return node_norm, edge_norm\n\n\nclass GraphSAINTNodeSampler(GraphSAINTSampler):\n    r\"\"\"The GraphSAINT node sampler class (see\n    :class:`~torch_geometric.loader.GraphSAINTSampler`).\n    \"\"\"\n    def _sample_nodes(self, batch_size):\n        edge_sample = torch.randint(0, self.E, (batch_size, self.batch_size),\n                                    dtype=torch.long)\n\n        return self.adj.storage.row()[edge_sample]\n\n\nclass GraphSAINTEdgeSampler(GraphSAINTSampler):\n    r\"\"\"The GraphSAINT edge sampler class (see\n    :class:`~torch_geometric.loader.GraphSAINTSampler`).\n    \"\"\"\n    def _sample_nodes(self, batch_size):\n        row, col, _ = self.adj.coo()\n\n        deg_in = 1. / self.adj.storage.colcount()\n        deg_out = 1. / self.adj.storage.rowcount()\n        prob = (1. / deg_in[row]) + (1. 
/ deg_out[col])\n\n        # Parallel multinomial sampling (without replacement)\n        # https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503\n        rand = torch.rand(batch_size, self.E).log() / (prob + 1e-10)\n        edge_sample = rand.topk(self.batch_size, dim=-1).indices\n\n        source_node_sample = col[edge_sample]\n        target_node_sample = row[edge_sample]\n\n        return torch.cat([source_node_sample, target_node_sample], -1)\n\n\nclass GraphSAINTRandomWalkSampler(GraphSAINTSampler):\n    r\"\"\"The GraphSAINT random walk sampler class (see\n    :class:`~torch_geometric.loader.GraphSAINTSampler`).\n\n    Args:\n        walk_length (int): The length of each random walk.\n    \"\"\"\n    def __init__(self, data, batch_size: int, walk_length: int,\n                 num_steps: int = 1, sample_coverage: int = 0,\n                 save_dir: Optional[str] = None, log: bool = True, **kwargs):\n        self.walk_length = walk_length\n        super().__init__(data, batch_size, num_steps, sample_coverage,\n                         save_dir, log, **kwargs)\n\n    @property\n    def _filename(self):\n        return (f'{self.__class__.__name__.lower()}_{self.walk_length}_'\n                f'{self.sample_coverage}.pt')\n\n    def _sample_nodes(self, batch_size):\n        start = torch.randint(0, self.N, (batch_size, ), dtype=torch.long)\n        node_idx = self.adj.random_walk(start.flatten(), self.walk_length)\n        return node_idx.view(-1)\n"
  },
  {
    "path": "torch_geometric/loader/hgt_loader.py",
    "content": "from typing import Callable, Dict, List, Optional, Tuple, Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import FeatureStore, GraphStore, HeteroData\nfrom torch_geometric.loader import NodeLoader\nfrom torch_geometric.sampler import HGTSampler\nfrom torch_geometric.typing import NodeType\n\n\nclass HGTLoader(NodeLoader):\n    r\"\"\"The Heterogeneous Graph Sampler from the `\"Heterogeneous Graph\n    Transformer\" <https://arxiv.org/abs/2003.01332>`_ paper.\n    This loader allows for mini-batch training of GNNs on large-scale graphs\n    where full-batch training is not feasible.\n\n    :class:`~torch_geometric.loader.HGTLoader` tries to (1) keep a similar\n    number of nodes and edges for each type and (2) keep the sampled sub-graph\n    dense to minimize the information loss and reduce the sample variance.\n\n    Methodically, :class:`~torch_geometric.loader.HGTLoader` keeps track of a\n    node budget for each node type, which is then used to determine the\n    sampling probability of a node.\n    In particular, the probability of sampling a node is determined by the\n    number of connections to already sampled nodes and their node degrees.\n    With this, :class:`~torch_geometric.loader.HGTLoader` will sample a fixed\n    amount of neighbors for each node type in each iteration, as given by the\n    :obj:`num_samples` argument.\n\n    Sampled nodes are sorted based on the order in which they were sampled.\n    In particular, the first :obj:`batch_size` nodes represent the set of\n    original mini-batch nodes.\n\n    .. note::\n\n        For an example of using :class:`~torch_geometric.loader.HGTLoader`, see\n        `examples/hetero/to_hetero_mag.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        hetero/to_hetero_mag.py>`_.\n\n    .. 
code-block:: python\n\n        from torch_geometric.loader import HGTLoader\n        from torch_geometric.datasets import OGB_MAG\n\n        hetero_data = OGB_MAG(path)[0]\n\n        loader = HGTLoader(\n            hetero_data,\n            # Sample 512 nodes per type and per iteration for 4 iterations\n            num_samples={key: [512] * 4 for key in hetero_data.node_types},\n            # Use a batch size of 128 for sampling training nodes of type paper\n            batch_size=128,\n            input_nodes=('paper', hetero_data['paper'].train_mask),\n        )\n\n        sampled_hetero_data = next(iter(loader))\n        print(sampled_hetero_data['paper'].batch_size)\n        >>> 128\n\n    Args:\n        data (Any): A :class:`~torch_geometric.data.Data`,\n            :class:`~torch_geometric.data.HeteroData`, or\n            (:class:`~torch_geometric.data.FeatureStore`,\n            :class:`~torch_geometric.data.GraphStore`) data object.\n        num_samples (List[int] or Dict[str, List[int]]): The number of nodes to\n            sample in each iteration and for each node type.\n            If given as a list, will sample the same number of nodes for each\n            node type.\n        input_nodes (str or Tuple[str, torch.Tensor]): The indices of nodes for\n            which neighbors are sampled to create mini-batches.\n            Needs to be passed as a tuple that holds the node type and\n            corresponding node indices.\n            Node indices need to be either given as a :obj:`torch.LongTensor`\n            or :obj:`torch.BoolTensor`.\n            If node indices are set to :obj:`None`, all nodes of this specific\n            type will be considered.\n        transform (callable, optional): A function/transform that takes in\n            a sampled mini-batch and returns a transformed version.\n            (default: :obj:`None`)\n        transform_sampler_output (callable, optional): A function/transform\n            that takes in a 
:class:`torch_geometric.sampler.SamplerOutput` and\n            returns a transformed version. (default: :obj:`None`)\n        is_sorted (bool, optional): If set to :obj:`True`, assumes that\n            :obj:`edge_index` is sorted by column. This avoids internal\n            re-sorting of the data and can improve runtime and memory\n            efficiency. (default: :obj:`False`)\n        filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n            the returned data in each worker's subprocess.\n            If set to :obj:`False`, will filter the returned data in the main\n            process.\n            If set to :obj:`None`, will automatically infer the decision based\n            on whether data partially lives on the GPU\n            (:obj:`filter_per_worker=True`) or entirely on the CPU\n            (:obj:`filter_per_worker=False`).\n            There exist different trade-offs for setting this option.\n            Specifically, setting this option to :obj:`True` for in-memory\n            datasets will move all features to shared memory, which may result\n            in too many open file handles. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[HeteroData, Tuple[FeatureStore, GraphStore]],\n        num_samples: Union[List[int], Dict[NodeType, List[int]]],\n        input_nodes: Union[NodeType, Tuple[NodeType, Optional[Tensor]]],\n        is_sorted: bool = False,\n        transform: Optional[Callable] = None,\n        transform_sampler_output: Optional[Callable] = None,\n        filter_per_worker: Optional[bool] = None,\n        **kwargs,\n    ):\n        hgt_sampler = HGTSampler(\n            data,\n            num_samples=num_samples,\n            is_sorted=is_sorted,\n            share_memory=kwargs.get('num_workers', 0) > 0,\n        )\n\n        super().__init__(\n            data=data,\n            node_sampler=hgt_sampler,\n            input_nodes=input_nodes,\n            transform=transform,\n            transform_sampler_output=transform_sampler_output,\n            filter_per_worker=filter_per_worker,\n            **kwargs,\n        )\n"
  },
  {
    "path": "torch_geometric/loader/ibmb_loader.py",
    "content": "import logging\nimport math\nfrom typing import (\n    Any,\n    Callable,\n    Iterator,\n    List,\n    NamedTuple,\n    Optional,\n    Tuple,\n    Union,\n)\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import get_ppr, is_undirected, subgraph\n\ntry:\n    import numba\n    WITH_NUMBA = True\nexcept ImportError:  # pragma: no cover\n    WITH_NUMBA = False\n\n\nclass OutputNodes(NamedTuple):\n    seed_id: Tensor\n    auxiliary_id: Tensor\n\n\nclass _IBMBBaseLoader(torch.utils.data.DataLoader):\n    def __init__(self, data: Data, **kwargs):\n        kwargs.pop('collate_fn', None)\n        batch_size = kwargs.get('batch_size', 1)\n\n        output_nodes = self.get_output_nodes()\n\n        if batch_size == 1:  # Pre-process subgraphs:\n            data_list = ...\n            super().__init__(data_list, collate_fn=self._cache_fn, **kwargs)\n        else:\n            self.data = data\n            super().__init__(output_nodes, collate_fn=self._collate_fn,\n                             **kwargs)\n\n    def get_output_nodes(self) -> List[OutputNodes]:\n        raise NotImplementedError\n\n    def _cache_fn(self, data_list: List[Data]) -> Data:\n        assert len(data_list) == 1\n        return data_list[0]\n\n    def _collate_fn(self, output_nodes: List[OutputNodes]) -> Data:\n        raise NotImplementedError\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n\n\n###############################################################################\n\n\ndef get_partitions(\n    edge_index: Union[Tensor, SparseTensor],\n    num_partitions: int,\n    indices: Tensor,\n    num_nodes: int,\n    output_weight: Optional[float] = None,\n) -> List[Tensor]:\n    assert isinstance(\n        edge_index,\n        (torch.LongTensor,\n         SparseTensor)), f'Unsupported 
edge_index type {type(edge_index)}'\n    if isinstance(edge_index, torch.LongTensor):\n        edge_index = SparseTensor.from_edge_index(\n            edge_index, sparse_sizes=(num_nodes, num_nodes))\n\n    if output_weight is not None and output_weight != 1:\n        node_weight = torch.ones(num_nodes)\n        node_weight[indices] = output_weight\n    else:\n        node_weight = None\n\n    _, partptr, perm = edge_index.partition(num_parts=num_partitions,\n                                            recursive=False, weighted=False,\n                                            node_weight=node_weight)\n\n    partitions = []\n    for i in range(len(partptr) - 1):\n        partitions.append(perm[partptr[i]:partptr[i + 1]])\n\n    return partitions\n\n\ndef get_pair_wise_distance(\n    ys: List,\n    num_classes: int,\n    dist_type: str = 'kl',\n) -> np.ndarray:\n    num_batches = len(ys)\n\n    counts = np.zeros((num_batches, num_classes), dtype=np.int32)\n    for i in range(num_batches):\n        unique, count = np.unique(ys[i], return_counts=True)\n        counts[i, unique] = count\n\n    counts += 1\n    counts = counts / counts.sum(1).reshape(-1, 1)\n    pairwise_dist = np.zeros((num_batches, num_batches), dtype=np.float64)\n\n    for i in range(0, num_batches - 1):\n        for j in range(i + 1, num_batches):\n            if dist_type == 'l1':\n                pairwise_dist[i, j] = np.sum(np.abs(counts[i] - counts[j]))\n            elif dist_type == 'kl':\n\n                def kl_divergence(p: np.ndarray, q: np.ndarray):\n                    return (p * np.log(p / q)).sum()\n\n                pairwise_dist[i, j] = kl_divergence(counts[i],\n                                                    counts[j]) + kl_divergence(\n                                                        counts[j], counts[i])\n            else:\n                raise ValueError\n\n    pairwise_dist += pairwise_dist.T\n    pairwise_dist += 1e-5  # for numerical stability\n    
np.fill_diagonal(pairwise_dist, 0.)\n\n    return pairwise_dist\n\n\ndef indices_complete_check(\n    loader: List[Tuple[Union[Tensor, np.ndarray], Union[Tensor, np.ndarray]]],\n    output_indices: Union[Tensor, np.ndarray],\n):\n    if isinstance(output_indices, Tensor):\n        output_indices = output_indices.cpu().numpy()\n\n    outs = []\n    for out, aux in loader:\n        if isinstance(out, Tensor):\n            out = out.cpu().numpy()\n        if isinstance(aux, Tensor):\n            aux = aux.cpu().numpy()\n\n        assert np.all(np.isin(out,\n                              aux)), \"Not all output nodes are in aux nodes!\"\n        outs.append(out)\n\n    outs = np.sort(np.concatenate(outs))\n    assert np.all(\n        outs == np.sort(output_indices)), \"Output nodes missing or duplicate!\"\n\n\ndef get_subgraph(\n    out_indices: Tensor,\n    graph: Data,\n    return_edge_index_type: str,\n    adj: SparseTensor,\n    **kwargs,\n):\n    if return_edge_index_type == 'adj':\n        assert adj is not None\n\n    if return_edge_index_type == 'adj':\n        subg = Data(x=graph.x[out_indices], y=graph.y[out_indices],\n                    edge_index=adj[out_indices, :][:, out_indices])\n    elif return_edge_index_type == 'edge_index':\n        edge_index, edge_attr = subgraph(out_indices, graph.edge_index,\n                                         graph.edge_attr, relabel_nodes=True,\n                                         num_nodes=graph.num_nodes,\n                                         return_edge_mask=False)\n        subg = Data(x=graph.x[out_indices], y=graph.y[out_indices],\n                    edge_index=edge_index, edge_attr=edge_attr)\n    else:\n        raise NotImplementedError\n\n    for k, v in kwargs.items():\n        subg[k] = v\n\n    return subg\n\n\ndef define_sampler(\n    batch_order: str,\n    ys: List[Union[Tensor, np.ndarray, List]],\n    num_classes: int,\n    dist_type: str = 'kl',\n):\n    if batch_order == 'rand':\n        
logging.info(\"Running with random order\")\n        sampler = torch.utils.data.RandomSampler(ys)\n    elif batch_order in ['order', 'sample']:\n        kl_div = get_pair_wise_distance(ys, num_classes, dist_type=dist_type)\n        if batch_order == 'order':\n            from python_tsp.heuristics import solve_tsp_simulated_annealing\n            best_perm, _ = solve_tsp_simulated_annealing(kl_div)\n            logging.info(f\"Running with given order: {best_perm}\")\n            sampler = IBMBOrderedSampler(best_perm)\n        else:\n            logging.info(\"Running with weighted sampling\")\n            sampler = IBMBWeightedSampler(kl_div)\n    else:\n        raise ValueError\n\n    return sampler\n\n\ndef create_batchwise_out_aux_pairs(\n    adj: SparseTensor,\n    partitions: List[Union[torch.LongTensor, np.ndarray]],\n    prime_indices: Union[torch.LongTensor, np.ndarray],\n    topk: int,\n    num_outnodeset_per_batch: int = 50,\n    alpha: float = 0.2,\n    ppr_iterations: int = 50,\n) -> List[Tuple[np.ndarray, np.ndarray]]:\n    def ppr_power_method(\n        adj: SparseTensor,\n        batch: List[Union[np.ndarray, torch.LongTensor]],\n        topk: int,\n        num_iter: int,\n        alpha: float,\n    ) -> List[np.ndarray]:\n\n        topk_neighbors = []\n        logits = torch.zeros(\n            adj.size(0), len(batch),\n            device=adj.device())  # each column contains a set of output nodes\n        for i, tele_set in enumerate(batch):\n            logits[tele_set, i] = 1. 
/ len(tele_set)\n\n        new_logits = logits.clone()\n        for _ in range(num_iter):\n            new_logits = adj @ new_logits * (1 - alpha) + alpha * logits\n\n        inds = new_logits.argsort(0)\n        nonzeros = (new_logits > 0).sum(0)\n        nonzeros = torch.minimum(\n            nonzeros,\n            torch.tensor([topk], dtype=torch.int64, device=adj.device()))\n        for i in range(new_logits.shape[1]):\n            topk_neighbors.append(inds[-nonzeros[i]:, i].cpu().numpy())\n\n        return topk_neighbors\n\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    if isinstance(prime_indices, Tensor):\n        prime_indices = prime_indices.cpu().numpy()\n\n    adj = adj.to(device)\n\n    cur_output_nodes = []\n    loader = []\n\n    pbar = tqdm(range(len(partitions)))\n    pbar.set_description(\"Processing topic-sensitive PPR batches\")\n    for n in pbar:\n        part = partitions[n]\n        if isinstance(part, Tensor):\n            part = part.cpu().numpy()\n\n        primes_in_part, *_ = np.intersect1d(part, prime_indices,\n                                            assume_unique=True,\n                                            return_indices=True)\n        if len(primes_in_part):  # no output nodes in this partition\n            cur_output_nodes.append(primes_in_part)\n\n        # accumulate enough output nodes to make good use of GPU memory\n        if len(cur_output_nodes\n               ) >= num_outnodeset_per_batch or n == len(partitions) - 1:\n            topk_neighbors = ppr_power_method(adj, cur_output_nodes, topk,\n                                              ppr_iterations, alpha)\n            for i in range(len(cur_output_nodes)):\n                # force output nodes to be aux nodes\n                auxiliary_nodes = np.union1d(cur_output_nodes[i],\n                                             topk_neighbors[i])\n                loader.append((cur_output_nodes[i], auxiliary_nodes))\n            
cur_output_nodes = []\n\n        if torch.cuda.is_available():\n            torch.cuda.empty_cache()\n\n    return loader\n\n\ndef get_pairs(ppr_mat: Any) -> np.ndarray:\n    ppr_mat = ppr_mat + ppr_mat.transpose()\n\n    ppr_mat = ppr_mat.tocoo()\n    row, col, data = ppr_mat.row, ppr_mat.col, ppr_mat.data\n    mask = (row > col)  # lu\n\n    row, col, data = row[mask], col[mask], data[mask]\n    sort_arg = np.argsort(data)[::-1]\n    # sort_arg = parallel_sort.parallel_argsort(data)[::-1]\n\n    # map prime_nodes to arange\n    ppr_pairs = np.vstack((row[sort_arg], col[sort_arg])).T\n    return ppr_pairs\n\n\n_prime_orient_merge_numba: Optional[Callable] = None\n\n\ndef prime_orient_merge(\n    ppr_pairs: np.ndarray,\n    primes_per_batch: int,\n    num_nodes: int,\n):\n    if not WITH_NUMBA:  # pragma: no cover\n        raise ImportError(\"'prime_orient_merge' requires the 'numba' package\")\n\n    global _prime_orient_merge_numba\n    if _prime_orient_merge_numba is None:\n        _prime_orient_merge_numba = numba.njit(cache=True)(_prime_orient_merge)\n\n    return _prime_orient_merge_numba(ppr_pairs, primes_per_batch, num_nodes)\n\n\ndef _prime_orient_merge(\n    ppr_pairs: np.ndarray,\n    primes_per_batch: int,\n    num_nodes: int,\n):\n    id_primes_list = list(np.arange(num_nodes, dtype=np.int32).reshape(-1, 1))\n    node_id_list = np.arange(num_nodes, dtype=np.int32)\n    placeholder = np.zeros(0, dtype=np.int32)\n\n    for i, j in ppr_pairs:\n        id1, id2 = node_id_list[i], node_id_list[j]\n        if id1 > id2:\n            id1, id2 = id2, id1\n\n        if id1 != id2 and len(id_primes_list[id1]) + len(\n                id_primes_list[id2]) <= primes_per_batch:\n            id_primes_list[id1] = np.concatenate(\n                (id_primes_list[id1], id_primes_list[id2]))\n            node_id_list[id_primes_list[id2]] = id1\n            id_primes_list[id2] = placeholder\n\n    prime_lst = list()\n    ids = np.unique(node_id_list)\n\n    for _id in 
ids:\n        prime_lst.append(list(id_primes_list[_id]))\n\n    return list(prime_lst)\n\n\ndef prime_post_process(loader, merge_max_size):\n    from heapq import heapify, heappop, heappush\n\n    h = [(\n        len(p),\n        p,\n    ) for p in loader]\n    heapify(h)\n\n    while len(h) > 1:\n        len1, p1 = heappop(h)\n        len2, p2 = heappop(h)\n        if len1 + len2 <= merge_max_size:\n            heappush(h, (len1 + len2, p1 + p2))\n        else:\n            heappush(h, (\n                len1,\n                p1,\n            ))\n            heappush(h, (\n                len2,\n                p2,\n            ))\n            break\n\n    new_batch = []\n\n    while len(h):\n        _, p = heappop(h)\n        new_batch.append(p)\n\n    return new_batch\n\n\ndef topk_ppr_matrix(\n    edge_index: Tensor,\n    num_nodes: int,\n    alpha: float,\n    eps: float,\n    output_node_indices: Union[np.ndarray, torch.LongTensor],\n    topk: int,\n    normalization='row',\n) -> Tuple[Any, List[np.ndarray]]:\n    neighbors, weights = get_ppr(edge_index, alpha, eps, output_node_indices,\n                                 num_nodes)\n\n    _, neighbor_counts = neighbors[0].unique(return_counts=True)\n\n    ppr_matrix = SparseTensor(\n        row=torch.arange(\n            len(output_node_indices)).repeat_interleave(neighbor_counts),\n        col=neighbors[1], value=weights,\n        sparse_sizes=(len(output_node_indices),\n                      num_nodes)).to_scipy(layout='csr')\n\n    neighbors = [\n        n.cpu().numpy()\n        for n in torch.split(neighbors[1],\n                             neighbor_counts.cpu().tolist(), dim=0)\n    ]\n    weights = [\n        n.cpu().numpy()\n        for n in torch.split(weights,\n                             neighbor_counts.cpu().tolist(), dim=0)\n    ]\n\n    def sparsify(neighbors: List[np.ndarray], weights: List[np.ndarray],\n                 topk: int):\n        new_neighbors = []\n        for n, w in 
zip(neighbors, weights):\n            idx_topk = np.argsort(w)[-topk:]\n            new_neighbor = n[idx_topk]\n            new_neighbors.append(new_neighbor)\n\n        return new_neighbors\n\n    neighbors = sparsify(neighbors, weights, topk)\n    neighbors = [\n        np.union1d(nei, pr) for nei, pr in zip(neighbors, output_node_indices)\n    ]\n\n    _, out_degree = torch.unique(edge_index[0], sorted=True,\n                                 return_counts=True)\n    if normalization == 'sym':\n        # Assume undirected (symmetric) adjacency matrix\n        deg_sqrt = np.sqrt(np.maximum(out_degree, 1e-12))\n        deg_inv_sqrt = 1. / deg_sqrt\n\n        row, col = ppr_matrix.nonzero()\n        ppr_matrix.data = deg_sqrt[output_node_indices[row]] * \\\n            ppr_matrix.data * \\\n            deg_inv_sqrt[col]\n    elif normalization == 'col':\n        # Assume undirected (symmetric) adjacency matrix\n        deg_inv = 1. / np.maximum(out_degree, 1e-12)\n\n        row, col = ppr_matrix.nonzero()\n        ppr_matrix.data = out_degree[output_node_indices[row]] * \\\n            ppr_matrix.data * \\\n            deg_inv[col]\n    elif normalization == 'row':\n        pass\n    else:\n        raise ValueError(f\"Unknown PPR normalization: {normalization}\")\n\n    return ppr_matrix, neighbors\n\n\nclass IBMBBaseLoader(torch.utils.data.DataLoader):\n    def __init__(\n        self,\n        data_list: Union[List[Data], List[Tuple]],\n        graph: Data,\n        adj: SparseTensor,\n        return_edge_index_type: str,\n        **kwargs,\n    ):\n        self.graph = graph\n        self.adj = adj\n        self.return_edge_index_type = return_edge_index_type\n        if 'collate_fn' in kwargs:\n            del kwargs['collate_fn']\n        super().__init__(data_list, collate_fn=self.collate_fn, **kwargs)\n\n    def create_loader(self, *args, **kwargs):\n        raise NotImplementedError\n\n    @classmethod\n    def prepare_cache(\n        cls,\n        graph: 
Data,\n        batch_wise_out_aux_pairs: List[Tuple[np.ndarray, np.ndarray]],\n        adj: Optional[SparseTensor],\n        return_edge_index_type: str,\n    ):\n        subgraphs = []\n\n        pbar = tqdm(batch_wise_out_aux_pairs)\n        pbar.set_description(\n            f\"Caching data with type {return_edge_index_type}\")\n\n        if return_edge_index_type == 'adj':\n            assert adj is not None\n\n        for out, aux in pbar:\n            mask = torch.from_numpy(np.isin(aux, out))\n            if isinstance(aux, np.ndarray):\n                aux = torch.from_numpy(aux)\n            subg = get_subgraph(aux, graph, return_edge_index_type, adj,\n                                output_node_mask=mask)\n            subgraphs.append(subg)\n\n        return subgraphs\n\n    @classmethod\n    def create_adj_from_edge_index(\n        cls,\n        edge_index: Tensor,\n        num_nodes: int,\n        normalization: str,\n    ):\n        assert normalization in ['sym', 'rw']\n        adj = SparseTensor.from_edge_index(\n            edge_index,\n            sparse_sizes=(num_nodes, num_nodes),\n        )\n        adj = adj.fill_value(1.)\n        degree = adj.sum(0)\n\n        degree[degree == 0.] 
= 1e-12\n        deg_inv = 1 / degree\n\n        if normalization == 'sym':\n            deg_inv_sqrt = deg_inv**0.5\n            adj = adj * deg_inv_sqrt.reshape(1, -1)\n            adj = adj * deg_inv_sqrt.reshape(-1, 1)\n        elif normalization == 'rw':\n            adj = adj * deg_inv.reshape(-1, 1)\n\n        return adj\n\n    def collate_fn(self, data_list: List[Union[Data, Tuple]]):\n        if len(data_list) == 1 and isinstance(data_list[0], Data):\n            return data_list[0]\n\n        out, aux = zip(*data_list)\n        out = np.concatenate(out)\n        aux = np.unique(np.concatenate(aux))\n        mask = torch.from_numpy(np.isin(aux, out))\n        aux = torch.from_numpy(aux)\n\n        subg = get_subgraph(aux, self.graph, self.return_edge_index_type,\n                            self.adj, output_node_mask=mask)\n        return subg\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n\n\nclass IBMBBatchLoader(IBMBBaseLoader):\n    r\"\"\"The batch-wise influence-based data loader from the\n    `\"Influence-Based Mini-Batching for Graph Neural Networks\"\n    <https://arxiv.org/abs/2212.09083>`__ paper.\n\n    First, the METIS graph partitioning algorithm separates the graph into\n    :obj:`num_partitions` many partitions.\n    Afterwards, input/seed nodes and their auxiliary nodes (found via\n    topic-sensitive PageRank) are used to form a mini-batch.\n\n    If :obj:`batch_size` is set to :obj:`1`, mini-batches are pre-calculated\n    and cached in memory.\n    Otherwise, only input nodes and their auxiliary nodes are pre-computed, and\n    mini-batches are collated on-the-fly.\n\n    Args:\n        data (torch_geometric.data.Data): A\n            :class:`~torch_geometric.data.Data` object.\n        batch_order (str): A string indicating the batch order type (one of\n            :obj:`\"order\"`, :obj:`\"sample\"` or :obj:`\"rand\"`).\n            If :obj:`\"order\"`, calculates the pair-wise KL divergence 
between\n            every two batches to organize an optimal order.\n            If :obj:`\"sample\"`, samples the next batch w.r.t. the last one in\n            which a batch with higher KL divergence score is more likely to be\n            sampled.\n            If :obj:`\"rand\"`, batches are generated randomly.\n        num_partitions (int): The number of partitions.\n        input_nodes (torch.Tensor): A vector containing the set of seed\n            nodes.\n        batch_expand_ratio (float, optional): The ratio between the returned\n            batch size and the original partition size. For example, set it to\n            :obj:`2.0` in case you would like the batch to have double the\n            number of nodes as the size of its partition.\n            (default: :obj:`1.0`)\n        metis_input_node_weight (float, optional): The weights on the input\n            nodes for METIS graph partitioning. (default: :obj:`None`)\n        alpha (float, optional): The teleport probability of the PageRank\n            calculation. (default: :obj:`0.2`)\n        approximate_ppr_iterations (int, optional): The number of power\n            iterations for PageRank calculation. (default: :obj:`50`)\n        return_edge_index_type (str, optional): A string indicating the output\n            type of edge indices (one of :obj:`\"edge_index\"` or :obj:`\"adj\"`).\n            If set to :obj:`\"adj\"`, the :obj:`edge_index` of the batch will\n            be a :class:`torch_sparse.SparseTensor`, otherwise a\n            :class:`torch.Tensor`. 
(default: :obj:`\"edge_index\"`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Data,\n        batch_order: str,\n        num_partitions: int,\n        input_nodes: Tensor,\n        batch_expand_ratio: Optional[float] = 1.0,\n        metis_input_node_weight: Optional[float] = None,\n        alpha: Optional[float] = 0.2,\n        approximate_ppr_iterations: Optional[int] = 50,\n        return_edge_index_type: str = 'edge_index',\n        **kwargs,\n    ):\n        self.subgraphs = []\n        self.batch_wise_out_aux_pairs = []\n\n        assert is_undirected(\n            data.edge_index,\n            num_nodes=data.num_nodes), \"Assume the graph to be undirected\"\n        assert batch_order in ['rand', 'sample', 'order'\n                               ], f\"Unsupported batch order: {batch_order}\"\n\n        adj = self.create_adj_from_edge_index(\n            data.edge_index,\n            data.num_nodes,\n            normalization='rw',\n        )\n\n        self.cache_data = kwargs['batch_size'] == 1\n        self.num_partitions = num_partitions\n        self.output_indices = input_nodes\n        assert return_edge_index_type in ['adj', 'edge_index']\n        self.return_edge_index_type = return_edge_index_type\n        self.batch_expand_ratio = batch_expand_ratio\n        self.metis_output_weight = metis_input_node_weight\n        self.num_outnodeset_per_batch = 50\n        self.alpha = alpha\n        self.approximate_ppr_iterations = approximate_ppr_iterations\n\n        self.create_loader(data, adj)\n\n        if len(self.batch_wise_out_aux_pairs) > 2:  # <= 2 order makes no sense\n            ys = [\n                data.y[out].numpy() for out, _ in self.batch_wise_out_aux_pairs\n            ]\n            sampler = define_sampler(batch_order, ys, 
data.y.max().item() + 1)\n        else:\n            sampler = None\n\n        if not self.cache_data:\n            cached_data = data  # need to cache the original graph\n            if return_edge_index_type == 'adj':\n                cached_adj = adj\n            else:\n                cached_adj = None\n        else:\n            cached_data = None\n            cached_adj = None\n\n        super().__init__(\n            self.subgraphs\n            if self.cache_data else self.batch_wise_out_aux_pairs,\n            cached_data,\n            cached_adj,\n            return_edge_index_type,\n            sampler=sampler,\n            **kwargs,\n        )\n\n    def create_loader(self, graph: Data, adj: SparseTensor):\n        partitions = get_partitions(\n            adj,\n            self.num_partitions,\n            self.output_indices,\n            graph.num_nodes,\n            self.metis_output_weight,\n        )\n\n        # get output - auxiliary node pairs\n        topk = math.ceil(self.batch_expand_ratio * graph.num_nodes /\n                         self.num_partitions)\n        batch_wise_out_aux_pairs = create_batchwise_out_aux_pairs(\n            adj, partitions, self.output_indices, topk,\n            self.num_outnodeset_per_batch, self.alpha,\n            self.approximate_ppr_iterations)\n\n        indices_complete_check(batch_wise_out_aux_pairs, self.output_indices)\n        self.batch_wise_out_aux_pairs = batch_wise_out_aux_pairs\n\n        if self.cache_data:\n            self.subgraphs = self.prepare_cache(\n                graph,\n                batch_wise_out_aux_pairs,\n                adj,\n                self.return_edge_index_type,\n            )\n\n\nclass IBMBNodeLoader(IBMBBaseLoader):\n    r\"\"\"The node-wise influence-based data loader from the\n    `\"Influence-Based Mini-Batching for Graph Neural Networks\"\n    <https://arxiv.org/abs/2212.09083>`__ paper.\n\n    First, the Personalized PageRank (PPR) score for each input node is\n  
  computed, for which the :obj:`k` nodes with the highest scores are taken\n    as auxiliary nodes.\n    Afterwards, input nodes are merged according to their pair-wise PPR scores.\n\n    Similar to :class:`~torch_geometric.loader.IBMBBatchLoader`, subgraphs are\n    cached in memory for :obj:`batch_size = 1`, and collated on-the-fly\n    otherwise.\n\n    Args:\n        data (torch_geometric.data.Data): A\n            :class:`~torch_geometric.data.Data` object.\n        batch_order (str): A string indicating the batch order type (one of\n            :obj:`\"order\"`, :obj:`\"sample\"` or :obj:`\"rand\"`).\n            If :obj:`\"order\"`, calculates the pair-wise KL divergence between\n            every two batches to organize an optimal order.\n            If :obj:`\"sample\"`, samples the next batch w.r.t. the last one in\n            which a batch with higher KL divergence score is more likely to be\n            sampled.\n            If :obj:`\"rand\"`, batches are generated randomly.\n        input_nodes (torch.Tensor): A vector containing the set of seed\n            nodes.\n        num_auxiliary_nodes (int): The number of auxiliary nodes per input\n            node.\n        num_nodes_per_batch (int): The number of seed nodes per batch.\n        alpha (float, optional): The teleport probability of the PageRank\n            calculation. (default: :obj:`0.2`)\n        eps (float, optional): The threshold for stopping the PPR calculation.\n            The smaller :obj:`eps` is, the more accurate the PPR results are,\n            but the longer the calculation takes.\n            (default: :obj:`1e-5`)\n        return_edge_index_type (str, optional): A string indicating the output\n            type of edge indices (one of :obj:`\"edge_index\"` or :obj:`\"adj\"`).\n            If set to :obj:`\"adj\"`, the :obj:`edge_index` of the batch will\n            be a :class:`torch_sparse.SparseTensor`, otherwise a\n            :class:`torch.Tensor`. 
(default: :obj:`\"edge_index\"`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Data,\n        batch_order: str,\n        input_nodes: torch.Tensor,\n        num_auxiliary_nodes: int,\n        num_nodes_per_batch: int,\n        alpha: float = 0.2,\n        eps: float = 1e-5,\n        return_edge_index_type: str = 'edge_index',\n        **kwargs,\n    ):\n        self.subgraphs = []\n        self.node_wise_out_aux_pairs = []\n\n        assert is_undirected(\n            data.edge_index,\n            num_nodes=data.num_nodes), \"Assume the graph to be undirected\"\n        assert batch_order in ['rand', 'sample', 'order'\n                               ], f\"Unsupported batch order: {batch_order}\"\n\n        if return_edge_index_type == 'adj':\n            adj = self.create_adj_from_edge_index(data.edge_index,\n                                                  data.num_nodes,\n                                                  normalization='rw')\n        else:\n            adj = None\n\n        self.cache_data = kwargs['batch_size'] == 1\n        self._batchsize = kwargs['batch_size']\n        self.output_indices = input_nodes.numpy()\n        assert return_edge_index_type in ['adj', 'edge_index']\n        self.return_edge_index_type = return_edge_index_type\n        self.num_auxiliary_node_per_output = num_auxiliary_nodes\n        self.num_output_nodes_per_batch = num_nodes_per_batch\n        self.alpha = alpha\n        self.eps = eps\n\n        self.create_loader(data, adj)\n\n        if len(self.node_wise_out_aux_pairs) > 2:  # <= 2 order makes no sense\n            ys = [\n                data.y[out].numpy() for out, _ in self.node_wise_out_aux_pairs\n            ]\n            sampler = define_sampler(batch_order, ys, data.y.max().item() + 1)\n       
 else:\n            sampler = None\n\n        if not self.cache_data:\n            cached_graph = data  # need to cache the original graph\n            cached_adj = adj\n        else:\n            cached_graph = None\n            cached_adj = None\n\n        super().__init__(\n            self.subgraphs\n            if self.cache_data else self.node_wise_out_aux_pairs,\n            cached_graph,\n            cached_adj,\n            return_edge_index_type,\n            sampler=sampler,\n            **kwargs,\n        )\n\n    def create_loader(self, graph: Data, adj: SparseTensor):\n        logging.info(\"Start PPR calculation\")\n        ppr_matrix, neighbors = topk_ppr_matrix(\n            graph.edge_index, graph.num_nodes, self.alpha, self.eps,\n            torch.from_numpy(self.output_indices),\n            self.num_auxiliary_node_per_output)\n\n        ppr_matrix = ppr_matrix[:, self.output_indices]\n\n        logging.info(\"Getting PPR pairs\")\n        ppr_pairs = get_pairs(ppr_matrix)\n\n        output_list = prime_orient_merge(\n            ppr_pairs,\n            self.num_output_nodes_per_batch,\n            len(self.output_indices),\n        )\n        output_list = prime_post_process(\n            output_list,\n            self.num_output_nodes_per_batch,\n        )\n        node_wise_out_aux_pairs = []\n\n        if isinstance(neighbors, list):\n            neighbors = np.array(neighbors, dtype=object)\n\n        def _union(inputs):\n            return np.unique(np.concatenate(inputs))\n\n        for p in output_list:\n            node_wise_out_aux_pairs.append(\n                (self.output_indices[p],\n                 _union(neighbors[p]).astype(np.int64)))\n\n        indices_complete_check(node_wise_out_aux_pairs, self.output_indices)\n        self.node_wise_out_aux_pairs = node_wise_out_aux_pairs\n\n        if self.cache_data:\n            self.subgraphs = self.prepare_cache(\n                graph,\n                node_wise_out_aux_pairs,\n      
          adj,\n                self.return_edge_index_type,\n            )\n\n\nclass IBMBOrderedSampler(torch.utils.data.Sampler[int]):\n    r\"\"\"A sampler with given order, specially for IBMB loaders.\n\n    Args:\n        data_source (np.ndarray, torch.Tensor, List): A :obj:`np.ndarray`,\n            :obj:`torch.Tensor`, or :obj:`List` data object. Contains the\n            order of the batches.\n    \"\"\"\n    def __init__(self, data_source: Union[np.ndarray, torch.Tensor,\n                                          List]) -> None:\n        self.data_source = data_source\n        super().__init__(data_source)\n\n    def __iter__(self) -> Iterator[int]:\n        return iter(self.data_source)\n\n    def __len__(self) -> int:\n        return len(self.data_source)\n\n\nclass IBMBWeightedSampler(torch.utils.data.Sampler[int]):\n    r\"\"\"A weighted sampler wrt the pair wise KL divergence.\n    The very first batch after initialization is sampled randomly,\n    with the next ones being sampled according to the last batch,\n    including the first batch in the next round.\n\n    Args:\n        batch_kl_div (np.ndarray, torch.Tensor): A :obj:`np.ndarray` or\n            :obj:`torch.Tensor`, each element [i, j] contains the pair wise\n            KL divergence between batch i and j.\n    \"\"\"\n    def __init__(self, batch_kl_div: Union[np.ndarray, torch.Tensor]) -> None:\n        data_source = np.arange(batch_kl_div.shape[0])\n        self.data_source = data_source\n        self.batch_kl_div = batch_kl_div\n        self.last_train_batch_id = 0\n        super().__init__(data_source)\n\n    def __iter__(self) -> Iterator[int]:\n        probs = self.batch_kl_div.copy()\n\n        last = self.last_train_batch_id\n        num_batches = probs.shape[0]\n\n        fetch_idx = []\n\n        next_id = 0\n        while np.any(probs):\n            next_id = np.random.choice(num_batches, size=None, replace=False,\n                                       p=probs[last] / 
probs[last].sum())\n            last = next_id\n            fetch_idx.append(next_id)\n            probs[:, next_id] = 0.\n\n        self.last_train_batch_id = next_id\n\n        return iter(fetch_idx)\n\n    def __len__(self) -> int:\n        return len(self.data_source)\n"
  },
  {
    "path": "torch_geometric/loader/imbalanced_sampler.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, Dataset, InMemoryDataset\n\n\nclass ImbalancedSampler(torch.utils.data.WeightedRandomSampler):\n    r\"\"\"A weighted random sampler that randomly samples elements according to\n    class distribution.\n    As such, it will either remove samples from the majority class\n    (under-sampling) or add more examples from the minority class\n    (over-sampling).\n\n    **Graph-level sampling:**\n\n    .. code-block:: python\n\n        from torch_geometric.loader import DataLoader, ImbalancedSampler\n\n        sampler = ImbalancedSampler(dataset)\n        loader = DataLoader(dataset, batch_size=64, sampler=sampler, ...)\n\n    **Node-level sampling:**\n\n    .. code-block:: python\n\n        from torch_geometric.loader import NeighborLoader, ImbalancedSampler\n\n        sampler = ImbalancedSampler(data, input_nodes=data.train_mask)\n        loader = NeighborLoader(data, input_nodes=data.train_mask,\n                                batch_size=64, num_neighbors=[-1, -1],\n                                sampler=sampler, ...)\n\n    You can also pass in the class labels directly as a :class:`torch.Tensor`:\n\n    .. 
code-block:: python\n\n        from torch_geometric.loader import NeighborLoader, ImbalancedSampler\n\n        sampler = ImbalancedSampler(data.y)\n        loader = NeighborLoader(data, input_nodes=data.train_mask,\n                                batch_size=64, num_neighbors=[-1, -1],\n                                sampler=sampler, ...)\n\n    Args:\n        dataset (Dataset or Data or Tensor): The dataset or class distribution\n            from which to sample the data, given either as a\n            :class:`~torch_geometric.data.Dataset`,\n            :class:`~torch_geometric.data.Data`, or :class:`torch.Tensor`\n            object.\n        input_nodes (Tensor, optional): The indices of nodes that are used by\n            the corresponding loader, *e.g.*, by\n            :class:`~torch_geometric.loader.NeighborLoader`.\n            If set to :obj:`None`, all nodes will be considered.\n            This argument should only be set for node-level loaders and does\n            not have any effect when operating on a set of graphs as given by\n            :class:`~torch_geometric.data.Dataset`. (default: :obj:`None`)\n        num_samples (int, optional): The number of samples to draw for a single\n            epoch. If set to :obj:`None`, will sample as many elements as there\n            exist in the underlying data. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        dataset: Union[Dataset, Data, List[Data], Tensor],\n        input_nodes: Optional[Tensor] = None,\n        num_samples: Optional[int] = None,\n    ):\n\n        if isinstance(dataset, Data):\n            y = dataset.y.view(-1)\n            assert dataset.num_nodes == y.numel()\n            y = y[input_nodes] if input_nodes is not None else y\n\n        elif isinstance(dataset, Tensor):\n            y = dataset.view(-1)\n            y = y[input_nodes] if input_nodes is not None else y\n\n        elif isinstance(dataset, InMemoryDataset):\n            y = dataset.y.view(-1)\n            assert len(dataset) == y.numel()\n\n        else:\n            ys = [data.y for data in dataset]\n            if isinstance(ys[0], Tensor):\n                y = torch.cat(ys, dim=0).view(-1)\n            else:\n                y = torch.tensor(ys).view(-1)\n            assert len(dataset) == y.numel()\n\n        assert y.dtype == torch.long  # Require classification.\n\n        num_samples = y.numel() if num_samples is None else num_samples\n\n        class_weight = 1. / y.bincount()\n        weight = class_weight[y]\n\n        return super().__init__(weight, num_samples, replacement=True)\n"
  },
  {
    "path": "torch_geometric/loader/link_loader.py",
    "content": "from typing import Any, Callable, Iterator, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData\nfrom torch_geometric.loader.base import DataLoaderIterator\nfrom torch_geometric.loader.mixin import (\n    AffinityMixin,\n    LogMemoryMixin,\n    MultithreadingMixin,\n)\nfrom torch_geometric.loader.utils import (\n    filter_custom_hetero_store,\n    filter_custom_store,\n    filter_data,\n    filter_hetero_data,\n    get_edge_label_index,\n    infer_filter_per_worker,\n)\nfrom torch_geometric.sampler import (\n    BaseSampler,\n    EdgeSamplerInput,\n    HeteroSamplerOutput,\n    NegativeSampling,\n    SamplerOutput,\n)\nfrom torch_geometric.typing import InputEdges, OptTensor\n\n\nclass LinkLoader(\n        torch.utils.data.DataLoader,\n        AffinityMixin,\n        MultithreadingMixin,\n        LogMemoryMixin,\n):\n    r\"\"\"A data loader that performs mini-batch sampling from link information,\n    using a generic :class:`~torch_geometric.sampler.BaseSampler`\n    implementation that defines a\n    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges` function and\n    is supported on the provided input :obj:`data` object.\n\n    .. 
note::\n        Negative sampling is currently implemented in an approximate\n        way, *i.e.* negative edges may contain false negatives.\n\n    Args:\n        data (Any): A :class:`~torch_geometric.data.Data`,\n            :class:`~torch_geometric.data.HeteroData`, or\n            (:class:`~torch_geometric.data.FeatureStore`,\n            :class:`~torch_geometric.data.GraphStore`) data object.\n        link_sampler (torch_geometric.sampler.BaseSampler): The sampler\n            implementation to be used with this loader.\n            Needs to implement\n            :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.\n            The sampler implementation must be compatible with the input\n            :obj:`data` object.\n        edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):\n            The edge indices, holding source and destination nodes to start\n            sampling from.\n            If set to :obj:`None`, all edges will be considered.\n            In heterogeneous graphs, needs to be passed as a tuple that holds\n            the edge type and corresponding edge indices.\n            (default: :obj:`None`)\n        edge_label (Tensor, optional): The labels of edge indices from which to\n            start sampling from. Must be the same length as\n            the :obj:`edge_label_index`. (default: :obj:`None`)\n        edge_label_time (Tensor, optional): The timestamps of edge indices from\n            which to start sampling from. Must be the same length as\n            :obj:`edge_label_index`. If set, temporal sampling will be\n            used such that neighbors are guaranteed to fulfill temporal\n            constraints, *i.e.*, neighbors have an earlier timestamp than\n            the output edge. The :obj:`time_attr` needs to be set for this\n            to work. 
(default: :obj:`None`)\n        neg_sampling (NegativeSampling, optional): The negative sampling\n            configuration.\n            For negative sampling mode :obj:`\"binary\"`, samples can be accessed\n            via the attributes :obj:`edge_label_index` and :obj:`edge_label` in\n            the respective edge type of the returned mini-batch.\n            In case :obj:`edge_label` does not exist, it will be automatically\n            created and represents a binary classification task (:obj:`0` =\n            negative edge, :obj:`1` = positive edge).\n            In case :obj:`edge_label` does exist, it has to be a categorical\n            label from :obj:`0` to :obj:`num_classes - 1`.\n            After negative sampling, label :obj:`0` represents negative edges,\n            and labels :obj:`1` to :obj:`num_classes` represent the labels of\n            positive edges.\n            Note that returned labels are of type :obj:`torch.float` for binary\n            classification (to facilitate the ease-of-use of\n            :meth:`F.binary_cross_entropy`) and of type\n            :obj:`torch.long` for multi-class classification (to facilitate the\n            ease-of-use of :meth:`F.cross_entropy`).\n            For negative sampling mode :obj:`\"triplet\"`, samples can be\n            accessed via the attributes :obj:`src_index`, :obj:`dst_pos_index`\n            and :obj:`dst_neg_index` in the respective node types of the\n            returned mini-batch.\n            :obj:`edge_label` needs to be :obj:`None` for :obj:`\"triplet\"`\n            negative sampling mode.\n            If set to :obj:`None`, no negative sampling strategy is applied.\n            (default: :obj:`None`)\n        neg_sampling_ratio (int or float, optional): The ratio of sampled\n            negative edges to the number of positive edges.\n            Deprecated in favor of the :obj:`neg_sampling` argument.\n            (default: :obj:`None`).\n        transform (callable, 
optional): A function/transform that takes in\n            a sampled mini-batch and returns a transformed version.\n            (default: :obj:`None`)\n        transform_sampler_output (callable, optional): A function/transform\n            that takes in a :class:`torch_geometric.sampler.SamplerOutput` and\n            returns a transformed version. (default: :obj:`None`)\n        filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n            the returned data in each worker's subprocess.\n            If set to :obj:`False`, will filter the returned data in the main\n            process.\n            If set to :obj:`None`, will automatically infer the decision based\n            on whether data partially lives on the GPU\n            (:obj:`filter_per_worker=True`) or entirely on the CPU\n            (:obj:`filter_per_worker=False`).\n            There exist different trade-offs for setting this option.\n            Specifically, setting this option to :obj:`True` for in-memory\n            datasets will move all features to shared memory, which may result\n            in too many open file handles. (default: :obj:`None`)\n        custom_cls (HeteroData, optional): A custom\n            :class:`~torch_geometric.data.HeteroData` class to return for\n            mini-batches in case of remote backends. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n        link_sampler: BaseSampler,\n        edge_label_index: InputEdges = None,\n        edge_label: OptTensor = None,\n        edge_label_time: OptTensor = None,\n        neg_sampling: Optional[NegativeSampling] = None,\n        neg_sampling_ratio: Optional[Union[int, float]] = None,\n        transform: Optional[Callable] = None,\n        transform_sampler_output: Optional[Callable] = None,\n        filter_per_worker: Optional[bool] = None,\n        custom_cls: Optional[HeteroData] = None,\n        input_id: OptTensor = None,\n        **kwargs,\n    ):\n        if filter_per_worker is None:\n            filter_per_worker = infer_filter_per_worker(data)\n\n        # Remove for PyTorch Lightning:\n        kwargs.pop('dataset', None)\n        kwargs.pop('collate_fn', None)\n        # Save for PyTorch Lightning:\n        self.edge_label_index = edge_label_index\n\n        if neg_sampling_ratio is not None and neg_sampling_ratio != 0.0:\n            # TODO: Deprecation warning.\n            neg_sampling = NegativeSampling(\"binary\", neg_sampling_ratio)\n\n        # Get edge type (or `None` for homogeneous graphs):\n        input_type, edge_label_index = get_edge_label_index(\n            data, edge_label_index)\n\n        self.data = data\n        self.link_sampler = link_sampler\n        self.neg_sampling = NegativeSampling.cast(neg_sampling)\n        self.transform = transform\n        self.transform_sampler_output = transform_sampler_output\n        self.filter_per_worker = filter_per_worker\n        self.custom_cls = custom_cls\n\n        if (self.neg_sampling is not None and self.neg_sampling.is_binary()\n                and 
edge_label is not None and edge_label.min() == 0):\n            # Increment labels such that `zero` now denotes \"negative\".\n            edge_label = edge_label + 1\n\n        if (self.neg_sampling is not None and self.neg_sampling.is_triplet()\n                and edge_label is not None):\n            raise ValueError(\"'edge_label' needs to be undefined for \"\n                             \"'triplet'-based negative sampling. Please use \"\n                             \"`src_index`, `dst_pos_index` and \"\n                             \"`dst_neg_index` of the returned mini-batch \"\n                             \"instead to differentiate between positive and \"\n                             \"negative samples.\")\n\n        self.input_data = EdgeSamplerInput(\n            input_id=input_id,\n            row=edge_label_index[0],\n            col=edge_label_index[1],\n            label=edge_label,\n            time=edge_label_time,\n            input_type=input_type,\n        )\n\n        iterator = range(edge_label_index.size(1))\n        super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)\n\n    def __call__(\n        self,\n        index: Union[Tensor, List[int]],\n    ) -> Union[Data, HeteroData]:\n        r\"\"\"Samples a subgraph from a batch of input edges.\"\"\"\n        out = self.collate_fn(index)\n        if not self.filter_per_worker:\n            out = self.filter_fn(out)\n        return out\n\n    def collate_fn(self, index: Union[Tensor, List[int]]) -> Any:\n        r\"\"\"Samples a subgraph from a batch of input edges.\"\"\"\n        input_data: EdgeSamplerInput = self.input_data[index]\n\n        out = self.link_sampler.sample_from_edges(\n            input_data, neg_sampling=self.neg_sampling)\n\n        if self.filter_per_worker:  # Execute `filter_fn` in the worker process\n            out = self.filter_fn(out)\n\n        return out\n\n    def filter_fn(\n        self,\n        out: Union[SamplerOutput, HeteroSamplerOutput],\n    
) -> Union[Data, HeteroData]:\n        r\"\"\"Joins the sampled nodes with their corresponding features,\n        returning the resulting :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` object to be used downstream.\n        \"\"\"\n        if self.transform_sampler_output:\n            out = self.transform_sampler_output(out)\n\n        if isinstance(out, SamplerOutput):\n            if isinstance(self.data, Data):\n                data = filter_data(  #\n                    self.data, out.node, out.row, out.col, out.edge,\n                    self.link_sampler.edge_permutation)\n\n            else:  # Tuple[FeatureStore, GraphStore]\n\n                # Hack to detect whether we are in a distributed setting.\n                if (self.link_sampler.__class__.__name__ ==\n                        'DistNeighborSampler'):\n                    edge_index = torch.stack([out.row, out.col])\n                    data = Data(edge_index=edge_index)\n                    # Metadata entries are populated in\n                    # `DistributedNeighborSampler._collate_fn()`\n                    data.x = out.metadata[-3]\n                    data.y = out.metadata[-2]\n                    data.edge_attr = out.metadata[-1]\n                else:\n                    data = filter_custom_store(  #\n                        *self.data, out.node, out.row, out.col, out.edge,\n                        self.custom_cls)\n\n            if 'n_id' not in data:\n                data.n_id = out.node\n            if out.edge is not None and 'e_id' not in data:\n                edge = out.edge.to(torch.long)\n                perm = self.link_sampler.edge_permutation\n                data.e_id = perm[edge] if perm is not None else edge\n\n            data.batch = out.batch\n            data.num_sampled_nodes = out.num_sampled_nodes\n            data.num_sampled_edges = out.num_sampled_edges\n\n            data.input_id = out.metadata[0]\n\n            
if self.neg_sampling is None or self.neg_sampling.is_binary():\n                data.edge_label_index = out.metadata[1]\n                data.edge_label = out.metadata[2]\n                data.edge_label_time = out.metadata[3]\n            elif self.neg_sampling.is_triplet():\n                data.src_index = out.metadata[1]\n                data.dst_pos_index = out.metadata[2]\n                data.dst_neg_index = out.metadata[3]\n                data.seed_time = out.metadata[4]\n                # Sanity removals in case `edge_label_index` and\n                # `edge_label_time` are attributes of the base `data` object:\n                del data.edge_label_index  # Sanity removals.\n                del data.edge_label_time\n\n        elif isinstance(out, HeteroSamplerOutput):\n            if isinstance(self.data, HeteroData):\n                data = filter_hetero_data(  #\n                    self.data, out.node, out.row, out.col, out.edge,\n                    self.link_sampler.edge_permutation)\n\n            else:  # Tuple[FeatureStore, GraphStore]\n\n                # Hack to detect whether we are in a distributed setting.\n                if (self.link_sampler.__class__.__name__ ==\n                        'DistNeighborSampler'):\n                    import torch_geometric.distributed as dist\n                    data = dist.utils.filter_dist_store(\n                        *self.data, out.node, out.row, out.col, out.edge,\n                        self.custom_cls, out.metadata,\n                        self.input_data.input_type)\n                else:\n                    data = filter_custom_hetero_store(  #\n                        *self.data, out.node, out.row, out.col, out.edge,\n                        self.custom_cls)\n\n            for key, node in out.node.items():\n                if 'n_id' not in data[key]:\n                    data[key].n_id = node\n\n            for key, edge in (out.edge or {}).items():\n                if edge is not None and 
'e_id' not in data[key]:\n                    edge = edge.to(torch.long)\n                    perm = self.link_sampler.edge_permutation\n                    if perm is not None and perm.get(key, None) is not None:\n                        edge = perm[key][edge]\n                    data[key].e_id = edge\n\n            data.set_value_dict('batch', out.batch)\n            data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes)\n            data.set_value_dict('num_sampled_edges', out.num_sampled_edges)\n\n            input_type = self.input_data.input_type\n            data[input_type].input_id = out.metadata[0]\n\n            if self.neg_sampling is None or self.neg_sampling.is_binary():\n                data[input_type].edge_label_index = out.metadata[1]\n                data[input_type].edge_label = out.metadata[2]\n                data[input_type].edge_label_time = out.metadata[3]\n            elif self.neg_sampling.is_triplet():\n                data[input_type[0]].src_index = out.metadata[1]\n                data[input_type[-1]].dst_pos_index = out.metadata[2]\n                data[input_type[-1]].dst_neg_index = out.metadata[3]\n                data[input_type[0]].seed_time = out.metadata[4]\n                data[input_type[-1]].seed_time = out.metadata[4]\n                # Sanity removals in case `edge_label_index` and\n                # `edge_label_time` are attributes of the base `data` object:\n                if input_type in data.edge_types:\n                    del data[input_type].edge_label_index\n                    del data[input_type].edge_label_time\n\n        else:\n            raise TypeError(f\"'{self.__class__.__name__}' found invalid \"\n                            f\"type: '{type(out)}'\")\n\n        return data if self.transform is None else self.transform(data)\n\n    def _get_iterator(self) -> Iterator:\n        if self.filter_per_worker:\n            return super()._get_iterator()\n\n        # Execute `filter_fn` in the main 
process:\n        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/loader/link_neighbor_loader.py",
    "content": "from typing import Callable, Dict, List, Optional, Tuple, Union\n\nfrom torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData\nfrom torch_geometric.loader.link_loader import LinkLoader\nfrom torch_geometric.sampler import NegativeSampling, NeighborSampler\nfrom torch_geometric.sampler.base import SubgraphType\nfrom torch_geometric.typing import EdgeType, InputEdges, OptTensor\n\n\nclass LinkNeighborLoader(LinkLoader):\n    r\"\"\"A link-based data loader derived as an extension of the node-based\n    :class:`torch_geometric.loader.NeighborLoader`.\n    This loader allows for mini-batch training of GNNs on large-scale graphs\n    where full-batch training is not feasible.\n\n    More specifically, this loader first selects a sample of edges from the\n    set of input edges :obj:`edge_label_index` (which may or not be edges in\n    the original graph) and then constructs a subgraph from all the nodes\n    present in this list by sampling :obj:`num_neighbors` neighbors in each\n    iteration.\n\n    .. code-block:: python\n\n        from torch_geometric.datasets import Planetoid\n        from torch_geometric.loader import LinkNeighborLoader\n\n        data = Planetoid(path, name='Cora')[0]\n\n        loader = LinkNeighborLoader(\n            data,\n            # Sample 30 neighbors for each node for 2 iterations\n            num_neighbors=[30] * 2,\n            # Use a batch size of 128 for sampling training nodes\n            batch_size=128,\n            edge_label_index=data.edge_index,\n        )\n\n        sampled_data = next(iter(loader))\n        print(sampled_data)\n        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],\n                 train_mask=[1368], val_mask=[1368], test_mask=[1368],\n                 edge_label_index=[2, 128])\n\n    It is additionally possible to provide edge labels for sampled edges, which\n    are then added to the batch:\n\n    .. 
code-block:: python\n\n        loader = LinkNeighborLoader(\n            data,\n            num_neighbors=[30] * 2,\n            batch_size=128,\n            edge_label_index=data.edge_index,\n            edge_label=torch.ones(data.edge_index.size(1))\n        )\n\n        sampled_data = next(iter(loader))\n        print(sampled_data)\n        >>> Data(x=[1368, 1433], edge_index=[2, 3103], y=[1368],\n                 train_mask=[1368], val_mask=[1368], test_mask=[1368],\n                 edge_label_index=[2, 128], edge_label=[128])\n\n    The rest of the functionality mirrors that of\n    :class:`~torch_geometric.loader.NeighborLoader`, including support for\n    heterogeneous graphs.\n    In particular, the data loader will add the following attributes to the\n    returned mini-batch:\n\n    * :obj:`n_id` The global node index for every sampled node\n    * :obj:`e_id` The global edge index for every sampled edge\n    * :obj:`input_id`: The global index of the :obj:`edge_label_index`\n    * :obj:`num_sampled_nodes`: The number of sampled nodes in each hop\n    * :obj:`num_sampled_edges`: The number of sampled edges in each hop\n\n    .. note::\n        Negative sampling is currently implemented in an approximate\n        way, *i.e.* negative edges may contain false negatives.\n\n    .. 
warning::\n        Note that the sampling scheme is independent from the edge we are\n        making a prediction for.\n        That is, by default supervision edges in :obj:`edge_label_index`\n        **will not** get masked out during sampling.\n        In case there exists an overlap between message passing edges in\n        :obj:`data.edge_index` and supervision edges in\n        :obj:`edge_label_index`, you might end up sampling an edge you are\n        making a prediction for.\n        You can generally avoid this behavior (if desired) by making\n        :obj:`data.edge_index` and :obj:`edge_label_index` two disjoint sets of\n        edges, *e.g.*, via the\n        :class:`~torch_geometric.transforms.RandomLinkSplit` transformation and\n        its :obj:`disjoint_train_ratio` argument.\n\n    Args:\n        data (Any): A :class:`~torch_geometric.data.Data`,\n            :class:`~torch_geometric.data.HeteroData`, or\n            (:class:`~torch_geometric.data.FeatureStore`,\n            :class:`~torch_geometric.data.GraphStore`) data object.\n        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The\n            number of neighbors to sample for each node in each iteration.\n            If an entry is set to :obj:`-1`, all neighbors will be included.\n            In heterogeneous graphs, may also take in a dictionary denoting\n            the amount of neighbors to sample for each individual edge type.\n        edge_label_index (Tensor or EdgeType or Tuple[EdgeType, Tensor]):\n            The edge indices for which neighbors are sampled to create\n            mini-batches.\n            If set to :obj:`None`, all edges will be considered.\n            In heterogeneous graphs, needs to be passed as a tuple that holds\n            the edge type and corresponding edge indices.\n            (default: :obj:`None`)\n        edge_label (Tensor, optional): The labels of edge indices for\n            which neighbors are sampled. 
Must be the same length as\n            the :obj:`edge_label_index`. If set to :obj:`None`, it is set to\n            `torch.zeros(...)` internally. (default: :obj:`None`)\n        edge_label_time (Tensor, optional): The timestamps for edge indices\n            for which neighbors are sampled. Must be the same length as\n            :obj:`edge_label_index`. If set, temporal sampling will be\n            used such that neighbors are guaranteed to fulfill temporal\n            constraints, *i.e.*, neighbors have an earlier timestamp than\n            the output edge. The :obj:`time_attr` needs to be set for this\n            to work. (default: :obj:`None`)\n        replace (bool, optional): If set to :obj:`True`, will sample with\n            replacement. (default: :obj:`False`)\n        subgraph_type (SubgraphType or str, optional): The type of the returned\n            subgraph.\n            If set to :obj:`\"directional\"`, the returned subgraph only holds\n            the sampled (directed) edges which are necessary to compute\n            representations for the sampled seed nodes.\n            If set to :obj:`\"bidirectional\"`, sampled edges are converted to\n            bidirectional edges.\n            If set to :obj:`\"induced\"`, the returned subgraph contains the\n            induced subgraph of all sampled nodes.\n            (default: :obj:`\"directional\"`)\n        disjoint (bool, optional): If set to :obj:`True`, each seed node will\n            create its own disjoint subgraph.\n            If set to :obj:`True`, mini-batch outputs will have a :obj:`batch`\n            vector holding the mapping of nodes to their respective subgraph.\n            Will get automatically set to :obj:`True` in case of temporal\n            sampling. 
(default: :obj:`False`)\n        temporal_strategy (str, optional): The sampling strategy when using\n            temporal sampling (:obj:`\"uniform\"`, :obj:`\"last\"`).\n            If set to :obj:`\"uniform\"`, will sample uniformly across neighbors\n            that fulfill temporal constraints.\n            If set to :obj:`\"last\"`, will sample the last `num_neighbors` that\n            fulfill temporal constraints.\n            (default: :obj:`\"uniform\"`)\n        neg_sampling (NegativeSampling, optional): The negative sampling\n            configuration.\n            For negative sampling mode :obj:`\"binary\"`, samples can be accessed\n            via the attributes :obj:`edge_label_index` and :obj:`edge_label` in\n            the respective edge type of the returned mini-batch.\n            In case :obj:`edge_label` does not exist, it will be automatically\n            created and represents a binary classification task (:obj:`0` =\n            negative edge, :obj:`1` = positive edge).\n            In case :obj:`edge_label` does exist, it has to be a categorical\n            label from :obj:`0` to :obj:`num_classes - 1`.\n            After negative sampling, label :obj:`0` represents negative edges,\n            and labels :obj:`1` to :obj:`num_classes` represent the labels of\n            positive edges.\n            Note that returned labels are of type :obj:`torch.float` for binary\n            classification (to facilitate the ease-of-use of\n            :meth:`F.binary_cross_entropy`) and of type\n            :obj:`torch.long` for multi-class classification (to facilitate the\n            ease-of-use of :meth:`F.cross_entropy`).\n            For negative sampling mode :obj:`\"triplet\"`, samples can be\n            accessed via the attributes :obj:`src_index`, :obj:`dst_pos_index`\n            and :obj:`dst_neg_index` in the respective node types of the\n            returned mini-batch.\n            :obj:`edge_label` needs to be :obj:`None` for 
:obj:`\"triplet\"`\n            negative sampling mode.\n            If set to :obj:`None`, no negative sampling strategy is applied.\n            For example, use\n            :obj:`neg_sampling=dict(mode='binary', amount=0.5)`.\n            (default: :obj:`None`)\n        neg_sampling_ratio (int or float, optional): The ratio of sampled\n            negative edges to the number of positive edges.\n            Deprecated in favor of the :obj:`neg_sampling` argument.\n            (default: :obj:`None`)\n        time_attr (str, optional): The name of the attribute that denotes\n            timestamps for either the nodes or edges in the graph.\n            If set, temporal sampling will be used such that neighbors are\n            guaranteed to fulfill temporal constraints, *i.e.* neighbors have\n            an earlier or equal timestamp than the center node.\n            Only used if :obj:`edge_label_time` is set. (default: :obj:`None`)\n        weight_attr (str, optional): The name of the attribute that denotes\n            edge weights in the graph.\n            If set, weighted/biased sampling will be used such that neighbors\n            are more likely to get sampled the higher their edge weights are.\n            Edge weights do not need to sum to one, but must be non-negative,\n            finite and have a non-zero sum within local neighborhoods.\n            (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in\n            a sampled mini-batch and returns a transformed version.\n            (default: :obj:`None`)\n        transform_sampler_output (callable, optional): A function/transform\n            that takes in a :class:`torch_geometric.sampler.SamplerOutput` and\n            returns a transformed version. 
(default: :obj:`None`)\n        is_sorted (bool, optional): If set to :obj:`True`, assumes that\n            :obj:`edge_index` is sorted by column.\n            If :obj:`time_attr` is set, additionally requires that rows are\n            sorted according to time within individual neighborhoods.\n            This avoids internal re-sorting of the data and can improve\n            runtime and memory efficiency. (default: :obj:`False`)\n        filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n            the returned data in each worker's subprocess.\n            If set to :obj:`False`, will filter the returned data in the main\n            process.\n            If set to :obj:`None`, will automatically infer the decision based\n            on whether data partially lives on the GPU\n            (:obj:`filter_per_worker=True`) or entirely on the CPU\n            (:obj:`filter_per_worker=False`).\n            There exist different trade-offs for setting this option.\n            Specifically, setting this option to :obj:`True` for in-memory\n            datasets will move all features to shared memory, which may result\n            in too many open file handles. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],\n        edge_label_index: InputEdges = None,\n        edge_label: OptTensor = None,\n        edge_label_time: OptTensor = None,\n        replace: bool = False,\n        subgraph_type: Union[SubgraphType, str] = 'directional',\n        disjoint: bool = False,\n        temporal_strategy: str = 'uniform',\n        neg_sampling: Optional[NegativeSampling] = None,\n        neg_sampling_ratio: Optional[Union[int, float]] = None,\n        time_attr: Optional[str] = None,\n        weight_attr: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        transform_sampler_output: Optional[Callable] = None,\n        is_sorted: bool = False,\n        filter_per_worker: Optional[bool] = None,\n        neighbor_sampler: Optional[NeighborSampler] = None,\n        directed: bool = True,  # Deprecated.\n        **kwargs,\n    ):\n        if (edge_label_time is not None) != (time_attr is not None):\n            raise ValueError(\n                f\"Received conflicting 'edge_label_time' and 'time_attr' \"\n                f\"arguments: 'edge_label_time' is \"\n                f\"{'set' if edge_label_time is not None else 'not set'} \"\n                f\"while 'time_attr' is \"\n                f\"{'set' if time_attr is not None else 'not set'}. 
\"\n                f\"Both arguments must be provided for temporal sampling.\")\n\n        if neighbor_sampler is None:\n            neighbor_sampler = NeighborSampler(\n                data,\n                num_neighbors=num_neighbors,\n                replace=replace,\n                subgraph_type=subgraph_type,\n                disjoint=disjoint,\n                temporal_strategy=temporal_strategy,\n                time_attr=time_attr,\n                weight_attr=weight_attr,\n                is_sorted=is_sorted,\n                share_memory=kwargs.get('num_workers', 0) > 0,\n                directed=directed,\n            )\n\n        super().__init__(\n            data=data,\n            link_sampler=neighbor_sampler,\n            edge_label_index=edge_label_index,\n            edge_label=edge_label,\n            edge_label_time=edge_label_time,\n            neg_sampling=neg_sampling,\n            neg_sampling_ratio=neg_sampling_ratio,\n            transform=transform,\n            transform_sampler_output=transform_sampler_output,\n            filter_per_worker=filter_per_worker,\n            **kwargs,\n        )\n"
  },
  {
    "path": "torch_geometric/loader/mixin.py",
    "content": "import glob\nimport logging\nimport os\nimport os.path as osp\nimport warnings\nfrom contextlib import contextmanager\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport psutil\nimport torch\n\nfrom torch_geometric.data import HeteroData\n\n\ndef get_numa_nodes_cores() -> Dict[str, Any]:\n    \"\"\"Parses numa nodes information into a dictionary.\n\n    ..code-block::\n\n        {<node_id>: [(<core_id>, [<sibling_thread_id_0>, <sibling_thread_id_1>\n        ...]), ...], ...}\n\n        # For example:\n        {0: [(0, [0, 4]), (1, [1, 5])], 1: [(2, [2, 6]), (3, [3, 7])]}\n\n    If not available, returns an empty dictionary.\n    \"\"\"\n    numa_node_paths = glob.glob('/sys/devices/system/node/node[0-9]*')\n\n    if not numa_node_paths:\n        return {}\n\n    nodes = {}\n    try:\n        for node_path in numa_node_paths:\n            numa_node_id = int(osp.basename(node_path)[4:])\n\n            thread_siblings = {}\n            for cpu_dir in glob.glob(osp.join(node_path, 'cpu[0-9]*')):\n                cpu_id = int(osp.basename(cpu_dir)[3:])\n                if cpu_id > 0:\n                    with open(osp.join(cpu_dir, 'online')) as core_online_file:\n                        core_online = int(\n                            core_online_file.read().splitlines()[0])\n                else:\n                    core_online = 1  # cpu0 is always online (special case)\n                if core_online == 1:\n                    with open(osp.join(cpu_dir, 'topology',\n                                       'core_id')) as core_id_file:\n                        core_id = int(core_id_file.read().strip())\n                        if core_id in thread_siblings:\n                            thread_siblings[core_id].append(cpu_id)\n                        else:\n                            thread_siblings[core_id] = [cpu_id]\n\n            nodes[numa_node_id] = sorted([(k, sorted(v))\n                                          for k, v in 
thread_siblings.items()])\n\n    except (OSError, ValueError, IndexError):\n        warnings.warn('Failed to read NUMA info', stacklevel=2)\n        return {}\n\n    return nodes\n\n\nclass WorkerInitWrapper:\n    r\"\"\"Wraps the :attr:`worker_init_fn` argument for\n    :class:`torch.utils.data.DataLoader` workers.\n    \"\"\"\n    def __init__(self, func: Callable) -> None:\n        self.func = func\n\n    def __call__(self, worker_id: int) -> None:\n        if self.func is not None:\n            self.func(worker_id)\n\n\nclass LogMemoryMixin:\n    r\"\"\"A context manager to enable logging of memory consumption in\n    :class:`~torch.utils.data.DataLoader` workers.\n    \"\"\"\n    def _mem_init_fn(self, worker_id: int) -> None:\n        proc = psutil.Process(os.getpid())\n        memory = proc.memory_info().rss / (1024 * 1024)\n        logging.debug(f\"Worker {worker_id} @ PID {proc.pid}: {memory:.2f} MB\")\n\n        # Chain worker init functions:\n        self._old_worker_init_fn(worker_id)\n\n    @contextmanager\n    def enable_memory_log(self) -> None:\n        self._old_worker_init_fn = WorkerInitWrapper(self.worker_init_fn)\n        try:\n            self.worker_init_fn = self._mem_init_fn\n            yield\n        finally:\n            self.worker_init_fn = self._old_worker_init_fn\n\n\nclass MultithreadingMixin:\n    r\"\"\"A context manager to enable multi-threading in\n    :class:`~torch.utils.data.DataLoader` workers.\n    It changes the default number of threads used in the loader from :obj:`1`\n    to :obj:`worker_threads`.\n    \"\"\"\n    def _mt_init_fn(self, worker_id: int) -> None:\n        try:\n            torch.set_num_threads(int(self._worker_threads))\n        except IndexError as e:\n            raise ValueError(f\"Cannot set {self._worker_threads} threads \"\n                             f\"in worker {worker_id}\") from e\n\n        # Chain worker init functions:\n        self._old_worker_init_fn(worker_id)\n\n    @contextmanager\n    def 
enable_multithreading(\n        self,\n        worker_threads: Optional[int] = None,\n    ) -> None:\n        r\"\"\"Enables multithreading in worker subprocesses.\n        This option requires changing the start method from :obj:`\"fork\"` to\n        :obj:`\"spawn\"`.\n\n        .. code-block:: python\n\n            def run():\n                loader = NeighborLoader(data, num_workers=3)\n                with loader.enable_multithreading(10):\n                    for batch in loader:\n                        pass\n\n            if __name__ == '__main__':\n                torch.multiprocessing.set_start_method('spawn')\n                run()\n\n        Args:\n            worker_threads (int, optional): The number of threads to use in\n                each worker process.\n                By default, the available threads are split evenly across\n                all workers.\n                (default: :obj:`torch.get_num_threads() // num_workers`)\n        \"\"\"\n        if not self.num_workers > 0:\n            raise ValueError(f\"'enable_multithreading' needs to be performed \"\n                             f\"with at least one worker \"\n                             f\"(got {self.num_workers})\")\n\n        if worker_threads is None:\n            worker_threads = torch.get_num_threads() // self.num_workers\n\n        self._worker_threads = worker_threads\n\n        if worker_threads > torch.get_num_threads():\n            raise ValueError(f\"'worker_threads' should be smaller than the \"\n                             f\"total available number of threads \"\n                             f\"{torch.get_num_threads()} \"\n                             f\"(got {worker_threads})\")\n\n        context = torch.multiprocessing.get_context()._name\n        if context != 'spawn':\n            raise ValueError(f\"'enable_multithreading' can only be used with \"\n                             f\"the 'spawn' multiprocessing context \"\n                             f\"(got {context})\")\n\n        self._old_worker_init_fn = 
WorkerInitWrapper(self.worker_init_fn)\n        try:\n            logging.debug(f\"Using {worker_threads} threads in each worker\")\n            self.worker_init_fn = self._mt_init_fn\n            yield\n        finally:\n            self.worker_init_fn = self._old_worker_init_fn\n\n\nclass AffinityMixin:\n    r\"\"\"A context manager to enable CPU affinity for data loader workers\n    (only used when running on CPU devices).\n\n    Affinitization places data loader worker threads on specific CPU cores.\n    In effect, it allows for more efficient local memory allocation and reduces\n    remote memory calls.\n    Every time a process or thread moves from one core to another, registers\n    and caches need to be flushed and reloaded.\n    This can become very costly if it happens often, and our threads may also\n    no longer be close to their data, or be able to share data in a cache.\n\n    See `here <https://pytorch-geometric.readthedocs.io/en/latest/advanced/\n    cpu_affinity.html>`__ for the accompanying tutorial.\n\n    .. warning::\n\n        To correctly affinitize compute threads (*i.e.* with\n        :obj:`KMP_AFFINITY`), please make sure that you exclude\n        :obj:`loader_cores` from the list of cores available for the main\n        process.\n        Otherwise, core oversubscription will occur and degrade performance.\n\n    .. 
code-block:: python\n\n        loader = NeighborLoader(data, num_workers=3)\n        with loader.enable_cpu_affinity(loader_cores=[0, 1, 2]):\n            for batch in loader:\n                pass\n\n    \"\"\"\n    def _aff_init_fn(self, worker_id: int) -> None:\n        try:\n            worker_cores = self.loader_cores[worker_id]\n            if not isinstance(worker_cores, List):\n                worker_cores = [worker_cores]\n\n            if torch.multiprocessing.get_context()._name == 'spawn':\n                torch.set_num_threads(len(worker_cores))\n\n            psutil.Process().cpu_affinity(worker_cores)\n\n        except IndexError as e:\n            raise ValueError(f\"Cannot use CPU affinity for worker ID \"\n                             f\"{worker_id} on CPU {self.loader_cores}\") from e\n\n        # Chain worker init functions:\n        self._old_worker_init_fn(worker_id)\n\n    @contextmanager\n    def enable_cpu_affinity(\n        self,\n        loader_cores: Optional[Union[List[List[int]], List[int]]] = None,\n    ) -> None:\n        r\"\"\"Enables CPU affinity.\n\n        Args:\n            loader_cores ([int], optional): List of CPU cores to which data\n                loader workers should be affinitized.\n                By default, it will affinitize to :obj:`numa0` cores.\n                If used with :obj:`\"spawn\"` multiprocessing context, it will\n                automatically enable multithreading and use multiple cores\n                per worker.\n        \"\"\"\n        if not self.num_workers > 0:\n            raise ValueError(\n                f\"'enable_cpu_affinity' should be used with at least one \"\n                f\"worker (got {self.num_workers})\")\n        if loader_cores and len(loader_cores) != self.num_workers:\n            raise ValueError(\n                f\"The number of loader cores (got {len(loader_cores)}) \"\n                f\"in 'enable_cpu_affinity' should match with the number \"\n                f\"of 
workers (got {self.num_workers})\")\n        if isinstance(self.data, HeteroData):\n            warnings.warn(\n                \"Due to conflicting parallelization methods it is not advised \"\n                \"to use affinitization with 'HeteroData' datasets. \"\n                \"Use `enable_multithreading` for better performance.\",\n                stacklevel=2)\n\n        self.loader_cores = loader_cores[:] if loader_cores else None\n        if self.loader_cores is None:\n            numa_info = get_numa_nodes_cores()\n\n            if numa_info and len(numa_info[0]) > self.num_workers:\n                # Take one thread per each node 0 core:\n                node0_cores = [cpus[0] for core_id, cpus in numa_info[0]]\n                node0_cores.sort()\n            else:\n                node0_cores = list(range(psutil.cpu_count(logical=False)))\n\n            if len(node0_cores) < self.num_workers:\n                raise ValueError(\n                    f\"More workers (got {self.num_workers}) than available \"\n                    f\"cores (got {len(node0_cores)})\")\n\n            # Set default loader core IDs:\n            if torch.multiprocessing.get_context()._name == 'spawn':\n                work_thread_pool = int(len(node0_cores) / self.num_workers)\n                self.loader_cores = [\n                    list(\n                        range(\n                            work_thread_pool * i,\n                            work_thread_pool * (i + 1),\n                        )) for i in range(self.num_workers)\n                ]\n            else:\n                self.loader_cores = node0_cores[:self.num_workers]\n\n        self._old_worker_init_fn = WorkerInitWrapper(self.worker_init_fn)\n        try:\n            self.worker_init_fn = self._aff_init_fn\n            logging.debug(f\"{self.num_workers} data loader workers are \"\n                          f\"assigned to CPUs {self.loader_cores}\")\n            yield\n        finally:\n            
self.worker_init_fn = self._old_worker_init_fn\n"
  },
  {
    "path": "torch_geometric/loader/neighbor_loader.py",
    "content": "from typing import Callable, Dict, List, Optional, Tuple, Union\n\nfrom torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData\nfrom torch_geometric.loader.node_loader import NodeLoader\nfrom torch_geometric.sampler import NeighborSampler\nfrom torch_geometric.sampler.base import SubgraphType\nfrom torch_geometric.typing import EdgeType, InputNodes, OptTensor\n\n\nclass NeighborLoader(NodeLoader):\n    r\"\"\"A data loader that performs neighbor sampling as introduced in the\n    `\"Inductive Representation Learning on Large Graphs\"\n    <https://arxiv.org/abs/1706.02216>`_ paper.\n    This loader allows for mini-batch training of GNNs on large-scale graphs\n    where full-batch training is not feasible.\n\n    More specifically, :obj:`num_neighbors` denotes how many neighbors are\n    sampled for each node in each iteration.\n    :class:`~torch_geometric.loader.NeighborLoader` takes in this list of\n    :obj:`num_neighbors` and iteratively samples :obj:`num_neighbors[i]` for\n    each node involved in iteration :obj:`i - 1`.\n\n    Sampled nodes are sorted based on the order in which they were sampled.\n    In particular, the first :obj:`batch_size` nodes represent the set of\n    original mini-batch nodes.\n\n    .. 
code-block:: python\n\n        from torch_geometric.datasets import Planetoid\n        from torch_geometric.loader import NeighborLoader\n\n        data = Planetoid(path, name='Cora')[0]\n\n        loader = NeighborLoader(\n            data,\n            # Sample 30 neighbors for each node for 2 iterations\n            num_neighbors=[30] * 2,\n            # Use a batch size of 128 for sampling training nodes\n            batch_size=128,\n            input_nodes=data.train_mask,\n        )\n\n        sampled_data = next(iter(loader))\n        print(sampled_data.batch_size)\n        >>> 128\n\n    By default, the data loader will only include the edges that were\n    originally sampled (:obj:`directed = True`).\n    This option should only be used in case the number of hops is equivalent to\n    the number of GNN layers.\n    In case the number of GNN layers is greater than the number of hops,\n    consider setting :obj:`directed = False`, which will include all edges\n    between all sampled nodes (but is slightly slower as a result).\n\n    Furthermore, :class:`~torch_geometric.loader.NeighborLoader` works for both\n    **homogeneous** graphs stored via :class:`~torch_geometric.data.Data` as\n    well as **heterogeneous** graphs stored via\n    :class:`~torch_geometric.data.HeteroData`.\n    When operating in heterogeneous graphs, up to :obj:`num_neighbors`\n    neighbors will be sampled for each :obj:`edge_type`.\n    However, more fine-grained control over\n    the amount of sampled neighbors of individual edge types is possible:\n\n    .. 
code-block:: python\n\n        from torch_geometric.datasets import OGB_MAG\n        from torch_geometric.loader import NeighborLoader\n\n        hetero_data = OGB_MAG(path)[0]\n\n        loader = NeighborLoader(\n            hetero_data,\n            # Sample 30 neighbors for each node and edge type for 2 iterations\n            num_neighbors={key: [30] * 2 for key in hetero_data.edge_types},\n            # Use a batch size of 128 for sampling training nodes of type paper\n            batch_size=128,\n            input_nodes=('paper', hetero_data['paper'].train_mask),\n        )\n\n        sampled_hetero_data = next(iter(loader))\n        print(sampled_hetero_data['paper'].batch_size)\n        >>> 128\n\n    .. note::\n\n        For an example of using\n        :class:`~torch_geometric.loader.NeighborLoader`, see\n        `examples/hetero/to_hetero_mag.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/hetero/to_hetero_mag.py>`_.\n\n    The :class:`~torch_geometric.loader.NeighborLoader` will return subgraphs\n    where global node indices are mapped to local indices corresponding to this\n    specific subgraph. However, often times it is desired to map the nodes of\n    the current subgraph back to the global node indices. The\n    :class:`~torch_geometric.loader.NeighborLoader` will include this mapping\n    as part of the :obj:`data` object:\n\n    .. 
code-block:: python\n\n        loader = NeighborLoader(data, ...)\n        sampled_data = next(iter(loader))\n        print(sampled_data.n_id)  # Global node index of each node in batch.\n\n    In particular, the data loader will add the following attributes to the\n    returned mini-batch:\n\n    * :obj:`batch_size` The number of seed nodes (first nodes in the batch)\n    * :obj:`n_id` The global node index for every sampled node\n    * :obj:`e_id` The global edge index for every sampled edge\n    * :obj:`input_id`: The global index of the :obj:`input_nodes`\n    * :obj:`num_sampled_nodes`: The number of sampled nodes in each hop\n    * :obj:`num_sampled_edges`: The number of sampled edges in each hop\n\n    Args:\n        data (Any): A :class:`~torch_geometric.data.Data`,\n            :class:`~torch_geometric.data.HeteroData`, or\n            (:class:`~torch_geometric.data.FeatureStore`,\n            :class:`~torch_geometric.data.GraphStore`) data object.\n        num_neighbors (List[int] or Dict[Tuple[str, str, str], List[int]]): The\n            number of neighbors to sample for each node in each iteration.\n            If an entry is set to :obj:`-1`, all neighbors will be included.\n            In heterogeneous graphs, may also take in a dictionary denoting\n            the amount of neighbors to sample for each individual edge type.\n        input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The\n            indices of nodes for which neighbors are sampled to create\n            mini-batches.\n            Needs to be either given as a :obj:`torch.LongTensor` or\n            :obj:`torch.BoolTensor`.\n            If set to :obj:`None`, all nodes will be considered.\n            In heterogeneous graphs, needs to be passed as a tuple that holds\n            the node type and node indices. 
(default: :obj:`None`)\n        input_time (torch.Tensor, optional): Optional values to override the\n            timestamp for the input nodes given in :obj:`input_nodes`. If not\n            set, will use the timestamps in :obj:`time_attr` as default (if\n            present). The :obj:`time_attr` needs to be set for this to work.\n            (default: :obj:`None`)\n        replace (bool, optional): If set to :obj:`True`, will sample with\n            replacement. (default: :obj:`False`)\n        subgraph_type (SubgraphType or str, optional): The type of the returned\n            subgraph.\n            If set to :obj:`\"directional\"`, the returned subgraph only holds\n            the sampled (directed) edges which are necessary to compute\n            representations for the sampled seed nodes.\n            If set to :obj:`\"bidirectional\"`, sampled edges are converted to\n            bidirectional edges.\n            If set to :obj:`\"induced\"`, the returned subgraph contains the\n            induced subgraph of all sampled nodes.\n            (default: :obj:`\"directional\"`)\n        disjoint (bool, optional): If set to :obj:`True`, each seed node will\n            create its own disjoint subgraph.\n            If set to :obj:`True`, mini-batch outputs will have a :obj:`batch`\n            vector holding the mapping of nodes to their respective subgraph.\n            Will get automatically set to :obj:`True` in case of temporal\n            sampling. 
(default: :obj:`False`)\n        temporal_strategy (str, optional): The sampling strategy when using\n            temporal sampling (:obj:`\"uniform\"`, :obj:`\"last\"`).\n            If set to :obj:`\"uniform\"`, will sample uniformly across neighbors\n            that fulfill temporal constraints.\n            If set to :obj:`\"last\"`, will sample the last `num_neighbors` that\n            fulfill temporal constraints.\n            (default: :obj:`\"uniform\"`)\n        time_attr (str, optional): The name of the attribute that denotes\n            timestamps for either the nodes or edges in the graph.\n            If set, temporal sampling will be used such that neighbors are\n            guaranteed to fulfill temporal constraints, *i.e.* neighbors have\n            an earlier or equal timestamp than the center node.\n            (default: :obj:`None`)\n        weight_attr (str, optional): The name of the attribute that denotes\n            edge weights in the graph.\n            If set, weighted/biased sampling will be used such that neighbors\n            are more likely to get sampled the higher their edge weights are.\n            Edge weights do not need to sum to one, but must be non-negative,\n            finite and have a non-zero sum within local neighborhoods.\n            (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in\n            a sampled mini-batch and returns a transformed version.\n            (default: :obj:`None`)\n        transform_sampler_output (callable, optional): A function/transform\n            that takes in a :class:`torch_geometric.sampler.SamplerOutput` and\n            returns a transformed version. 
(default: :obj:`None`)\n        is_sorted (bool, optional): If set to :obj:`True`, assumes that\n            :obj:`edge_index` is sorted by column.\n            If :obj:`time_attr` is set, additionally requires that rows are\n            sorted according to time within individual neighborhoods.\n            This avoids internal re-sorting of the data and can improve\n            runtime and memory efficiency. (default: :obj:`False`)\n        filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n            the returned data in each worker's subprocess.\n            If set to :obj:`False`, will filter the returned data in the main\n            process.\n            If set to :obj:`None`, will automatically infer the decision based\n            on whether data partially lives on the GPU\n            (:obj:`filter_per_worker=True`) or entirely on the CPU\n            (:obj:`filter_per_worker=False`).\n            There exist different trade-offs for setting this option.\n            Specifically, setting this option to :obj:`True` for in-memory\n            datasets will move all features to shared memory, which may result\n            in too many open file handles. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n        num_neighbors: Union[List[int], Dict[EdgeType, List[int]]],\n        input_nodes: InputNodes = None,\n        input_time: OptTensor = None,\n        replace: bool = False,\n        subgraph_type: Union[SubgraphType, str] = 'directional',\n        disjoint: bool = False,\n        temporal_strategy: str = 'uniform',\n        time_attr: Optional[str] = None,\n        weight_attr: Optional[str] = None,\n        transform: Optional[Callable] = None,\n        transform_sampler_output: Optional[Callable] = None,\n        is_sorted: bool = False,\n        filter_per_worker: Optional[bool] = None,\n        neighbor_sampler: Optional[NeighborSampler] = None,\n        directed: bool = True,  # Deprecated.\n        **kwargs,\n    ):\n        if input_time is not None and time_attr is None:\n            raise ValueError(\"Received conflicting 'input_time' and \"\n                             \"'time_attr' arguments: 'input_time' is set \"\n                             \"while 'time_attr' is not set.\")\n\n        if neighbor_sampler is None:\n            neighbor_sampler = NeighborSampler(\n                data,\n                num_neighbors=num_neighbors,\n                replace=replace,\n                subgraph_type=subgraph_type,\n                disjoint=disjoint,\n                temporal_strategy=temporal_strategy,\n                time_attr=time_attr,\n                weight_attr=weight_attr,\n                is_sorted=is_sorted,\n                share_memory=kwargs.get('num_workers', 0) > 0,\n                directed=directed,\n            )\n\n        super().__init__(\n            data=data,\n            
node_sampler=neighbor_sampler,\n            input_nodes=input_nodes,\n            input_time=input_time,\n            transform=transform,\n            transform_sampler_output=transform_sampler_output,\n            filter_per_worker=filter_per_worker,\n            **kwargs,\n        )\n"
  },
  {
    "path": "torch_geometric/loader/neighbor_sampler.py",
    "content": "from typing import Callable, List, NamedTuple, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import SparseTensor\n\n\nclass EdgeIndex(NamedTuple):\n    edge_index: Tensor\n    e_id: Optional[Tensor]\n    size: Tuple[int, int]\n\n    def to(self, *args, **kwargs):\n        edge_index = self.edge_index.to(*args, **kwargs)\n        e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None\n        return EdgeIndex(edge_index, e_id, self.size)\n\n\nclass Adj(NamedTuple):\n    adj_t: SparseTensor\n    e_id: Optional[Tensor]\n    size: Tuple[int, int]\n\n    def to(self, *args, **kwargs):\n        adj_t = self.adj_t.to(*args, **kwargs)\n        e_id = self.e_id.to(*args, **kwargs) if self.e_id is not None else None\n        return Adj(adj_t, e_id, self.size)\n\n\nclass NeighborSampler(torch.utils.data.DataLoader):\n    r\"\"\"The neighbor sampler from the `\"Inductive Representation Learning on\n    Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper, which allows\n    for mini-batch training of GNNs on large-scale graphs where full-batch\n    training is not feasible.\n\n    Given a GNN with :math:`L` layers and a specific mini-batch of nodes\n    :obj:`node_idx` for which we want to compute embeddings, this module\n    iteratively samples neighbors and constructs bipartite graphs that simulate\n    the actual computation flow of GNNs.\n\n    More specifically, :obj:`sizes` denotes how much neighbors we want to\n    sample for each node in each layer.\n    This module then takes in these :obj:`sizes` and iteratively samples\n    :obj:`sizes[l]` for each node involved in layer :obj:`l`.\n    In the next layer, sampling is repeated for the union of nodes that were\n    already encountered.\n    The actual computation graphs are then returned in reverse-mode, meaning\n    that we pass messages from a larger set of nodes to a smaller one, until we\n    reach the nodes for which we 
originally wanted to compute embeddings.\n\n    Hence, an item returned by :class:`NeighborSampler` holds the current\n    :obj:`batch_size`, the IDs :obj:`n_id` of all nodes involved in the\n    computation, and a list of bipartite graph objects via the tuple\n    :obj:`(edge_index, e_id, size)`, where :obj:`edge_index` represents the\n    bipartite edges between source and target nodes, :obj:`e_id` denotes the\n    IDs of original edges in the full graph, and :obj:`size` holds the shape\n    of the bipartite graph.\n    For each bipartite graph, target nodes are also included at the beginning\n    of the list of source nodes so that one can easily apply skip-connections\n    or add self-loops.\n\n    .. warning::\n\n        :class:`~torch_geometric.loader.NeighborSampler` is deprecated and will\n        be removed in a future release.\n        Use :class:`torch_geometric.loader.NeighborLoader` instead.\n\n    .. note::\n\n        For an example of using :obj:`NeighborSampler`, see\n        `examples/reddit.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        reddit.py>`_ or\n        `examples/ogbn_train.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        ogbn_train.py>`_.\n\n    Args:\n        edge_index (Tensor or SparseTensor): A :obj:`torch.LongTensor` or a\n            :class:`torch_sparse.SparseTensor` that defines the underlying\n            graph connectivity/message passing flow.\n            :obj:`edge_index` holds the indices of a (sparse) symmetric\n            adjacency matrix.\n            If :obj:`edge_index` is of type :obj:`torch.LongTensor`, its shape\n            must be defined as :obj:`[2, num_edges]`, where messages from nodes\n            :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`\n            (in case :obj:`flow=\"source_to_target\"`).\n            If :obj:`edge_index` is of type :class:`torch_sparse.SparseTensor`,\n            its sparse indices 
:obj:`(row, col)` should relate to\n            :obj:`row = edge_index[1]` and :obj:`col = edge_index[0]`.\n            The major difference between both formats is that we need to input\n            the *transposed* sparse adjacency matrix.\n        sizes ([int]): The number of neighbors to sample for each node in each\n            layer. If set to :obj:`sizes[l] = -1`, all neighbors are included\n            in layer :obj:`l`.\n        node_idx (LongTensor, optional): The nodes that should be considered\n            for creating mini-batches. If set to :obj:`None`, all nodes will be\n            considered.\n        num_nodes (int, optional): The number of nodes in the graph.\n            (default: :obj:`None`)\n        return_e_id (bool, optional): If set to :obj:`False`, will not return\n            original edge indices of sampled edges. This saves memory and is\n            only useful when operating on graphs without edge features.\n            (default: :obj:`True`)\n        transform (callable, optional): A function/transform that takes in\n            a sampled mini-batch and returns a transformed version.\n            (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(self, edge_index: Union[Tensor, SparseTensor],\n                 sizes: List[int], node_idx: Optional[Tensor] = None,\n                 num_nodes: Optional[int] = None, return_e_id: bool = True,\n                 transform: Optional[Callable] = None, **kwargs):\n\n        edge_index = edge_index.to('cpu')\n\n        # Remove for PyTorch Lightning:\n        kwargs.pop('dataset', None)\n        kwargs.pop('collate_fn', None)\n\n        # Save for PyTorch Lightning < 1.6:\n        self.edge_index = edge_index\n        self.node_idx = node_idx\n        self.num_nodes = num_nodes\n\n        
self.sizes = sizes\n        self.return_e_id = return_e_id\n        self.transform = transform\n        self.is_sparse_tensor = isinstance(edge_index, SparseTensor)\n        self.__val__ = None\n\n        # Obtain a *transposed* `SparseTensor` instance.\n        if not self.is_sparse_tensor:\n            if (num_nodes is None and node_idx is not None\n                    and node_idx.dtype == torch.bool):\n                num_nodes = node_idx.size(0)\n            if (num_nodes is None and node_idx is not None\n                    and node_idx.dtype == torch.long):\n                num_nodes = max(int(edge_index.max()), int(node_idx.max())) + 1\n            if num_nodes is None:\n                num_nodes = int(edge_index.max()) + 1\n\n            value = torch.arange(edge_index.size(1)) if return_e_id else None\n            self.adj_t = SparseTensor(row=edge_index[0], col=edge_index[1],\n                                      value=value,\n                                      sparse_sizes=(num_nodes, num_nodes)).t()\n        else:\n            adj_t = edge_index\n            if return_e_id:\n                self.__val__ = adj_t.storage.value()\n                value = torch.arange(adj_t.nnz())\n                adj_t = adj_t.set_value(value, layout='coo')\n            self.adj_t = adj_t\n\n        self.adj_t.storage.rowptr()\n\n        if node_idx is None:\n            node_idx = torch.arange(self.adj_t.sparse_size(0))\n        elif node_idx.dtype == torch.bool:\n            node_idx = node_idx.nonzero(as_tuple=False).view(-1)\n\n        super().__init__(\n            node_idx.view(-1).tolist(), collate_fn=self.sample, **kwargs)\n\n    def sample(self, batch):\n        if not isinstance(batch, Tensor):\n            batch = torch.tensor(batch)\n\n        batch_size: int = len(batch)\n\n        adjs = []\n        n_id = batch\n        for size in self.sizes:\n            adj_t, n_id = self.adj_t.sample_adj(n_id, size, replace=False)\n            e_id = 
adj_t.storage.value()\n            size = adj_t.sparse_sizes()[::-1]\n            if self.__val__ is not None:\n                adj_t.set_value_(self.__val__[e_id], layout='coo')\n\n            if self.is_sparse_tensor:\n                adjs.append(Adj(adj_t, e_id, size))\n            else:\n                row, col, _ = adj_t.coo()\n                edge_index = torch.stack([col, row], dim=0)\n                adjs.append(EdgeIndex(edge_index, e_id, size))\n\n        adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]\n        out = (batch_size, n_id, adjs)\n        out = self.transform(*out) if self.transform is not None else out\n        return out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(sizes={self.sizes})'\n"
  },
  {
    "path": "torch_geometric/loader/node_loader.py",
    "content": "from typing import Any, Callable, Iterator, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData\nfrom torch_geometric.loader.base import DataLoaderIterator\nfrom torch_geometric.loader.mixin import (\n    AffinityMixin,\n    LogMemoryMixin,\n    MultithreadingMixin,\n)\nfrom torch_geometric.loader.utils import (\n    filter_custom_hetero_store,\n    filter_custom_store,\n    filter_data,\n    filter_hetero_data,\n    get_input_nodes,\n    infer_filter_per_worker,\n)\nfrom torch_geometric.sampler import (\n    BaseSampler,\n    HeteroSamplerOutput,\n    NodeSamplerInput,\n    SamplerOutput,\n)\nfrom torch_geometric.typing import InputNodes, OptTensor\n\n\nclass NodeLoader(\n        torch.utils.data.DataLoader,\n        AffinityMixin,\n        MultithreadingMixin,\n        LogMemoryMixin,\n):\n    r\"\"\"A data loader that performs mini-batch sampling from node information,\n    using a generic :class:`~torch_geometric.sampler.BaseSampler`\n    implementation that defines a\n    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes` function and\n    is supported on the provided input :obj:`data` object.\n\n    Args:\n        data (Any): A :class:`~torch_geometric.data.Data`,\n            :class:`~torch_geometric.data.HeteroData`, or\n            (:class:`~torch_geometric.data.FeatureStore`,\n            :class:`~torch_geometric.data.GraphStore`) data object.\n        node_sampler (torch_geometric.sampler.BaseSampler): The sampler\n            implementation to be used with this loader.\n            Needs to implement\n            :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes`.\n            The sampler implementation must be compatible with the input\n            :obj:`data` object.\n        input_nodes (torch.Tensor or str or Tuple[str, torch.Tensor]): The\n            indices of seed nodes to start sampling from.\n            
Needs to be either given as a :obj:`torch.LongTensor` or\n            :obj:`torch.BoolTensor`.\n            If set to :obj:`None`, all nodes will be considered.\n            In heterogeneous graphs, needs to be passed as a tuple that holds\n            the node type and node indices. (default: :obj:`None`)\n        input_time (torch.Tensor, optional): Optional values to override the\n            timestamp for the input nodes given in :obj:`input_nodes`. If not\n            set, will use the timestamps in :obj:`time_attr` as default (if\n            present). The :obj:`time_attr` needs to be set for this to work.\n            (default: :obj:`None`)\n        transform (callable, optional): A function/transform that takes in\n            a sampled mini-batch and returns a transformed version.\n            (default: :obj:`None`)\n        transform_sampler_output (callable, optional): A function/transform\n            that takes in a :class:`torch_geometric.sampler.SamplerOutput` and\n            returns a transformed version. (default: :obj:`None`)\n        filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n            the returned data in each worker's subprocess.\n            If set to :obj:`False`, will filter the returned data in the main\n            process.\n            If set to :obj:`None`, will automatically infer the decision based\n            on whether data partially lives on the GPU\n            (:obj:`filter_per_worker=True`) or entirely on the CPU\n            (:obj:`filter_per_worker=False`).\n            There exist different trade-offs for setting this option.\n            Specifically, setting this option to :obj:`True` for in-memory\n            datasets will move all features to shared memory, which may result\n            in too many open file handles. 
(default: :obj:`None`)\n        custom_cls (HeteroData, optional): A custom\n            :class:`~torch_geometric.data.HeteroData` class to return for\n            mini-batches in case of remote backends. (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n        node_sampler: BaseSampler,\n        input_nodes: InputNodes = None,\n        input_time: OptTensor = None,\n        transform: Optional[Callable] = None,\n        transform_sampler_output: Optional[Callable] = None,\n        filter_per_worker: Optional[bool] = None,\n        custom_cls: Optional[HeteroData] = None,\n        input_id: OptTensor = None,\n        **kwargs,\n    ):\n        if filter_per_worker is None:\n            filter_per_worker = infer_filter_per_worker(data)\n\n        self.data = data\n        self.node_sampler = node_sampler\n        self.input_nodes = input_nodes\n        self.input_time = input_time\n        self.transform = transform\n        self.transform_sampler_output = transform_sampler_output\n        self.filter_per_worker = filter_per_worker\n        self.custom_cls = custom_cls\n        self.input_id = input_id\n\n        kwargs.pop('dataset', None)\n        kwargs.pop('collate_fn', None)\n\n        # Get node type (or `None` for homogeneous graphs):\n        input_type, input_nodes, input_id = get_input_nodes(\n            data, input_nodes, input_id)\n\n        self.input_data = NodeSamplerInput(\n            input_id=input_id,\n            node=input_nodes,\n            time=input_time,\n            input_type=input_type,\n        )\n\n        iterator = range(input_nodes.size(0))\n        super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)\n\n    def __call__(\n        
self,\n        index: Union[Tensor, List[int]],\n    ) -> Union[Data, HeteroData]:\n        r\"\"\"Samples a subgraph from a batch of input nodes.\"\"\"\n        out = self.collate_fn(index)\n        if not self.filter_per_worker:\n            out = self.filter_fn(out)\n        return out\n\n    def collate_fn(self, index: Union[Tensor, List[int]]) -> Any:\n        r\"\"\"Samples a subgraph from a batch of input nodes.\"\"\"\n        input_data: NodeSamplerInput = self.input_data[index]\n\n        out = self.node_sampler.sample_from_nodes(input_data)\n\n        if self.filter_per_worker:  # Execute `filter_fn` in the worker process\n            out = self.filter_fn(out)\n\n        return out\n\n    def filter_fn(\n        self,\n        out: Union[SamplerOutput, HeteroSamplerOutput],\n    ) -> Union[Data, HeteroData]:\n        r\"\"\"Joins the sampled nodes with their corresponding features,\n        returning the resulting :class:`~torch_geometric.data.Data` or\n        :class:`~torch_geometric.data.HeteroData` object to be used downstream.\n        \"\"\"\n        if self.transform_sampler_output:\n            out = self.transform_sampler_output(out)\n\n        if isinstance(out, SamplerOutput):\n            if isinstance(self.data, Data):\n                data = filter_data(  #\n                    self.data, out.node, out.row, out.col, out.edge,\n                    self.node_sampler.edge_permutation)\n\n            else:  # Tuple[FeatureStore, GraphStore]\n\n                # Hack to detect whether we are in a distributed setting.\n                if (self.node_sampler.__class__.__name__ ==\n                        'DistNeighborSampler'):\n                    edge_index = torch.stack([out.row, out.col])\n                    data = Data(edge_index=edge_index)\n                    # Metadata entries are populated in\n                    # `DistributedNeighborSampler._collate_fn()`\n                    data.x = out.metadata[-3]\n                    data.y = 
out.metadata[-2]\n                    data.edge_attr = out.metadata[-1]\n                else:\n                    data = filter_custom_store(  #\n                        *self.data, out.node, out.row, out.col, out.edge,\n                        self.custom_cls)\n\n            if 'n_id' not in data:\n                data.n_id = out.node\n            if out.edge is not None and 'e_id' not in data:\n                edge = out.edge.to(torch.long)\n                perm = self.node_sampler.edge_permutation\n                data.e_id = perm[edge] if perm is not None else edge\n\n            data.batch = out.batch\n            data.num_sampled_nodes = out.num_sampled_nodes\n            data.num_sampled_edges = out.num_sampled_edges\n\n            if out.orig_row is not None and out.orig_col is not None:\n                data._orig_edge_index = torch.stack([\n                    out.orig_row,\n                    out.orig_col,\n                ], dim=0)\n\n            data.input_id = out.metadata[0]\n            data.seed_time = out.metadata[1]\n            data.batch_size = out.metadata[0].size(0)\n\n        elif isinstance(out, HeteroSamplerOutput):\n            if isinstance(self.data, HeteroData):\n                data = filter_hetero_data(  #\n                    self.data, out.node, out.row, out.col, out.edge,\n                    self.node_sampler.edge_permutation)\n\n            else:  # Tuple[FeatureStore, GraphStore]\n\n                # Hack to detect whether we are in a distributed setting.\n                if (self.node_sampler.__class__.__name__ ==\n                        'DistNeighborSampler'):\n                    import torch_geometric.distributed as dist\n\n                    data = dist.utils.filter_dist_store(\n                        *self.data, out.node, out.row, out.col, out.edge,\n                        self.custom_cls, out.metadata,\n                        self.input_data.input_type)\n                else:\n                    data = 
filter_custom_hetero_store(  #\n                        *self.data, out.node, out.row, out.col, out.edge,\n                        self.custom_cls)\n\n            for key, node in out.node.items():\n                if 'n_id' not in data[key]:\n                    data[key].n_id = node\n\n            for key, edge in (out.edge or {}).items():\n                if edge is not None and 'e_id' not in data[key]:\n                    edge = edge.to(torch.long)\n                    perm = self.node_sampler.edge_permutation\n                    if perm is not None and perm.get(key, None) is not None:\n                        edge = perm[key][edge]\n                    data[key].e_id = edge\n\n            data.set_value_dict('batch', out.batch)\n            data.set_value_dict('num_sampled_nodes', out.num_sampled_nodes)\n            data.set_value_dict('num_sampled_edges', out.num_sampled_edges)\n\n            if out.orig_row is not None and out.orig_col is not None:\n                for key in out.orig_row.keys():\n                    data[key]._orig_edge_index = torch.stack([\n                        out.orig_row[key],\n                        out.orig_col[key],\n                    ], dim=0)\n\n            input_type = self.input_data.input_type\n            data[input_type].input_id = out.metadata[0]\n            data[input_type].seed_time = out.metadata[1]\n            data[input_type].batch_size = out.metadata[0].size(0)\n\n        else:\n            raise TypeError(f\"'{self.__class__.__name__}' found invalid \"\n                            f\"type: '{type(out)}'\")\n\n        return data if self.transform is None else self.transform(data)\n\n    def _get_iterator(self) -> Iterator:\n        if self.filter_per_worker:\n            return super()._get_iterator()\n\n        # if not self.is_cuda_available and not self.cpu_affinity_enabled:\n        # TODO: Add manual page for best CPU practices\n        # link = ...\n        # Warning('Dataloader CPU affinity opt is 
not enabled, consider '\n        #          'switching it on with enable_cpu_affinity() or see CPU '\n        #          f'best practices for PyG [{link}])')\n\n        # Execute `filter_fn` in the main process:\n        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/loader/prefetch.py",
    "content": "import warnings\nfrom contextlib import nullcontext\nfrom functools import partial\nfrom typing import Any, Optional\n\nimport torch\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.typing import WITH_IPEX\n\n\nclass DeviceHelper:\n    def __init__(self, device: Optional[torch.device] = None):\n        with_cuda = torch.cuda.is_available()\n        with_xpu = torch.xpu.is_available() if WITH_IPEX else False\n\n        if device is None:\n            if with_cuda:\n                device = 'cuda'\n            elif with_xpu:\n                device = 'xpu'\n            else:\n                device = 'cpu'\n\n        self.device = torch.device(device)\n        self.is_gpu = self.device.type in ['cuda', 'xpu']\n\n        if ((self.device.type == 'cuda' and not with_cuda)\n                or (self.device.type == 'xpu' and not with_xpu)):\n            warnings.warn(\n                f\"Requested device '{self.device.type}' is not \"\n                f\"available, falling back to CPU\", stacklevel=2)\n            self.device = torch.device('cpu')\n\n        self.stream = None\n        self.stream_context = nullcontext\n        self.module = getattr(torch, self.device.type) if self.is_gpu else None\n\n    def maybe_init_stream(self) -> None:\n        if self.is_gpu:\n            self.stream = self.module.Stream()\n            self.stream_context = partial(\n                self.module.stream,\n                stream=self.stream,\n            )\n\n    def maybe_wait_stream(self) -> None:\n        if self.stream is not None:\n            self.module.current_stream().wait_stream(self.stream)\n\n\nclass PrefetchLoader:\n    r\"\"\"A GPU prefetcher class for asynchronously transferring data of a\n    :class:`torch.utils.data.DataLoader` from host memory to device memory.\n\n    Args:\n        loader (torch.utils.data.DataLoader): The data loader.\n        device (torch.device, optional): The device to load the data to.\n            (default: 
:obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        loader: DataLoader,\n        device: Optional[torch.device] = None,\n    ):\n        self.loader = loader\n        self.device_helper = DeviceHelper(device)\n\n    def non_blocking_transfer(self, batch: Any) -> Any:\n        if not self.device_helper.is_gpu:\n            return batch\n        if isinstance(batch, (list, tuple)):\n            return [self.non_blocking_transfer(v) for v in batch]\n        if isinstance(batch, dict):\n            return {k: self.non_blocking_transfer(v) for k, v in batch.items()}\n\n        batch = batch.pin_memory()\n        return batch.to(self.device_helper.device, non_blocking=True)\n\n    def __iter__(self) -> Any:\n        first = True\n        self.device_helper.maybe_init_stream()\n\n        batch = None\n        for next_batch in self.loader:\n\n            with self.device_helper.stream_context():\n                next_batch = self.non_blocking_transfer(next_batch)\n\n            if not first:\n                yield batch\n            else:\n                first = False\n\n            self.device_helper.maybe_wait_stream()\n\n            batch = next_batch\n\n        yield batch\n\n    def __len__(self) -> int:\n        return len(self.loader)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.loader})'\n"
  },
  {
    "path": "torch_geometric/loader/random_node_loader.py",
    "content": "import math\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.hetero_data import to_homogeneous_edge_index\n\n\nclass RandomNodeLoader(torch.utils.data.DataLoader):\n    r\"\"\"A data loader that randomly samples nodes within a graph and returns\n    their induced subgraph.\n\n    .. note::\n\n        For an example of using\n        :class:`~torch_geometric.loader.RandomNodeLoader`, see\n        `examples/ogbn_proteins_deepgcn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        ogbn_proteins_deepgcn.py>`_.\n\n    Args:\n        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):\n            The :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` graph object.\n        num_parts (int): The number of partitions.\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData],\n        num_parts: int,\n        **kwargs,\n    ):\n        self.data = data\n        self.num_parts = num_parts\n\n        if isinstance(data, HeteroData):\n            edge_index, node_dict, edge_dict = to_homogeneous_edge_index(data)\n            self.node_dict, self.edge_dict = node_dict, edge_dict\n        else:\n            edge_index = data.edge_index\n\n        self.edge_index = edge_index\n        self.num_nodes = data.num_nodes\n\n        super().__init__(\n            range(self.num_nodes),\n            batch_size=math.ceil(self.num_nodes / num_parts),\n            collate_fn=self.collate_fn,\n            **kwargs,\n        )\n\n    def collate_fn(self, index):\n        if not isinstance(index, Tensor):\n            index = torch.tensor(index)\n\n        if isinstance(self.data, Data):\n            return 
self.data.subgraph(index)\n\n        elif isinstance(self.data, HeteroData):\n            node_dict = {\n                key: index[(index >= start) & (index < end)] - start\n                for key, (start, end) in self.node_dict.items()\n            }\n            return self.data.subgraph(node_dict)\n"
  },
  {
    "path": "torch_geometric/loader/shadow.py",
    "content": "import copy\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor\n\n\nclass ShaDowKHopSampler(torch.utils.data.DataLoader):\n    r\"\"\"The ShaDow :math:`k`-hop sampler from the `\"Decoupling the Depth and\n    Scope of Graph Neural Networks\" <https://arxiv.org/abs/2201.07858>`_ paper.\n    Given a graph in a :obj:`data` object, the sampler will create shallow,\n    localized subgraphs.\n    A deep GNN on this local graph then smooths the informative local signals.\n\n    .. note::\n\n        For an example of using :class:`ShaDowKHopSampler`, see\n        `examples/shadow.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/shadow.py>`_.\n\n    Args:\n        data (torch_geometric.data.Data): The graph data object.\n        depth (int): The depth/number of hops of the localized subgraph.\n        num_neighbors (int): The number of neighbors to sample for each node in\n            each hop.\n        node_idx (LongTensor or BoolTensor, optional): The nodes that should be\n            considered for creating mini-batches.\n            If set to :obj:`None`, all nodes will be\n            considered.\n        replace (bool, optional): If set to :obj:`True`, will sample neighbors\n            with replacement. 
(default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size` or\n            :obj:`num_workers`.\n    \"\"\"\n    def __init__(self, data: Data, depth: int, num_neighbors: int,\n                 node_idx: Optional[Tensor] = None, replace: bool = False,\n                 **kwargs):\n\n        if not WITH_TORCH_SPARSE:\n            raise ImportError(\n                f\"'{self.__class__.__name__}' requires 'torch-sparse'\")\n\n        self.data = copy.copy(data)\n        self.depth = depth\n        self.num_neighbors = num_neighbors\n        self.replace = replace\n\n        if data.edge_index is not None:\n            self.is_sparse_tensor = False\n            row, col = data.edge_index.cpu()\n            self.adj_t = SparseTensor(\n                row=row, col=col, value=torch.arange(col.size(0)),\n                sparse_sizes=(data.num_nodes, data.num_nodes)).t()\n        else:\n            self.is_sparse_tensor = True\n            self.adj_t = data.adj_t.cpu()\n\n        if node_idx is None:\n            node_idx = torch.arange(self.adj_t.sparse_size(0))\n        elif node_idx.dtype == torch.bool:\n            node_idx = node_idx.nonzero(as_tuple=False).view(-1)\n        self.node_idx = node_idx\n\n        super().__init__(node_idx.tolist(), collate_fn=self.__collate__,\n                         **kwargs)\n\n    def __collate__(self, n_id):\n        n_id = torch.tensor(n_id)\n\n        rowptr, col, value = self.adj_t.csr()\n        out = torch.ops.torch_sparse.ego_k_hop_sample_adj(\n            rowptr, col, n_id, self.depth, self.num_neighbors, self.replace)\n        rowptr, col, n_id, e_id, ptr, root_n_id = out\n\n        adj_t = SparseTensor(rowptr=rowptr, col=col,\n                             value=value[e_id] if value is not None else None,\n                             sparse_sizes=(n_id.numel(), n_id.numel()),\n                             is_sorted=True, 
trust_data=True)\n\n        batch = Batch(batch=torch.ops.torch_sparse.ptr2ind(ptr, n_id.numel()),\n                      ptr=ptr)\n        batch.root_n_id = root_n_id\n\n        if self.is_sparse_tensor:\n            batch.adj_t = adj_t\n        else:\n            row, col, e_id = adj_t.t().coo()\n            batch.edge_index = torch.stack([row, col], dim=0)\n\n        for k, v in self.data:\n            if k in ['edge_index', 'adj_t', 'num_nodes', 'batch', 'ptr']:\n                continue\n            if k == 'y' and v.size(0) == self.data.num_nodes:\n                batch[k] = v[n_id][root_n_id]\n            elif isinstance(v, Tensor) and v.size(0) == self.data.num_nodes:\n                batch[k] = v[n_id]\n            elif isinstance(v, Tensor) and v.size(0) == self.data.num_edges:\n                batch[k] = v[e_id]\n            else:\n                batch[k] = v\n\n        return batch\n"
  },
  {
    "path": "torch_geometric/loader/temporal_dataloader.py",
    "content": "from typing import List\n\nimport torch\n\nfrom torch_geometric.data import TemporalData\n\n\nclass TemporalDataLoader(torch.utils.data.DataLoader):\n    r\"\"\"A data loader which merges successive events of a\n    :class:`torch_geometric.data.TemporalData` to a mini-batch.\n\n    Args:\n        data (TemporalData): The :obj:`~torch_geometric.data.TemporalData`\n            from which to load the data.\n        batch_size (int, optional): How many samples per batch to load.\n            (default: :obj:`1`)\n        neg_sampling_ratio (float, optional): The ratio of sampled negative\n            destination nodes to the number of positive destination nodes.\n            (default: :obj:`0.0`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        data: TemporalData,\n        batch_size: int = 1,\n        neg_sampling_ratio: float = 0.0,\n        **kwargs,\n    ):\n        # Remove for PyTorch Lightning:\n        kwargs.pop('dataset', None)\n        kwargs.pop('collate_fn', None)\n        kwargs.pop('shuffle', None)\n\n        self.data = data\n        self.events_per_batch = batch_size\n        self.neg_sampling_ratio = neg_sampling_ratio\n\n        if neg_sampling_ratio > 0:\n            self.min_dst = int(data.dst.min())\n            self.max_dst = int(data.dst.max())\n\n        if kwargs.get('drop_last', False) and len(data) % batch_size != 0:\n            arange = range(0, len(data) - batch_size, batch_size)\n        else:\n            arange = range(0, len(data), batch_size)\n\n        super().__init__(arange, 1, shuffle=False, collate_fn=self, **kwargs)\n\n    def __call__(self, arange: List[int]) -> TemporalData:\n        batch = self.data[arange[0]:arange[0] + self.events_per_batch]\n\n        n_ids = [batch.src, batch.dst]\n\n        if self.neg_sampling_ratio > 0:\n            batch.neg_dst = torch.randint(\n                
low=self.min_dst,\n                high=self.max_dst + 1,\n                size=(round(self.neg_sampling_ratio * batch.dst.size(0)), ),\n                dtype=batch.dst.dtype,\n                device=batch.dst.device,\n            )\n            n_ids += [batch.neg_dst]\n\n        batch.n_id = torch.cat(n_ids, dim=0).unique()\n\n        return batch\n"
  },
  {
    "path": "torch_geometric/loader/utils.py",
    "content": "import copy\nimport logging\nimport math\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.data import (\n    Data,\n    FeatureStore,\n    GraphStore,\n    HeteroData,\n    TensorAttr,\n    remote_backend_utils,\n)\nfrom torch_geometric.data.storage import EdgeStorage, NodeStorage\nfrom torch_geometric.typing import (\n    EdgeType,\n    FeatureTensorType,\n    InputEdges,\n    InputNodes,\n    NodeType,\n    OptTensor,\n    SparseTensor,\n    TensorFrame,\n)\n\n\ndef index_select(\n    value: FeatureTensorType,\n    index: Tensor,\n    dim: int = 0,\n) -> Tensor:\n    r\"\"\"Indexes the :obj:`value` tensor along dimension :obj:`dim` using the\n    entries in :obj:`index`.\n\n    Args:\n        value (torch.Tensor or np.ndarray): The input tensor.\n        index (torch.Tensor): The 1-D tensor containing the indices to index.\n        dim (int, optional): The dimension in which to index.\n            (default: :obj:`0`)\n\n    .. 
warning::\n\n        :obj:`index` is cast to a :obj:`torch.int64` tensor internally, as\n        `PyTorch currently only supports indexing\n        <https://github.com/pytorch/pytorch/issues/61819>`_ via\n        :obj:`torch.int64`.\n    \"\"\"\n    # PyTorch currently only supports indexing via `torch.int64`:\n    # https://github.com/pytorch/pytorch/issues/61819\n    index = index.to(torch.int64)\n\n    if isinstance(value, Tensor):\n        out: Optional[Tensor] = None\n        if torch.utils.data.get_worker_info() is not None:\n            # If we are in a background process, we write directly into a\n            # shared memory tensor to avoid an extra copy:\n            size = list(value.shape)\n            size[dim] = index.numel()\n            numel = math.prod(size)\n            if torch_geometric.typing.WITH_PT20:\n                storage = value.untyped_storage()._new_shared(\n                    numel * value.element_size())\n            else:\n                storage = value.storage()._new_shared(numel)\n            out = value.new(storage).view(size)\n\n        return torch.index_select(value, dim, index, out=out)\n\n    if isinstance(value, TensorFrame):\n        assert dim == 0\n        return value[index]\n\n    elif isinstance(value, np.ndarray):\n        return torch.from_numpy(np.take(value, index, axis=dim))\n\n    raise ValueError(f\"Encountered invalid feature tensor type \"\n                     f\"(got '{type(value)}')\")\n\n\ndef filter_node_store_(store: NodeStorage, out_store: NodeStorage,\n                       index: Tensor):\n    # Filters a node storage object to only hold the nodes in `index`:\n    for key, value in store.items():\n        if key == 'num_nodes':\n            out_store.num_nodes = index.numel()\n\n        elif store.is_node_attr(key):\n            if isinstance(value, (Tensor, TensorFrame)):\n                index = index.to(value.device)\n            elif isinstance(value, np.ndarray):\n                index = 
index.cpu()\n            dim = store._parent().__cat_dim__(key, value, store)\n            out_store[key] = index_select(value, index, dim=dim)\n\n\ndef filter_edge_store_(store: EdgeStorage, out_store: EdgeStorage, row: Tensor,\n                       col: Tensor, index: OptTensor, perm: OptTensor = None):\n    # Filters a edge storage object to only hold the edges in `index`,\n    # which represents the new graph as denoted by `(row, col)`:\n    for key, value in store.items():\n        if key == 'edge_index':\n            edge_index = torch.stack([row, col], dim=0).to(value.device)\n            # TODO Integrate `EdgeIndex` into `custom_store`.\n            # edge_index = EdgeIndex(\n            #     torch.stack([row, col], dim=0).to(value.device),\n            #     sparse_size=out_store.size(),\n            #     sort_order='col',\n            #     # TODO Support `is_undirected`.\n            # )\n            out_store.edge_index = edge_index\n\n        elif key == 'adj_t':\n            # NOTE: We expect `(row, col)` to be sorted by `col` (CSC layout).\n            row = row.to(value.device())\n            col = col.to(value.device())\n            edge_attr = value.storage.value()\n            if edge_attr is not None:\n                if index is not None:\n                    index = index.to(edge_attr.device)\n                    edge_attr = index_select(edge_attr, index, dim=0)\n                else:\n                    edge_attr = None\n            sparse_sizes = out_store.size()[::-1]\n            # TODO Currently, we set `is_sorted=False`, see:\n            # https://github.com/pyg-team/pytorch_geometric/issues/4346\n            out_store.adj_t = SparseTensor(row=col, col=row, value=edge_attr,\n                                           sparse_sizes=sparse_sizes,\n                                           is_sorted=False, trust_data=True)\n\n        elif store.is_edge_attr(key):\n            if index is None:\n                out_store[key] = None\n  
              continue\n\n            dim = store._parent().__cat_dim__(key, value, store)\n            if isinstance(value, (Tensor, TensorFrame)):\n                index = index.to(value.device)\n            elif isinstance(value, np.ndarray):\n                index = index.cpu()\n            if perm is None:\n                out_store[key] = index_select(value, index, dim=dim)\n            else:\n                if isinstance(value, (Tensor, TensorFrame)):\n                    perm = perm.to(value.device)\n                elif isinstance(value, np.ndarray):\n                    perm = perm.cpu()\n                out_store[key] = index_select(\n                    value,\n                    perm[index.to(torch.int64)],\n                    dim=dim,\n                )\n\n\ndef filter_data(data: Data, node: Tensor, row: Tensor, col: Tensor,\n                edge: OptTensor, perm: OptTensor = None) -> Data:\n    # Filters a data object to only hold nodes in `node` and edges in `edge`:\n    out = copy.copy(data)\n    filter_node_store_(data._store, out._store, node)\n    filter_edge_store_(data._store, out._store, row, col, edge, perm)\n    return out\n\n\ndef filter_hetero_data(\n    data: HeteroData,\n    node_dict: Dict[NodeType, Tensor],\n    row_dict: Dict[EdgeType, Tensor],\n    col_dict: Dict[EdgeType, Tensor],\n    edge_dict: Dict[EdgeType, OptTensor],\n    perm_dict: Optional[Dict[EdgeType, OptTensor]] = None,\n) -> HeteroData:\n    # Filters a heterogeneous data object to only hold nodes in `node` and\n    # edges in `edge` for each node and edge type, respectively:\n    out = copy.copy(data)\n\n    for node_type in out.node_types:\n        # Handle the case of disconnected graph sampling:\n        if node_type not in node_dict:\n            node_dict[node_type] = torch.empty(0, dtype=torch.long)\n\n        filter_node_store_(data[node_type], out[node_type],\n                           node_dict[node_type])\n\n    for edge_type in out.edge_types:\n        
# Handle the case of disconnected graph sampling:\n        if edge_type not in row_dict:\n            row_dict[edge_type] = torch.empty(0, dtype=torch.long)\n        if edge_type not in col_dict:\n            col_dict[edge_type] = torch.empty(0, dtype=torch.long)\n        if edge_type not in edge_dict:\n            edge_dict[edge_type] = torch.empty(0, dtype=torch.long)\n\n        filter_edge_store_(\n            data[edge_type],\n            out[edge_type],\n            row_dict[edge_type],\n            col_dict[edge_type],\n            edge_dict[edge_type],\n            perm_dict.get(edge_type, None) if perm_dict else None,\n        )\n\n    return out\n\n\ndef filter_custom_store(\n    feature_store: FeatureStore,\n    graph_store: GraphStore,\n    node: Tensor,\n    row: Tensor,\n    col: Tensor,\n    edge: OptTensor,\n    custom_cls: Optional[Data] = None,\n) -> Data:\n    r\"\"\"Constructs a :class:`~torch_geometric.data.Data` object from a feature\n    store and graph store instance.\n    \"\"\"\n    # Construct a new `Data` object:\n    data = custom_cls() if custom_cls is not None else Data()\n\n    data.edge_index = torch.stack([row, col], dim=0)\n\n    # Filter node storage:\n    required_attrs = []\n    for attr in feature_store.get_all_tensor_attrs():\n        attr.index = node  # TODO Support edge features.\n        required_attrs.append(attr)\n        data.num_nodes = attr.index.size(0)\n\n    # NOTE Here, we utilize `feature_store.multi_get` to give the feature store\n    # full control over optimizing how it returns features (since the call is\n    # synchronous, this amounts to giving the feature store control over all\n    # iteration).\n    tensors = feature_store.multi_get_tensor(required_attrs)\n    for i, attr in enumerate(required_attrs):\n        data[attr.attr_name] = tensors[i]\n\n    return data\n\n\ndef filter_custom_hetero_store(\n    feature_store: FeatureStore,\n    graph_store: GraphStore,\n    node_dict: Dict[str, Tensor],\n    
row_dict: Dict[str, Tensor],\n    col_dict: Dict[str, Tensor],\n    edge_dict: Dict[str, OptTensor],\n    custom_cls: Optional[HeteroData] = None,\n) -> HeteroData:\n    r\"\"\"Constructs a :class:`~torch_geometric.data.HeteroData` object from a\n    feature store and graph store instance.\n    \"\"\"\n    # Construct a new `HeteroData` object:\n    data = custom_cls() if custom_cls is not None else HeteroData()\n\n    # Filter node storage:\n    required_attrs = []\n    for attr in feature_store.get_all_tensor_attrs():\n        if attr.group_name in node_dict:\n            attr.index = node_dict[attr.group_name]\n            required_attrs.append(attr)\n            data[attr.group_name].num_nodes = attr.index.size(0)\n\n    # NOTE Here, we utilize `feature_store.multi_get` to give the feature store\n    # full control over optimizing how it returns features (since the call is\n    # synchronous, this amounts to giving the feature store control over all\n    # iteration).\n    tensors = feature_store.multi_get_tensor(required_attrs)\n    for i, attr in enumerate(required_attrs):\n        data[attr.group_name][attr.attr_name] = tensors[i]\n\n    # Filter edge storage:\n    # TODO support edge attributes\n    for attr in graph_store.get_all_edge_attrs():\n        key = attr.edge_type\n        if key in row_dict and key in col_dict:\n            edge_index = torch.stack([row_dict[key], col_dict[key]], dim=0)\n            data[attr.edge_type].edge_index = edge_index\n\n    return data\n\n\n# Input Utilities #############################################################\n\n\ndef get_input_nodes(\n    data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n    input_nodes: Union[InputNodes, TensorAttr],\n    input_id: Optional[Tensor] = None,\n) -> Tuple[Optional[str], Tensor, Optional[Tensor]]:\n    def to_index(nodes, input_id) -> Tuple[Tensor, Optional[Tensor]]:\n        if isinstance(nodes, Tensor) and nodes.dtype == torch.bool:\n            nodes = 
nodes.nonzero(as_tuple=False).view(-1)\n            if input_id is not None:\n                assert input_id.numel() == nodes.numel()\n            else:\n                input_id = nodes\n            return nodes, input_id\n\n        if not isinstance(nodes, Tensor):\n            nodes = torch.tensor(nodes, dtype=torch.long)\n\n        if input_id is not None:\n            assert input_id.numel() == nodes.numel()\n\n        return nodes, input_id\n\n    if isinstance(data, Data):\n        if input_nodes is None:\n            return None, torch.arange(data.num_nodes), None\n        return None, *to_index(input_nodes, input_id)\n\n    elif isinstance(data, HeteroData):\n        assert input_nodes is not None\n\n        if isinstance(input_nodes, str):\n            return input_nodes, torch.arange(data[input_nodes].num_nodes), None\n\n        assert isinstance(input_nodes, (list, tuple))\n        assert len(input_nodes) == 2\n        assert isinstance(input_nodes[0], str)\n\n        node_type, input_nodes = input_nodes\n        if input_nodes is None:\n            return node_type, torch.arange(data[node_type].num_nodes), None\n        return node_type, *to_index(input_nodes, input_id)\n\n    else:  # Tuple[FeatureStore, GraphStore]\n        feature_store, graph_store = data\n        assert input_nodes is not None\n\n        if isinstance(input_nodes, Tensor):\n            return None, *to_index(input_nodes, input_id)\n\n        if isinstance(input_nodes, str):\n            num_nodes = remote_backend_utils.num_nodes(  #\n                feature_store, graph_store, input_nodes)\n            return input_nodes, torch.arange(num_nodes), None\n\n        if isinstance(input_nodes, (list, tuple)):\n            assert len(input_nodes) == 2\n            assert isinstance(input_nodes[0], str)\n\n            node_type, input_nodes = input_nodes\n            if input_nodes is None:\n                num_nodes = remote_backend_utils.num_nodes(  #\n                    
feature_store, graph_store, node_type)\n                return node_type, torch.arange(num_nodes), None\n\n            return node_type, *to_index(input_nodes, input_id)\n\n\ndef get_edge_label_index(\n    data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n    edge_label_index: InputEdges,\n) -> Tuple[Optional[str], Tensor]:\n    edge_type = None\n    if isinstance(data, Data):\n        if edge_label_index is None:\n            return None, data.edge_index\n        return None, edge_label_index\n\n    assert edge_label_index is not None\n    assert isinstance(edge_label_index, (list, tuple))\n\n    if isinstance(data, HeteroData):\n        if isinstance(edge_label_index[0], str):\n            edge_type = edge_label_index\n            edge_type = data._to_canonical(*edge_type)\n            assert edge_type in data.edge_types\n            return edge_type, data[edge_type].edge_index\n\n        assert len(edge_label_index) == 2\n\n        edge_type, edge_label_index = edge_label_index\n        edge_type = data._to_canonical(*edge_type)\n\n        if edge_label_index is None:\n            return edge_type, data[edge_type].edge_index\n\n        return edge_type, edge_label_index\n\n    else:  # Tuple[FeatureStore, GraphStore]\n        _, graph_store = data\n\n        # Need the edge index in COO for LinkNeighborLoader:\n        def _get_edge_index(edge_type):\n            row_dict, col_dict, _ = graph_store.coo([edge_type])\n            row = list(row_dict.values())[0]\n            col = list(col_dict.values())[0]\n            return torch.stack((row, col), dim=0)\n\n        if isinstance(edge_label_index[0], str):\n            edge_type = edge_label_index\n            return edge_type, _get_edge_index(edge_type)\n\n        assert len(edge_label_index) == 2\n        edge_type, edge_label_index = edge_label_index\n\n        if edge_label_index is None:\n            return edge_type, _get_edge_index(edge_type)\n\n        return edge_type, 
edge_label_index\n\n\ndef infer_filter_per_worker(data: Any) -> bool:\n    out = True\n    if isinstance(data, (Data, HeteroData)) and data.is_cuda:\n        out = False\n    logging.debug(f\"Inferred 'filter_per_worker={out}' option for feature \"\n                  f\"fetching routines of the data loader\")\n    return out\n"
  },
  {
    "path": "torch_geometric/loader/zip_loader.py",
    "content": "from typing import Any, Iterator, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.loader import LinkLoader, NodeLoader\nfrom torch_geometric.loader.base import DataLoaderIterator\nfrom torch_geometric.loader.utils import infer_filter_per_worker\n\n\nclass ZipLoader(torch.utils.data.DataLoader):\n    r\"\"\"A loader that returns a tuple of data objects by sampling from multiple\n    :class:`NodeLoader` or :class:`LinkLoader` instances.\n\n    Args:\n        loaders (List[NodeLoader] or List[LinkLoader]): The loader instances.\n        filter_per_worker (bool, optional): If set to :obj:`True`, will filter\n            the returned data in each worker's subprocess.\n            If set to :obj:`False`, will filter the returned data in the main\n            process.\n            If set to :obj:`None`, will automatically infer the decision based\n            on whether data partially lives on the GPU\n            (:obj:`filter_per_worker=False`) or entirely on the CPU\n            (:obj:`filter_per_worker=True`).\n            There exist different trade-offs for setting this option.\n            Specifically, setting this option to :obj:`True` for in-memory\n            datasets will move all features to shared memory, which may result\n            in too many open file handles. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch.utils.data.DataLoader`, such as :obj:`batch_size`,\n            :obj:`shuffle`, :obj:`drop_last` or :obj:`num_workers`.\n    \"\"\"\n    def __init__(\n        self,\n        loaders: Union[List[NodeLoader], List[LinkLoader]],\n        filter_per_worker: Optional[bool] = None,\n        **kwargs,\n    ):\n        if filter_per_worker is None:\n            filter_per_worker = infer_filter_per_worker(loaders[0].data)\n\n        # Remove for PyTorch Lightning:\n        kwargs.pop('dataset', None)\n        kwargs.pop('collate_fn', None)\n\n        for loader in loaders:\n            if not callable(getattr(loader, 'collate_fn', None)):\n                raise ValueError(\n                    f\"'{loader.__class__.__name__}' does not have a \"\n                    f\"'collate_fn' method\")\n            if not callable(getattr(loader, 'filter_fn', None)):\n                raise ValueError(\n                    f\"'{loader.__class__.__name__}' does not have a \"\n                    f\"'filter_fn' method\")\n            loader.filter_per_worker = filter_per_worker\n\n        iterator = range(min([len(loader.dataset) for loader in loaders]))\n        super().__init__(iterator, collate_fn=self.collate_fn, **kwargs)\n\n        self.loaders = loaders\n        self.filter_per_worker = filter_per_worker\n\n    def __call__(\n        self,\n        index: Union[Tensor, List[int]],\n    ) -> Union[Tuple[Data, ...], Tuple[HeteroData, ...]]:\n        r\"\"\"Samples subgraphs from a batch of input IDs.\"\"\"\n        out = self.collate_fn(index)\n        if not self.filter_per_worker:\n            out = self.filter_fn(out)\n        return out\n\n    def collate_fn(self, index: List[int]) -> Tuple[Any, ...]:\n        if not isinstance(index, Tensor):\n            index = torch.tensor(index, dtype=torch.long)\n\n        return tuple(loader.collate_fn(index) for loader in self.loaders)\n\n    def 
filter_fn(\n        self,\n        outs: Tuple[Any, ...],\n    ) -> Tuple[Union[Data, HeteroData], ...]:\n        loaders = self.loaders\n        return tuple(loader.filter_fn(v) for loader, v in zip(loaders, outs))\n\n    def _get_iterator(self) -> Iterator:\n        if self.filter_per_worker:\n            return super()._get_iterator()\n\n        # Execute `filter_fn` in the main process:\n        return DataLoaderIterator(super()._get_iterator(), self.filter_fn)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(loaders={self.loaders})'\n"
  },
  {
    "path": "torch_geometric/logging.py",
    "content": "import sys\nfrom typing import Any\n\n_wandb_initialized: bool = False\n\n\ndef init_wandb(name: str, **kwargs: Any) -> None:\n    if '--wandb' not in sys.argv:\n        return\n\n    from datetime import datetime\n\n    import wandb\n\n    wandb.init(\n        project=name,\n        entity='pytorch-geometric',\n        name=datetime.now().strftime('%Y-%m-%d_%H:%M'),\n        config=kwargs,\n    )\n\n    global _wandb_initialized\n    _wandb_initialized = True\n\n\ndef log(**kwargs: Any) -> None:\n    def _map(value: Any) -> str:\n        if isinstance(value, int) and not isinstance(value, bool):\n            return f'{value:03d}'\n        if isinstance(value, float):\n            return f'{value:.4f}'\n        return value\n\n    print(', '.join(f'{key}: {_map(value)}' for key, value in kwargs.items()))\n\n    if _wandb_initialized:\n        import wandb\n        wandb.log(kwargs)\n"
  },
  {
    "path": "torch_geometric/metrics/__init__.py",
    "content": "# flake8: noqa\n\nfrom .link_pred import (\n    LinkPredMetric,\n    LinkPredMetricCollection,\n    LinkPredPrecision,\n    LinkPredRecall,\n    LinkPredF1,\n    LinkPredMAP,\n    LinkPredNDCG,\n    LinkPredMRR,\n    LinkPredHitRatio,\n    LinkPredCoverage,\n    LinkPredDiversity,\n    LinkPredPersonalization,\n    LinkPredAveragePopularity,\n)\n\nlink_pred_metrics = [\n    'LinkPredMetric',\n    'LinkPredMetricCollection',\n    'LinkPredPrecision',\n    'LinkPredRecall',\n    'LinkPredF1',\n    'LinkPredMAP',\n    'LinkPredNDCG',\n    'LinkPredMRR',\n    'LinkPredHitRatio',\n    'LinkPredCoverage',\n    'LinkPredDiversity',\n    'LinkPredPersonalization',\n    'LinkPredAveragePopularity',\n]\n\n__all__ = link_pred_metrics\n"
  },
  {
    "path": "torch_geometric/metrics/link_pred.py",
    "content": "from dataclasses import dataclass\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import cumsum, scatter\n\ntry:\n    import torchmetrics  # noqa\n    WITH_TORCHMETRICS = True\n    BaseMetric = torchmetrics.Metric\nexcept Exception:\n    WITH_TORCHMETRICS = False\n    BaseMetric = torch.nn.Module  # type: ignore\n\n\n@dataclass(repr=False)\nclass LinkPredMetricData:\n    pred_index_mat: Tensor\n    edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]]\n    edge_label_weight: Optional[Tensor] = None\n\n    def __post_init__(self) -> None:\n        # Filter all negative weights - they should not be used as ground-truth\n        if self.edge_label_weight is not None:\n            pos_mask = self.edge_label_weight > 0\n            self.edge_label_weight = self.edge_label_weight[pos_mask]\n            if isinstance(self.edge_label_index, Tensor):\n                self.edge_label_index = self.edge_label_index[:, pos_mask]\n            else:\n                self.edge_label_index = (\n                    self.edge_label_index[0][pos_mask],\n                    self.edge_label_index[1][pos_mask],\n                )\n\n    @property\n    def pred_rel_mat(self) -> Tensor:\n        r\"\"\"Returns a matrix indicating the relevance of the `k`-th prediction.\n        If :obj:`edge_label_weight` is not given, relevance will be denoted as\n        binary.\n        \"\"\"\n        if hasattr(self, '_pred_rel_mat'):\n            return self._pred_rel_mat  # type: ignore\n\n        if self.edge_label_index[1].numel() == 0:\n            self._pred_rel_mat = torch.zeros_like(\n                self.pred_index_mat,\n                dtype=torch.bool if self.edge_label_weight is None else\n                torch.get_default_dtype(),\n            )\n            return self._pred_rel_mat\n\n        # Flatten both prediction and ground-truth indices, and determine\n        # overlaps afterwards via 
`torch.searchsorted`.\n        max_index = max(\n            self.pred_index_mat.max()\n            if self.pred_index_mat.numel() > 0 else 0,\n            self.edge_label_index[1].max()\n            if self.edge_label_index[1].numel() > 0 else 0,\n        ) + 1\n        arange = torch.arange(\n            start=0,\n            end=max_index * self.pred_index_mat.size(0),  # type: ignore\n            step=max_index,  # type: ignore\n            device=self.pred_index_mat.device,\n        ).view(-1, 1)\n        flat_pred_index = (self.pred_index_mat + arange).view(-1)\n        flat_label_index = max_index * self.edge_label_index[0]\n        flat_label_index = flat_label_index + self.edge_label_index[1]\n        flat_label_index, perm = flat_label_index.sort()\n        edge_label_weight = self.edge_label_weight\n        if edge_label_weight is not None:\n            assert edge_label_weight.size() == self.edge_label_index[0].size()\n            edge_label_weight = edge_label_weight[perm]\n\n        pos = torch.searchsorted(flat_label_index, flat_pred_index)\n        pos = pos.clamp(max=flat_label_index.size(0) - 1)  # Out-of-bounds.\n\n        pred_rel_mat = flat_label_index[pos] == flat_pred_index  # Find matches\n        if edge_label_weight is not None:\n            pred_rel_mat = edge_label_weight[pos].where(\n                pred_rel_mat,\n                pred_rel_mat.new_zeros(1),\n            )\n        pred_rel_mat = pred_rel_mat.view(self.pred_index_mat.size())\n\n        self._pred_rel_mat = pred_rel_mat\n        return pred_rel_mat\n\n    @property\n    def label_count(self) -> Tensor:\n        r\"\"\"The number of ground-truth labels for every example.\"\"\"\n        if hasattr(self, '_label_count'):\n            return self._label_count  # type: ignore\n\n        label_count = scatter(\n            torch.ones_like(self.edge_label_index[0]),\n            self.edge_label_index[0],\n            dim=0,\n            dim_size=self.pred_index_mat.size(0),\n     
       reduce='sum',\n        )\n\n        self._label_count = label_count\n        return label_count\n\n    @property\n    def label_weight_sum(self) -> Tensor:\n        r\"\"\"The sum of edge label weights for every example.\"\"\"\n        if self.edge_label_weight is None:\n            return self.label_count\n\n        if hasattr(self, '_label_weight_sum'):\n            return self._label_weight_sum  # type: ignore\n\n        label_weight_sum = scatter(\n            self.edge_label_weight,\n            self.edge_label_index[0],\n            dim=0,\n            dim_size=self.pred_index_mat.size(0),\n            reduce='sum',\n        )\n\n        self._label_weight_sum = label_weight_sum\n        return label_weight_sum\n\n    @property\n    def edge_label_weight_pos(self) -> Optional[Tensor]:\n        r\"\"\"Returns the position of edge label weights in descending order\n        within example-wise buckets.\n        \"\"\"\n        if self.edge_label_weight is None:\n            return None\n\n        if hasattr(self, '_edge_label_weight_pos'):\n            return self._edge_label_weight_pos  # type: ignore\n\n        # Get the permutation via two sorts: One globally on the weights,\n        # followed by a (stable) sort on the example indices.\n        perm1 = self.edge_label_weight.argsort(descending=True)\n        perm2 = self.edge_label_index[0][perm1].argsort(stable=True)\n        perm = perm1[perm2]\n        # Invert the permutation to get the final position:\n        pos = torch.empty_like(perm)\n        pos[perm] = torch.arange(perm.size(0), device=perm.device)\n        # Normalize position to zero within all buckets:\n        pos = pos - cumsum(self.label_count)[self.edge_label_index[0]]\n\n        self._edge_label_weight_pos = pos\n        return pos\n\n\nclass _LinkPredMetric(BaseMetric):\n    r\"\"\"An abstract class for computing link prediction retrieval metrics.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate 
against.\n    \"\"\"\n    is_differentiable: bool = False\n    full_state_update: bool = False\n    higher_is_better: Optional[bool] = None\n\n    def __init__(self, k: int) -> None:\n        super().__init__()\n\n        if k <= 0:\n            raise ValueError(f\"'k' needs to be a positive integer in \"\n                             f\"'{self.__class__.__name__}' (got {k})\")\n\n        self.k = k\n\n    def update(\n        self,\n        pred_index_mat: Tensor,\n        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],\n        edge_label_weight: Optional[Tensor] = None,\n    ) -> None:\n        r\"\"\"Updates the state variables based on the current mini-batch\n        prediction.\n\n        :meth:`update` can be repeated multiple times to accumulate the results\n        of successive predictions, *e.g.*, inside a mini-batch training or\n        evaluation loop.\n\n        Args:\n            pred_index_mat (torch.Tensor): The top-:math:`k` predictions of\n                every example in the mini-batch with shape\n                :obj:`[batch_size, k]`.\n            edge_label_index (torch.Tensor): The ground-truth indices for every\n                example in the mini-batch, given in COO format of shape\n                :obj:`[2, num_ground_truth_indices]`.\n            edge_label_weight (torch.Tensor, optional): The weight of the\n                ground-truth indices for every example in the mini-batch of\n                shape :obj:`[num_ground_truth_indices]`. If given, needs to be\n                a vector of positive values. Required for weighted metrics,\n                ignored otherwise. 
(default: :obj:`None`)\n        \"\"\"\n        raise NotImplementedError\n\n    def compute(self) -> Tensor:\n        r\"\"\"Computes the final metric value.\"\"\"\n        raise NotImplementedError\n\n    def reset(self) -> None:\n        r\"\"\"Resets metric state variables to their default value.\"\"\"\n        if WITH_TORCHMETRICS:\n            super().reset()\n        else:\n            self._reset()\n\n    def _reset(self) -> None:\n        raise NotImplementedError\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(k={self.k})'\n\n\nclass LinkPredMetric(_LinkPredMetric):\n    r\"\"\"An abstract class for computing link prediction retrieval metrics.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n    \"\"\"\n    weighted: bool\n\n    def __init__(self, k: int) -> None:\n        super().__init__(k)\n\n        self.accum: Tensor\n        self.total: Tensor\n\n        if WITH_TORCHMETRICS:\n            self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')\n            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')\n        else:\n            self.register_buffer('accum', torch.tensor(0.), persistent=False)\n            self.register_buffer('total', torch.tensor(0), persistent=False)\n\n    def update(\n        self,\n        pred_index_mat: Tensor,\n        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],\n        edge_label_weight: Optional[Tensor] = None,\n    ) -> None:\n        if self.weighted and edge_label_weight is None:\n            raise ValueError(f\"'edge_label_weight' is a required argument for \"\n                             f\"weighted '{self.__class__.__name__}' metrics\")\n        if not self.weighted:\n            edge_label_weight = None\n\n        data = LinkPredMetricData(\n            pred_index_mat=pred_index_mat,\n            edge_label_index=edge_label_index,\n            edge_label_weight=edge_label_weight,\n        )\n        
self._update(data)\n\n    def _update(self, data: LinkPredMetricData) -> None:\n        metric = self._compute(data)\n\n        self.accum += metric.sum()\n        self.total += (data.label_count > 0).sum()\n\n    def compute(self) -> Tensor:\n        if self.total == 0:\n            return torch.zeros_like(self.accum)\n        return self.accum / self.total\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        r\"\"\"Computes the specific metric.\n        To be implemented separately for each metric class.\n\n        Args:\n            data (LinkPredMetricData): The mini-batch data for computing a link\n                prediction metric per example.\n        \"\"\"\n        raise NotImplementedError\n\n    def _reset(self) -> None:\n        self.accum.zero_()\n        self.total.zero_()\n\n    def __repr__(self) -> str:\n        weighted_repr = ', weighted=True' if self.weighted else ''\n        return f'{self.__class__.__name__}(k={self.k}{weighted_repr})'\n\n\nclass LinkPredMetricCollection(torch.nn.ModuleDict):\n    r\"\"\"A collection of metrics to reduce and speed-up computation of link\n    prediction metrics.\n\n    .. code-block:: python\n\n        from torch_geometric.metrics import (\n            LinkPredMAP,\n            LinkPredMetricCollection,\n            LinkPredPrecision,\n            LinkPredRecall,\n        )\n\n        metrics = LinkPredMetricCollection([\n            LinkPredMAP(k=10),\n            LinkPredPrecision(k=100),\n            LinkPredRecall(k=50),\n        ])\n\n        metrics.update(pred_index_mat, edge_label_index)\n        out = metrics.compute()\n        metrics.reset()\n\n        print(out)\n        >>> {'LinkPredMAP@10': tensor(0.375),\n        ...  'LinkPredPrecision@100': tensor(0.127),\n        ...  
'LinkPredRecall@50': tensor(0.483)}\n\n    Args:\n        metrics: The link prediction metrics.\n    \"\"\"\n    def __init__(\n        self,\n        metrics: Union[\n            List[LinkPredMetric],\n            Dict[str, LinkPredMetric],\n        ],\n    ) -> None:\n        super().__init__()\n\n        if isinstance(metrics, (list, tuple)):\n            metrics = {\n                (f'{\"Weighted\" if getattr(metric, \"weighted\", False) else \"\"}'\n                 f'{metric.__class__.__name__}@{metric.k}'):\n                metric\n                for metric in metrics\n            }\n        assert len(metrics) > 0\n        assert isinstance(metrics, dict)\n\n        for name, metric in metrics.items():\n            assert isinstance(metric, _LinkPredMetric)\n            self[name] = metric\n\n    @property\n    def max_k(self) -> int:\n        r\"\"\"The maximum number of top-:math:`k` predictions to evaluate\n        against.\n        \"\"\"\n        return max([\n            metric.k  # type: ignore[return-value]\n            for metric in self.values()\n        ])  # type: ignore[type-var]\n\n    @property\n    def weighted(self) -> bool:\n        r\"\"\"Returns :obj:`True` in case the collection holds at least one\n        weighted link prediction metric.\n        \"\"\"\n        return any(\n            [getattr(metric, 'weighted', False) for metric in self.values()])\n\n    def update(  # type: ignore\n        self,\n        pred_index_mat: Tensor,\n        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],\n        edge_label_weight: Optional[Tensor] = None,\n    ) -> None:\n        r\"\"\"Updates the state variables based on the current mini-batch\n        prediction.\n\n        :meth:`update` can be repeated multiple times to accumulate the results\n        of successive predictions, *e.g.*, inside a mini-batch training or\n        evaluation loop.\n\n        Args:\n            pred_index_mat (torch.Tensor): The top-:math:`k` predictions 
of\n                every example in the mini-batch with shape\n                :obj:`[batch_size, k]`.\n            edge_label_index (torch.Tensor): The ground-truth indices for every\n                example in the mini-batch, given in COO format of shape\n                :obj:`[2, num_ground_truth_indices]`.\n            edge_label_weight (torch.Tensor, optional): The weight of the\n                ground-truth indices for every example in the mini-batch of\n                shape :obj:`[num_ground_truth_indices]`. If given, needs to be\n                a vector of positive values. Required for weighted metrics,\n                ignored otherwise. (default: :obj:`None`)\n        \"\"\"\n        if self.weighted and edge_label_weight is None:\n            raise ValueError(f\"'edge_label_weight' is a required argument for \"\n                             f\"weighted '{self.__class__.__name__}' metrics\")\n\n        data = LinkPredMetricData(  # Share metric data across metrics.\n            pred_index_mat=pred_index_mat,\n            edge_label_index=edge_label_index,\n            edge_label_weight=edge_label_weight,\n        )\n\n        for metric in self.values():\n            if isinstance(metric, LinkPredMetric) and metric.weighted:\n                metric._update(data)\n                if WITH_TORCHMETRICS:\n                    metric._update_count += 1\n\n        data.edge_label_weight = None\n        if hasattr(data, '_pred_rel_mat'):\n            data._pred_rel_mat = data._pred_rel_mat != 0.0\n        if hasattr(data, '_label_weight_sum'):\n            del data._label_weight_sum\n        if hasattr(data, '_edge_label_weight_pos'):\n            del data._edge_label_weight_pos\n\n        for metric in self.values():\n            if isinstance(metric, LinkPredMetric) and not metric.weighted:\n                metric._update(data)\n                if WITH_TORCHMETRICS:\n                    metric._update_count += 1\n\n        for metric in self.values():\n      
      if not isinstance(metric, LinkPredMetric):\n                metric.update(  # type: ignore[operator]\n                    pred_index_mat,\n                    edge_label_index,\n                    edge_label_weight,\n                )\n\n    def compute(self) -> Dict[str, Tensor]:\n        r\"\"\"Computes the final metric values.\"\"\"\n        return {\n            name: metric.compute()  # type: ignore[operator]\n            for name, metric in self.items()\n        }\n\n    def reset(self) -> None:\n        r\"\"\"Reset metric state variables to their default value.\"\"\"\n        for metric in self.values():\n            metric.reset()  # type: ignore[operator]\n\n    def __repr__(self) -> str:\n        names = [f'  {name}: {metric},\\n' for name, metric in self.items()]\n        return f'{self.__class__.__name__}([\\n{\"\".join(names)}])'\n\n\nclass LinkPredPrecision(LinkPredMetric):\n    r\"\"\"A link prediction metric to compute Precision @ :math:`k`, *i.e.* the\n    proportion of recommendations within the top-:math:`k` that are actually\n    relevant.\n\n    A higher precision indicates the model's ability to surface relevant items\n    early in the ranking.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n    \"\"\"\n    higher_is_better: bool = True\n    weighted: bool = False\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        pred_rel_mat = data.pred_rel_mat[:, :self.k]\n        return pred_rel_mat.sum(dim=-1) / self.k\n\n\nclass LinkPredRecall(LinkPredMetric):\n    r\"\"\"A link prediction metric to compute Recall @ :math:`k`, *i.e.* the\n    proportion of relevant items that appear within the top-:math:`k`.\n\n    A higher recall indicates the model's ability to retrieve a larger\n    proportion of relevant items.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n    \"\"\"\n    higher_is_better: bool = True\n\n    def __init__(self, k: 
int, weighted: bool = False):\n        super().__init__(k=k)\n        self.weighted = weighted\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        pred_rel_mat = data.pred_rel_mat[:, :self.k]\n        return pred_rel_mat.sum(dim=-1) / data.label_weight_sum.clamp(min=1e-7)\n\n\nclass LinkPredF1(LinkPredMetric):\n    r\"\"\"A link prediction metric to compute F1 @ :math:`k`.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n    \"\"\"\n    higher_is_better: bool = True\n    weighted: bool = False\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        pred_rel_mat = data.pred_rel_mat[:, :self.k]\n        isin_count = pred_rel_mat.sum(dim=-1)\n        precision = isin_count / self.k\n        recall = isin_count / data.label_count.clamp(min=1e-7)\n        return 2 * precision * recall / (precision + recall).clamp(min=1e-7)\n\n\nclass LinkPredMAP(LinkPredMetric):\n    r\"\"\"A link prediction metric to compute MAP @ :math:`k` (Mean Average\n    Precision), considering the order of relevant items within the\n    top-:math:`k`.\n\n    MAP @ :math:`k` can provide a more comprehensive view of ranking quality\n    than precision alone.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n    \"\"\"\n    higher_is_better: bool = True\n    weighted: bool = False\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        pred_rel_mat = data.pred_rel_mat[:, :self.k]\n        device = pred_rel_mat.device\n        arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)\n        cum_precision = pred_rel_mat.cumsum(dim=1) / arange\n        return ((cum_precision * pred_rel_mat).sum(dim=-1) /\n                data.label_count.clamp(min=1e-7, max=self.k))\n\n\nclass LinkPredNDCG(LinkPredMetric):\n    r\"\"\"A link prediction metric to compute the NDCG @ :math:`k` (Normalized\n    Discounted Cumulative Gain).\n\n    In particular, can account for 
the position of relevant items by\n    considering relevance scores, giving higher weight to more relevant items\n    appearing at the top.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n        weighted (bool, optional): If set to :obj:`True`, assumes sorted lists\n            of ground-truth items according to a relevance score as given by\n            :obj:`edge_label_weight`. (default: :obj:`False`)\n    \"\"\"\n    higher_is_better: bool = True\n\n    def __init__(self, k: int, weighted: bool = False):\n        super().__init__(k=k)\n        self.weighted = weighted\n\n        dtype = torch.get_default_dtype()\n        discount = torch.arange(2, k + 2, dtype=dtype).log2()\n\n        self.discount: Tensor\n        self.register_buffer('discount', discount, persistent=False)\n\n        if not weighted:\n            self.register_buffer('idcg', cumsum(1.0 / discount),\n                                 persistent=False)\n        else:\n            self.idcg = None\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        pred_rel_mat = data.pred_rel_mat[:, :self.k]\n        discount = self.discount[:pred_rel_mat.size(1)].view(1, -1)\n        dcg = (pred_rel_mat / discount).sum(dim=-1)\n\n        if not self.weighted:\n            assert self.idcg is not None\n            idcg = self.idcg[data.label_count.clamp(max=self.k)]\n        else:\n            assert data.edge_label_weight is not None\n            pos = data.edge_label_weight_pos\n            assert pos is not None\n\n            discount = torch.cat([\n                self.discount,\n                self.discount.new_full((1, ), fill_value=float('inf')),\n            ])\n            discount = discount[pos.clamp(max=self.k)]\n\n            idcg = scatter(  # Apply discount and aggregate:\n                data.edge_label_weight / discount,\n                data.edge_label_index[0],\n                dim_size=data.pred_index_mat.size(0),\n            
    reduce='sum',\n            )\n\n        out = dcg / idcg\n        out[out.isnan() | out.isinf()] = 0.0\n        return out\n\n\nclass LinkPredMRR(LinkPredMetric):\n    r\"\"\"A link prediction metric to compute the MRR @ :math:`k` (Mean\n    Reciprocal Rank), *i.e.* the mean reciprocal rank of the first correct\n    prediction (or zero otherwise).\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n    \"\"\"\n    higher_is_better: bool = True\n    weighted: bool = False\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        pred_rel_mat = data.pred_rel_mat[:, :self.k]\n        device = pred_rel_mat.device\n        arange = torch.arange(1, pred_rel_mat.size(1) + 1, device=device)\n        return (pred_rel_mat / arange).max(dim=-1)[0]\n\n\nclass LinkPredHitRatio(LinkPredMetric):\n    r\"\"\"A link prediction metric to compute the hit ratio @ :math:`k`, *i.e.*\n    the percentage of users for whom at least one relevant item is present\n    within the top-:math:`k` recommendations.\n\n    A high ratio signifies the model's effectiveness in satisfying a broad\n    range of user preferences.\n    \"\"\"\n    higher_is_better: bool = True\n    weighted: bool = False\n\n    def _compute(self, data: LinkPredMetricData) -> Tensor:\n        pred_rel_mat = data.pred_rel_mat[:, :self.k]\n        return pred_rel_mat.max(dim=-1)[0].to(torch.get_default_dtype())\n\n\nclass LinkPredCoverage(_LinkPredMetric):\n    r\"\"\"A link prediction metric to compute the Coverage @ :math:`k` of\n    predictions, *i.e.* the percentage of unique items recommended across all\n    users within the top-:math:`k`.\n\n    Higher coverage indicates a wider exploration of the item catalog.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n        num_dst_nodes (int): The total number of destination nodes.\n    \"\"\"\n    higher_is_better: bool = True\n\n    def __init__(self, k: int, 
num_dst_nodes: int) -> None:\n        super().__init__(k)\n        self.num_dst_nodes = num_dst_nodes\n\n        self.mask: Tensor\n        mask = torch.zeros(num_dst_nodes, dtype=torch.bool)\n        if WITH_TORCHMETRICS:\n            self.add_state('mask', mask, dist_reduce_fx='max')\n        else:\n            self.register_buffer('mask', mask, persistent=False)\n\n    def update(\n        self,\n        pred_index_mat: Tensor,\n        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],\n        edge_label_weight: Optional[Tensor] = None,\n    ) -> None:\n        self.mask[pred_index_mat[:, :self.k].flatten()] = True\n\n    def compute(self) -> Tensor:\n        return self.mask.to(torch.get_default_dtype()).mean()\n\n    def _reset(self) -> None:\n        self.mask.zero_()\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(k={self.k}, '\n                f'num_dst_nodes={self.num_dst_nodes})')\n\n\nclass LinkPredDiversity(_LinkPredMetric):\n    r\"\"\"A link prediction metric to compute the Diversity @ :math:`k` of\n    predictions according to item categories.\n\n    Diversity is computed as\n\n    .. math::\n        div_{u@k} = 1 - \\left( \\frac{1}{k \\cdot (k-1)} \\right) \\sum_{i \\neq j}\n        sim(i, j)\n\n    where\n\n    .. 
math::\n        sim(i,j) = \\begin{cases}\n            1 & \\quad \\text{if } i,j \\text{ share category,}\\\\\n            0 & \\quad \\text{otherwise.}\n        \\end{cases}\n\n    which measures the pair-wise inequality of recommendations according to\n    item categories.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n        category (torch.Tensor): A vector that assigns each destination node to\n            a specific category.\n    \"\"\"\n    higher_is_better: bool = True\n\n    def __init__(self, k: int, category: Tensor) -> None:\n        super().__init__(k)\n\n        self.accum: Tensor\n        self.total: Tensor\n\n        if WITH_TORCHMETRICS:\n            self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')\n            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')\n        else:\n            self.register_buffer('accum', torch.tensor(0.), persistent=False)\n            self.register_buffer('total', torch.tensor(0), persistent=False)\n\n        self.category: Tensor\n        self.register_buffer('category', category, persistent=False)\n\n    def update(\n        self,\n        pred_index_mat: Tensor,\n        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],\n        edge_label_weight: Optional[Tensor] = None,\n    ) -> None:\n        category = self.category[pred_index_mat[:, :self.k]]\n\n        sim = (category.unsqueeze(-2) == category.unsqueeze(-1)).sum(dim=-1)\n        div = 1 - 1 / (self.k * (self.k - 1)) * (sim - 1).sum(dim=-1)\n\n        self.accum += div.sum()\n        self.total += pred_index_mat.size(0)\n\n    def compute(self) -> Tensor:\n        if self.total == 0:\n            return torch.zeros_like(self.accum)\n        return self.accum / self.total\n\n    def _reset(self) -> None:\n        self.accum.zero_()\n        self.total.zero_()\n\n\nclass LinkPredPersonalization(_LinkPredMetric):\n    r\"\"\"A link prediction metric to compute the Personalization 
@ :math:`k`,\n    *i.e.* the dissimilarity of recommendations across different users.\n\n    Higher personalization suggests that the model tailors recommendations to\n    individual user preferences rather than providing generic results.\n\n    Dissimilarity is defined by the average inverse cosine similarity between\n    users' lists of recommendations.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n        max_src_nodes (int, optional): The maximum source nodes to consider to\n            compute pair-wise dissimilarity. If specified,\n            Personalization @ :math:`k` is approximated to avoid computation\n            blowup due to quadratic complexity. (default: :obj:`2**12`)\n        batch_size (int, optional): The batch size to determine how many pairs\n            of user recommendations should be processed at once.\n            (default: :obj:`2**16`)\n    \"\"\"\n    higher_is_better: bool = True\n\n    def __init__(\n        self,\n        k: int,\n        max_src_nodes: Optional[int] = 2**12,\n        batch_size: int = 2**16,\n    ) -> None:\n        super().__init__(k)\n        self.max_src_nodes = max_src_nodes\n        self.batch_size = batch_size\n\n        self.preds: List[Tensor]\n        self.total: Tensor\n\n        if WITH_TORCHMETRICS:\n            self.add_state('preds', default=[], dist_reduce_fx='cat')\n            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')\n        else:\n            self.preds = []\n            self.register_buffer('total', torch.tensor(0), persistent=False)\n\n    def update(\n        self,\n        pred_index_mat: Tensor,\n        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],\n        edge_label_weight: Optional[Tensor] = None,\n    ) -> None:\n\n        # NOTE Move to CPU to avoid memory blowup.\n        pred_index_mat = pred_index_mat[:, :self.k].cpu()\n\n        if self.max_src_nodes is None:\n            
self.preds.append(pred_index_mat)\n            self.total += pred_index_mat.size(0)\n        elif self.total < self.max_src_nodes:\n            remaining = int(self.max_src_nodes - self.total)\n            pred_index_mat = pred_index_mat[:remaining]\n            self.preds.append(pred_index_mat)\n            self.total += pred_index_mat.size(0)\n\n    def compute(self) -> Tensor:\n        device = self.total.device\n        score = torch.tensor(0.0, device=device)\n        total = torch.tensor(0, device=device)\n\n        if len(self.preds) == 0:\n            return score\n\n        pred = torch.cat(self.preds, dim=0)\n\n        if pred.size(0) == 0:\n            return score\n\n        # Calculate all pairs of nodes (e.g., triu_indices with offset=1).\n        # NOTE We do this in chunks to avoid memory blow-up, which leads to a\n        # more efficient but trickier implementation.\n        num_pairs = (pred.size(0) * (pred.size(0) - 1)) // 2\n        offset = torch.arange(pred.size(0) - 1, 0, -1, device=device)\n        rowptr = cumsum(offset)\n        for start in range(0, num_pairs, self.batch_size):\n            end = min(start + self.batch_size, num_pairs)\n            idx = torch.arange(start, end, device=device)\n\n            # Find the corresponding row:\n            row = torch.searchsorted(rowptr, idx, right=True) - 1\n            # Find the corresponding column:\n            col = idx - rowptr[row] + (pred.size(0) - offset[row])\n\n            left = pred[row.cpu()].to(device)\n            right = pred[col.cpu()].to(device)\n\n            # Use offset to work around applying `isin` along a specific dim:\n            i = max(int(left.max()), int(right.max())) + 1\n            idx = torch.arange(0, i * row.size(0), i, device=device)\n            idx = idx.view(-1, 1)\n            isin = torch.isin(left + idx, right + idx)\n\n            # Compute personalization via average inverse cosine similarity:\n            cos = isin.sum(dim=-1) / pred.size(1)\n  
          score += (1 - cos).sum()\n            total += cos.numel()\n\n        return score / total\n\n    def _reset(self) -> None:\n        self.preds = []\n        self.total.zero_()\n\n\nclass LinkPredAveragePopularity(_LinkPredMetric):\n    r\"\"\"A link prediction metric to compute the Average Recommendation\n    Popularity (ARP) @ :math:`k`, which provides insights into the model's\n    tendency to recommend popular items by averaging the popularity scores of\n    items within the top-:math:`k` recommendations.\n\n    Args:\n        k (int): The number of top-:math:`k` predictions to evaluate against.\n        popularity (torch.Tensor): The popularity of every item in the training\n            set, *e.g.*, the number of times an item has been rated.\n    \"\"\"\n    higher_is_better: bool = False\n\n    def __init__(self, k: int, popularity: Tensor) -> None:\n        super().__init__(k)\n\n        self.accum: Tensor\n        self.total: Tensor\n\n        if WITH_TORCHMETRICS:\n            self.add_state('accum', torch.tensor(0.), dist_reduce_fx='sum')\n            self.add_state('total', torch.tensor(0), dist_reduce_fx='sum')\n        else:\n            self.register_buffer('accum', torch.tensor(0.), persistent=False)\n            self.register_buffer('total', torch.tensor(0), persistent=False)\n\n        self.popularity: Tensor\n        self.register_buffer('popularity', popularity, persistent=False)\n\n    def update(\n        self,\n        pred_index_mat: Tensor,\n        edge_label_index: Union[Tensor, Tuple[Tensor, Tensor]],\n        edge_label_weight: Optional[Tensor] = None,\n    ) -> None:\n        pred_index_mat = pred_index_mat[:, :self.k]\n        popularity = self.popularity[pred_index_mat]\n        popularity = popularity.to(self.accum.dtype).mean(dim=-1)\n        self.accum += popularity.sum()\n        self.total += popularity.numel()\n\n    def compute(self) -> Tensor:\n        if self.total == 0:\n            return 
torch.zeros_like(self.accum)\n        return self.accum / self.total\n\n    def _reset(self) -> None:\n        self.accum.zero_()\n        self.total.zero_()\n"
  },
  {
    "path": "torch_geometric/nn/__init__.py",
    "content": "from .reshape import Reshape\nfrom .sequential import Sequential\nfrom .data_parallel import DataParallel\nfrom .to_hetero_transformer import to_hetero\nfrom .to_hetero_with_bases_transformer import to_hetero_with_bases\nfrom .to_fixed_size_transformer import to_fixed_size\nfrom .encoding import PositionalEncoding, TemporalEncoding\nfrom .summary import summary\n\nfrom .aggr import *  # noqa\nfrom .attention import *  # noqa\nfrom .conv import *  # noqa\nfrom .pool import *  # noqa\nfrom .glob import *  # noqa\nfrom .norm import *  # noqa\nfrom .unpool import *  # noqa\nfrom .dense import *  # noqa\nfrom .kge import *  # noqa\nfrom .models import *  # noqa\nfrom .functional import *  # noqa\n\n__all__ = [\n    'Reshape',\n    'Sequential',\n    'DataParallel',\n    'to_hetero',\n    'to_hetero_with_bases',\n    'to_fixed_size',\n    'PositionalEncoding',\n    'TemporalEncoding',\n    'summary',\n]\n"
  },
  {
    "path": "torch_geometric/nn/aggr/__init__.py",
    "content": "from .base import Aggregation\nfrom .multi import MultiAggregation\nfrom .basic import (\n    MeanAggregation,\n    SumAggregation,\n    MaxAggregation,\n    MinAggregation,\n    MulAggregation,\n    VarAggregation,\n    StdAggregation,\n    SoftmaxAggregation,\n    PowerMeanAggregation,\n)\nfrom .quantile import MedianAggregation, QuantileAggregation\nfrom .lstm import LSTMAggregation\nfrom .gru import GRUAggregation\nfrom .set2set import Set2Set\nfrom .scaler import DegreeScalerAggregation\nfrom .equilibrium import EquilibriumAggregation\nfrom .sort import SortAggregation\nfrom .gmt import GraphMultisetTransformer\nfrom .attention import AttentionalAggregation\nfrom .mlp import MLPAggregation\nfrom .deep_sets import DeepSetsAggregation\nfrom .set_transformer import SetTransformerAggregation\nfrom .lcm import LCMAggregation\nfrom .variance_preserving import VariancePreservingAggregation\nfrom .patch_transformer import PatchTransformerAggregation\n\n__all__ = classes = [\n    'Aggregation',\n    'MultiAggregation',\n    'SumAggregation',\n    'MeanAggregation',\n    'MaxAggregation',\n    'MinAggregation',\n    'MulAggregation',\n    'VarAggregation',\n    'StdAggregation',\n    'SoftmaxAggregation',\n    'PowerMeanAggregation',\n    'MedianAggregation',\n    'QuantileAggregation',\n    'LSTMAggregation',\n    'GRUAggregation',\n    'Set2Set',\n    'DegreeScalerAggregation',\n    'SortAggregation',\n    'GraphMultisetTransformer',\n    'AttentionalAggregation',\n    'EquilibriumAggregation',\n    'MLPAggregation',\n    'DeepSetsAggregation',\n    'SetTransformerAggregation',\n    'LCMAggregation',\n    'VariancePreservingAggregation',\n    'PatchTransformerAggregation',\n]\n"
  },
  {
    "path": "torch_geometric/nn/aggr/attention.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.utils import softmax\n\n\nclass AttentionalAggregation(Aggregation):\n    r\"\"\"The soft attention aggregation layer from the `\"Graph Matching Networks\n    for Learning the Similarity of Graph Structured Objects\"\n    <https://arxiv.org/abs/1904.12787>`_ paper.\n\n    .. math::\n        \\mathbf{r}_i = \\sum_{n=1}^{N_i} \\mathrm{softmax} \\left(\n        h_{\\mathrm{gate}} ( \\mathbf{x}_n ) \\right) \\cdot\n        h_{\\mathbf{\\Theta}} ( \\mathbf{x}_n ),\n\n    where :math:`h_{\\mathrm{gate}} \\colon \\mathbb{R}^F \\to\n    \\mathbb{R}` and :math:`h_{\\mathbf{\\Theta}}` denote neural networks, *i.e.*\n    MLPs.\n\n    Args:\n        gate_nn (torch.nn.Module): A neural network :math:`h_{\\mathrm{gate}}`\n            that computes attention scores by mapping node features :obj:`x` of\n            shape :obj:`[-1, in_channels]` to shape :obj:`[-1, 1]` (for\n            node-level gating) or :obj:`[-1, out_channels]` (for feature-level\n            gating), *e.g.*, defined by :class:`torch.nn.Sequential`.\n        nn (torch.nn.Module, optional): A neural network\n            :math:`h_{\\mathbf{\\Theta}}` that maps node features :obj:`x` of\n            shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]`\n            before combining them with the attention scores, *e.g.*, defined by\n            :class:`torch.nn.Sequential`. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        gate_nn: torch.nn.Module,\n        nn: Optional[torch.nn.Module] = None,\n    ):\n        super().__init__()\n\n        from torch_geometric.nn import MLP\n\n        self.gate_nn = self.gate_mlp = None\n        if isinstance(gate_nn, MLP):\n            self.gate_mlp = gate_nn\n        else:\n            self.gate_nn = gate_nn\n\n        self.nn = self.mlp = None\n        if isinstance(nn, MLP):\n            self.mlp = nn\n        else:\n            self.nn = nn\n\n    def reset_parameters(self):\n        reset(self.gate_nn)\n        reset(self.gate_mlp)\n        reset(self.nn)\n        reset(self.mlp)\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        if self.gate_mlp is not None:\n            gate = self.gate_mlp(x, batch=index, batch_size=dim_size)\n        else:\n            gate = self.gate_nn(x)\n\n        if self.mlp is not None:\n            x = self.mlp(x, batch=index, batch_size=dim_size)\n        elif self.nn is not None:\n            x = self.nn(x)\n\n        gate = softmax(gate, index, ptr, dim_size, dim)\n        return self.reduce(gate * x, index, ptr, dim_size, dim)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'gate_nn={self.gate_mlp or self.gate_nn}, '\n                f'nn={self.mlp or self.nn})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/base.py",
    "content": "from typing import Final, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.utils import scatter, segment, to_dense_batch\n\n\nclass Aggregation(torch.nn.Module):\n    r\"\"\"An abstract base class for implementing custom aggregations.\n\n    Aggregation can be either performed via an :obj:`index` vector, which\n    defines the mapping from input elements to their location in the output:\n\n    |\n\n    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/\n            master/docs/source/_figures/add.svg?sanitize=true\n        :align: center\n        :width: 400px\n\n    |\n\n    Notably, :obj:`index` does not have to be sorted (for most aggregation\n    operators):\n\n    .. code-block:: python\n\n       # Feature matrix holding 10 elements with 64 features each:\n       x = torch.randn(10, 64)\n\n       # Assign each element to one of three sets:\n       index = torch.tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2])\n\n       output = aggr(x, index)  #  Output shape: [3, 64]\n\n    Alternatively, aggregation can be achieved via a \"compressed\" index vector\n    called :obj:`ptr`. Here, elements within the same set need to be grouped\n    together in the input, and :obj:`ptr` defines their boundaries:\n\n    .. 
code-block:: python\n\n       # Feature matrix holding 10 elements with 64 features each:\n       x = torch.randn(10, 64)\n\n       # Define the boundary indices for three sets:\n       ptr = torch.tensor([0, 4, 7, 10])\n\n       output = aggr(x, ptr=ptr)  #  Output shape: [3, 64]\n\n    Note that at least one of :obj:`index` or :obj:`ptr` must be defined.\n\n    Shapes:\n        - **input:**\n          node features :math:`(*, |\\mathcal{V}|, F_{in})` or edge features\n          :math:`(*, |\\mathcal{E}|, F_{in})`,\n          index vector :math:`(|\\mathcal{V}|)` or :math:`(|\\mathcal{E}|)`,\n        - **output:** graph features :math:`(*, |\\mathcal{G}|, F_{out})` or\n          node features :math:`(*, |\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self) -> None:\n        super().__init__()\n\n        self._deterministic: Final[bool] = (\n            torch.are_deterministic_algorithms_enabled()\n            or torch.is_deterministic_algorithms_warn_only_enabled())\n\n    def forward(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            index (torch.Tensor, optional): The indices of elements for\n                applying the aggregation.\n                One of :obj:`index` or :obj:`ptr` must be defined.\n                (default: :obj:`None`)\n            ptr (torch.Tensor, optional): If given, computes the aggregation\n                based on sorted inputs in CSR representation.\n                One of :obj:`index` or :obj:`ptr` must be defined.\n                (default: :obj:`None`)\n            dim_size (int, optional): The size of the output tensor at\n                dimension :obj:`dim` after aggregation. 
(default: :obj:`None`)\n            dim (int, optional): The dimension in which to aggregate.\n                (default: :obj:`-2`)\n            max_num_elements (int, optional): The maximum number of elements\n                within a single aggregation group. (default: :obj:`None`)\n        \"\"\"\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n\n    @disable_dynamic_shapes(required_args=['dim_size'])\n    def __call__(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        **kwargs,\n    ) -> Tensor:\n\n        if dim >= x.dim() or dim < -x.dim():\n            raise ValueError(f\"Encountered invalid dimension '{dim}' of \"\n                             f\"source tensor with {x.dim()} dimensions\")\n\n        if index is None and ptr is None:\n            index = x.new_zeros(x.size(dim), dtype=torch.long)\n\n        if ptr is not None:\n            if dim_size is None:\n                dim_size = ptr.numel() - 1\n            elif dim_size != ptr.numel() - 1:\n                raise ValueError(f\"Encountered invalid 'dim_size' (got \"\n                                 f\"'{dim_size}' but expected \"\n                                 f\"'{ptr.numel() - 1}')\")\n\n        if index is not None and dim_size is None:\n            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0\n\n        try:\n            return super().__call__(x, index=index, ptr=ptr, dim_size=dim_size,\n                                    dim=dim, **kwargs)\n        except (IndexError, RuntimeError) as e:\n            if index is not None:\n                if index.numel() > 0 and dim_size <= int(index.max()):\n                    raise ValueError(f\"Encountered invalid 'dim_size' (got \"\n                                     f\"'{dim_size}' but expected \"\n                                     f\">= 
'{int(index.max()) + 1}')\") from e\n            raise e\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n\n    # Assertions ##############################################################\n\n    def assert_index_present(self, index: Optional[Tensor]):\n        # TODO Currently, not all aggregators support `ptr`. This assert helps\n        # to ensure that we require `index` to be passed to the computation:\n        if index is None:\n            raise NotImplementedError(\n                \"Aggregation requires 'index' to be specified\")\n\n    def assert_sorted_index(self, index: Optional[Tensor]):\n        if index is not None and not torch.all(index[:-1] <= index[1:]):\n            raise ValueError(\"Can not perform aggregation since the 'index' \"\n                             \"tensor is not sorted. Specifically, if you use \"\n                             \"this aggregation as part of 'MessagePassing', \"\n                             \"ensure that 'edge_index' is sorted by \"\n                             \"destination nodes, e.g., by calling \"\n                             \"`data.sort(sort_by_row=False)`\")\n\n    def assert_two_dimensional_input(self, x: Tensor, dim: int):\n        if x.dim() != 2:\n            raise ValueError(f\"Aggregation requires two-dimensional inputs \"\n                             f\"(got '{x.dim()}')\")\n\n        if dim not in [-2, 0]:\n            raise ValueError(f\"Aggregation needs to perform aggregation in \"\n                             f\"first dimension (got '{dim}')\")\n\n    # Helper methods ##########################################################\n\n    def reduce(self, x: Tensor, index: Optional[Tensor] = None,\n               ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n               dim: int = -2, reduce: str = 'sum') -> Tensor:\n\n        if ptr is not None:\n            if index is None or self._deterministic:\n                ptr = expand_left(ptr, dim, 
dims=x.dim())\n                return segment(x, ptr, reduce=reduce)\n\n        if index is None:\n            raise RuntimeError(\"Aggregation requires 'index' to be specified\")\n\n        return scatter(x, index, dim, dim_size, reduce)\n\n    def to_dense_batch(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        fill_value: float = 0.0,\n        max_num_elements: Optional[int] = None,\n    ) -> Tuple[Tensor, Tensor]:\n\n        # TODO Currently, `to_dense_batch` can only operate on `index`:\n        self.assert_index_present(index)\n        self.assert_sorted_index(index)\n        self.assert_two_dimensional_input(x, dim)\n\n        return to_dense_batch(\n            x,\n            index,\n            batch_size=dim_size,\n            fill_value=fill_value,\n            max_num_nodes=max_num_elements,\n        )\n\n\n###############################################################################\n\n\ndef expand_left(ptr: Tensor, dim: int, dims: int) -> Tensor:\n    for _ in range(dims + dim if dim < 0 else dim):\n        ptr = ptr.unsqueeze(0)\n    return ptr\n"
  },
  {
    "path": "torch_geometric/nn/aggr/basic.py",
    "content": "import math\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.utils import softmax\n\n\nclass SumAggregation(Aggregation):\n    r\"\"\"An aggregation operator that sums up features across a set of elements.\n\n    .. math::\n        \\mathrm{sum}(\\mathcal{X}) = \\sum_{\\mathbf{x}_i \\in \\mathcal{X}}\n        \\mathbf{x}_i.\n    \"\"\"\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n        return self.reduce(x, index, ptr, dim_size, dim, reduce='sum')\n\n\nclass MeanAggregation(Aggregation):\n    r\"\"\"An aggregation operator that averages features across a set of\n    elements.\n\n    .. math::\n        \\mathrm{mean}(\\mathcal{X}) = \\frac{1}{|\\mathcal{X}|}\n        \\sum_{\\mathbf{x}_i \\in \\mathcal{X}} \\mathbf{x}_i.\n    \"\"\"\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n        return self.reduce(x, index, ptr, dim_size, dim, reduce='mean')\n\n\nclass MaxAggregation(Aggregation):\n    r\"\"\"An aggregation operator that takes the feature-wise maximum across a\n    set of elements.\n\n    .. math::\n        \\mathrm{max}(\\mathcal{X}) = \\max_{\\mathbf{x}_i \\in \\mathcal{X}}\n        \\mathbf{x}_i.\n    \"\"\"\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n        return self.reduce(x, index, ptr, dim_size, dim, reduce='max')\n\n\nclass MinAggregation(Aggregation):\n    r\"\"\"An aggregation operator that takes the feature-wise minimum across a\n    set of elements.\n\n    .. 
math::\n        \\mathrm{min}(\\mathcal{X}) = \\min_{\\mathbf{x}_i \\in \\mathcal{X}}\n        \\mathbf{x}_i.\n    \"\"\"\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n        return self.reduce(x, index, ptr, dim_size, dim, reduce='min')\n\n\nclass MulAggregation(Aggregation):\n    r\"\"\"An aggregation operator that multiplies features across a set of\n    elements.\n\n    .. math::\n        \\mathrm{mul}(\\mathcal{X}) = \\prod_{\\mathbf{x}_i \\in \\mathcal{X}}\n        \\mathbf{x}_i.\n    \"\"\"\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n        # TODO Currently, `mul` reduction can only operate on `index`:\n        self.assert_index_present(index)\n        return self.reduce(x, index, None, dim_size, dim, reduce='mul')\n\n\nclass VarAggregation(Aggregation):\n    r\"\"\"An aggregation operator that takes the feature-wise variance across a\n    set of elements.\n\n    .. math::\n        \\mathrm{var}(\\mathcal{X}) = \\mathrm{mean}(\\{ \\mathbf{x}_i^2 :\n        \\mathbf{x}_i \\in \\mathcal{X} \\}) - \\mathrm{mean}(\\mathcal{X})^2.\n\n    Args:\n        semi_grad (bool, optional): If set to :obj:`True`, will turn off\n            gradient calculation during :math:`E[X^2]` computation. Therefore,\n            only semi-gradients are used during backpropagation. 
Useful for\n            saving memory and accelerating backward computation.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(self, semi_grad: bool = False):\n        super().__init__()\n        self.semi_grad = semi_grad\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n        mean = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')\n        if self.semi_grad:\n            with torch.no_grad():\n                mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')\n        else:\n            mean2 = self.reduce(x * x, index, ptr, dim_size, dim, 'mean')\n        return mean2 - mean * mean\n\n\nclass StdAggregation(Aggregation):\n    r\"\"\"An aggregation operator that takes the feature-wise standard deviation\n    across a set of elements.\n\n    .. math::\n        \\mathrm{std}(\\mathcal{X}) = \\sqrt{\\mathrm{var}(\\mathcal{X})}.\n\n    Args:\n        semi_grad (bool, optional): If set to :obj:`True`, will turn off\n            gradient calculation during :math:`E[X^2]` computation. Therefore,\n            only semi-gradients are used during backpropagation. 
Useful for\n            saving memory and accelerating backward computation.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(self, semi_grad: bool = False):\n        super().__init__()\n        self.var_aggr = VarAggregation(semi_grad)\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n        var = self.var_aggr(x, index, ptr, dim_size, dim)\n        # Allow \"undefined\" gradient at `sqrt(0.0)`:\n        out = var.clamp(min=1e-5).sqrt()\n        out = out.masked_fill(out <= math.sqrt(1e-5), 0.0)\n        return out\n\n\nclass SoftmaxAggregation(Aggregation):\n    r\"\"\"The softmax aggregation operator based on a temperature term, as\n    described in the `\"DeeperGCN: All You Need to Train Deeper GCNs\"\n    <https://arxiv.org/abs/2006.07739>`_ paper.\n\n    .. math::\n        \\mathrm{softmax}(\\mathcal{X}|t) = \\sum_{\\mathbf{x}_i\\in\\mathcal{X}}\n        \\frac{\\exp(t\\cdot\\mathbf{x}_i)}{\\sum_{\\mathbf{x}_j\\in\\mathcal{X}}\n        \\exp(t\\cdot\\mathbf{x}_j)}\\cdot\\mathbf{x}_{i},\n\n    where :math:`t` controls the softness of the softmax when aggregating over\n    a set of features :math:`\\mathcal{X}`.\n\n    Args:\n        t (float, optional): Initial inverse temperature for softmax\n            aggregation. (default: :obj:`1.0`)\n        learn (bool, optional): If set to :obj:`True`, will learn the value\n            :obj:`t` for softmax aggregation dynamically.\n            (default: :obj:`False`)\n        semi_grad (bool, optional): If set to :obj:`True`, will turn off\n            gradient calculation during softmax computation. Therefore, only\n            semi-gradients are used during backpropagation. Useful for saving\n            memory and accelerating backward computation when :obj:`t` is not\n            learnable. 
(default: :obj:`False`)\n        channels (int, optional): Number of channels to learn from :math:`t`.\n            If set to a value greater than :obj:`1`, :math:`t` will be learned\n            per input feature channel. This requires compatible shapes for the\n            input to the forward calculation. (default: :obj:`1`)\n    \"\"\"\n    def __init__(self, t: float = 1.0, learn: bool = False,\n                 semi_grad: bool = False, channels: int = 1):\n        super().__init__()\n\n        if learn and semi_grad:\n            raise ValueError(\n                f\"Cannot enable 'semi_grad' in '{self.__class__.__name__}' in \"\n                f\"case the temperature term 't' is learnable\")\n\n        if not learn and channels != 1:\n            raise ValueError(f\"Cannot set 'channels' greater than '1' in case \"\n                             f\"'{self.__class__.__name__}' is not trainable\")\n\n        self._init_t = t\n        self.learn = learn\n        self.semi_grad = semi_grad\n        self.channels = channels\n\n        self.t = Parameter(torch.empty(channels)) if learn else t\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if isinstance(self.t, Tensor):\n            self.t.data.fill_(self._init_t)\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        t = self.t\n        if self.channels != 1:\n            self.assert_two_dimensional_input(x, dim)\n            assert isinstance(t, Tensor)\n            t = t.view(-1, self.channels)\n\n        alpha = x\n        if not isinstance(t, (int, float)) or t != 1:\n            alpha = x * t\n\n        if not self.learn and self.semi_grad:\n            with torch.no_grad():\n                alpha = softmax(alpha, index, ptr, dim_size, dim)\n        else:\n            alpha = softmax(alpha, index, ptr, dim_size, dim)\n        return 
self.reduce(x * alpha, index, ptr, dim_size, dim, reduce='sum')\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(learn={self.learn})')\n\n\nclass PowerMeanAggregation(Aggregation):\n    r\"\"\"The powermean aggregation operator based on a power term, as\n    described in the `\"DeeperGCN: All You Need to Train Deeper GCNs\"\n    <https://arxiv.org/abs/2006.07739>`_ paper.\n\n    .. math::\n        \\mathrm{powermean}(\\mathcal{X}|p) = \\left(\\frac{1}{|\\mathcal{X}|}\n        \\sum_{\\mathbf{x}_i\\in\\mathcal{X}}\\mathbf{x}_i^{p}\\right)^{1/p},\n\n    where :math:`p` controls the power of the powermean when aggregating over\n    a set of features :math:`\\mathcal{X}`.\n\n    Args:\n        p (float, optional): Initial power for powermean aggregation.\n            (default: :obj:`1.0`)\n        learn (bool, optional): If set to :obj:`True`, will learn the value\n            :obj:`p` for powermean aggregation dynamically.\n            (default: :obj:`False`)\n        channels (int, optional): Number of channels to learn from :math:`p`.\n            If set to a value greater than :obj:`1`, :math:`p` will be learned\n            per input feature channel. This requires compatible shapes for the\n            input to the forward calculation. (default: :obj:`1`)\n        clamp_min (float, optional): Lower-bound of the range to be clamped\n            to. There is no lower bound if set to :obj:`None`.\n        clamp_max (float, optional): Upper-bound of the range to be clamped\n            to. 
There is no upper bound if set to :obj:`None`.\n    \"\"\"\n    def __init__(\n        self,\n        p: float = 1.0,\n        learn: bool = False,\n        channels: int = 1,\n        clamp_min: Optional[float] = 1e-4,\n        clamp_max: Optional[float] = 100.,\n    ) -> None:\n        super().__init__()\n\n        if not learn and channels != 1:\n            raise ValueError(f\"Cannot set 'channels' greater than '1' in case \"\n                             f\"'{self.__class__.__name__}' is not trainable\")\n\n        self._init_p = p\n        self.learn = learn\n        self.channels = channels\n\n        self.p = Parameter(torch.empty(channels)) if learn else p\n        self.reset_parameters()\n        self.min_value = clamp_min\n        self.max_value = clamp_max\n\n    def reset_parameters(self):\n        if isinstance(self.p, Tensor):\n            self.p.data.fill_(self._init_p)\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        p = self.p\n        if self.channels != 1:\n            assert isinstance(p, Tensor)\n            self.assert_two_dimensional_input(x, dim)\n            p = p.view(-1, self.channels)\n\n        if not isinstance(p, (int, float)) or p != 1:\n            x = x.clamp(min=self.min_value, max=self.max_value).pow(p)\n\n        out = self.reduce(x, index, ptr, dim_size, dim, reduce='mean')\n\n        if not isinstance(p, (int, float)) or p != 1:\n            out = out.clamp(min=self.min_value, max=self.max_value).pow(1. / p)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(learn={self.learn})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/deep_sets.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.inits import reset\n\n\nclass DeepSetsAggregation(Aggregation):\n    r\"\"\"Performs Deep Sets aggregation in which the elements to aggregate are\n    first transformed by a Multi-Layer Perceptron (MLP)\n    :math:`\\phi_{\\mathbf{\\Theta}}`, summed, and then transformed by another MLP\n    :math:`\\rho_{\\mathbf{\\Theta}}`, as suggested in the `\"Graph Neural Networks\n    with Adaptive Readouts\" <https://arxiv.org/abs/2211.04952>`_ paper.\n\n    Args:\n        local_nn (torch.nn.Module, optional): The neural network\n            :math:`\\phi_{\\mathbf{\\Theta}}`, *e.g.*, defined by\n            :class:`torch.nn.Sequential` or\n            :class:`torch_geometric.nn.models.MLP`. (default: :obj:`None`)\n        global_nn (torch.nn.Module, optional): The neural network\n            :math:`\\rho_{\\mathbf{\\Theta}}`, *e.g.*, defined by\n            :class:`torch.nn.Sequential` or\n            :class:`torch_geometric.nn.models.MLP`. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        local_nn: Optional[torch.nn.Module] = None,\n        global_nn: Optional[torch.nn.Module] = None,\n    ):\n        super().__init__()\n\n        from torch_geometric.nn import MLP\n\n        self.local_nn = self.local_mlp = None\n        if isinstance(local_nn, MLP):\n            self.local_mlp = local_nn\n        else:\n            self.local_nn = local_nn\n\n        self.global_nn = self.global_mlp = None\n        if isinstance(global_nn, MLP):\n            self.global_mlp = global_nn\n        else:\n            self.global_nn = global_nn\n\n    def reset_parameters(self):\n        reset(self.local_nn)\n        reset(self.local_mlp)\n        reset(self.global_nn)\n        reset(self.global_mlp)\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        if self.local_mlp is not None:\n            x = self.local_mlp(x, batch=index, batch_size=dim_size)\n        if self.local_nn is not None:\n            x = self.local_nn(x)\n\n        x = self.reduce(x, index, ptr, dim_size, dim, reduce='sum')\n\n        if self.global_mlp is not None:\n            x = self.global_mlp(x, batch=index, batch_size=dim_size)\n        elif self.global_nn is not None:\n            x = self.global_nn(x)\n\n        return x\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'local_nn={self.local_mlp or self.local_nn}, '\n                f'global_nn={self.global_mlp or self.global_nn})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/equilibrium.py",
    "content": "from typing import Callable, List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.utils import scatter\n\n\nclass ResNetPotential(torch.nn.Module):\n    def __init__(self, in_channels: int, out_channels: int,\n                 num_layers: List[int]):\n\n        super().__init__()\n        sizes = [in_channels] + num_layers + [out_channels]\n        self.layers = torch.nn.ModuleList([\n            torch.nn.Sequential(torch.nn.Linear(in_size, out_size),\n                                torch.nn.LayerNorm(out_size), torch.nn.Tanh())\n            for in_size, out_size in zip(sizes[:-2], sizes[1:-1])\n        ])\n        self.layers.append(torch.nn.Linear(sizes[-2], sizes[-1]))\n\n        self.res_trans = torch.nn.ModuleList([\n            torch.nn.Linear(in_channels, layer_size)\n            for layer_size in num_layers + [out_channels]\n        ])\n\n    def forward(self, x: Tensor, y: Tensor, index: Optional[Tensor],\n                dim_size: Optional[int] = None) -> Tensor:\n        if index is None:\n            inp = torch.cat([x, y.expand(x.size(0), -1)], dim=1)\n        else:\n            inp = torch.cat([x, y[index]], dim=1)\n\n        h = inp\n        for layer, res in zip(self.layers, self.res_trans):\n            h = layer(h)\n            h = res(inp) + h\n\n        if index is None:\n            return h.mean()\n\n        if dim_size is None:\n            dim_size = int(index.max().item() + 1)\n\n        return scatter(h, index, 0, dim_size, reduce='mean').sum()\n\n\nclass MomentumOptimizer(torch.nn.Module):\n    r\"\"\"Provides an inner loop optimizer for the implicitly defined output\n    layer. 
It is based on an unrolled Nesterov momentum algorithm.\n\n    Args:\n        learning_rate (float): The learning rate of the optimizer.\n        momentum (float): The momentum of the optimizer.\n        learnable (bool): If :obj:`True`, the :obj:`learning_rate` and\n            :obj:`momentum` will be learnable parameters. If :obj:`False`,\n            they are fixed. (default: :obj:`True`)\n    \"\"\"\n    def __init__(self, learning_rate: float = 0.1, momentum: float = 0.9,\n                 learnable: bool = True):\n        super().__init__()\n\n        self._initial_lr = learning_rate\n        self._initial_mom = momentum\n        self._lr = torch.nn.Parameter(Tensor([learning_rate]),\n                                      requires_grad=learnable)\n        self._mom = torch.nn.Parameter(Tensor([momentum]),\n                                       requires_grad=learnable)\n        self.softplus = torch.nn.Softplus()\n        self.sigmoid = torch.nn.Sigmoid()\n\n    def reset_parameters(self):\n        self._lr.data.fill_(self._initial_lr)\n        self._mom.data.fill_(self._initial_mom)\n\n    @property\n    def learning_rate(self):\n        return self.softplus(self._lr)\n\n    @property\n    def momentum(self):\n        return self.sigmoid(self._mom)\n\n    def forward(\n        self,\n        x: Tensor,\n        y: Tensor,\n        index: Optional[Tensor],\n        dim_size: Optional[int],\n        func: Callable[[Tensor, Tensor, Optional[Tensor], Optional[int]],\n                       Tensor],\n        iterations: int = 5,\n    ) -> Tuple[Tensor, float]:\n\n        momentum_buffer = torch.zeros_like(y)\n        for _ in range(iterations):\n            val = func(x, y, index, dim_size)\n            grad = torch.autograd.grad(val, y, create_graph=True,\n                                       retain_graph=True)[0]\n            delta = self.learning_rate * grad\n            momentum_buffer = self.momentum * momentum_buffer - delta\n            y = y + momentum_buffer\n        return y\n\n\nclass 
EquilibriumAggregation(Aggregation):\n    r\"\"\"The equilibrium aggregation layer from the `\"Equilibrium Aggregation:\n    Encoding Sets via Optimization\" <https://arxiv.org/abs/2202.12795>`_ paper.\n\n    The output of this layer :math:`\\mathbf{y}` is defined implicitly via a\n    potential function :math:`F(\\mathbf{x}, \\mathbf{y})`, a regularization term\n    :math:`R(\\mathbf{y})`, and the condition\n\n    .. math::\n        \\mathbf{y} = \\arg\\min_\\mathbf{y} R(\\mathbf{y}) + \\sum_{i}\n        F(\\mathbf{x}_i, \\mathbf{y}).\n\n    The given implementation uses a ResNet-like model for the potential\n    function and a simple :math:`L_2` norm :math:`R(\\mathbf{y}) =\n    \\textrm{softplus}(\\lambda) \\cdot {\\| \\mathbf{y} \\|}^2_2` for the\n    regularizer with learnable weight :math:`\\lambda`.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        num_layers (List[int]): List of hidden channels in the potential\n            function.\n        grad_iter (int): The number of steps to take in the internal gradient\n            descent. 
(default: :obj:`5`)\n        lamb (float): The initial regularization constant.\n            (default: :obj:`0.1`)\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int,\n                 num_layers: List[int], grad_iter: int = 5, lamb: float = 0.1):\n        super().__init__()\n\n        self.potential = ResNetPotential(in_channels + out_channels, 1,\n                                         num_layers)\n        self.optimizer = MomentumOptimizer()\n        self.initial_lamb = lamb\n        self.lamb = torch.nn.Parameter(Tensor(1), requires_grad=True)\n        self.softplus = torch.nn.Softplus()\n        self.grad_iter = grad_iter\n        self.output_dim = out_channels\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.lamb.data.fill_(self.initial_lamb)\n        reset(self.optimizer)\n        reset(self.potential)\n\n    def init_output(self, dim_size: int) -> Tensor:\n        return torch.zeros(dim_size, self.output_dim, requires_grad=True,\n                           device=self.lamb.device).float()\n\n    def reg(self, y: Tensor) -> Tensor:\n        return self.softplus(self.lamb) * y.square().sum(dim=-1).mean()\n\n    def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor],\n               dim_size: Optional[int] = None):\n        return self.potential(x, y, index, dim_size) + self.reg(y)\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        self.assert_index_present(index)\n\n        dim_size = int(index.max()) + 1 if dim_size is None else dim_size\n\n        with torch.enable_grad():\n            y = self.optimizer(x, self.init_output(dim_size), index, dim_size,\n                               self.energy, iterations=self.grad_iter)\n\n        return y\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}()')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/fused.py",
    "content": "import math\nfrom typing import Dict, List, Optional, Tuple, Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr.base import Aggregation\nfrom torch_geometric.nn.aggr.basic import (\n    MaxAggregation,\n    MeanAggregation,\n    MinAggregation,\n    MulAggregation,\n    StdAggregation,\n    SumAggregation,\n    VarAggregation,\n)\nfrom torch_geometric.nn.resolver import aggregation_resolver\nfrom torch_geometric.utils import scatter\n\n\nclass FusedAggregation(Aggregation):\n    r\"\"\"Helper class to fuse computation of multiple aggregations together.\n\n    Used internally in :class:`~torch_geometric.nn.aggr.MultiAggregation` to\n    speed-up computation.\n    Currently, the following optimizations are performed:\n\n    * :class:`MeanAggregation` will share the output with\n      :class:`SumAggregation` in case it is present as well.\n\n    * :class:`VarAggregation` will share the output with either\n      :class:`MeanAggregation` or :class:`SumAggregation` in case one of them\n      is present as well.\n\n    * :class:`StdAggregation` will share the output with either\n      :class:`VarAggregation`, :class:`MeanAggregation` or\n      :class:`SumAggregation` in case one of them is present as well.\n\n    In addition, temporary values such as the count per group index are shared\n    as well.\n\n    Benchmarking results on PyTorch 1.12 (summed over 1000 runs):\n\n    +------------------------------+---------+---------+\n    | Aggregators                  | Vanilla | Fusion  |\n    +==============================+=========+=========+\n    | :obj:`[sum, mean]`           | 0.3325s | 0.1996s |\n    +------------------------------+---------+---------+\n    | :obj:`[sum, mean, min, max]` | 0.7139s | 0.5037s |\n    +------------------------------+---------+---------+\n    | :obj:`[sum, mean, var]`      | 0.6849s | 0.3871s |\n    +------------------------------+---------+---------+\n    | :obj:`[sum, mean, var, std]` | 1.0955s | 0.3973s |\n  
  +------------------------------+---------+---------+\n\n    Args:\n        aggrs (list): The list of aggregation schemes to use.\n    \"\"\"\n    # We can fuse all aggregations together that rely on `scatter` directives.\n    FUSABLE_AGGRS = {\n        SumAggregation,\n        MeanAggregation,\n        MinAggregation,\n        MaxAggregation,\n        MulAggregation,\n        VarAggregation,\n        StdAggregation,\n    }\n\n    # All aggregations that rely on computing the degree of indices.\n    DEGREE_BASED_AGGRS = {\n        MeanAggregation,\n        VarAggregation,\n        StdAggregation,\n    }\n\n    # Map aggregations to `reduce` options in `scatter` directives.\n    REDUCE = {\n        'SumAggregation': 'sum',\n        'MeanAggregation': 'sum',\n        'MinAggregation': 'min',\n        'MaxAggregation': 'max',\n        'MulAggregation': 'mul',\n        'VarAggregation': 'pow_sum',\n        'StdAggregation': 'pow_sum',\n    }\n\n    def __init__(self, aggrs: List[Union[Aggregation, str]]):\n        super().__init__()\n\n        if not isinstance(aggrs, (list, tuple)):\n            raise ValueError(f\"'aggrs' of '{self.__class__.__name__}' should \"\n                             f\"be a list or tuple (got '{type(aggrs)}').\")\n\n        if len(aggrs) == 0:\n            raise ValueError(f\"'aggrs' of '{self.__class__.__name__}' should \"\n                             f\"not be empty.\")\n\n        aggrs = [aggregation_resolver(aggr) for aggr in aggrs]\n        aggr_classes = [aggr.__class__ for aggr in aggrs]\n        self.aggr_names = [cls.__name__ for cls in aggr_classes]\n        self.aggr_index: Dict[str, int] = {\n            name: i\n            for i, name in enumerate(self.aggr_names)\n        }\n\n        for cls in aggr_classes:\n            if cls not in self.FUSABLE_AGGRS:\n                raise ValueError(f\"Received aggregation '{cls.__name__}' in \"\n                                 f\"'{self.__class__.__name__}' which is not \"\n          
                       f\"fusable\")\n\n        self.semi_grad = False\n        for aggr in aggrs:\n            if hasattr(aggr, 'semi_grad'):\n                self.semi_grad = self.semi_grad or aggr.semi_grad\n\n        # Check whether we need to compute degree information:\n        self.need_degree = False\n        for cls in aggr_classes:\n            if cls in self.DEGREE_BASED_AGGRS:\n                self.need_degree = True\n\n        # Determine which reduction to use for each aggregator:\n        # An entry of `None` means that this operator re-uses intermediate\n        # outputs from other aggregators.\n        reduce_ops: List[Optional[str]] = []\n        # Determine which `(Aggregator, index)` to use as intermediate output:\n        lookup_ops: List[Optional[Tuple[str, int]]] = []\n\n        for name in self.aggr_names:\n            if name == 'MeanAggregation':\n                # Directly use output of `SumAggregation`:\n                if 'SumAggregation' in self.aggr_index:\n                    reduce_ops.append(None)\n                    lookup_ops.append((\n                        'SumAggregation',\n                        self.aggr_index['SumAggregation'],\n                    ))\n                else:\n                    reduce_ops.append(self.REDUCE[name])\n                    lookup_ops.append(None)\n\n            elif name == 'VarAggregation':\n                if 'MeanAggregation' in self.aggr_index:\n                    reduce_ops.append(self.REDUCE[name])\n                    lookup_ops.append((\n                        'MeanAggregation',\n                        self.aggr_index['MeanAggregation'],\n                    ))\n                elif 'SumAggregation' in self.aggr_index:\n                    reduce_ops.append(self.REDUCE[name])\n                    lookup_ops.append((\n                        'SumAggregation',\n                        self.aggr_index['SumAggregation'],\n                    ))\n                else:\n                 
   reduce_ops.append(self.REDUCE[name])\n                    lookup_ops.append(None)\n\n            elif name == 'StdAggregation':\n                # Directly use output of `VarAggregation`:\n                if 'VarAggregation' in self.aggr_index:\n                    reduce_ops.append(None)\n                    lookup_ops.append((\n                        'VarAggregation',\n                        self.aggr_index['VarAggregation'],\n                    ))\n                elif 'MeanAggregation' in self.aggr_index:\n                    reduce_ops.append(self.REDUCE[name])\n                    lookup_ops.append((\n                        'MeanAggregation',\n                        self.aggr_index['MeanAggregation'],\n                    ))\n                elif 'SumAggregation' in self.aggr_index:\n                    reduce_ops.append(self.REDUCE[name])\n                    lookup_ops.append((\n                        'SumAggregation',\n                        self.aggr_index['SumAggregation'],\n                    ))\n                else:\n                    reduce_ops.append(self.REDUCE[name])\n                    lookup_ops.append(None)\n\n            else:\n                reduce_ops.append(self.REDUCE[name])\n                lookup_ops.append(None)\n\n        self.reduce_ops: List[Optional[str]] = reduce_ops\n        self.lookup_ops: List[Optional[Tuple[str, int]]] = lookup_ops\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> List[Tensor]:\n\n        # Assert two-dimensional input for now to simplify computation:\n        # TODO refactor this to support any dimension.\n        self.assert_index_present(index)\n        self.assert_two_dimensional_input(x, dim)\n\n        assert index is not None\n\n        if dim_size is None:\n            if ptr is not None:\n                dim_size = ptr.numel() - 1\n            else:\n       
         dim_size = int(index.max()) + 1 if index.numel() > 0 else 0\n\n        count: Optional[Tensor] = None\n        if self.need_degree:\n            count = x.new_zeros(dim_size)\n            count.scatter_add_(0, index, x.new_ones(x.size(0)))\n            count = count.clamp_(min=1).view(-1, 1)\n\n        #######################################################################\n\n        outs: List[Optional[Tensor]] = []\n\n        # Iterate over all reduction ops to compute first results:\n        for reduce in self.reduce_ops:\n            if reduce is None:\n                outs.append(None)\n                continue\n            assert isinstance(reduce, str)\n\n            if reduce == 'pow_sum':\n                if self.semi_grad:\n                    out = scatter(x.detach() * x.detach(), index, 0, dim_size,\n                                  reduce='sum')\n                else:\n                    out = scatter(x * x, index, 0, dim_size, reduce='sum')\n            else:\n                out = scatter(x, index, 0, dim_size, reduce=reduce)\n\n            outs.append(out)\n\n        #######################################################################\n\n        # Compute `MeanAggregation` first to be able to re-use it:\n        i = self.aggr_index.get('MeanAggregation')\n        if i is not None:\n            assert count is not None\n\n            if self.lookup_ops[i] is None:\n                sum_ = outs[i]\n            else:\n                lookup_op = self.lookup_ops[i]\n                assert lookup_op is not None\n                tmp_aggr, j = lookup_op\n                assert tmp_aggr == 'SumAggregation'\n\n                sum_ = outs[j]\n\n            assert sum_ is not None\n            outs[i] = sum_ / count\n\n        # Compute `VarAggregation` second to be able to re-use it:\n        if 'VarAggregation' in self.aggr_index:\n            i = self.aggr_index['VarAggregation']\n\n            assert count is not None\n\n            if 
self.lookup_ops[i] is None:\n                sum_ = scatter(x, index, 0, dim_size, reduce='sum')\n                mean = sum_ / count\n            else:\n                lookup_op = self.lookup_ops[i]\n                assert lookup_op is not None\n                tmp_aggr, j = lookup_op\n\n                if tmp_aggr == 'SumAggregation':\n                    sum_ = outs[j]\n                    assert sum_ is not None\n                    mean = sum_ / count\n                elif tmp_aggr == 'MeanAggregation':\n                    mean = outs[j]\n                else:\n                    raise NotImplementedError\n\n            pow_sum = outs[i]\n\n            assert pow_sum is not None\n            assert mean is not None\n            outs[i] = (pow_sum / count) - (mean * mean)\n\n        # Compute `StdAggregation` last:\n        if 'StdAggregation' in self.aggr_index:\n            i = self.aggr_index['StdAggregation']\n\n            var: Optional[Tensor] = None\n            pow_sum: Optional[Tensor] = None\n            mean: Optional[Tensor] = None\n\n            if self.lookup_ops[i] is None:\n                pow_sum = outs[i]\n                sum_ = scatter(x, index, 0, dim_size, reduce='sum')\n                assert count is not None\n                mean = sum_ / count\n            else:\n                lookup_op = self.lookup_ops[i]\n                assert lookup_op is not None\n                tmp_aggr, j = lookup_op\n\n                if tmp_aggr == 'VarAggregation':\n                    var = outs[j]\n                elif tmp_aggr == 'SumAggregation':\n                    pow_sum = outs[i]\n                    sum_ = outs[j]\n                    assert sum_ is not None\n                    assert count is not None\n                    mean = sum_ / count\n                elif tmp_aggr == 'MeanAggregation':\n                    pow_sum = outs[i]\n                    mean = outs[j]\n                else:\n                    raise NotImplementedError\n\n   
         if var is None:\n                assert pow_sum is not None\n                assert count is not None\n                assert mean is not None\n                var = (pow_sum / count) - (mean * mean)\n\n            # Allow \"undefined\" gradient at `sqrt(0.0)`:\n            out = var.clamp(min=1e-5).sqrt()\n            out = out.masked_fill(out <= math.sqrt(1e-5), 0.0)\n\n            outs[i] = out\n\n        #######################################################################\n\n        vals: List[Tensor] = []\n        for out in outs:\n            assert out is not None\n            vals.append(out)\n\n        return vals\n"
  },
  {
    "path": "torch_geometric/nn/aggr/gmt.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.aggr.utils import (\n    PoolingByMultiheadAttention,\n    SetAttentionBlock,\n)\n\n\nclass GraphMultisetTransformer(Aggregation):\n    r\"\"\"The Graph Multiset Transformer pooling operator from the\n    `\"Accurate Learning of Graph Representations\n    with Graph Multiset Pooling\" <https://arxiv.org/abs/2102.11533>`_ paper.\n\n    The :class:`GraphMultisetTransformer` aggregates elements into\n    :math:`k` representative elements via attention-based pooling, computes the\n    interaction among them via :obj:`num_encoder_blocks` self-attention blocks,\n    and finally pools the representative elements via attention-based pooling\n    into a single cluster.\n\n    .. note::\n\n        :class:`GraphMultisetTransformer` requires sorted indices :obj:`index`\n        as input. Specifically, if you use this aggregation as part of\n        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that\n        :obj:`edge_index` is sorted by destination nodes, either by manually\n        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`\n        or by calling :meth:`torch_geometric.data.Data.sort`.\n\n    Args:\n        channels (int): Size of each input sample.\n        k (int): Number of :math:`k` representative nodes after pooling.\n        num_encoder_blocks (int, optional): Number of Set Attention Blocks\n            (SABs) between the two pooling blocks. (default: :obj:`1`)\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        norm (str, optional): If set to :obj:`True`, will apply layer\n            normalization. 
(default: :obj:`False`)\n        dropout (float, optional): Dropout probability of attention weights.\n            (default: :obj:`0`)\n    \"\"\"\n    def __init__(\n        self,\n        channels: int,\n        k: int,\n        num_encoder_blocks: int = 1,\n        heads: int = 1,\n        layer_norm: bool = False,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n\n        self.channels = channels\n        self.k = k\n        self.heads = heads\n        self.layer_norm = layer_norm\n        self.dropout = dropout\n\n        self.pma1 = PoolingByMultiheadAttention(channels, k, heads, layer_norm,\n                                                dropout)\n        self.encoders = torch.nn.ModuleList([\n            SetAttentionBlock(channels, heads, layer_norm, dropout)\n            for _ in range(num_encoder_blocks)\n        ])\n        self.pma2 = PoolingByMultiheadAttention(channels, 1, heads, layer_norm,\n                                                dropout)\n\n    def reset_parameters(self):\n        self.pma1.reset_parameters()\n        for encoder in self.encoders:\n            encoder.reset_parameters()\n        self.pma2.reset_parameters()\n\n    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])\n    def forward(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n\n        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim,\n                                      max_num_elements=max_num_elements)\n\n        x = self.pma1(x, mask)\n\n        for encoder in self.encoders:\n            x = encoder(x)\n\n        x = self.pma2(x)\n\n        return x.squeeze(1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.channels}, '\n                f'k={self.k}, heads={self.heads}, '\n          
      f'layer_norm={self.layer_norm}, '\n                f'dropout={self.dropout})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/gru.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\nfrom torch.nn import GRU\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.aggr import Aggregation\n\n\nclass GRUAggregation(Aggregation):\n    r\"\"\"Performs GRU aggregation in which the elements to aggregate are\n    interpreted as a sequence, as described in the `\"Graph Neural Networks\n    with Adaptive Readouts\" <https://arxiv.org/abs/2211.04952>`_ paper.\n\n    .. note::\n\n        :class:`GRUAggregation` requires sorted indices :obj:`index` as input.\n        Specifically, if you use this aggregation as part of\n        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that\n        :obj:`edge_index` is sorted by destination nodes, either by manually\n        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`\n        or by calling :meth:`torch_geometric.data.Data.sort`.\n\n    .. warning::\n\n        :class:`GRUAggregation` is not a permutation-invariant operator.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        **kwargs (optional): Additional arguments of :class:`torch.nn.GRU`.\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, **kwargs):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.gru = GRU(in_channels, out_channels, batch_first=True, **kwargs)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.gru.reset_parameters()\n\n    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])\n    def forward(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n\n        x, _ = self.to_dense_batch(x, 
index, ptr, dim_size, dim,\n                                   max_num_elements=max_num_elements)\n\n        return self.gru(x)[0][:, -1]\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/lcm.py",
    "content": "from math import ceil, log2\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import GRUCell, Linear\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.aggr import Aggregation\n\n\nclass LCMAggregation(Aggregation):\n    r\"\"\"The Learnable Commutative Monoid aggregation from the\n    `\"Learnable Commutative Monoids for Graph Neural Networks\"\n    <https://arxiv.org/abs/2212.08541>`_ paper, in which the elements are\n    aggregated using a binary tree reduction with\n    :math:`\\mathcal{O}(\\log |\\mathcal{V}|)` depth.\n\n    .. note::\n\n        :class:`LCMAggregation` requires sorted indices :obj:`index` as input.\n        Specifically, if you use this aggregation as part of\n        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that\n        :obj:`edge_index` is sorted by destination nodes, either by manually\n        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`\n        or by calling :meth:`torch_geometric.data.Data.sort`.\n\n    .. warning::\n\n        :class:`LCMAggregation` is not a permutation-invariant operator.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        project (bool, optional): If set to :obj:`True`, the layer will apply a\n            linear transformation followed by an activation function before\n            aggregation. 
(default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        project: bool = True,\n    ):\n        super().__init__()\n\n        if in_channels != out_channels and not project:\n            raise ValueError(f\"Inputs of '{self.__class__.__name__}' must be \"\n                             f\"projected if `in_channels != out_channels`\")\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.project = project\n\n        if self.project:\n            self.lin = Linear(in_channels, out_channels)\n        else:\n            self.lin = None\n\n        self.gru_cell = GRUCell(out_channels, out_channels)\n\n    def reset_parameters(self):\n        if self.project:\n            self.lin.reset_parameters()\n        self.gru_cell.reset_parameters()\n\n    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])\n    def forward(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n\n        if self.project:\n            x = self.lin(x).relu()\n\n        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,\n                                   max_num_elements=max_num_elements)\n\n        x = x.permute(1, 0, 2)  # [num_neighbors, num_nodes, num_features]\n        _, num_nodes, num_features = x.size()\n\n        depth = ceil(log2(x.size(0)))\n        for _ in range(depth):\n            half_size = ceil(x.size(0) / 2)\n\n            if x.size(0) % 2 == 1:\n                # This level of the tree has an odd number of nodes, so the\n                # remaining unmatched node gets moved to the next level.\n                x, remainder = x[:-1], x[-1:]\n            else:\n                remainder = None\n\n            left_right = x.view(-1, 2, num_nodes, 
num_features)\n            right_left = left_right.flip(dims=[1])\n\n            left_right = left_right.reshape(-1, num_features)\n            right_left = right_left.reshape(-1, num_features)\n\n            # Execute the GRUCell for all (left, right) pairs in the current\n            # level of the tree in parallel:\n            out = self.gru_cell(left_right, right_left)\n            out = out.view(-1, 2, num_nodes, num_features)\n            out = out.mean(dim=1)\n            if remainder is not None:\n                out = torch.cat([out, remainder], dim=0)\n\n            x = out.view(half_size, num_nodes, num_features)\n\n        assert x.size(0) == 1\n        return x.squeeze(0)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, project={self.project})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/lstm.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\nfrom torch.nn import LSTM\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.aggr import Aggregation\n\n\nclass LSTMAggregation(Aggregation):\n    r\"\"\"Performs LSTM-style aggregation in which the elements to aggregate are\n    interpreted as a sequence, as described in the `\"Inductive Representation\n    Learning on Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper.\n\n    .. note::\n\n        :class:`LSTMAggregation` requires sorted indices :obj:`index` as input.\n        Specifically, if you use this aggregation as part of\n        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that\n        :obj:`edge_index` is sorted by destination nodes, either by manually\n        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`\n        or by calling :meth:`torch_geometric.data.Data.sort`.\n\n    .. warning::\n\n        :class:`LSTMAggregation` is not a permutation-invariant operator.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, **kwargs):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.lstm.reset_parameters()\n\n    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])\n    def forward(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n\n        x, _ = 
self.to_dense_batch(x, index, ptr, dim_size, dim,\n                                   max_num_elements=max_num_elements)\n\n        return self.lstm(x)[0][:, -1]\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/mlp.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\n\n\nclass MLPAggregation(Aggregation):\n    r\"\"\"Performs MLP aggregation in which the elements to aggregate are\n    flattened into a single vectorial representation, and are then processed by\n    a Multi-Layer Perceptron (MLP), as described in the `\"Graph Neural Networks\n    with Adaptive Readouts\" <https://arxiv.org/abs/2211.04952>`_ paper.\n\n    .. note::\n\n        :class:`MLPAggregation` requires sorted indices :obj:`index` as input.\n        Specifically, if you use this aggregation as part of\n        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that\n        :obj:`edge_index` is sorted by destination nodes, either by manually\n        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`\n        or by calling :meth:`torch_geometric.data.Data.sort`.\n\n    .. warning::\n\n        :class:`MLPAggregation` is not a permutation-invariant operator.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        max_num_elements (int): The maximum number of elements to aggregate per\n            group.\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.models.MLP`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        max_num_elements: int,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.max_num_elements = max_num_elements\n\n        from torch_geometric.nn import MLP\n        self.mlp = MLP(\n            in_channels=in_channels * max_num_elements,\n            out_channels=out_channels,\n            **kwargs,\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.mlp.reset_parameters()\n\n    
def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,\n                                   max_num_elements=self.max_num_elements)\n\n        return self.mlp(x.view(-1, x.size(1) * x.size(2)), index, dim_size)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, '\n                f'max_num_elements={self.max_num_elements})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/multi.py",
    "content": "import copy\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear, MultiheadAttention\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.aggr.fused import FusedAggregation\nfrom torch_geometric.nn.dense import HeteroDictLinear\nfrom torch_geometric.nn.resolver import aggregation_resolver\n\n\nclass MultiAggregation(Aggregation):\n    r\"\"\"Performs aggregations with one or more aggregators and combines\n    aggregated results, as described in the `\"Principal Neighbourhood\n    Aggregation for Graph Nets\" <https://arxiv.org/abs/2004.05718>`_ and\n    `\"Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions\"\n    <https://arxiv.org/abs/2104.01481>`_ papers.\n\n    Args:\n        aggrs (list): The list of aggregation schemes to use.\n        aggrs_kwargs (dict, optional): Arguments passed to the\n            respective aggregation function in case it gets automatically\n            resolved. (default: :obj:`None`)\n        mode (str, optional): The combine mode to use for combining\n            aggregated results from multiple aggregations (:obj:`\"cat\"`,\n            :obj:`\"proj\"`, :obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"max\"`,\n            :obj:`\"min\"`, :obj:`\"logsumexp\"`, :obj:`\"std\"`, :obj:`\"var\"`,\n            :obj:`\"attn\"`). (default: :obj:`\"cat\"`)\n        mode_kwargs (dict, optional): Arguments passed for the combine\n            :obj:`mode`. When :obj:`\"proj\"` or :obj:`\"attn\"` is used as the\n            combine :obj:`mode`, :obj:`in_channels` (int or tuple) and\n            :obj:`out_channels` (int) are needed to be specified respectively\n            for the size of each input sample to combine from the respective\n            aggregation outputs and the size of each output sample after\n            combination. 
When :obj:`\"attn\"` mode is used, :obj:`num_heads`\n            (int) sets the number of parallel attention heads and defaults to\n            :obj:`1`. (default: :obj:`None`)\n    \"\"\"\n    fused_out_index: List[int]\n    is_fused_aggr: List[bool]\n\n    def __init__(\n        self,\n        aggrs: List[Union[Aggregation, str]],\n        aggrs_kwargs: Optional[List[Dict[str, Any]]] = None,\n        mode: Optional[str] = 'cat',\n        mode_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n\n        super().__init__()\n\n        if not isinstance(aggrs, (list, tuple)):\n            raise ValueError(f\"'aggrs' of '{self.__class__.__name__}' should \"\n                             f\"be a list or tuple (got '{type(aggrs)}').\")\n\n        if len(aggrs) == 0:\n            raise ValueError(f\"'aggrs' of '{self.__class__.__name__}' should \"\n                             f\"not be empty.\")\n\n        if aggrs_kwargs is None:\n            aggrs_kwargs = [{}] * len(aggrs)\n        elif len(aggrs) != len(aggrs_kwargs):\n            raise ValueError(f\"'aggrs_kwargs' with invalid length passed to \"\n                             f\"'{self.__class__.__name__}' \"\n                             f\"(got '{len(aggrs_kwargs)}', \"\n                             f\"expected '{len(aggrs)}'). 
Ensure that both \"\n                             f\"'aggrs' and 'aggrs_kwargs' are consistent.\")\n\n        self.aggrs = torch.nn.ModuleList([\n            aggregation_resolver(aggr, **aggr_kwargs)\n            for aggr, aggr_kwargs in zip(aggrs, aggrs_kwargs)\n        ])\n\n        # Divide the set into fusable and non-fusable aggregations:\n        fused_aggrs: List[Aggregation] = []\n        self.fused_out_index: List[int] = []\n        self.is_fused_aggr: List[bool] = []\n        for i, aggr in enumerate(self.aggrs):\n            if aggr.__class__ in FusedAggregation.FUSABLE_AGGRS:\n                fused_aggrs.append(aggr)\n                self.fused_out_index.append(i)\n                self.is_fused_aggr.append(True)\n            else:\n                self.is_fused_aggr.append(False)\n\n        if len(fused_aggrs) > 0:\n            self.fused_aggr = FusedAggregation(fused_aggrs)\n        else:\n            self.fused_aggr = None\n\n        self.mode = mode\n        mode_kwargs = copy.copy(mode_kwargs) or {}\n\n        self.in_channels = mode_kwargs.pop('in_channels', None)\n        self.out_channels = mode_kwargs.pop('out_channels', None)\n\n        if mode == 'proj' or mode == 'attn':\n            if len(aggrs) == 1:\n                raise ValueError(\"Multiple aggregations are required for \"\n                                 \"'proj' or 'attn' combine mode.\")\n\n            if (self.in_channels and self.out_channels) is None:\n                raise ValueError(\n                    f\"Combine mode '{mode}' must have `in_channels` \"\n                    f\"and `out_channels` specified.\")\n\n            if isinstance(self.in_channels, int):\n                self.in_channels = [self.in_channels] * len(aggrs)\n\n            if mode == 'proj':\n                self.lin = Linear(\n                    sum(self.in_channels),\n                    self.out_channels,\n                    **mode_kwargs,\n                )\n\n            elif mode == 'attn':\n      
          channels = {str(k): v for k, v, in enumerate(self.in_channels)}\n                self.lin_heads = HeteroDictLinear(channels, self.out_channels)\n                num_heads = mode_kwargs.pop('num_heads', 1)\n                self.multihead_attn = MultiheadAttention(\n                    self.out_channels,\n                    num_heads,\n                    **mode_kwargs,\n                )\n\n        dense_combine_modes = [\n            'sum', 'mean', 'max', 'min', 'logsumexp', 'std', 'var'\n        ]\n        if mode in dense_combine_modes:\n            self.dense_combine = getattr(torch, mode)\n\n    def reset_parameters(self):\n        for aggr in self.aggrs:\n            aggr.reset_parameters()\n        if self.mode == 'proj':\n            self.lin.reset_parameters()\n        if self.mode == 'attn':\n            self.lin_heads.reset_parameters()\n            self.multihead_attn._reset_parameters()\n\n    def get_out_channels(self, in_channels: int) -> int:\n        if self.out_channels is not None:\n            return self.out_channels\n        # TODO Support having customized `out_channels` in each aggregation.\n        if self.mode == 'cat':\n            return in_channels * len(self.aggrs)\n        return in_channels\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        # `FusedAggregation` is currently limited to two-dimensional inputs:\n        if index is None or x.dim() != 2 or self.fused_aggr is None:\n            outs = [aggr(x, index, ptr, dim_size, dim) for aggr in self.aggrs]\n            return self.combine(outs)\n\n        outs: List[Tensor] = [x] * len(self.aggrs)  # Fill with dummy tensors.\n\n        fused_outs = self.fused_aggr(x, index, ptr, dim_size, dim)\n        for i, out in zip(self.fused_out_index, fused_outs):\n            outs[i] = out\n\n        for i, aggr in 
enumerate(self.aggrs):\n            if not self.is_fused_aggr[i]:\n                outs[i] = aggr(x, index, ptr, dim_size, dim)\n\n        return self.combine(outs)\n\n    def combine(self, inputs: List[Tensor]) -> Tensor:\n        if len(inputs) == 1:\n            return inputs[0]\n\n        if self.mode == 'cat':\n            return torch.cat(inputs, dim=-1)\n\n        if hasattr(self, 'lin'):\n            return self.lin(torch.cat(inputs, dim=-1))\n\n        if hasattr(self, 'multihead_attn'):\n            x_dict = {str(k): v for k, v, in enumerate(inputs)}\n            x_dict = self.lin_heads(x_dict)\n            xs = [x_dict[str(key)] for key in range(len(inputs))]\n            x = torch.stack(xs, dim=0)\n            attn_out, _ = self.multihead_attn(x, x, x)\n            return torch.mean(attn_out, dim=0)\n\n        if hasattr(self, 'dense_combine'):\n            out = self.dense_combine(torch.stack(inputs, dim=0), dim=0)\n            return out if isinstance(out, Tensor) else out[0]\n\n        raise ValueError(f\"Combine mode '{self.mode}' is not supported.\")\n\n    def __repr__(self) -> str:\n        aggrs = ',\\n'.join([f'  {aggr}' for aggr in self.aggrs]) + ',\\n'\n        return f'{self.__class__.__name__}([\\n{aggrs}], mode={self.mode})'\n"
  },
  {
    "path": "torch_geometric/nn/aggr/patch_transformer.py",
    "content": "import math\nfrom typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.aggr.utils import MultiheadAttentionBlock\nfrom torch_geometric.nn.encoding import PositionalEncoding\nfrom torch_geometric.utils import scatter\n\n\nclass PatchTransformerAggregation(Aggregation):\n    r\"\"\"Performs patch transformer aggregation in which the elements to\n    aggregate are processed by multi-head attention blocks across patches, as\n    described in the `\"Simplifying Temporal Heterogeneous Network for\n    Continuous-Time Link Prediction\"\n    <https://dl.acm.org/doi/pdf/10.1145/3583780.3615059>`_ paper.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        patch_size (int): Number of elements in a patch.\n        hidden_channels (int): Intermediate size of each sample.\n        num_transformer_blocks (int, optional): Number of transformer blocks\n            (default: :obj:`1`).\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        dropout (float, optional): Dropout probability of attention weights.\n            (default: :obj:`0.0`)\n        aggr (str or list[str], optional): The aggregation module, *e.g.*,\n            :obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`,\n            :obj:`\"var\"`, :obj:`\"std\"`. 
(default: :obj:`\"mean\"`)\n        device (torch.device, optional): The device of the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        patch_size: int,\n        hidden_channels: int,\n        num_transformer_blocks: int = 1,\n        heads: int = 1,\n        dropout: float = 0.0,\n        aggr: Union[str, List[str]] = 'mean',\n        device: Optional[torch.device] = None,\n    ) -> None:\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.patch_size = patch_size\n        self.aggrs = [aggr] if isinstance(aggr, str) else aggr\n\n        assert len(self.aggrs) > 0\n        for aggr in self.aggrs:\n            assert aggr in ['sum', 'mean', 'min', 'max', 'var', 'std']\n\n        self.lin = torch.nn.Linear(in_channels, hidden_channels, device=device)\n        self.pad_projector = torch.nn.Linear(\n            patch_size * hidden_channels,\n            hidden_channels,\n            device=device,\n        )\n        self.pe = PositionalEncoding(hidden_channels, device=device)\n\n        self.blocks = torch.nn.ModuleList([\n            MultiheadAttentionBlock(\n                channels=hidden_channels,\n                heads=heads,\n                layer_norm=True,\n                dropout=dropout,\n                device=device,\n            ) for _ in range(num_transformer_blocks)\n        ])\n\n        self.fc = torch.nn.Linear(\n            hidden_channels * len(self.aggrs),\n            out_channels,\n            device=device,\n        )\n\n    def reset_parameters(self) -> None:\n        self.lin.reset_parameters()\n        self.pad_projector.reset_parameters()\n        self.pe.reset_parameters()\n        for block in self.blocks:\n            block.reset_parameters()\n        self.fc.reset_parameters()\n\n    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])\n    
def forward(\n        self,\n        x: Tensor,\n        index: Tensor,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n\n        if max_num_elements is None:\n            if ptr is not None:\n                count = ptr.diff()\n            else:\n                count = scatter(torch.ones_like(index), index, dim=0,\n                                dim_size=dim_size, reduce='sum')\n            max_num_elements = int(count.max()) + 1\n\n        # Set `max_num_elements` to a multiple of `patch_size`:\n        max_num_elements = (math.floor(max_num_elements / self.patch_size) *\n                            self.patch_size)\n\n        x = self.lin(x)\n\n        # TODO If groups are heavily unbalanced, this will create a lot of\n        # \"empty\" patches. Try to figure out a way to fix this.\n        # [batch_size, num_patches * patch_size, hidden_channels]\n        x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,\n                                   max_num_elements=max_num_elements)\n\n        # [batch_size, num_patches, patch_size * hidden_channels]\n        x = x.view(x.size(0), max_num_elements // self.patch_size,\n                   self.patch_size * x.size(-1))\n\n        # [batch_size, num_patches, hidden_channels]\n        x = self.pad_projector(x)\n\n        x = x + self.pe(torch.arange(x.size(1), device=x.device))\n\n        # [batch_size, num_patches, hidden_channels]\n        for block in self.blocks:\n            x = block(x, x)\n\n        # [batch_size, hidden_channels]\n        outs: List[Tensor] = []\n        for aggr in self.aggrs:\n            out = getattr(torch, aggr)(x, dim=1)\n            outs.append(out[0] if isinstance(out, tuple) else out)\n        out = torch.cat(outs, dim=1) if len(outs) > 1 else outs[0]\n\n        # [batch_size, out_channels]\n        return self.fc(out)\n\n    def __repr__(self) -> str:\n      
  return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, patch_size={self.patch_size})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/quantile.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.utils import cumsum\n\n\nclass QuantileAggregation(Aggregation):\n    r\"\"\"An aggregation operator that returns the feature-wise :math:`q`-th\n    quantile of a set :math:`\\mathcal{X}`.\n\n    That is, for every feature :math:`d`, it computes\n\n    .. math::\n        {\\mathrm{Q}_q(\\mathcal{X})}_d = \\begin{cases}\n            x_{\\pi_i,d} & i = q \\cdot n, \\\\\n            f(x_{\\pi_i,d}, x_{\\pi_{i+1},d}) & i < q \\cdot n < i + 1,\\\\\n        \\end{cases}\n\n    where :math:`x_{\\pi_1,d} \\le \\dots \\le x_{\\pi_i,d} \\le \\dots \\le\n    x_{\\pi_n,d}` and :math:`f(a, b)` is an interpolation\n    function defined by :obj:`interpolation`.\n\n    Args:\n        q (float or list): The quantile value(s) :math:`q`. Can be a scalar or\n            a list of scalars in the range :math:`[0, 1]`. If more than a\n            quantile is passed, the results are concatenated.\n        interpolation (str): Interpolation method applied if the quantile point\n            :math:`q\\cdot n` lies between two values\n            :math:`a \\le b`. 
Can be one of the following:\n\n            * :obj:`\"lower\"`: Returns the one with lowest value.\n\n            * :obj:`\"higher\"`: Returns the one with highest value.\n\n            * :obj:`\"midpoint\"`: Returns the average of the two values.\n\n            * :obj:`\"nearest\"`: Returns the one whose index is nearest to the\n              quantile point.\n\n            * :obj:`\"linear\"`: Returns a linear combination of the two\n              elements, defined as\n              :math:`f(a, b) = a + (b - a)\\cdot(q\\cdot n - i)`.\n\n            (default: :obj:`\"linear\"`)\n        fill_value (float, optional): The default value in the case no entry is\n            found for a given index (default: :obj:`0.0`).\n    \"\"\"\n    interpolations = {'linear', 'lower', 'higher', 'nearest', 'midpoint'}\n\n    def __init__(self, q: Union[float, List[float]],\n                 interpolation: str = 'linear', fill_value: float = 0.0):\n        super().__init__()\n\n        qs = [q] if not isinstance(q, (list, tuple)) else q\n        if len(qs) == 0:\n            raise ValueError(\"Provide at least one quantile value for `q`.\")\n        if not all(0. <= quantile <= 1. 
for quantile in qs):\n            raise ValueError(\"`q` must be in the range [0, 1].\")\n        if interpolation not in self.interpolations:\n            raise ValueError(f\"Invalid interpolation method \"\n                             f\"(got '{interpolation}')\")\n\n        self._q = q\n        self.register_buffer('q', torch.tensor(qs).view(-1, 1))\n        self.interpolation = interpolation\n        self.fill_value = fill_value\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        dim = x.dim() + dim if dim < 0 else dim\n\n        self.assert_index_present(index)\n        assert index is not None  # Required for TorchScript.\n\n        count = torch.bincount(index, minlength=dim_size or 0)\n        ptr = cumsum(count)[:-1]\n\n        # In case there exist dangling indices (`dim_size > index.max()`), we\n        # need to clamp them to prevent out-of-bound issues:\n        if dim_size is not None:\n            ptr = ptr.clamp(max=x.size(dim) - 1)\n\n        q_point = self.q * (count - 1) + ptr\n        q_point = q_point.t().reshape(-1)\n\n        shape = [1] * x.dim()\n        shape[dim] = -1\n        index = index.view(shape).expand_as(x)\n\n        # Two sorts: the first one on the value,\n        # the second (stable) on the indices:\n        x, x_perm = torch.sort(x, dim=dim)\n        index = index.take_along_dim(x_perm, dim=dim)\n        index, index_perm = torch.sort(index, dim=dim, stable=True)\n        x = x.take_along_dim(index_perm, dim=dim)\n\n        # Compute the quantile interpolations:\n        if self.interpolation == 'lower':\n            quantile = x.index_select(dim, q_point.floor().long())\n        elif self.interpolation == 'higher':\n            quantile = x.index_select(dim, q_point.ceil().long())\n        elif self.interpolation == 'nearest':\n            quantile = x.index_select(dim, 
q_point.round().long())\n        else:\n            l_quant = x.index_select(dim, q_point.floor().long())\n            r_quant = x.index_select(dim, q_point.ceil().long())\n\n            if self.interpolation == 'linear':\n                q_frac = q_point.frac().view(shape)\n                quantile = l_quant + (r_quant - l_quant) * q_frac\n            else:  # 'midpoint'\n                quantile = 0.5 * l_quant + 0.5 * r_quant\n\n        # If the number of elements is zero, fill with pre-defined value:\n        repeats = self.q.numel()\n        mask = (count == 0).repeat_interleave(\n            repeats, output_size=repeats * count.numel()).view(shape)\n        out = quantile.masked_fill(mask, self.fill_value)\n\n        if self.q.numel() > 1:\n            shape = list(out.shape)\n            shape = (shape[:dim] + [shape[dim] // self.q.numel(), -1] +\n                     shape[dim + 2:])\n            out = out.view(shape)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(q={self._q})')\n\n\nclass MedianAggregation(QuantileAggregation):\n    r\"\"\"An aggregation operator that returns the feature-wise median of a set.\n\n    That is, for every feature :math:`d`, it computes\n\n    .. math::\n        {\\mathrm{median}(\\mathcal{X})}_d = x_{\\pi_i,d}\n\n    where :math:`x_{\\pi_1,d} \\le x_{\\pi_2,d} \\le \\dots \\le\n    x_{\\pi_n,d}` and :math:`i = \\lfloor \\frac{n}{2} \\rfloor`.\n\n    .. 
note::\n        If the median lies between two values, the lower one is returned.\n        To compute the midpoint (or another kind of interpolation) of the two\n        values, use :class:`QuantileAggregation` instead.\n\n    Args:\n        fill_value (float, optional): The default value in the case no entry is\n            found for a given index (default: :obj:`0.0`).\n    \"\"\"\n    def __init__(self, fill_value: float = 0.0):\n        super().__init__(0.5, 'lower', fill_value)\n\n    def __repr__(self) -> str:\n        return f\"{self.__class__.__name__}()\"\n"
  },
  {
    "path": "torch_geometric/nn/aggr/scaler.py",
    "content": "from typing import Any, Dict, List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation, MultiAggregation\nfrom torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver\nfrom torch_geometric.utils import degree\n\n\nclass DegreeScalerAggregation(Aggregation):\n    r\"\"\"Combines one or more aggregators and transforms its output with one or\n    more scalers as introduced in the `\"Principal Neighbourhood Aggregation for\n    Graph Nets\" <https://arxiv.org/abs/2004.05718>`_ paper.\n    The scalers are normalised by the in-degree of the training set and so must\n    be provided at time of construction.\n    See :class:`torch_geometric.nn.conv.PNAConv` for more information.\n\n    Args:\n        aggr (str or [str] or Aggregation): The aggregation scheme to use.\n            See :class:`~torch_geometric.nn.conv.MessagePassing` for more\n            information.\n        scaler (str or list): Set of scaling function identifiers, namely one\n            or more of :obj:`\"identity\"`, :obj:`\"amplification\"`,\n            :obj:`\"attenuation\"`, :obj:`\"linear\"` and :obj:`\"inverse_linear\"`.\n        deg (Tensor): Histogram of in-degrees of nodes in the training set,\n            used by scalers to normalize.\n        train_norm (bool, optional): Whether normalization parameters\n            are trainable. (default: :obj:`False`)\n        aggr_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective aggregation function in case it gets automatically\n            resolved. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        aggr: Union[str, List[str], Aggregation],\n        scaler: Union[str, List[str]],\n        deg: Tensor,\n        train_norm: bool = False,\n        aggr_kwargs: Optional[List[Dict[str, Any]]] = None,\n    ):\n        super().__init__()\n\n        if isinstance(aggr, (str, Aggregation)):\n            self.aggr = aggr_resolver(aggr, **(aggr_kwargs or {}))\n        elif isinstance(aggr, (tuple, list)):\n            self.aggr = MultiAggregation(aggr, aggr_kwargs)\n        else:\n            raise ValueError(f\"Only strings, lists, tuples and instances of \"\n                             f\"`torch_geometric.nn.aggr.Aggregation` are \"\n                             f\"valid aggregation schemes (got '{type(aggr)}')\")\n\n        # Wrap a single scaler identifier into a list (check `scaler`, not\n        # `aggr`, so that a string `aggr` with a list of scalers works):\n        self.scaler = [scaler] if isinstance(scaler, str) else scaler\n\n        deg = deg.to(torch.float)\n        N = int(deg.sum())\n        bin_degree = torch.arange(deg.numel(), device=deg.device)\n\n        self.init_avg_deg_lin = float((bin_degree * deg).sum()) / N\n        self.init_avg_deg_log = float(((bin_degree + 1).log() * deg).sum()) / N\n\n        if train_norm:\n            self.avg_deg_lin = torch.nn.Parameter(torch.empty(1))\n            self.avg_deg_log = torch.nn.Parameter(torch.empty(1))\n        else:\n            self.register_buffer('avg_deg_lin', torch.empty(1))\n            self.register_buffer('avg_deg_log', torch.empty(1))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.avg_deg_lin.data.fill_(self.init_avg_deg_lin)\n        self.avg_deg_log.data.fill_(self.init_avg_deg_log)\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        # TODO Currently, `degree` can only operate on `index`:\n        self.assert_index_present(index)\n\n        out = self.aggr(x, index, ptr, dim_size, 
dim)\n\n        assert index is not None\n        deg = degree(index, num_nodes=dim_size, dtype=out.dtype)\n        size = [1] * len(out.size())\n        size[dim] = -1\n        deg = deg.view(size)\n\n        outs = []\n        for scaler in self.scaler:\n            if scaler == 'identity':\n                out_scaler = out\n            elif scaler == 'amplification':\n                out_scaler = out * (torch.log(deg + 1) / self.avg_deg_log)\n            elif scaler == 'attenuation':\n                # Clamp minimum degree to one to avoid dividing by zero:\n                out_scaler = out * (self.avg_deg_log /\n                                    torch.log(deg.clamp(min=1) + 1))\n            elif scaler == 'linear':\n                out_scaler = out * (deg / self.avg_deg_lin)\n            elif scaler == 'inverse_linear':\n                # Clamp minimum degree to one to avoid dividing by zero:\n                out_scaler = out * (self.avg_deg_lin / deg.clamp(min=1))\n            else:\n                raise ValueError(f\"Unknown scaler '{scaler}'\")\n            outs.append(out_scaler)\n\n        return torch.cat(outs, dim=-1) if len(outs) > 1 else outs[0]\n"
  },
  {
    "path": "torch_geometric/nn/aggr/set2set.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.utils import softmax\n\n\nclass Set2Set(Aggregation):\n    r\"\"\"The Set2Set aggregation operator based on iterative content-based\n    attention, as described in the `\"Order Matters: Sequence to sequence for\n    Sets\" <https://arxiv.org/abs/1511.06391>`_ paper.\n\n    .. math::\n        \\mathbf{q}_t &= \\mathrm{LSTM}(\\mathbf{q}^{*}_{t-1})\n\n        \\alpha_{i,t} &= \\mathrm{softmax}(\\mathbf{x}_i \\cdot \\mathbf{q}_t)\n\n        \\mathbf{r}_t &= \\sum_{i=1}^N \\alpha_{i,t} \\mathbf{x}_i\n\n        \\mathbf{q}^{*}_t &= \\mathbf{q}_t \\, \\Vert \\, \\mathbf{r}_t,\n\n    where :math:`\\mathbf{q}^{*}_T` defines the output of the layer with twice\n    the dimensionality as the input.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        processing_steps (int): Number of iterations :math:`T`.\n        **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.\n    \"\"\"\n    def __init__(self, in_channels: int, processing_steps: int, **kwargs):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = 2 * in_channels\n        self.processing_steps = processing_steps\n        self.lstm = torch.nn.LSTM(self.out_channels, in_channels, **kwargs)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.lstm.reset_parameters()\n\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        self.assert_index_present(index)\n        self.assert_two_dimensional_input(x, dim)\n\n        h = (x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))),\n             x.new_zeros((self.lstm.num_layers, dim_size, x.size(-1))))\n        q_star = x.new_zeros(dim_size, self.out_channels)\n\n     
   for _ in range(self.processing_steps):\n            q, h = self.lstm(q_star.unsqueeze(0), h)\n            q = q.view(dim_size, self.in_channels)\n            e = (x * q[index]).sum(dim=-1, keepdim=True)\n            a = softmax(e, index, ptr, dim_size, dim)\n            r = self.reduce(a * x, index, ptr, dim_size, dim, reduce='sum')\n            q_star = torch.cat([q, r], dim=-1)\n\n        return q_star\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/set_transformer.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.aggr.utils import (\n    PoolingByMultiheadAttention,\n    SetAttentionBlock,\n)\n\n\nclass SetTransformerAggregation(Aggregation):\n    r\"\"\"Performs \"Set Transformer\" aggregation in which the elements to\n    aggregate are processed by multi-head attention blocks, as described in\n    the `\"Graph Neural Networks with Adaptive Readouts\"\n    <https://arxiv.org/abs/2211.04952>`_ paper.\n\n    .. note::\n\n        :class:`SetTransformerAggregation` requires sorted indices :obj:`index`\n        as input. Specifically, if you use this aggregation as part of\n        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that\n        :obj:`edge_index` is sorted by destination nodes, either by manually\n        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`\n        or by calling :meth:`torch_geometric.data.Data.sort`.\n\n    Args:\n        channels (int): Size of each input sample.\n        num_seed_points (int, optional): Number of seed points.\n            (default: :obj:`1`)\n        num_encoder_blocks (int, optional): Number of Set Attention Blocks\n            (SABs) in the encoder. (default: :obj:`1`).\n        num_decoder_blocks (int, optional): Number of Set Attention Blocks\n            (SABs) in the decoder. (default: :obj:`1`).\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the seed embeddings\n            are averaged instead of concatenated. (default: :obj:`True`)\n        layer_norm (str, optional): If set to :obj:`True`, will apply layer\n            normalization. 
(default: :obj:`False`)\n        dropout (float, optional): Dropout probability of attention weights.\n            (default: :obj:`0`)\n    \"\"\"\n    def __init__(\n        self,\n        channels: int,\n        num_seed_points: int = 1,\n        num_encoder_blocks: int = 1,\n        num_decoder_blocks: int = 1,\n        heads: int = 1,\n        concat: bool = True,\n        layer_norm: bool = False,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n\n        self.channels = channels\n        self.num_seed_points = num_seed_points\n        self.heads = heads\n        self.concat = concat\n        self.layer_norm = layer_norm\n        self.dropout = dropout\n\n        self.encoders = torch.nn.ModuleList([\n            SetAttentionBlock(channels, heads, layer_norm, dropout)\n            for _ in range(num_encoder_blocks)\n        ])\n\n        self.pma = PoolingByMultiheadAttention(channels, num_seed_points,\n                                               heads, layer_norm, dropout)\n\n        self.decoders = torch.nn.ModuleList([\n            SetAttentionBlock(channels, heads, layer_norm, dropout)\n            for _ in range(num_decoder_blocks)\n        ])\n\n    def reset_parameters(self):\n        for encoder in self.encoders:\n            encoder.reset_parameters()\n        self.pma.reset_parameters()\n        for decoder in self.decoders:\n            decoder.reset_parameters()\n\n    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])\n    def forward(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n\n        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim,\n                                      max_num_elements=max_num_elements)\n\n        for encoder in self.encoders:\n            x = encoder(x, mask)\n\n        
x = self.pma(x, mask)\n\n        for decoder in self.decoders:\n            x = decoder(x)\n\n        x = x.nan_to_num()\n\n        return x.flatten(1, 2) if self.concat else x.mean(dim=1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.channels}, '\n                f'num_seed_points={self.num_seed_points}, '\n                f'heads={self.heads}, '\n                f'layer_norm={self.layer_norm}, '\n                f'dropout={self.dropout})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/sort.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.aggr import Aggregation\n\n\nclass SortAggregation(Aggregation):\n    r\"\"\"The pooling operator from the `\"An End-to-End Deep Learning\n    Architecture for Graph Classification\"\n    <https://www.cse.wustl.edu/~muhan/papers/AAAI_2018_DGCNN.pdf>`_ paper,\n    where node features are sorted in descending order based on their last\n    feature channel. The first :math:`k` nodes form the output of the layer.\n\n    .. note::\n\n        :class:`SortAggregation` requires sorted indices :obj:`index` as input.\n        Specifically, if you use this aggregation as part of\n        :class:`~torch_geometric.nn.conv.MessagePassing`, ensure that\n        :obj:`edge_index` is sorted by destination nodes, either by manually\n        sorting edge indices via :meth:`~torch_geometric.utils.sort_edge_index`\n        or by calling :meth:`torch_geometric.data.Data.sort`.\n\n    Args:\n        k (int): The number of nodes to hold for each graph.\n    \"\"\"\n    def __init__(self, k: int):\n        super().__init__()\n        self.k = k\n\n    @disable_dynamic_shapes(required_args=['dim_size', 'max_num_elements'])\n    def forward(\n        self,\n        x: Tensor,\n        index: Optional[Tensor] = None,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n        dim: int = -2,\n        max_num_elements: Optional[int] = None,\n    ) -> Tensor:\n\n        fill_value = x.detach().min() - 1\n        batch_x, _ = self.to_dense_batch(x, index, ptr, dim_size, dim,\n                                         fill_value=fill_value,\n                                         max_num_elements=max_num_elements)\n        B, N, D = batch_x.size()\n\n        _, perm = batch_x[:, :, -1].sort(dim=-1, descending=True)\n        arange = torch.arange(B, dtype=torch.long, device=perm.device) * N\n   
     perm = perm + arange.view(-1, 1)\n\n        batch_x = batch_x.view(B * N, D)\n        batch_x = batch_x[perm]\n        batch_x = batch_x.view(B, N, D)\n\n        if N >= self.k:\n            batch_x = batch_x[:, :self.k].contiguous()\n        else:\n            expand_batch_x = batch_x.new_full((B, self.k - N, D), fill_value)\n            batch_x = torch.cat([batch_x, expand_batch_x], dim=1)\n\n        batch_x[batch_x == fill_value] = 0\n        x = batch_x.view(B, self.k * D)\n\n        return x\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(k={self.k})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/utils.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import LayerNorm, Linear, MultiheadAttention, Parameter\n\n\nclass MultiheadAttentionBlock(torch.nn.Module):\n    r\"\"\"The Multihead Attention Block (MAB) from the `\"Set Transformer: A\n    Framework for Attention-based Permutation-Invariant Neural Networks\"\n    <https://arxiv.org/abs/1810.00825>`_ paper.\n\n    .. math::\n\n        \\mathrm{MAB}(\\mathbf{x}, \\mathbf{y}) &= \\mathrm{LayerNorm}(\\mathbf{h} +\n        \\mathbf{W} \\mathbf{h})\n\n        \\mathbf{h} &= \\mathrm{LayerNorm}(\\mathbf{x} +\n        \\mathrm{Multihead}(\\mathbf{x}, \\mathbf{y}, \\mathbf{y}))\n\n    Args:\n        channels (int): Size of each input sample.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        norm (str, optional): If set to :obj:`False`, will not apply layer\n            normalization. (default: :obj:`True`)\n        dropout (float, optional): Dropout probability of attention weights.\n            (default: :obj:`0`)\n        device (torch.device, optional): The device of the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,\n                 dropout: float = 0.0, device: Optional[torch.device] = None):\n        super().__init__()\n\n        self.channels = channels\n        self.heads = heads\n        self.dropout = dropout\n\n        self.attn = MultiheadAttention(\n            channels,\n            heads,\n            batch_first=True,\n            dropout=dropout,\n            device=device,\n        )\n        self.lin = Linear(channels, channels, device=device)\n        self.layer_norm1 = LayerNorm(channels,\n                                     device=device) if layer_norm else None\n        self.layer_norm2 = LayerNorm(channels,\n                                     device=device) if layer_norm else None\n\n    def 
reset_parameters(self):\n        self.attn._reset_parameters()\n        self.lin.reset_parameters()\n        if self.layer_norm1 is not None:\n            self.layer_norm1.reset_parameters()\n        if self.layer_norm2 is not None:\n            self.layer_norm2.reset_parameters()\n\n    def forward(self, x: Tensor, y: Tensor, x_mask: Optional[Tensor] = None,\n                y_mask: Optional[Tensor] = None) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        if y_mask is not None:\n            y_mask = ~y_mask\n\n        out, _ = self.attn(x, y, y, y_mask, need_weights=False)\n\n        if x_mask is not None:\n            out[~x_mask] = 0.\n\n        out = out + x\n\n        if self.layer_norm1 is not None:\n            out = self.layer_norm1(out)\n\n        out = out + self.lin(out).relu()\n\n        if self.layer_norm2 is not None:\n            out = self.layer_norm2(out)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.channels}, '\n                f'heads={self.heads}, '\n                f'layer_norm={self.layer_norm1 is not None}, '\n                f'dropout={self.dropout})')\n\n\nclass SetAttentionBlock(torch.nn.Module):\n    r\"\"\"The Set Attention Block (SAB) from the `\"Set Transformer: A\n    Framework for Attention-based Permutation-Invariant Neural Networks\"\n    <https://arxiv.org/abs/1810.00825>`_ paper.\n\n    .. math::\n\n        \\mathrm{SAB}(\\mathbf{X}) = \\mathrm{MAB}(\\mathbf{x}, \\mathbf{x})\n\n    Args:\n        channels (int): Size of each input sample.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        layer_norm (bool, optional): If set to :obj:`False`, will not apply\n            layer normalization. 
(default: :obj:`True`)\n        dropout (float, optional): Dropout probability of attention weights.\n            (default: :obj:`0`)\n    \"\"\"\n    def __init__(self, channels: int, heads: int = 1, layer_norm: bool = True,\n                 dropout: float = 0.0):\n        super().__init__()\n        self.mab = MultiheadAttentionBlock(channels, heads, layer_norm,\n                                           dropout)\n\n    def reset_parameters(self):\n        self.mab.reset_parameters()\n\n    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:\n        return self.mab(x, x, mask, mask)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.mab.channels}, '\n                f'heads={self.mab.heads}, '\n                f'layer_norm={self.mab.layer_norm1 is not None}, '\n                f'dropout={self.mab.dropout})')\n\n\nclass InducedSetAttentionBlock(torch.nn.Module):\n    r\"\"\"The Induced Set Attention Block (ISAB) from the `\"Set Transformer: A\n    Framework for Attention-based Permutation-Invariant Neural Networks\"\n    <https://arxiv.org/abs/1810.00825>`_ paper.\n\n    .. math::\n\n        \\mathrm{ISAB}(\\mathbf{X}) &= \\mathrm{MAB}(\\mathbf{x}, \\mathbf{h})\n\n        \\mathbf{h} &= \\mathrm{MAB}(\\mathbf{I}, \\mathbf{x})\n\n    where :math:`\\mathbf{I}` denotes :obj:`num_induced_points` learnable\n    vectors.\n\n    Args:\n        channels (int): Size of each input sample.\n        num_induced_points (int): Number of induced points.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        layer_norm (bool, optional): If set to :obj:`False`, will not apply\n            layer normalization. 
(default: :obj:`True`)\n        dropout (float, optional): Dropout probability of attention weights.\n            (default: :obj:`0`)\n    \"\"\"\n    def __init__(self, channels: int, num_induced_points: int, heads: int = 1,\n                 layer_norm: bool = True, dropout: float = 0.0):\n        super().__init__()\n        self.ind = Parameter(torch.empty(1, num_induced_points, channels))\n        self.mab1 = MultiheadAttentionBlock(channels, heads, layer_norm,\n                                            dropout)\n        self.mab2 = MultiheadAttentionBlock(channels, heads, layer_norm,\n                                            dropout)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.ind)\n        self.mab1.reset_parameters()\n        self.mab2.reset_parameters()\n\n    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:\n        h = self.mab1(self.ind.expand(x.size(0), -1, -1), x, y_mask=mask)\n        return self.mab2(x, h, x_mask=mask)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.ind.size(2)}, '\n                f'num_induced_points={self.ind.size(1)}, '\n                f'heads={self.mab1.heads}, '\n                f'layer_norm={self.mab1.layer_norm1 is not None}, '\n                f'dropout={self.mab1.dropout})')\n\n\nclass PoolingByMultiheadAttention(torch.nn.Module):\n    r\"\"\"The Pooling by Multihead Attention (PMA) layer from the `\"Set\n    Transformer: A Framework for Attention-based Permutation-Invariant Neural\n    Networks\" <https://arxiv.org/abs/1810.00825>`_ paper.\n\n    .. 
math::\n\n        \\mathrm{PMA}(\\mathbf{X}) = \\mathrm{MAB}(\\mathbf{S}, \\mathbf{x})\n\n    where :math:`\\mathbf{S}` denotes :obj:`num_seed_points` learnable vectors.\n\n    Args:\n        channels (int): Size of each input sample.\n        num_seed_points (int, optional): Number of seed points.\n            (default: :obj:`1`)\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        layer_norm (bool, optional): If set to :obj:`False`, will not apply\n            layer normalization. (default: :obj:`True`)\n        dropout (float, optional): Dropout probability of attention weights.\n            (default: :obj:`0`)\n    \"\"\"\n    def __init__(self, channels: int, num_seed_points: int = 1, heads: int = 1,\n                 layer_norm: bool = True, dropout: float = 0.0):\n        super().__init__()\n        self.lin = Linear(channels, channels)\n        self.seed = Parameter(torch.empty(1, num_seed_points, channels))\n        self.mab = MultiheadAttentionBlock(channels, heads, layer_norm,\n                                           dropout)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.lin.reset_parameters()\n        torch.nn.init.xavier_uniform_(self.seed)\n        self.mab.reset_parameters()\n\n    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:\n        x = self.lin(x).relu()\n        return self.mab(self.seed.expand(x.size(0), -1, -1), x, y_mask=mask)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.seed.size(2)}, '\n                f'num_seed_points={self.seed.size(1)}, '\n                f'heads={self.mab.heads}, '\n                f'layer_norm={self.mab.layer_norm1 is not None}, '\n                f'dropout={self.mab.dropout})')\n"
  },
  {
    "path": "torch_geometric/nn/aggr/variance_preserving.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.utils import degree\nfrom torch_geometric.utils._scatter import broadcast\n\n\nclass VariancePreservingAggregation(Aggregation):\n    r\"\"\"Performs the Variance Preserving Aggregation (VPA) from the `\"GNN-VPA:\n    A Variance-Preserving Aggregation Strategy for Graph Neural Networks\"\n    <https://arxiv.org/abs/2403.04747>`_ paper.\n\n    .. math::\n        \\mathrm{vpa}(\\mathcal{X}) = \\frac{1}{\\sqrt{|\\mathcal{X}|}}\n        \\sum_{\\mathbf{x}_i \\in \\mathcal{X}} \\mathbf{x}_i\n    \"\"\"\n    def forward(self, x: Tensor, index: Optional[Tensor] = None,\n                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,\n                dim: int = -2) -> Tensor:\n\n        out = self.reduce(x, index, ptr, dim_size, dim, reduce='sum')\n\n        if ptr is not None:\n            count = ptr.diff().to(out.dtype)\n        else:\n            count = degree(index, dim_size, dtype=out.dtype)\n\n        count = count.sqrt().clamp(min=1.0)\n        count = broadcast(count, ref=out, dim=dim)\n\n        return out / count\n"
  },
  {
    "path": "torch_geometric/nn/attention/__init__.py",
    "content": "from .performer import PerformerAttention\nfrom .qformer import QFormer\nfrom .sgformer import SGFormerAttention\nfrom .polynormer import PolynormerAttention\n\n__all__ = classes = [\n    'PerformerAttention',\n    'QFormer',\n    'SGFormerAttention',\n    'PolynormerAttention',\n]\n"
  },
  {
    "path": "torch_geometric/nn/attention/performer.py",
    "content": "import math\nfrom typing import Callable, Optional\n\nimport torch\nfrom torch import Tensor\n\n\ndef _orthogonal_matrix(dim: int) -> Tensor:\n    r\"\"\"Get an orthogonal matrix by applying QR decomposition.\"\"\"\n    # Random matrix from normal distribution\n    mat = torch.randn((dim, dim))\n    # QR decomposition to two orthogonal matrices\n    q, _ = torch.linalg.qr(mat.cpu(), mode='reduced')\n    return q.t()\n\n\ndef orthogonal_matrix(num_rows: int, num_cols: int) -> Tensor:\n    r\"\"\"Generate an orthogonal matrix with `num_rows` rows\n    and `num_cols` columns.\n    \"\"\"\n    num_full_blocks = int(num_rows / num_cols)\n    blocks = []\n    for _ in range(num_full_blocks):\n        q = _orthogonal_matrix(num_cols)\n        blocks.append(q)\n    remain_rows = num_rows - num_full_blocks * num_cols\n    if remain_rows > 0:\n        q = _orthogonal_matrix(num_cols)\n        blocks.append(q[:remain_rows])\n    mat = torch.cat(blocks)\n    # multiplier = torch.randn((num_rows, num_cols)).norm(dim=1)\n    # scaler = torch.diag(multiplier)\n    # mat = scaler @ mat\n    return mat\n\n\ndef linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:\n    r\"\"\"Efficient attention mechanism from the\n    `\"Rethinking Attention with Performers\"\n    <https://arxiv.org/abs/2009.14794>`_ paper.\n\n    .. 
math::\n        \\mathbf{\\hat{D}}^{-1}(\\mathbf{Q}'((\\mathbf{K}')^{\\top} \\mathbf{V}))\n\n    \"\"\"\n    D_inv = 1.0 / (q @ k.sum(dim=-2).unsqueeze(-1))\n    kv = k.transpose(-2, -1) @ v\n    qkv = q @ kv\n    out = torch.einsum('...L,...Ld->...Ld', D_inv.squeeze(-1), qkv)\n    return out\n\n\ndef generalized_kernel(\n        x: Tensor,\n        mat: Tensor,\n        kernel: Callable = torch.nn.ReLU(),\n        epsilon: float = 0.001,\n) -> Tensor:\n    batch_size, num_heads = x.size()[:2]\n    projection = mat.t().expand(batch_size, num_heads, -1, -1)\n    x = x @ projection\n    out = kernel(x) + epsilon\n    return out\n\n\nclass PerformerProjection(torch.nn.Module):\n    r\"\"\"The fast attention that uses a projection matrix\n    from the `\"Rethinking Attention with Performers\"\n    <https://arxiv.org/abs/2009.14794>`_ paper. This class\n    projects :math:`\\mathbf{Q}` and :math:`\\mathbf{K}` matrices\n    with specified kernel.\n\n    Args:\n        num_cols (int): Projection matrix number of columns.\n        kernel (Callable, optional): Kernels for generalized attention.\n            If not specified, `ReLU` kernel will be used.\n            (default: :obj:`torch.nn.ReLU()`)\n    \"\"\"\n    def __init__(self, num_cols: int, kernel: Callable = torch.nn.ReLU()):\n        super().__init__()\n        num_rows = int(num_cols * math.log(num_cols))\n        self.num_rows = num_rows\n        self.num_cols = num_cols\n        # Generate an orthogonal projection matrix\n        # with the shape (num_rows, num_cols)\n        projection_matrix = orthogonal_matrix(self.num_rows, self.num_cols)\n        self.register_buffer('projection_matrix', projection_matrix)\n        assert kernel is not None\n        self.kernel = kernel\n\n    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:\n        q = generalized_kernel(q, self.projection_matrix, self.kernel)\n        k = generalized_kernel(k, self.projection_matrix, self.kernel)\n        out = 
linear_attention(q, k, v)\n        return out\n\n\nclass PerformerAttention(torch.nn.Module):\n    r\"\"\"The linear scaled attention mechanism from the\n    `\"Rethinking Attention with Performers\"\n    <https://arxiv.org/abs/2009.14794>`_ paper.\n\n    Args:\n        channels (int): Size of each input sample.\n        heads (int): Number of parallel attention heads.\n        head_channels (int, optional): Size of each attention head.\n            (default: :obj:`64`)\n        kernel (Callable, optional): Kernels for generalized attention.\n            If not specified, `ReLU` kernel will be used.\n            (default: :obj:`torch.nn.ReLU()`)\n        qkv_bias (bool, optional): If set to :obj:`True`, adds a bias to the\n            query, key and value projections. (default: :obj:`False`)\n        attn_out_bias (bool, optional): If set to :obj:`True`, adds a bias to\n            the attention output. (default: :obj:`True`)\n        dropout (float, optional): Dropout probability of the final\n            attention output. 
(default: :obj:`0.0`)\n\n    \"\"\"\n    def __init__(\n        self,\n        channels: int,\n        heads: int,\n        head_channels: int = 64,\n        kernel: Callable = torch.nn.ReLU(),\n        qkv_bias: bool = False,\n        attn_out_bias: bool = True,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n        assert channels % heads == 0\n        if head_channels is None:\n            head_channels = channels // heads\n\n        self.heads = heads\n        self.head_channels = head_channels\n        self.kernel = kernel\n        self.fast_attn = PerformerProjection(head_channels, kernel)\n\n        inner_channels = head_channels * heads\n        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.attn_out = torch.nn.Linear(inner_channels, channels,\n                                        bias=attn_out_bias)\n        self.dropout = torch.nn.Dropout(dropout)\n\n    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes for each graph. 
(default: :obj:`None`)\n        \"\"\"\n        B, N, *_ = x.shape\n        q, k, v = self.q(x), self.k(x), self.v(x)\n        # Reshape and permute q, k and v to proper shape\n        # (B, N, num_heads * head_channels) to (b, num_heads, n, head_channels)\n        q, k, v = map(\n            lambda t: t.reshape(B, N, self.heads, self.head_channels).permute(\n                0, 2, 1, 3), (q, k, v))\n        if mask is not None:\n            mask = mask[:, None, :, None]\n            v.masked_fill_(~mask, 0.)\n        out = self.fast_attn(q, k, v)\n        out = out.permute(0, 2, 1, 3).reshape(B, N, -1)\n        out = self.attn_out(out)\n        out = self.dropout(out)\n        return out\n\n    @torch.no_grad()\n    def redraw_projection_matrix(self):\n        r\"\"\"As described in the paper, periodically redraw\n        examples to improve overall approximation of attention.\n        \"\"\"\n        num_rows = self.fast_attn.num_rows\n        num_cols = self.fast_attn.num_cols\n        projection_matrix = orthogonal_matrix(num_rows, num_cols)\n        self.fast_attn.projection_matrix.copy_(projection_matrix)\n        del projection_matrix\n\n    def _reset_parameters(self):\n        self.q.reset_parameters()\n        self.k.reset_parameters()\n        self.v.reset_parameters()\n        self.attn_out.reset_parameters()\n        self.redraw_projection_matrix()\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'heads={self.heads}, '\n                f'head_channels={self.head_channels} '\n                f'kernel={self.kernel})')\n"
  },
  {
    "path": "torch_geometric/nn/attention/polynormer.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\n\nclass PolynormerAttention(torch.nn.Module):\n    r\"\"\"The polynomial-expressive attention mechanism from the\n    `\"Polynormer: Polynomial-Expressive Graph Transformer in Linear Time\"\n    <https://arxiv.org/abs/2403.01232>`_ paper.\n\n    Args:\n        channels (int): Size of each input sample.\n        heads (int): Number of parallel attention heads.\n        head_channels (int, optional): Size of each attention head.\n            (default: :obj:`64`)\n        beta (float, optional): Polynormer beta initialization.\n            (default: :obj:`0.9`)\n        qkv_bias (bool, optional): If set to :obj:`True`, adds a bias to the\n            query, key and value projections. (default: :obj:`False`)\n        qk_shared (bool, optional): Whether the query and key weights are\n            shared. (default: :obj:`True`)\n        dropout (float, optional): Dropout probability of the final\n            attention output. 
(default: :obj:`0.0`)\n    \"\"\"\n    def __init__(\n        self,\n        channels: int,\n        heads: int,\n        head_channels: int = 64,\n        beta: float = 0.9,\n        qkv_bias: bool = False,\n        qk_shared: bool = True,\n        dropout: float = 0.0,\n    ) -> None:\n        super().__init__()\n\n        self.head_channels = head_channels\n        self.heads = heads\n        self.beta = beta\n        self.qk_shared = qk_shared\n\n        inner_channels = heads * head_channels\n        self.h_lins = torch.nn.Linear(channels, inner_channels)\n        if not self.qk_shared:\n            self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.lns = torch.nn.LayerNorm(inner_channels)\n        self.lin_out = torch.nn.Linear(inner_channels, inner_channels)\n        self.dropout = torch.nn.Dropout(dropout)\n\n    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes for each graph. 
(default: :obj:`None`)\n        \"\"\"\n        B, N, *_ = x.shape\n        h = self.h_lins(x)\n        k = self.k(x).sigmoid().view(B, N, self.head_channels, self.heads)\n        if self.qk_shared:\n            q = k\n        else:\n            q = F.sigmoid(self.q(x)).view(B, N, self.head_channels, self.heads)\n        v = self.v(x).view(B, N, self.head_channels, self.heads)\n\n        if mask is not None:\n            mask = mask[:, :, None, None]\n            v.masked_fill_(~mask, 0.)\n\n        # numerator\n        kv = torch.einsum('bndh, bnmh -> bdmh', k, v)\n        num = torch.einsum('bndh, bdmh -> bnmh', q, kv)\n\n        # denominator\n        k_sum = torch.einsum('bndh -> bdh', k)\n        den = torch.einsum('bndh, bdh -> bnh', q, k_sum).unsqueeze(2)\n\n        # linear global attention based on kernel trick\n        x = (num / (den + 1e-6)).reshape(B, N, -1)\n        x = self.lns(x) * (h + self.beta)\n        x = F.relu(self.lin_out(x))\n        x = self.dropout(x)\n\n        return x\n\n    def reset_parameters(self) -> None:\n        self.h_lins.reset_parameters()\n        if not self.qk_shared:\n            self.q.reset_parameters()\n        self.k.reset_parameters()\n        self.v.reset_parameters()\n        self.lns.reset_parameters()\n        self.lin_out.reset_parameters()\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'heads={self.heads}, '\n                f'head_channels={self.head_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/attention/qformer.py",
    "content": "from typing import Callable\n\nimport torch\n\n\nclass QFormer(torch.nn.Module):\n    r\"\"\"The Querying Transformer (Q-Former) from\n    `\"BLIP-2: Bootstrapping Language-Image Pre-training\n    with Frozen Image Encoders and Large Language Models\"\n    <https://arxiv.org/pdf/2301.12597>`_ paper.\n\n    Args:\n        input_dim (int): The number of features in the input.\n        hidden_dim (int): The dimension of the feed-forward network in each\n            encoder layer.\n        output_dim (int): The final output dimension.\n        num_heads (int): The number of multi-attention-heads.\n        num_layers (int): The number of sub-encoder-layers in the encoder.\n        dropout (float, optional): The dropout value in each encoder layer.\n            (default: :obj:`0.0`)\n        activation (Callable, optional): The activation function in each\n            encoder layer. (default: :obj:`torch.nn.ReLU()`)\n\n\n    .. note::\n        This is a simplified version of the original Q-Former implementation.\n    \"\"\"\n    def __init__(\n            self,\n            input_dim: int,\n            hidden_dim: int,\n            output_dim: int,\n            num_heads: int,\n            num_layers: int,\n            dropout: float = 0.0,\n            activation: Callable = torch.nn.ReLU(),\n    ) -> None:\n\n        super().__init__()\n        self.num_layers = num_layers\n        self.num_heads = num_heads\n\n        self.layer_norm = torch.nn.LayerNorm(input_dim)\n        self.encoder_layer = torch.nn.TransformerEncoderLayer(\n            d_model=input_dim,\n            nhead=num_heads,\n            dim_feedforward=hidden_dim,\n            dropout=dropout,\n            activation=activation,\n            batch_first=True,\n        )\n        self.encoder = torch.nn.TransformerEncoder(\n            self.encoder_layer,\n            num_layers=num_layers,\n        )\n        self.project = torch.nn.Linear(input_dim, output_dim)\n\n    def forward(self, x: torch.Tensor) -> torch.Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Input sequence to the encoder layer.\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N 
\\times F}`, with\n                batch-size :math:`B`, sequence length :math:`N`,\n                and feature dimension :math:`F`.\n        \"\"\"\n        x = self.layer_norm(x)\n        x = self.encoder(x)\n        out = self.project(x)\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'num_heads={self.num_heads}, '\n                f'num_layers={self.num_layers})')\n"
  },
  {
    "path": "torch_geometric/nn/attention/sgformer.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\n\nclass SGFormerAttention(torch.nn.Module):\n    r\"\"\"The simple global attention mechanism from the\n    `\"SGFormer: Simplifying and Empowering Transformers for\n    Large-Graph Representations\"\n    <https://arxiv.org/abs/2306.10759>`_ paper.\n\n    Args:\n        channels (int): Size of each input sample.\n        heads (int, optional): Number of parallel attention heads.\n            (default: :obj:`1`)\n        head_channels (int, optional): Size of each attention head.\n            (default: :obj:`64`)\n        qkv_bias (bool, optional): If set to :obj:`True`, adds a bias to the\n            query, key and value projections. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        channels: int,\n        heads: int = 1,\n        head_channels: int = 64,\n        qkv_bias: bool = False,\n    ) -> None:\n        super().__init__()\n        assert channels % heads == 0\n        if head_channels is None:\n            head_channels = channels // heads\n\n        self.heads = heads\n        self.head_channels = head_channels\n\n        inner_channels = head_channels * heads\n        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)\n\n    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes 
for each graph. (default: :obj:`None`)\n        \"\"\"\n        B, N, *_ = x.shape\n        qs, ks, vs = self.q(x), self.k(x), self.v(x)\n        # reshape and permute q, k and v to proper shape\n        # (b, n, num_heads * head_channels) to (b, n, num_heads, head_channels)\n        qs, ks, vs = map(\n            lambda t: t.reshape(B, N, self.heads, self.head_channels),\n            (qs, ks, vs))\n\n        if mask is not None:\n            mask = mask[:, :, None, None]\n            vs.masked_fill_(~mask, 0.)\n        # replace 0's with epsilon\n        epsilon = 1e-6\n        qs[qs == 0] = epsilon\n        ks[ks == 0] = epsilon\n        # normalize input, shape not changed\n        qs, ks = map(\n            lambda t: t / torch.linalg.norm(t, ord=2, dim=-1, keepdim=True),\n            (qs, ks))\n\n        # numerator\n        kvs = torch.einsum(\"blhm,blhd->bhmd\", ks, vs)\n        attention_num = torch.einsum(\"bnhm,bhmd->bnhd\", qs, kvs)\n        attention_num += N * vs\n\n        # denominator\n        all_ones = torch.ones([B, N]).to(ks.device)\n        ks_sum = torch.einsum(\"blhm,bl->bhm\", ks, all_ones)\n        attention_normalizer = torch.einsum(\"bnhm,bhm->bnh\", qs, ks_sum)\n        # attentive aggregated results\n        attention_normalizer = torch.unsqueeze(attention_normalizer,\n                                               len(attention_normalizer.shape))\n        attention_normalizer += torch.ones_like(attention_normalizer) * N\n        attn_output = attention_num / attention_normalizer\n\n        return attn_output.mean(dim=2)\n\n    def reset_parameters(self):\n        self.q.reset_parameters()\n        self.k.reset_parameters()\n        self.v.reset_parameters()\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'heads={self.heads}, '\n                f'head_channels={self.head_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/__init__.py",
    "content": "from .message_passing import MessagePassing\nfrom .simple_conv import SimpleConv\nfrom .gcn_conv import GCNConv\nfrom .cheb_conv import ChebConv\nfrom .sage_conv import SAGEConv\nfrom .cugraph.sage_conv import CuGraphSAGEConv\nfrom .graph_conv import GraphConv\nfrom .gravnet_conv import GravNetConv\nfrom .gated_graph_conv import GatedGraphConv\nfrom .res_gated_graph_conv import ResGatedGraphConv\nfrom .gat_conv import GATConv\nfrom .cugraph.gat_conv import CuGraphGATConv\nfrom .fused_gat_conv import FusedGATConv\nfrom .gatv2_conv import GATv2Conv\nfrom .transformer_conv import TransformerConv\nfrom .agnn_conv import AGNNConv\nfrom .tag_conv import TAGConv\nfrom .gin_conv import GINConv, GINEConv\nfrom .arma_conv import ARMAConv\nfrom .sg_conv import SGConv\nfrom .appnp import APPNP\nfrom .mf_conv import MFConv\nfrom .rgcn_conv import RGCNConv, FastRGCNConv\nfrom .cugraph.rgcn_conv import CuGraphRGCNConv\nfrom .rgat_conv import RGATConv\nfrom .signed_conv import SignedConv\nfrom .dna_conv import DNAConv\nfrom .point_conv import PointNetConv\nfrom .gmm_conv import GMMConv\nfrom .spline_conv import SplineConv\nfrom .nn_conv import NNConv\nfrom .cg_conv import CGConv\nfrom .edge_conv import EdgeConv, DynamicEdgeConv\nfrom .x_conv import XConv\nfrom .ppf_conv import PPFConv\nfrom .feast_conv import FeaStConv\nfrom .point_transformer_conv import PointTransformerConv\nfrom .hypergraph_conv import HypergraphConv\nfrom .le_conv import LEConv\nfrom .pna_conv import PNAConv\nfrom .cluster_gcn_conv import ClusterGCNConv\nfrom .gen_conv import GENConv\nfrom .gcn2_conv import GCN2Conv\nfrom .pan_conv import PANConv\nfrom .wl_conv import WLConv\nfrom .wl_conv_continuous import WLConvContinuous\nfrom .film_conv import FiLMConv\nfrom .supergat_conv import SuperGATConv\nfrom .fa_conv import FAConv\nfrom .eg_conv import EGConv\nfrom .pdn_conv import PDNConv\nfrom .general_conv import GeneralConv\nfrom .hgt_conv import HGTConv\nfrom .heat_conv import HEATConv\nfrom 
.hetero_conv import HeteroConv\nfrom .han_conv import HANConv\nfrom .lg_conv import LGConv\nfrom .ssg_conv import SSGConv\nfrom .point_gnn_conv import PointGNNConv\nfrom .gps_conv import GPSConv\nfrom .antisymmetric_conv import AntiSymmetricConv\nfrom .dir_gnn_conv import DirGNNConv\nfrom .mixhop_conv import MixHopConv\nfrom .meshcnn_conv import MeshCNNConv\n\nimport torch_geometric.nn.conv.utils  # noqa\n\n__all__ = [\n    'MessagePassing',\n    'SimpleConv',\n    'GCNConv',\n    'ChebConv',\n    'SAGEConv',\n    'CuGraphSAGEConv',\n    'GraphConv',\n    'GravNetConv',\n    'GatedGraphConv',\n    'ResGatedGraphConv',\n    'GATConv',\n    'CuGraphGATConv',\n    'FusedGATConv',\n    'GATv2Conv',\n    'TransformerConv',\n    'AGNNConv',\n    'TAGConv',\n    'GINConv',\n    'GINEConv',\n    'ARMAConv',\n    'SGConv',\n    'SSGConv',\n    'APPNP',\n    'MFConv',\n    'RGCNConv',\n    'FastRGCNConv',\n    'CuGraphRGCNConv',\n    'RGATConv',\n    'SignedConv',\n    'DNAConv',\n    'PointNetConv',\n    'GMMConv',\n    'SplineConv',\n    'NNConv',\n    'CGConv',\n    'EdgeConv',\n    'DynamicEdgeConv',\n    'XConv',\n    'PPFConv',\n    'FeaStConv',\n    'PointTransformerConv',\n    'HypergraphConv',\n    'LEConv',\n    'PNAConv',\n    'ClusterGCNConv',\n    'GENConv',\n    'GCN2Conv',\n    'PANConv',\n    'WLConv',\n    'WLConvContinuous',\n    'FiLMConv',\n    'SuperGATConv',\n    'FAConv',\n    'EGConv',\n    'PDNConv',\n    'GeneralConv',\n    'HGTConv',\n    'HEATConv',\n    'HeteroConv',\n    'HANConv',\n    'LGConv',\n    'PointGNNConv',\n    'GPSConv',\n    'AntiSymmetricConv',\n    'DirGNNConv',\n    'MixHopConv',\n    'MeshCNNConv',\n]\n\nclasses = __all__\n\nECConv = NNConv\nPointConv = PointNetConv\n"
  },
  {
    "path": "torch_geometric/nn/conv/agnn_conv.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse\nfrom torch_geometric.utils import add_self_loops, remove_self_loops, softmax\n\n\nclass AGNNConv(MessagePassing):\n    r\"\"\"The graph attentional propagation layer from the\n    `\"Attention-based Graph Neural Network for Semi-Supervised Learning\"\n    <https://arxiv.org/abs/1803.03735>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\mathbf{P} \\mathbf{X},\n\n    where the propagation matrix :math:`\\mathbf{P}` is computed as\n\n    .. math::\n        P_{i,j} = \\frac{\\exp( \\beta \\cdot \\cos(\\mathbf{x}_i, \\mathbf{x}_j))}\n        {\\sum_{k \\in \\mathcal{N}(i)\\cup \\{ i \\}} \\exp( \\beta \\cdot\n        \\cos(\\mathbf{x}_i, \\mathbf{x}_k))}\n\n    with trainable parameter :math:`\\beta`.\n\n    Args:\n        requires_grad (bool, optional): If set to :obj:`False`, :math:`\\beta`\n            will not be trainable. (default: :obj:`True`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F)`,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F)`\n    \"\"\"\n    def __init__(self, requires_grad: bool = True, add_self_loops: bool = True,\n                 **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.requires_grad = requires_grad\n        self.add_self_loops = add_self_loops\n\n        if requires_grad:\n            self.beta = Parameter(torch.empty(1))\n        else:\n            self.register_buffer('beta', torch.ones(1))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if self.requires_grad:\n            self.beta.data.fill_(1)\n\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        if self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                edge_index, _ = remove_self_loops(edge_index)\n                edge_index, _ = add_self_loops(edge_index,\n                                               num_nodes=x.size(self.node_dim))\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = torch_sparse.set_diag(edge_index)\n\n        x_norm = F.normalize(x, p=2., dim=-1)\n\n        # propagate_type: (x: Tensor, x_norm: Tensor)\n        return self.propagate(edge_index, x=x, x_norm=x_norm)\n\n    def message(self, x_j: Tensor, x_norm_i: Tensor, x_norm_j: Tensor,\n                index: Tensor, ptr: OptTensor,\n                size_i: Optional[int]) -> Tensor:\n        alpha = self.beta * (x_norm_i * x_norm_j).sum(dim=-1)\n        alpha = softmax(alpha, index, ptr, size_i)\n        return x_j * alpha.view(-1, 1)\n"
  },
  {
    "path": "torch_geometric/nn/conv/antisymmetric_conv.py",
    "content": "import math\nfrom typing import Any, Callable, Dict, Optional, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import GCNConv, MessagePassing\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.nn.resolver import activation_resolver\nfrom torch_geometric.typing import Adj\n\n\nclass AntiSymmetricConv(torch.nn.Module):\n    r\"\"\"The anti-symmetric graph convolutional operator from the\n    `\"Anti-Symmetric DGN: a stable architecture for Deep Graph Networks\"\n    <https://openreview.net/forum?id=J3Y7cgZOOS>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{x}_i + \\epsilon \\cdot \\sigma \\left(\n            (\\mathbf{W}-\\mathbf{W}^T-\\gamma \\mathbf{I}) \\mathbf{x}_i +\n            \\Phi(\\mathbf{X}, \\mathcal{N}_i) + \\mathbf{b}\\right),\n\n    where :math:`\\Phi(\\mathbf{X}, \\mathcal{N}_i)` denotes a\n    :class:`~torch_geometric.nn.conv.MessagePassing` layer.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        phi (MessagePassing, optional): The message passing module\n            :math:`\\Phi`. If set to :obj:`None`, will use a\n            :class:`~torch_geometric.nn.conv.GCNConv` layer as default.\n            (default: :obj:`None`)\n        num_iters (int, optional): The number of times the anti-symmetric deep\n            graph network operator is called. (default: :obj:`1`)\n        epsilon (float, optional): The discretization step size\n            :math:`\\epsilon`. (default: :obj:`0.1`)\n        gamma (float, optional): The strength of the diffusion :math:`\\gamma`.\n            It regulates the stability of the method. (default: :obj:`0.1`)\n        act (str, optional): The non-linear activation function :math:`\\sigma`,\n            *e.g.*, :obj:`\"tanh\"` or :obj:`\"relu\"`. 
(default: :obj:`\"tanh\"`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{in})`\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        phi: Optional[MessagePassing] = None,\n        num_iters: int = 1,\n        epsilon: float = 0.1,\n        gamma: float = 0.1,\n        act: Union[str, Callable, None] = 'tanh',\n        act_kwargs: Optional[Dict[str, Any]] = None,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.num_iters = num_iters\n        self.gamma = gamma\n        self.epsilon = epsilon\n        self.act = activation_resolver(act, **(act_kwargs or {}))\n\n        if phi is None:\n            phi = GCNConv(in_channels, in_channels, bias=False)\n\n        self.W = Parameter(torch.empty(in_channels, in_channels))\n        self.register_buffer('eye', torch.eye(in_channels))\n        self.phi = phi\n\n        if bias:\n            self.bias = Parameter(torch.empty(in_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        torch.nn.init.kaiming_uniform_(self.W, a=math.sqrt(5))\n        self.phi.reset_parameters()\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor:\n        r\"\"\"Runs the forward pass of the 
module.\"\"\"\n        antisymmetric_W = self.W - self.W.t() - self.gamma * self.eye\n\n        for _ in range(self.num_iters):\n            h = self.phi(x, edge_index, *args, **kwargs)\n            h = x @ antisymmetric_W.t() + h\n\n            if self.bias is not None:\n                h += self.bias\n\n            if self.act is not None:\n                h = self.act(h)\n\n            x = x + self.epsilon * h\n\n        return x\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'{self.in_channels}, '\n                f'phi={self.phi}, '\n                f'num_iters={self.num_iters}, '\n                f'epsilon={self.epsilon}, '\n                f'gamma={self.gamma})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/appnp.py",
    "content": "from typing import Optional\n\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor\nfrom torch_geometric.utils import is_torch_sparse_tensor, spmm, to_edge_index\nfrom torch_geometric.utils.sparse import set_sparse_value\n\n\nclass APPNP(MessagePassing):\n    r\"\"\"The approximate personalized propagation of neural predictions layer\n    from the `\"Predict then Propagate: Graph Neural Networks meet Personalized\n    PageRank\" <https://arxiv.org/abs/1810.05997>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{(0)} &= \\mathbf{X}\n\n        \\mathbf{X}^{(k)} &= (1 - \\alpha) \\mathbf{\\hat{D}}^{-1/2}\n        \\mathbf{\\hat{A}} \\mathbf{\\hat{D}}^{-1/2} \\mathbf{X}^{(k-1)} + \\alpha\n        \\mathbf{X}^{(0)}\n\n        \\mathbf{X}^{\\prime} &= \\mathbf{X}^{(K)},\n\n    where :math:`\\mathbf{\\hat{A}} = \\mathbf{A} + \\mathbf{I}` denotes the\n    adjacency matrix with inserted self-loops and\n    :math:`\\hat{D}_{ii} = \\sum_{j=0} \\hat{A}_{ij}` its diagonal degree matrix.\n    The adjacency matrix can include other values than :obj:`1` representing\n    edge weights via the optional :obj:`edge_weight` tensor.\n\n    Args:\n        K (int): Number of iterations :math:`K`.\n        alpha (float): Teleport probability :math:`\\alpha`.\n        dropout (float, optional): Dropout probability of edges during\n            training. (default: :obj:`0`)\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of :math:`\\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n            \\mathbf{\\hat{D}}^{-1/2}` on first execution, and will use the\n            cached version for further executions.\n            This parameter should only be set to :obj:`True` in transductive\n            learning scenarios. 
(default: :obj:`False`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        normalize (bool, optional): Whether to add self-loops and apply\n            symmetric normalization. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F)`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F)`\n    \"\"\"\n    _cached_edge_index: Optional[OptPairTensor]\n    _cached_adj_t: Optional[SparseTensor]\n\n    def __init__(self, K: int, alpha: float, dropout: float = 0.,\n                 cached: bool = False, add_self_loops: bool = True,\n                 normalize: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n        self.K = K\n        self.alpha = alpha\n        self.dropout = dropout\n        self.cached = cached\n        self.add_self_loops = add_self_loops\n        self.normalize = normalize\n\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n\n        if self.normalize:\n            if isinstance(edge_index, Tensor):\n                cache = self._cached_edge_index\n                if cache is None:\n                    edge_index, edge_weight = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, 
dtype=x.dtype)\n                    if self.cached:\n                        self._cached_edge_index = (edge_index, edge_weight)\n                else:\n                    edge_index, edge_weight = cache[0], cache[1]\n\n            elif isinstance(edge_index, SparseTensor):\n                cache = self._cached_adj_t\n                if cache is None:\n                    edge_index = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_adj_t = edge_index\n                else:\n                    edge_index = cache\n\n        h = x\n        for _ in range(self.K):\n            if self.dropout > 0 and self.training:\n                if isinstance(edge_index, Tensor):\n                    if is_torch_sparse_tensor(edge_index):\n                        _, edge_weight = to_edge_index(edge_index)\n                        edge_weight = F.dropout(edge_weight, p=self.dropout)\n                        edge_index = set_sparse_value(edge_index, edge_weight)\n                    else:\n                        assert edge_weight is not None\n                        edge_weight = F.dropout(edge_weight, p=self.dropout)\n                else:\n                    value = edge_index.storage.value()\n                    assert value is not None\n                    value = F.dropout(value, p=self.dropout)\n                    edge_index = edge_index.set_value(value, layout='coo')\n\n            # propagate_type: (x: Tensor, edge_weight: OptTensor)\n            x = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n            x = x * (1 - self.alpha)\n            x = x + self.alpha * h\n\n        return x\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def 
message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(K={self.K}, alpha={self.alpha})'\n"
  },
  {
    "path": "torch_geometric/nn/conv/arma_conv.py",
    "content": "from typing import Callable, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor, nn\nfrom torch.nn import Parameter, ReLU\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass ARMAConv(MessagePassing):\n    r\"\"\"The ARMA graph convolutional operator from the `\"Graph Neural Networks\n    with Convolutional ARMA Filters\" <https://arxiv.org/abs/1901.01343>`_\n    paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\frac{1}{K} \\sum_{k=1}^K \\mathbf{X}_k^{(T)},\n\n    with :math:`\\mathbf{X}_k^{(T)}` being recursively defined by\n\n    .. math::\n        \\mathbf{X}_k^{(t+1)} = \\sigma \\left( \\mathbf{\\hat{L}}\n        \\mathbf{X}_k^{(t)} \\mathbf{W} + \\mathbf{X}^{(0)} \\mathbf{V} \\right),\n\n    where :math:`\\mathbf{\\hat{L}} = \\mathbf{I} - \\mathbf{L} = \\mathbf{D}^{-1/2}\n    \\mathbf{A} \\mathbf{D}^{-1/2}` denotes the\n    modified Laplacian :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1/2}\n    \\mathbf{A} \\mathbf{D}^{-1/2}`.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample\n            :math:`\\mathbf{x}^{(t+1)}`.\n        num_stacks (int, optional): Number of parallel stacks :math:`K`.\n            (default: :obj:`1`).\n        num_layers (int, optional): Number of layers :math:`T`.\n            (default: :obj:`1`)\n        act (callable, optional): Activation function :math:`\\sigma`.\n            (default: :meth:`torch.nn.ReLU()`)\n        shared_weights (int, optional): If set to :obj:`True` the layers in\n            each stack will share the same parameters. 
(default: :obj:`False`)\n        dropout (float, optional): Dropout probability of the skip connection.\n            (default: :obj:`0.`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int,\n                 num_stacks: int = 1, num_layers: int = 1,\n                 shared_weights: bool = False,\n                 act: Optional[Callable] = ReLU(), dropout: float = 0.,\n                 bias: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_stacks = num_stacks\n        self.num_layers = num_layers\n        self.act = act\n        self.shared_weights = shared_weights\n        self.dropout = dropout\n\n        K, T, F_in, F_out = num_stacks, num_layers, in_channels, out_channels\n        T = 1 if self.shared_weights else T\n\n        self.weight = Parameter(torch.empty(max(1, T - 1), K, F_out, F_out))\n        if in_channels > 0:\n            self.init_weight = Parameter(torch.empty(K, F_in, F_out))\n            self.root_weight = Parameter(torch.empty(T, K, F_in, F_out))\n        else:\n            self.init_weight = torch.nn.parameter.UninitializedParameter()\n            self.root_weight = torch.nn.parameter.UninitializedParameter()\n            self._hook = self.register_forward_pre_hook(\n                self.initialize_parameters)\n\n        if bias:\n            
self.bias = Parameter(torch.empty(T, K, 1, F_out))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        glorot(self.weight)\n        if not isinstance(self.init_weight, torch.nn.UninitializedParameter):\n            glorot(self.init_weight)\n            glorot(self.root_weight)\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if isinstance(edge_index, Tensor):\n            edge_index, edge_weight = gcn_norm(  # yapf: disable\n                edge_index, edge_weight, x.size(self.node_dim),\n                add_self_loops=False, flow=self.flow, dtype=x.dtype)\n\n        elif isinstance(edge_index, SparseTensor):\n            edge_index = gcn_norm(  # yapf: disable\n                edge_index, edge_weight, x.size(self.node_dim),\n                add_self_loops=False, flow=self.flow, dtype=x.dtype)\n\n        x = x.unsqueeze(-3)\n        out = x\n        for t in range(self.num_layers):\n            if t == 0:\n                out = out @ self.init_weight\n            else:\n                out = out @ self.weight[0 if self.shared_weights else t - 1]\n\n            # propagate_type: (x: Tensor, edge_weight: OptTensor)\n            out = self.propagate(edge_index, x=out, edge_weight=edge_weight)\n\n            root = F.dropout(x, p=self.dropout, training=self.training)\n            root = root @ self.root_weight[0 if self.shared_weights else t]\n            out = out + root\n\n            if self.bias is not None:\n                out = out + self.bias[0 if self.shared_weights else t]\n\n            if self.act is not None:\n                out = self.act(out)\n\n        return out.mean(dim=-3)\n\n    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        return edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, 
adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    @torch.no_grad()\n    def initialize_parameters(self, module, input):\n        if isinstance(self.init_weight, nn.parameter.UninitializedParameter):\n            F_in, F_out = input[0].size(-1), self.out_channels\n            T, K = self.weight.size(0) + 1, self.weight.size(1)\n            self.init_weight.materialize((K, F_in, F_out))\n            self.root_weight.materialize((T, K, F_in, F_out))\n            glorot(self.init_weight)\n            glorot(self.root_weight)\n\n        module._hook.remove()\n        delattr(module, '_hook')\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_stacks={self.num_stacks}, '\n                f'num_layers={self.num_layers})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/cg_conv.py",
    "content": "from typing import Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import BatchNorm1d, Linear\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.typing import Adj, OptTensor, PairTensor\n\n\nclass CGConv(MessagePassing):\n    r\"\"\"The crystal graph convolutional operator from the\n    `\"Crystal Graph Convolutional Neural Networks for an\n    Accurate and Interpretable Prediction of Material Properties\"\n    <https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.120.145301>`_\n    paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{x}_i + \\sum_{j \\in \\mathcal{N}(i)}\n        \\sigma \\left( \\mathbf{z}_{i,j} \\mathbf{W}_f + \\mathbf{b}_f \\right)\n        \\odot g \\left( \\mathbf{z}_{i,j} \\mathbf{W}_s + \\mathbf{b}_s  \\right)\n\n    where :math:`\\mathbf{z}_{i,j} = [ \\mathbf{x}_i, \\mathbf{x}_j,\n    \\mathbf{e}_{i,j} ]` denotes the concatenation of central node features,\n    neighboring node features and edge features.\n    In addition, :math:`\\sigma` and :math:`g` denote the sigmoid and softplus\n    functions, respectively.\n\n    Args:\n        channels (int or tuple): Size of each input sample. A tuple\n            corresponds to the sizes of source and target dimensionalities.\n        dim (int, optional): Edge feature dimensionality. (default: :obj:`0`)\n        aggr (str, optional): The aggregation operator to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"add\"`)\n        batch_norm (bool, optional): If set to :obj:`True`, will make use of\n            batch normalization. (default: :obj:`False`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F)` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F)` or\n          :math:`(|\\mathcal{V_t}|, F_{t})` if bipartite\n    \"\"\"\n    def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0,\n                 aggr: str = 'add', batch_norm: bool = False,\n                 bias: bool = True, **kwargs):\n        super().__init__(aggr=aggr, **kwargs)\n        self.channels = channels\n        self.dim = dim\n        self.batch_norm = batch_norm\n\n        if isinstance(channels, int):\n            channels = (channels, channels)\n\n        self.lin_f = Linear(sum(channels) + dim, channels[1], bias=bias)\n        self.lin_s = Linear(sum(channels) + dim, channels[1], bias=bias)\n        if batch_norm:\n            self.bn = BatchNorm1d(channels[1])\n        else:\n            self.bn = None\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_f.reset_parameters()\n        self.lin_s.reset_parameters()\n        if self.bn is not None:\n            self.bn.reset_parameters()\n\n    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,\n                edge_attr: OptTensor = None) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: PairTensor, edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)\n        out = out if self.bn is None else self.bn(out)\n        out = out + x[1]\n        return out\n\n    def message(self, x_i, x_j, edge_attr: 
OptTensor) -> Tensor:\n        if edge_attr is None:\n            z = torch.cat([x_i, x_j], dim=-1)\n        else:\n            z = torch.cat([x_i, x_j, edge_attr], dim=-1)\n        return self.lin_f(z).sigmoid() * F.softplus(self.lin_s(z))\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.channels}, dim={self.dim})'\n"
  },
  {
    "path": "torch_geometric/nn/conv/cheb_conv.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import get_laplacian\n\n\nclass ChebConv(MessagePassing):\n    r\"\"\"The chebyshev spectral graph convolutional operator from the\n    `\"Convolutional Neural Networks on Graphs with Fast Localized Spectral\n    Filtering\" <https://arxiv.org/abs/1606.09375>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\sum_{k=1}^{K} \\mathbf{Z}^{(k)} \\cdot\n        \\mathbf{\\Theta}^{(k)}\n\n    where :math:`\\mathbf{Z}^{(k)}` is computed recursively by\n\n    .. math::\n        \\mathbf{Z}^{(1)} &= \\mathbf{X}\n\n        \\mathbf{Z}^{(2)} &= \\mathbf{\\hat{L}} \\cdot \\mathbf{X}\n\n        \\mathbf{Z}^{(k)} &= 2 \\cdot \\mathbf{\\hat{L}} \\cdot\n        \\mathbf{Z}^{(k-1)} - \\mathbf{Z}^{(k-2)}\n\n    and :math:`\\mathbf{\\hat{L}}` denotes the scaled and normalized Laplacian\n    :math:`\\frac{2\\mathbf{L}}{\\lambda_{\\max}} - \\mathbf{I}`.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        K (int): Chebyshev filter size :math:`K`.\n        normalization (str, optional): The normalization scheme for the graph\n            Laplacian (default: :obj:`\"sym\"`):\n\n            1. :obj:`None`: No normalization\n            :math:`\\mathbf{L} = \\mathbf{D} - \\mathbf{A}`\n\n            2. :obj:`\"sym\"`: Symmetric normalization\n            :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1/2} \\mathbf{A}\n            \\mathbf{D}^{-1/2}`\n\n            3. 
:obj:`\"rw\"`: Random-walk normalization\n            :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1} \\mathbf{A}`\n\n            :obj:`lambda_max` should be a :class:`torch.Tensor` of size\n            :obj:`[num_graphs]` in a mini-batch scenario and a\n            scalar/zero-dimensional tensor when operating on single graphs.\n            You can pre-compute :obj:`lambda_max` via the\n            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*,\n          batch vector :math:`(|\\mathcal{V}|)` *(optional)*,\n          maximum :obj:`lambda` value :math:`(|\\mathcal{G}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        K: int,\n        normalization: Optional[str] = 'sym',\n        bias: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        assert K > 0\n        assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.normalization = normalization\n        self.lins = torch.nn.ModuleList([\n            Linear(in_channels, out_channels, bias=False,\n                   weight_initializer='glorot') for _ in range(K)\n        ])\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', 
None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        for lin in self.lins:\n            lin.reset_parameters()\n        zeros(self.bias)\n\n    def __norm__(\n        self,\n        edge_index: Tensor,\n        num_nodes: Optional[int],\n        edge_weight: OptTensor,\n        normalization: Optional[str],\n        lambda_max: OptTensor = None,\n        dtype: Optional[int] = None,\n        batch: OptTensor = None,\n    ):\n        edge_index, edge_weight = get_laplacian(edge_index, edge_weight,\n                                                normalization, dtype,\n                                                num_nodes)\n        assert edge_weight is not None\n\n        if lambda_max is None:\n            lambda_max = 2.0 * edge_weight.max()\n        elif not isinstance(lambda_max, Tensor):\n            lambda_max = torch.tensor(lambda_max, dtype=dtype,\n                                      device=edge_index.device)\n        assert lambda_max is not None\n\n        if batch is not None and lambda_max.numel() > 1:\n            lambda_max = lambda_max[batch[edge_index[0]]]\n\n        edge_weight = (2.0 * edge_weight) / lambda_max\n        edge_weight.masked_fill_(edge_weight == float('inf'), 0)\n\n        loop_mask = edge_index[0] == edge_index[1]\n        edge_weight[loop_mask] -= 1\n\n        return edge_index, edge_weight\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_weight: OptTensor = None,\n        batch: OptTensor = None,\n        lambda_max: OptTensor = None,\n    ) -> Tensor:\n\n        edge_index, norm = self.__norm__(\n            edge_index,\n            x.size(self.node_dim),\n            edge_weight,\n            self.normalization,\n            lambda_max,\n            dtype=x.dtype,\n            batch=batch,\n        )\n\n        Tx_0 = x\n        Tx_1 = x  # Dummy.\n        out = self.lins[0](Tx_0)\n\n        # propagate_type: 
(x: Tensor, norm: Tensor)\n        if len(self.lins) > 1:\n            Tx_1 = self.propagate(edge_index, x=x, norm=norm)\n            out = out + self.lins[1](Tx_1)\n\n        for lin in self.lins[2:]:\n            Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm)\n            Tx_2 = 2. * Tx_2 - Tx_0\n            out = out + lin.forward(Tx_2)\n            Tx_0, Tx_1 = Tx_1, Tx_2\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:\n        return norm.view(-1, 1) * x_j\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, K={len(self.lins)}, '\n                f'normalization={self.normalization})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/cluster_gcn_conv.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse\nfrom torch_geometric.utils import (\n    add_self_loops,\n    degree,\n    is_torch_sparse_tensor,\n    remove_self_loops,\n    spmm,\n    to_edge_index,\n)\nfrom torch_geometric.utils.sparse import set_sparse_value\n\n\nclass ClusterGCNConv(MessagePassing):\n    r\"\"\"The ClusterGCN graph convolutional operator from the\n    `\"Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph\n    Convolutional Networks\" <https://arxiv.org/abs/1905.07953>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\left( \\mathbf{\\hat{A}} + \\lambda \\cdot\n        \\textrm{diag}(\\mathbf{\\hat{A}}) \\right) \\mathbf{X} \\mathbf{W}_1 +\n        \\mathbf{X} \\mathbf{W}_2\n\n    where :math:`\\mathbf{\\hat{A}} = {(\\mathbf{D} + \\mathbf{I})}^{-1}(\\mathbf{A}\n    + \\mathbf{I})`.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        diag_lambda (float, optional): Diagonal enhancement value\n            :math:`\\lambda`. (default: :obj:`0.`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int,\n                 diag_lambda: float = 0., add_self_loops: bool = True,\n                 bias: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.diag_lambda = diag_lambda\n        self.add_self_loops = add_self_loops\n\n        self.lin_out = Linear(in_channels, out_channels, bias=bias,\n                              weight_initializer='glorot')\n        self.lin_root = Linear(in_channels, out_channels, bias=False,\n                               weight_initializer='glorot')\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_out.reset_parameters()\n        self.lin_root.reset_parameters()\n\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        num_nodes = x.size(self.node_dim)\n        edge_weight: OptTensor = None\n\n        if isinstance(edge_index, SparseTensor):\n            assert edge_index.size(0) == edge_index.size(1)\n\n            if self.add_self_loops:\n                edge_index = torch_sparse.set_diag(edge_index)\n\n            col, row, _ = edge_index.coo()  # Transposed.\n            deg_inv = 1. 
/ torch_sparse.sum(edge_index, dim=1).clamp_(1.)\n\n            edge_weight = deg_inv[col]\n            edge_weight[row == col] += self.diag_lambda * deg_inv\n            edge_index = edge_index.set_value(edge_weight, layout='coo')\n\n        elif is_torch_sparse_tensor(edge_index):\n            assert edge_index.size(0) == edge_index.size(1)\n\n            if edge_index.layout == torch.sparse_csc:\n                raise NotImplementedError(\"Sparse CSC matrices are not yet \"\n                                          \"supported in 'ClusterGCNConv'\")\n\n            if self.add_self_loops:\n                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n\n            col_and_row, value = to_edge_index(edge_index)\n            col, row = col_and_row[0], col_and_row[1]\n            deg_inv = 1. / degree(col, num_nodes=edge_index.size(0)).clamp_(1.)\n\n            edge_weight = deg_inv[col]\n            edge_weight[row == col] += self.diag_lambda * deg_inv\n\n            edge_index = set_sparse_value(edge_index, edge_weight)\n\n        else:\n            if self.add_self_loops:\n                edge_index, _ = remove_self_loops(edge_index)\n                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n\n            row, col = edge_index[0], edge_index[1]\n            deg_inv = 1. 
/ degree(col, num_nodes=num_nodes).clamp_(1.)\n\n            edge_weight = deg_inv[col]\n            edge_weight[row == col] += self.diag_lambda * deg_inv\n\n        # propagate_type: (x: Tensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n        out = self.lin_out(out) + self.lin_root(x)\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        return edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, diag_lambda={self.diag_lambda})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/collect.jinja",
    "content": "from typing import List, NamedTuple, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.index import ptr2index\nfrom torch_geometric.utils import is_torch_sparse_tensor\nfrom torch_geometric.typing import SparseTensor\n\n\nclass CollectArgs(NamedTuple):\n{%- if collect_param_dict|length > 0 %}\n{%- for param in collect_param_dict.values() %}\n    {{param.name}}: {{param.type_repr}}\n{%- endfor %}\n{%- else %}\n    pass\n{%- endif %}\n\n\ndef {{collect_name}}(\n    self,\n    edge_index: Union[Tensor, SparseTensor],\n{%- for param in signature.param_dict.values() %}\n    {{param.name}}: {{param.type_repr}},\n{%- endfor %}\n    size: List[Optional[int]],\n) -> CollectArgs:\n\n    i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)\n\n    # Collect special arguments:\n    if isinstance(edge_index, Tensor):\n        if is_torch_sparse_tensor(edge_index):\n{%- if 'edge_index' in collect_param_dict %}\n            raise ValueError(\"Cannot collect 'edge_indices' for sparse matrices\")\n{%- endif %}\n            adj_t = edge_index\n            if adj_t.layout == torch.sparse_coo:\n                edge_index_i = adj_t.indices()[0]\n                edge_index_j = adj_t.indices()[1]\n                ptr = None\n            elif adj_t.layout == torch.sparse_csr:\n                ptr = adj_t.crow_indices()\n                edge_index_j = adj_t.col_indices()\n                edge_index_i = ptr2index(ptr, output_size=edge_index_j.numel())\n            else:\n                raise ValueError(f\"Received invalid layout '{adj_t.layout}'\")\n\n{%- if 'edge_weight' in collect_param_dict %}\n            if edge_weight is None:\n                edge_weight = adj_t.values()\n{%- elif 'edge_attr' in collect_param_dict %}\n            if edge_attr is None:\n                _value = adj_t.values()\n                edge_attr = None if _value.dim() == 1 else _value\n{%- elif 'edge_type' in 
collect_param_dict %}\n            if edge_type is None:\n                edge_type = adj_t.values()\n{%- endif %}\n\n        else:\n{%- if 'adj_t' in collect_param_dict %}\n            raise ValueError(\"Cannot collect 'adj_t' for edge indices\")\n{%- endif %}\n            edge_index_i = edge_index[i]\n            edge_index_j = edge_index[j]\n\n            ptr = None\n            if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n                if i == 0 and edge_index.is_sorted_by_row:\n                  (ptr, _), _ = edge_index.get_csr()\n                elif i == 1 and edge_index.is_sorted_by_col:\n                  (ptr, _), _ = edge_index.get_csc()\n\n    elif isinstance(edge_index, SparseTensor):\n{%- if 'edge_index' in collect_param_dict %}\n        raise ValueError(\"Cannot collect 'edge_indices' for sparse matrices\")\n{%- endif %}\n        adj_t = edge_index\n        edge_index_i, edge_index_j, _value = adj_t.coo()\n        ptr, _, _ = adj_t.csr()\n\n{%- if 'edge_weight' in collect_param_dict %}\n        if edge_weight is None:\n            edge_weight = _value\n{%- elif 'edge_attr' in collect_param_dict %}\n        if edge_attr is None:\n            edge_attr = None if _value is None or _value.dim() == 1 else _value\n{%- elif 'edge_type' in collect_param_dict %}\n        if edge_type is None:\n            edge_type = _value\n{%- endif %}\n\n    else:\n        raise NotImplementedError\n\n{%- if 'edge_weight' in collect_param_dict and\n    collect_param_dict['edge_weight'].type_repr.endswith('Tensor') %}\n    if torch.jit.is_scripting():\n        assert edge_weight is not None\n{%- elif 'edge_attr' in collect_param_dict and\n    collect_param_dict['edge_attr'].type_repr.endswith('Tensor') %}\n    if torch.jit.is_scripting():\n        assert edge_attr is not None\n{%- elif 'edge_type' in collect_param_dict and\n    collect_param_dict['edge_type'].type_repr.endswith('Tensor') %}\n    if torch.jit.is_scripting():\n        assert 
edge_type is not None\n{%- endif %}\n\n    # Collect user-defined arguments:\n{%- for name in collect_param_dict %}\n{%- if (name.endswith('_i') or name.endswith('_j')) and\n        name not in ['edge_index_i', 'edge_index_j', 'size_i', 'size_j'] %}\n    # ({{loop.index}}) - Collect `{{name}}`:\n    if isinstance({{name[:-2]}}, (tuple, list)):\n        assert len({{name[:-2]}}) == 2\n        _{{name[:-2]}}_0, _{{name[:-2]}}_1 = {{name[:-2]}}[0], {{name[:-2]}}[1]\n        if isinstance(_{{name[:-2]}}_0, Tensor):\n            self._set_size(size, 0, _{{name[:-2]}}_0)\n{%- if name.endswith('_j') %}\n            {{name}} = self._index_select(_{{name[:-2]}}_0, edge_index_{{name[-1]}})\n        else:\n            {{name}} = None\n{%- endif %}\n        if isinstance(_{{name[:-2]}}_1, Tensor):\n            self._set_size(size, 1, _{{name[:-2]}}_1)\n{%- if name.endswith('_i') %}\n            {{name}} = self._index_select(_{{name[:-2]}}_1, edge_index_{{name[-1]}})\n        else:\n            {{name}} = None\n{%- endif %}\n    elif isinstance({{name[:-2]}}, Tensor):\n        self._set_size(size, {{name[-1]}}, {{name[:-2]}})\n        {{name}} = self._index_select({{name[:-2]}}, edge_index_{{name[-1]}})\n    else:\n        {{name}} = None\n{%- endif %}\n{%- endfor %}\n\n    # Collect default arguments:\n{%- for name, param in collect_param_dict.items() %}\n{%- if name not in signature.param_dict and\n       not name.endswith('_i') and\n       not name.endswith('_j') and\n       name not in ['edge_index', 'adj_t', 'size', 'ptr', 'index', 'dim_size'] and\n       '_empty' not in param.default.__name__ %}\n    {{name}} = {{param.default}}\n{%- endif %}\n{%- endfor %}\n\n    index = edge_index_i\n    size_i = size[i] if size[i] is not None else size[j]\n    size_j = size[j] if size[j] is not None else size[i]\n    dim_size = size_i\n\n    return CollectArgs(\n{%- for name in collect_param_dict %}\n        {{name}},\n{%- endfor %}\n    )\n"
  },
  {
    "path": "torch_geometric/nn/conv/cugraph/__init__.py",
    "content": "from .base import CuGraphModule\nfrom .sage_conv import CuGraphSAGEConv\nfrom .gat_conv import CuGraphGATConv\nfrom .rgcn_conv import CuGraphRGCNConv\n\n__all__ = [\n    'CuGraphModule',\n    'CuGraphSAGEConv',\n    'CuGraphGATConv',\n    'CuGraphRGCNConv',\n]\n"
  },
  {
    "path": "torch_geometric/nn/conv/cugraph/base.py",
    "content": "from typing import Any, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\n\ntry:  # pragma: no cover\n    LEGACY_MODE = False\n    from pylibcugraphops.pytorch import CSC, HeteroCSC\n    HAS_PYLIBCUGRAPHOPS = True\nexcept ImportError:\n    HAS_PYLIBCUGRAPHOPS = False\n    try:  # pragma: no cover\n        from pylibcugraphops import (\n            make_fg_csr,\n            make_fg_csr_hg,\n            make_mfg_csr,\n            make_mfg_csr_hg,\n        )\n        LEGACY_MODE = True\n    except ImportError:\n        pass\n\n\nclass CuGraphModule(torch.nn.Module):  # pragma: no cover\n    r\"\"\"An abstract base class for implementing :obj:`cugraph`-based message\n    passing layers.\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n\n        if not HAS_PYLIBCUGRAPHOPS and not LEGACY_MODE:\n            raise ModuleNotFoundError(f\"'{self.__class__.__name__}' requires \"\n                                      f\"'pylibcugraphops>=23.02'\")\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n\n    def get_cugraph(\n        self,\n        edge_index: EdgeIndex,\n        max_num_neighbors: Optional[int] = None,\n    ) -> Any:\n        r\"\"\"Constructs a :obj:`cugraph` graph object from CSC representation.\n        Supports both bipartite and non-bipartite graphs.\n\n        Args:\n            edge_index (EdgeIndex): The edge indices.\n            max_num_neighbors (int, optional): The maximum number of neighbors\n                of a target node. It is only effective when operating in a\n                bipartite graph. When not given, will be computed on-the-fly,\n                leading to slightly worse performance. 
(default: :obj:`None`)\n        \"\"\"\n        if not isinstance(edge_index, EdgeIndex):\n            raise ValueError(f\"'edge_index' needs to be of type 'EdgeIndex' \"\n                             f\"(got {type(edge_index)})\")\n\n        edge_index = edge_index.sort_by('col')[0]\n        num_src_nodes = edge_index.get_sparse_size(0)\n        (colptr, row), _ = edge_index.get_csc()\n\n        if not row.is_cuda:\n            raise RuntimeError(f\"'{self.__class__.__name__}' requires GPU-\"\n                               f\"based processing (got CPU tensor)\")\n\n        if num_src_nodes != colptr.numel() - 1:  # Bipartite graph:\n            if max_num_neighbors is None:\n                max_num_neighbors = int((colptr[1:] - colptr[:-1]).max())\n\n            if LEGACY_MODE:\n                dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)\n                return make_mfg_csr(dst_nodes, colptr, row, max_num_neighbors,\n                                    num_src_nodes)\n\n            return CSC(colptr, row, num_src_nodes,\n                       dst_max_in_degree=max_num_neighbors)\n\n        if LEGACY_MODE:\n            return make_fg_csr(colptr, row)\n\n        return CSC(colptr, row, num_src_nodes=num_src_nodes)\n\n    def get_typed_cugraph(\n        self,\n        edge_index: EdgeIndex,\n        edge_type: Tensor,\n        num_edge_types: Optional[int] = None,\n        max_num_neighbors: Optional[int] = None,\n    ) -> Any:\n        r\"\"\"Constructs a typed :obj:`cugraph` graph object from a CSC\n        representation where each edge corresponds to a given edge type.\n        Supports both bipartite and non-bipartite graphs.\n\n        Args:\n            edge_index (EdgeIndex): The edge indices.\n            edge_type (torch.Tensor): The edge type.\n            num_edge_types (int, optional): The maximum number of edge types.\n                When not given, will be computed on-the-fly, leading to\n                slightly worse 
performance. (default: :obj:`None`)\n            max_num_neighbors (int, optional): The maximum number of neighbors\n                of a target node. It is only effective when operating in a\n                bipartite graph. When not given, will be computed on-the-fly,\n                leading to slightly worse performance. (default: :obj:`None`)\n        \"\"\"\n        if num_edge_types is None:\n            num_edge_types = int(edge_type.max()) + 1\n\n        if not isinstance(edge_index, EdgeIndex):\n            raise ValueError(f\"'edge_index' needs to be of type 'EdgeIndex' \"\n                             f\"(got {type(edge_index)})\")\n\n        edge_index, perm = edge_index.sort_by('col')\n        edge_type = edge_type[perm]\n        num_src_nodes = edge_index.get_sparse_size(0)\n        (colptr, row), _ = edge_index.get_csc()\n\n        edge_type = edge_type.int()\n\n        if num_src_nodes != colptr.numel() - 1:  # Bipartite graph:\n            if max_num_neighbors is None:\n                max_num_neighbors = int((colptr[1:] - colptr[:-1]).max())\n\n            if LEGACY_MODE:\n                dst_nodes = torch.arange(colptr.numel() - 1, device=row.device)\n                return make_mfg_csr_hg(dst_nodes, colptr, row,\n                                       max_num_neighbors, num_src_nodes,\n                                       n_node_types=0,\n                                       n_edge_types=num_edge_types,\n                                       out_node_types=None, in_node_types=None,\n                                       edge_types=edge_type)\n\n            return HeteroCSC(colptr, row, edge_type, num_src_nodes,\n                             num_edge_types,\n                             dst_max_in_degree=max_num_neighbors)\n\n        if LEGACY_MODE:\n            return make_fg_csr_hg(colptr, row, n_node_types=0,\n                                  n_edge_types=num_edge_types, node_types=None,\n                                  
edge_types=edge_type)\n\n        return HeteroCSC(colptr, row, edge_type, num_src_nodes, num_edge_types)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: EdgeIndex,\n        max_num_neighbors: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): The node features.\n            edge_index (EdgeIndex): The edge indices.\n            max_num_neighbors (int, optional): The maximum number of neighbors\n                of a target node. It is only effective when operating in a\n                bipartite graph. When not given, the value will be computed\n                on-the-fly, leading to slightly worse performance.\n                (default: :obj:`None`)\n        \"\"\"\n        raise NotImplementedError\n"
  },
  {
    "path": "torch_geometric/nn/conv/cugraph/gat_conv.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear, Parameter\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn.conv.cugraph import CuGraphModule\nfrom torch_geometric.nn.conv.cugraph.base import LEGACY_MODE\nfrom torch_geometric.nn.inits import zeros\n\ntry:\n    if LEGACY_MODE:\n        from pylibcugraphops.torch.autograd import mha_gat_n2n as GATConvAgg\n    else:\n        from pylibcugraphops.pytorch.operators import mha_gat_n2n as GATConvAgg\nexcept ImportError:\n    pass\n\n\nclass CuGraphGATConv(CuGraphModule):  # pragma: no cover\n    r\"\"\"The graph attentional operator from the `\"Graph Attention Networks\"\n    <https://arxiv.org/abs/1710.10903>`_ paper.\n\n    :class:`CuGraphGATConv` is an optimized version of\n    :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`cugraph-ops`\n    package that fuses message passing computation for accelerated execution\n    and lower memory footprint.\n\n    See :ref:`install-cugraph` for how to set up :obj:`cugraph-ops`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        heads: int = 1,\n        concat: bool = True,\n        negative_slope: float = 0.2,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.concat = concat\n        self.negative_slope = negative_slope\n\n        self.lin = Linear(in_channels, heads * out_channels, bias=False)\n        self.att = Parameter(torch.empty(2 * heads * out_channels))\n\n        if bias and concat:\n            self.bias = Parameter(torch.empty(heads * out_channels))\n        elif bias and not concat:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def 
reset_parameters(self):\n        self.lin.reset_parameters()\n        gain = torch.nn.init.calculate_gain('relu')\n        torch.nn.init.xavier_normal_(\n            self.att.view(2, self.heads, self.out_channels), gain=gain)\n        zeros(self.bias)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: EdgeIndex,\n        edge_attr: Tensor,\n        max_num_neighbors: Optional[int] = None,\n    ) -> Tensor:\n        graph = self.get_cugraph(edge_index, max_num_neighbors)\n\n        x = self.lin(x)\n\n        if LEGACY_MODE:\n            out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',\n                             self.negative_slope, False, self.concat,\n                             edge_feat=edge_attr)\n        else:\n            out = GATConvAgg(x, self.att, graph, self.heads, 'LeakyReLU',\n                             self.negative_slope, self.concat,\n                             edge_feat=edge_attr)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/cugraph/rgcn_conv.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn.conv.cugraph import CuGraphModule\nfrom torch_geometric.nn.conv.cugraph.base import LEGACY_MODE\nfrom torch_geometric.nn.inits import glorot, zeros\n\ntry:\n    if LEGACY_MODE:\n        from pylibcugraphops.torch.autograd import \\\n            agg_hg_basis_n2n_post as RGCNConvAgg\n    else:\n        from pylibcugraphops.pytorch.operators import \\\n            agg_hg_basis_n2n_post as RGCNConvAgg\nexcept ImportError:\n    pass\n\n\nclass CuGraphRGCNConv(CuGraphModule):  # pragma: no cover\n    r\"\"\"The relational graph convolutional operator from the `\"Modeling\n    Relational Data with Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1703.06103>`_ paper.\n\n    :class:`CuGraphRGCNConv` is an optimized version of\n    :class:`~torch_geometric.nn.conv.RGCNConv` based on the :obj:`cugraph-ops`\n    package that fuses message passing computation for accelerated execution\n    and lower memory footprint.\n\n    See :ref:`install-cugraph` for how to set up :obj:`cugraph-ops`.\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, num_relations: int,\n                 num_bases: Optional[int] = None, aggr: str = 'mean',\n                 root_weight: bool = True, bias: bool = True):\n        super().__init__()\n\n        if aggr not in ['sum', 'add', 'mean']:\n            raise ValueError(f\"Aggregation function must be either 'mean' \"\n                             f\"or 'sum' (got '{aggr}')\")\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_relations = num_relations\n        self.num_bases = num_bases\n        self.aggr = aggr\n        self.root_weight = root_weight\n\n        dim_root_weight = 1 if root_weight else 0\n\n        if num_bases is not None:\n            self.weight = Parameter(\n            
    torch.empty(num_bases + dim_root_weight, in_channels,\n                            out_channels))\n            self.comp = Parameter(torch.empty(num_relations, num_bases))\n        else:\n            self.weight = Parameter(\n                torch.empty(num_relations + dim_root_weight, in_channels,\n                            out_channels))\n            self.register_parameter('comp', None)\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        end = -1 if self.root_weight else None\n        glorot(self.weight[:end])\n        glorot(self.comp)\n        if self.root_weight:\n            glorot(self.weight[-1])\n        zeros(self.bias)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: EdgeIndex,\n        edge_type: Tensor,\n        max_num_neighbors: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): The node features.\n            edge_index (EdgeIndex): The edge indices.\n            edge_type (torch.Tensor): The edge type.\n            max_num_neighbors (int, optional): The maximum number of neighbors\n                of a target node. It is only effective when operating in a\n                bipartite graph. 
When not given, the value will be computed\n                on-the-fly, leading to slightly worse performance.\n                (default: :obj:`None`)\n        \"\"\"\n        graph = self.get_typed_cugraph(edge_index, edge_type,\n                                       self.num_relations, max_num_neighbors)\n\n        out = RGCNConvAgg(x, self.comp, graph, concat_own=self.root_weight,\n                          norm_by_out_degree=bool(self.aggr == 'mean'))\n\n        out = out @ self.weight.view(-1, self.out_channels)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_relations={self.num_relations})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/cugraph/sage_conv.py",
    "content": "from typing import Optional\n\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Linear\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn.conv.cugraph import CuGraphModule\nfrom torch_geometric.nn.conv.cugraph.base import LEGACY_MODE\n\ntry:\n    if LEGACY_MODE:\n        from pylibcugraphops.torch.autograd import \\\n            agg_concat_n2n as SAGEConvAgg\n    else:\n        from pylibcugraphops.pytorch.operators import \\\n            agg_concat_n2n as SAGEConvAgg\nexcept ImportError:\n    pass\n\n\nclass CuGraphSAGEConv(CuGraphModule):  # pragma: no cover\n    r\"\"\"The GraphSAGE operator from the `\"Inductive Representation Learning on\n    Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper.\n\n    :class:`CuGraphSAGEConv` is an optimized version of\n    :class:`~torch_geometric.nn.conv.SAGEConv` based on the :obj:`cugraph-ops`\n    package that fuses message passing computation for accelerated execution\n    and lower memory footprint.\n\n    See :ref:`install-cugraph` for how to set up :obj:`cugraph-ops`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        aggr: str = 'mean',\n        normalize: bool = False,\n        root_weight: bool = True,\n        project: bool = False,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        if aggr not in ['mean', 'sum', 'min', 'max']:\n            raise ValueError(f\"Aggregation function must be either 'mean', \"\n                             f\"'sum', 'min' or 'max' (got '{aggr}')\")\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.aggr = aggr\n        self.normalize = normalize\n        self.root_weight = root_weight\n        self.project = project\n\n        if self.project:\n            self.pre_lin = Linear(in_channels, in_channels, bias=True)\n\n        if self.root_weight:\n            self.lin = Linear(2 * 
in_channels, out_channels, bias=bias)\n        else:\n            self.lin = Linear(in_channels, out_channels, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        if self.project:\n            self.pre_lin.reset_parameters()\n        self.lin.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: EdgeIndex,\n        max_num_neighbors: Optional[int] = None,\n    ) -> Tensor:\n        graph = self.get_cugraph(edge_index, max_num_neighbors)\n\n        if self.project:\n            x = self.pre_lin(x).relu()\n\n        out = SAGEConvAgg(x, graph, self.aggr)\n\n        if self.root_weight:\n            out = self.lin(out)\n        else:\n            out = self.lin(out[:, :self.in_channels])\n\n        if self.normalize:\n            out = F.normalize(out, p=2., dim=-1)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, aggr={self.aggr})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/dir_gnn_conv.py",
    "content": "import copy\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\n\n\nclass DirGNNConv(torch.nn.Module):\n    r\"\"\"A generic wrapper for computing graph convolution on directed\n    graphs as described in the `\"Edge Directionality Improves Learning on\n    Heterophilic Graphs\" <https://arxiv.org/abs/2305.10498>`_ paper.\n    :class:`DirGNNConv` will pass messages both from source nodes to target\n    nodes and from target nodes to source nodes.\n\n    Args:\n        conv (MessagePassing): The underlying\n            :class:`~torch_geometric.nn.conv.MessagePassing` layer to use.\n        alpha (float, optional): The alpha coefficient used to weight the\n            aggregations of in- and out-edges as part of a convex combination.\n            (default: :obj:`0.5`)\n        root_weight (bool, optional): If set to :obj:`True`, the layer will add\n            transformed root node features to the output.\n            (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        conv: MessagePassing,\n        alpha: float = 0.5,\n        root_weight: bool = True,\n    ):\n        super().__init__()\n\n        self.alpha = alpha\n        self.root_weight = root_weight\n\n        self.conv_in = copy.deepcopy(conv)\n        self.conv_out = copy.deepcopy(conv)\n\n        if hasattr(conv, 'add_self_loops'):\n            self.conv_in.add_self_loops = False\n            self.conv_out.add_self_loops = False\n        if hasattr(conv, 'root_weight'):\n            self.conv_in.root_weight = False\n            self.conv_out.root_weight = False\n\n        if root_weight:\n            self.lin = torch.nn.Linear(conv.in_channels, conv.out_channels)\n        else:\n            self.lin = None\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.conv_in.reset_parameters()\n        
self.conv_out.reset_parameters()\n        if self.lin is not None:\n            self.lin.reset_parameters()\n\n    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        x_in = self.conv_in(x, edge_index)\n        x_out = self.conv_out(x, edge_index.flip([0]))\n\n        out = self.alpha * x_out + (1 - self.alpha) * x_in\n\n        if self.root_weight:\n            out = out + self.lin(x)\n\n        return out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.conv_in}, alpha={self.alpha})'\n"
  },
  {
    "path": "torch_geometric/nn/conv/dna_conv.py",
    "content": "import math\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.inits import kaiming_uniform, uniform\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor\n\n\nclass Linear(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, groups=1, bias=True):\n        super().__init__()\n        assert in_channels % groups == 0 and out_channels % groups == 0\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.groups = groups\n\n        self.weight = Parameter(\n            torch.empty(groups, in_channels // groups, out_channels // groups))\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        kaiming_uniform(self.weight, fan=self.weight.size(1), a=math.sqrt(5))\n        uniform(self.weight.size(1), self.bias)\n\n    def forward(self, src):\n        # Input: [*, in_channels]\n        # Output: [*, out_channels]\n\n        if self.groups > 1:\n            size = src.size()[:-1]\n            src = src.view(-1, self.groups, self.in_channels // self.groups)\n            src = src.transpose(0, 1).contiguous()\n            out = torch.matmul(src, self.weight)\n            out = out.transpose(1, 0).contiguous()\n            out = out.view(size + (self.out_channels, ))\n        else:\n            out = torch.matmul(src, self.weight.squeeze(0))\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def __repr__(self) -> str:  # pragma: no cover\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, 
groups={self.groups})')\n\n\ndef restricted_softmax(src, dim: int = -1, margin: float = 0.):\n    src_max = torch.clamp(src.max(dim=dim, keepdim=True)[0], min=0.)\n    out = (src - src_max).exp()\n    out = out / (out.sum(dim=dim, keepdim=True) + (margin - src_max).exp())\n    return out\n\n\nclass Attention(torch.nn.Module):\n    def __init__(self, dropout=0):\n        super().__init__()\n        self.dropout = dropout\n\n    def forward(self, query, key, value):\n        return self.compute_attention(query, key, value)\n\n    def compute_attention(self, query, key, value):\n        # query: [*, query_entries, dim_k]\n        # key: [*, key_entries, dim_k]\n        # value: [*, key_entries, dim_v]\n        # Output: [*, query_entries, dim_v]\n\n        assert query.dim() == key.dim() == value.dim() >= 2\n        assert query.size(-1) == key.size(-1)\n        assert key.size(-2) == value.size(-2)\n\n        # Score: [*, query_entries, key_entries]\n        score = torch.matmul(query, key.transpose(-2, -1))\n        score = score / math.sqrt(key.size(-1))\n        score = restricted_softmax(score, dim=-1)\n        score = F.dropout(score, p=self.dropout, training=self.training)\n\n        return torch.matmul(score, value)\n\n    def __repr__(self) -> str:  # pragma: no cover\n        return f'{self.__class__.__name__}(dropout={self.dropout})'\n\n\nclass MultiHead(Attention):\n    def __init__(self, in_channels, out_channels, heads=1, groups=1, dropout=0,\n                 bias=True):\n        super().__init__(dropout)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.groups = groups\n        self.bias = bias\n\n        assert in_channels % heads == 0 and out_channels % heads == 0\n        assert in_channels % groups == 0 and out_channels % groups == 0\n        assert max(groups, self.heads) % min(groups, self.heads) == 0\n\n        self.lin_q = Linear(in_channels, out_channels, groups, 
bias)\n        self.lin_k = Linear(in_channels, out_channels, groups, bias)\n        self.lin_v = Linear(in_channels, out_channels, groups, bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.lin_q.reset_parameters()\n        self.lin_k.reset_parameters()\n        self.lin_v.reset_parameters()\n\n    def forward(self, query, key, value):\n        # query: [*, query_entries, in_channels]\n        # key: [*, key_entries, in_channels]\n        # value: [*, key_entries, in_channels]\n        # Output: [*, query_entries, out_channels]\n\n        assert query.dim() == key.dim() == value.dim() >= 2\n        assert query.size(-1) == key.size(-1) == value.size(-1)\n        assert key.size(-2) == value.size(-2)\n\n        query = self.lin_q(query)\n        key = self.lin_k(key)\n        value = self.lin_v(value)\n\n        # query: [*, heads, query_entries, out_channels // heads]\n        # key: [*, heads, key_entries, out_channels // heads]\n        # value: [*, heads, key_entries, out_channels // heads]\n        size = query.size()[:-2]\n        out_channels_per_head = self.out_channels // self.heads\n\n        query_size = size + (query.size(-2), self.heads, out_channels_per_head)\n        query = query.view(query_size).transpose(-2, -3)\n\n        key_size = size + (key.size(-2), self.heads, out_channels_per_head)\n        key = key.view(key_size).transpose(-2, -3)\n\n        value_size = size + (value.size(-2), self.heads, out_channels_per_head)\n        value = value.view(value_size).transpose(-2, -3)\n\n        # Output: [*, heads, query_entries, out_channels // heads]\n        out = self.compute_attention(query, key, value)\n        # Output: [*, query_entries, heads, out_channels // heads]\n        out = out.transpose(-3, -2).contiguous()\n        # Output: [*, query_entries, out_channels]\n        out = out.view(size + (query.size(-2), self.out_channels))\n\n        return out\n\n    def __repr__(self) -> str:  # pragma: no 
cover\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads}, '\n                f'groups={self.groups}, dropout={self.dropout}, '\n                f'bias={self.bias})')\n\n\nclass DNAConv(MessagePassing):\n    r\"\"\"The dynamic neighborhood aggregation operator from the `\"Just Jump:\n    Towards Dynamic Neighborhood Aggregation in Graph Neural Networks\"\n    <https://arxiv.org/abs/1904.04849>`_ paper.\n\n    .. math::\n        \\mathbf{x}_v^{(t)} = h_{\\mathbf{\\Theta}}^{(t)} \\left( \\mathbf{x}_{v\n        \\leftarrow v}^{(t)}, \\left\\{ \\mathbf{x}_{v \\leftarrow w}^{(t)} : w \\in\n        \\mathcal{N}(v) \\right\\} \\right)\n\n    based on (multi-head) dot-product attention\n\n    .. math::\n        \\mathbf{x}_{v \\leftarrow w}^{(t)} = \\textrm{Attention} \\left(\n        \\mathbf{x}^{(t-1)}_v \\, \\mathbf{\\Theta}_Q^{(t)}, [\\mathbf{x}_w^{(1)},\n        \\ldots, \\mathbf{x}_w^{(t-1)}] \\, \\mathbf{\\Theta}_K^{(t)}, \\,\n        [\\mathbf{x}_w^{(1)}, \\ldots, \\mathbf{x}_w^{(t-1)}] \\,\n        \\mathbf{\\Theta}_V^{(t)} \\right)\n\n    with :math:`\\mathbf{\\Theta}_Q^{(t)}, \\mathbf{\\Theta}_K^{(t)},\n    \\mathbf{\\Theta}_V^{(t)}` denoting (grouped) projection matrices for query,\n    key and value information, respectively.\n    :math:`h^{(t)}_{\\mathbf{\\Theta}}` is implemented as a non-trainable\n    version of :class:`torch_geometric.nn.conv.GCNConv`.\n\n    .. note::\n        In contrast to other layers, this operator expects node features as\n        shape :obj:`[num_nodes, num_layers, channels]`.\n\n    Args:\n        channels (int): Size of each input/output sample.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        groups (int, optional): Number of groups to use for all linear\n            projections. (default: :obj:`1`)\n        dropout (float, optional): Dropout probability of attention\n            coefficients. 
(default: :obj:`0.`)\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of :math:`\\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n            \\mathbf{\\hat{D}}^{-1/2}` on first execution, and will use the\n            cached version for further executions.\n            This parameter should only be set to :obj:`True` in transductive\n            learning scenarios. (default: :obj:`False`)\n        normalize (bool, optional): Whether to add self-loops and apply\n            symmetric normalization. (default: :obj:`True`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, L, F)` where :math:`L` is the\n          number of layers,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F)`\n    \"\"\"\n\n    _cached_edge_index: Optional[OptPairTensor]\n    _cached_adj_t: Optional[SparseTensor]\n\n    def __init__(self, channels: int, heads: int = 1, groups: int = 1,\n                 dropout: float = 0., cached: bool = False,\n                 normalize: bool = True, add_self_loops: bool = True,\n                 bias: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(node_dim=0, **kwargs)\n\n        self.bias = bias\n        self.cached = cached\n        self.normalize = normalize\n        self.add_self_loops = add_self_loops\n\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n        self.multi_head = MultiHead(channels, channels, heads, groups, dropout,\n         
                           bias)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.multi_head.reset_parameters()\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): The input node features of shape\n                :obj:`[num_nodes, num_layers, channels]`.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_weight (torch.Tensor, optional): The edge weights.\n                (default: :obj:`None`)\n        \"\"\"\n        if x.dim() != 3:\n            raise ValueError('Feature shape must be [num_nodes, num_layers, '\n                             'channels].')\n\n        if self.normalize:\n            if isinstance(edge_index, Tensor):\n                cache = self._cached_edge_index\n                if cache is None:\n                    edge_index, edge_weight = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_edge_index = (edge_index, edge_weight)\n                else:\n                    edge_index, edge_weight = cache[0], cache[1]\n\n            elif isinstance(edge_index, SparseTensor):\n                cache = self._cached_adj_t\n                if cache is None:\n                    edge_index = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_adj_t = edge_index\n                else:\n     
               edge_index = cache\n\n        # propagate_type: (x: Tensor, edge_weight: OptTensor)\n        return self.propagate(edge_index, x=x, edge_weight=edge_weight)\n\n    def message(self, x_i: Tensor, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        x_i = x_i[:, -1:]  # [num_edges, 1, channels]\n        out = self.multi_head(x_i, x_j, x_j)  # [num_edges, 1, channels]\n        return edge_weight.view(-1, 1) * out.squeeze(1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.multi_head.in_channels}, '\n                f'heads={self.multi_head.heads}, '\n                f'groups={self.multi_head.groups})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/edge_conv.py",
    "content": "from typing import Callable, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor\n\nif torch_geometric.typing.WITH_TORCH_CLUSTER:\n    from torch_cluster import knn\nelse:\n    knn = None\n\n\nclass EdgeConv(MessagePassing):\n    r\"\"\"The edge convolutional operator from the `\"Dynamic Graph CNN for\n    Learning on Point Clouds\" <https://arxiv.org/abs/1801.07829>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\sum_{j \\in \\mathcal{N}(i)}\n        h_{\\mathbf{\\Theta}}(\\mathbf{x}_i \\, \\Vert \\,\n        \\mathbf{x}_j - \\mathbf{x}_i),\n\n    where :math:`h_{\\mathbf{\\Theta}}` denotes a neural network, *i.e.*, an MLP.\n\n    Args:\n        nn (torch.nn.Module): A neural network :math:`h_{\\mathbf{\\Theta}}` that\n            maps pair-wise concatenated node features :obj:`x` of shape\n            :obj:`[-1, 2 * in_channels]` to shape :obj:`[-1, out_channels]`,\n            *e.g.*, defined by :class:`torch.nn.Sequential`.\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"max\"`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V}|, F_{in}), (|\\mathcal{V}|, F_{in}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, nn: Callable, aggr: str = 'max', **kwargs):\n        super().__init__(aggr=aggr, **kwargs)\n        self.nn = nn\n    
    self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.nn)\n\n    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: PairTensor)\n        return self.propagate(edge_index, x=x)\n\n    def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:\n        return self.nn(torch.cat([x_i, x_j - x_i], dim=-1))\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(nn={self.nn})'\n\n\nclass DynamicEdgeConv(MessagePassing):\n    r\"\"\"The dynamic edge convolutional operator from the `\"Dynamic Graph CNN\n    for Learning on Point Clouds\" <https://arxiv.org/abs/1801.07829>`_ paper\n    (see :class:`torch_geometric.nn.conv.EdgeConv`), where the graph is\n    dynamically constructed using nearest neighbors in the feature space.\n\n    Args:\n        nn (torch.nn.Module): A neural network :math:`h_{\\mathbf{\\Theta}}` that\n            maps pair-wise concatenated node features :obj:`x` of shape\n            :obj:`[-1, 2 * in_channels]` to shape :obj:`[-1, out_channels]`,\n            *e.g.*, defined by :class:`torch.nn.Sequential`.\n        k (int): Number of nearest neighbors.\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"max\"`)\n        num_workers (int): Number of workers to use for k-NN computation.\n            Has no effect in case :obj:`batch` is not :obj:`None`, or the input\n            lies on the GPU. 
(default: :obj:`1`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V}|, F_{in}), (|\\mathcal{V}|, F_{in}))`\n          if bipartite,\n          batch vector :math:`(|\\mathcal{V}|)` or\n          :math:`((|\\mathcal{V}|), (|\\mathcal{V}|))`\n          if bipartite *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, nn: Callable, k: int, aggr: str = 'max',\n                 num_workers: int = 1, **kwargs):\n        super().__init__(aggr=aggr, flow='source_to_target', **kwargs)\n\n        if knn is None:\n            raise ImportError('`DynamicEdgeConv` requires `torch-cluster`.')\n\n        self.nn = nn\n        self.k = k\n        self.num_workers = num_workers\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        reset(self.nn)\n\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        batch: Union[OptTensor, Optional[PairTensor]] = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        if x[0].dim() != 2:\n            raise ValueError(\"Static graphs not supported in DynamicEdgeConv\")\n\n        b: PairOptTensor = (None, None)\n        if isinstance(batch, Tensor):\n            b = (batch, batch)\n        elif isinstance(batch, tuple):\n            assert batch is not None\n            b = (batch[0], batch[1])\n\n        edge_index = knn(x[0], x[1], self.k, b[0], b[1]).flip([0])\n\n        # propagate_type: (x: PairTensor)\n        return self.propagate(edge_index, x=x)\n\n    def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:\n        return self.nn(torch.cat([x_i, x_j - x_i], dim=-1))\n\n    def __repr__(self) -> str:\n        return 
f'{self.__class__.__name__}(nn={self.nn}, k={self.k})'\n"
  },
  {
    "path": "torch_geometric/nn/conv/edge_updater.jinja",
    "content": "import typing\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.utils import is_sparse\nfrom torch_geometric.typing import Size, SparseTensor\n{% for module in modules %}\nfrom {{module}} import *\n{%- endfor %}\n\n\n{% include \"collect.jinja\" %}\n\n\ndef edge_updater(\n    self,\n    edge_index: Union[Tensor, SparseTensor],\n{%- for param in signature.param_dict.values() %}\n    {{param.name}}: {{param.type_repr}},\n{%- endfor %}\n    size: Size = None,\n) -> {{signature.return_type_repr}}:\n\n    mutable_size = self._check_input(edge_index, size)\n\n    kwargs = self.{{collect_name}}(\n        edge_index,\n{%- for name in signature.param_dict %}\n        {{name}},\n{%- endfor %}\n        mutable_size,\n    )\n\n    # Begin Edge Update Forward Pre Hook #######################################\n    if not torch.jit.is_scripting() and not is_compiling():\n        for hook in self._edge_update_forward_pre_hooks.values():\n            hook_kwargs = dict(\n{%- for name in collect_param_dict %}\n                {{name}}=kwargs.{{name}},\n{%- endfor %}\n            )\n            res = hook(self, (edge_index, size, hook_kwargs))\n            if res is not None:\n                edge_index, size, hook_kwargs = res\n                kwargs = CollectArgs(\n{%- for name in collect_param_dict %}\n                    {{name}}=hook_kwargs['{{name}}'],\n{%- endfor %}\n                )\n    # End Edge Update Forward Pre Hook #########################################\n\n    out = self.edge_update(\n{%- for name in collect_param_dict %}\n        {{name}}=kwargs.{{name}},\n{%- endfor %}\n    )\n\n    # Begin Edge Update Forward Hook ###########################################\n    if not torch.jit.is_scripting() and not is_compiling():\n        for hook in self._edge_update_forward_hooks.values():\n            hook_kwargs = dict(\n{%- for name in 
collect_param_dict %}\n                {{name}}=kwargs.{{name}},\n{%- endfor %}\n            )\n            res = hook(self, (edge_index, size, hook_kwargs), out)\n            out = res if res is not None else out\n    # End Edge Update Forward Hook #############################################\n\n    return out\n"
  },
  {
    "path": "torch_geometric/nn/conv/eg_conv.py",
    "content": "from typing import List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse\nfrom torch_geometric.utils import add_remaining_self_loops, scatter, spmm\n\n\nclass EGConv(MessagePassing):\n    r\"\"\"The Efficient Graph Convolution from the `\"Adaptive Filters and\n    Aggregator Fusion for Efficient Graph Convolutions\"\n    <https://arxiv.org/abs/2104.01481>`_ paper.\n\n    Its node-wise formulation is given by:\n\n    .. math::\n        \\mathbf{x}_i^{\\prime} = {\\LARGE ||}_{h=1}^H \\sum_{\\oplus \\in\n        \\mathcal{A}} \\sum_{b = 1}^B w_{i, h, \\oplus, b} \\;\n        \\underset{j \\in \\mathcal{N}(i) \\cup \\{i\\}}{\\bigoplus}\n        \\mathbf{W}_b \\mathbf{x}_{j}\n\n    with :math:`\\mathbf{W}_b` denoting a basis weight,\n    :math:`\\oplus` denoting an aggregator, and :math:`w` denoting per-vertex\n    weighting coefficients across different heads, bases and aggregators.\n\n    EGC retains :math:`\\mathcal{O}(|\\mathcal{V}|)` memory usage, making it a\n    sensible alternative to :class:`~torch_geometric.nn.conv.GCNConv`,\n    :class:`~torch_geometric.nn.conv.SAGEConv` or\n    :class:`~torch_geometric.nn.conv.GINConv`.\n\n    .. 
note::\n        For an example of using :obj:`EGConv`, see `examples/egc.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/\n        examples/egc.py>`_.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        aggregators (List[str], optional): Aggregators to be used.\n            Supported aggregators are :obj:`\"sum\"`, :obj:`\"mean\"`,\n            :obj:`\"symnorm\"`, :obj:`\"max\"`, :obj:`\"min\"`, :obj:`\"std\"`,\n            :obj:`\"var\"`.\n            Multiple aggregators can be used to improve the performance.\n            (default: :obj:`[\"symnorm\"]`)\n        num_heads (int, optional): Number of heads :math:`H` to use. Must have\n            :obj:`out_channels % num_heads == 0`. It is recommended to set\n            :obj:`num_heads >= num_bases`. (default: :obj:`8`)\n        num_bases (int, optional): Number of basis weights :math:`B` to use.\n            (default: :obj:`4`)\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of the edge index with added self loops on first\n            execution, along with caching the calculation of the symmetric\n            normalized edge weights if the :obj:`\"symnorm\"` aggregator is\n            being used. This parameter should only be set to :obj:`True` in\n            transductive learning scenarios. (default: :obj:`False`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n\n    _cached_edge_index: Optional[Tuple[Tensor, OptTensor]]\n    _cached_adj_t: Optional[SparseTensor]\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        aggregators: Optional[List[str]] = None,\n        num_heads: int = 8,\n        num_bases: int = 4,\n        cached: bool = False,\n        add_self_loops: bool = True,\n        bias: bool = True,\n        **kwargs,\n    ):\n        super().__init__(node_dim=0, **kwargs)\n\n        if out_channels % num_heads != 0:\n            raise ValueError(f\"'out_channels' (got {out_channels}) must be \"\n                             f\"divisible by the number of heads \"\n                             f\"(got {num_heads})\")\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_heads = num_heads\n        self.num_bases = num_bases\n        self.cached = cached\n        self.add_self_loops = add_self_loops\n        self.aggregators = aggregators or ['symnorm']\n\n        for a in self.aggregators:\n            if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']:\n                raise ValueError(f\"Unsupported aggregator: '{a}'\")\n\n        self.bases_lin = Linear(in_channels,\n                                (out_channels // num_heads) * num_bases,\n                                bias=False, weight_initializer='glorot')\n        self.comb_lin = Linear(in_channels,\n                               num_heads * num_bases * len(self.aggregators))\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            
self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.bases_lin.reset_parameters()\n        self.comb_lin.reset_parameters()\n        zeros(self.bias)\n        self._cached_adj_t = None\n        self._cached_edge_index = None\n\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        symnorm_weight: OptTensor = None\n        if \"symnorm\" in self.aggregators:\n            if isinstance(edge_index, Tensor):\n                cache = self._cached_edge_index\n                if cache is None:\n                    edge_index, symnorm_weight = gcn_norm(  # yapf: disable\n                        edge_index, None, num_nodes=x.size(self.node_dim),\n                        improved=False, add_self_loops=self.add_self_loops,\n                        flow=self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_edge_index = (edge_index, symnorm_weight)\n                else:\n                    edge_index, symnorm_weight = cache\n\n            elif isinstance(edge_index, SparseTensor):\n                cache = self._cached_adj_t\n                if cache is None:\n                    edge_index = gcn_norm(  # yapf: disable\n                        edge_index, None, num_nodes=x.size(self.node_dim),\n                        improved=False, add_self_loops=self.add_self_loops,\n                        flow=self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_adj_t = edge_index\n                else:\n                    edge_index = cache\n\n        elif self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                cache = self._cached_edge_index\n                if self.cached and cache is not None:\n                    edge_index = cache[0]\n                else:\n                    edge_index, _ = add_remaining_self_loops(edge_index)\n           
         if self.cached:\n                        self._cached_edge_index = (edge_index, None)\n\n            elif isinstance(edge_index, SparseTensor):\n                cache = self._cached_adj_t\n                if self.cached and cache is not None:\n                    edge_index = cache\n                else:\n                    edge_index = torch_sparse.fill_diag(edge_index, 1.0)\n                    if self.cached:\n                        self._cached_adj_t = edge_index\n\n        # [num_nodes, (out_channels // num_heads) * num_bases]\n        bases = self.bases_lin(x)\n        # [num_nodes, num_heads * num_bases * num_aggrs]\n        weightings = self.comb_lin(x)\n\n        # [num_nodes, num_aggregators, (out_channels // num_heads) * num_bases]\n        # propagate_type: (x: Tensor, symnorm_weight: OptTensor)\n        aggregated = self.propagate(edge_index, x=bases,\n                                    symnorm_weight=symnorm_weight)\n\n        weightings = weightings.view(-1, self.num_heads,\n                                     self.num_bases * len(self.aggregators))\n        aggregated = aggregated.view(\n            -1,\n            len(self.aggregators) * self.num_bases,\n            self.out_channels // self.num_heads,\n        )\n\n        # [num_nodes, num_heads, out_channels // num_heads]\n        out = torch.matmul(weightings, aggregated)\n        out = out.view(-1, self.out_channels)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor) -> Tensor:\n        return x_j\n\n    def aggregate(self, inputs: Tensor, index: Tensor,\n                  dim_size: Optional[int] = None,\n                  symnorm_weight: OptTensor = None) -> Tensor:\n\n        outs = []\n        for aggr in self.aggregators:\n            if aggr == 'symnorm':\n                assert symnorm_weight is not None\n                out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0,\n             
                 dim_size, reduce='sum')\n            elif aggr == 'var' or aggr == 'std':\n                mean = scatter(inputs, index, 0, dim_size, reduce='mean')\n                mean_squares = scatter(inputs * inputs, index, 0, dim_size,\n                                       reduce='mean')\n                out = mean_squares - mean * mean\n                if aggr == 'std':\n                    out = out.clamp(min=1e-5).sqrt()\n            else:\n                out = scatter(inputs, index, 0, dim_size, reduce=aggr)\n\n            outs.append(out)\n\n        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        adj_t_2 = adj_t\n        if len(self.aggregators) > 1 and 'symnorm' in self.aggregators:\n            if isinstance(adj_t, SparseTensor):\n                adj_t_2 = adj_t.set_value(None)\n            else:\n                adj_t_2 = adj_t.clone()\n                adj_t_2.values().fill_(1.0)\n\n        outs = []\n        for aggr in self.aggregators:\n            if aggr == 'symnorm':\n                out = spmm(adj_t, x, reduce='sum')\n            elif aggr in ['var', 'std']:\n                mean = spmm(adj_t_2, x, reduce='mean')\n                mean_sq = spmm(adj_t_2, x * x, reduce='mean')\n                out = mean_sq - mean * mean\n                if aggr == 'std':\n                    out = torch.sqrt(out.relu_() + 1e-5)\n            else:\n                out = spmm(adj_t_2, x, reduce=aggr)\n\n            outs.append(out)\n\n        return torch.stack(outs, dim=1) if len(outs) > 1 else outs[0]\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, aggregators={self.aggregators})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/fa_conv.py",
    "content": "import typing\nfrom typing import Optional, Tuple, Union\n\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import PairTensor  # noqa\nfrom torch_geometric.typing import (\n    Adj,\n    NoneType,\n    OptPairTensor,\n    OptTensor,\n    SparseTensor,\n)\nfrom torch_geometric.utils import is_torch_sparse_tensor\nfrom torch_geometric.utils.sparse import set_sparse_value\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload_method as overload\n\n\nclass FAConv(MessagePassing):\n    r\"\"\"The Frequency Adaptive Graph Convolution operator from the\n    `\"Beyond Low-Frequency Information in Graph Convolutional Networks\"\n    <https://arxiv.org/abs/2101.00797>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i= \\epsilon \\cdot \\mathbf{x}^{(0)}_i +\n        \\sum_{j \\in \\mathcal{N}(i)} \\frac{\\alpha_{i,j}}{\\sqrt{d_i d_j}}\n        \\mathbf{x}_{j}\n\n    where :math:`\\mathbf{x}^{(0)}_i` and :math:`d_i` denote the initial feature\n    representation and node degree of node :math:`i`, respectively.\n    The attention coefficients :math:`\\alpha_{i,j}` are computed as\n\n    .. math::\n        \\mathbf{\\alpha}_{i,j} = \\textrm{tanh}(\\mathbf{a}^{\\top}[\\mathbf{x}_i,\n        \\mathbf{x}_j])\n\n    based on the trainable parameter vector :math:`\\mathbf{a}`.\n\n    Args:\n        channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        eps (float, optional): :math:`\\epsilon`-value. (default: :obj:`0.1`)\n        dropout (float, optional): Dropout probability of the normalized\n            coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. 
(default: :obj:`0`).\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of :math:`\\sqrt{d_i d_j}` on first execution, and\n            will use the cached version for further executions.\n            This parameter should only be set to :obj:`True` in transductive\n            learning scenarios. (default: :obj:`False`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        normalize (bool, optional): Whether to add self-loops (if\n            :obj:`add_self_loops` is :obj:`True`) and compute\n            symmetric normalization coefficients on the fly.\n            If set to :obj:`False`, :obj:`edge_weight` needs to be provided in\n            the layer's :meth:`forward` method. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F)`,\n          initial node features :math:`(|\\mathcal{V}|, F)`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F)` or\n          :math:`((|\\mathcal{V}|, F), ((2, |\\mathcal{E}|),\n          (|\\mathcal{E}|)))` if :obj:`return_attention_weights=True`\n    \"\"\"\n    _cached_edge_index: Optional[OptPairTensor]\n    _cached_adj_t: Optional[SparseTensor]\n    _alpha: OptTensor\n\n    def __init__(self, channels: int, eps: float = 0.1, dropout: float = 0.0,\n                 cached: bool = False, add_self_loops: bool = True,\n                 normalize: bool = True, **kwargs):\n\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.channels = channels\n        self.eps = eps\n        self.dropout = dropout\n        self.cached = cached\n        
self.add_self_loops = add_self_loops\n        self.normalize = normalize\n\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n        self._alpha = None\n\n        self.att_l = Linear(channels, 1, bias=False)\n        self.att_r = Linear(channels, 1, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.att_l.reset_parameters()\n        self.att_r.reset_parameters()\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n    @overload\n    def forward(\n        self,\n        x: Tensor,\n        x_0: Tensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n        return_attention_weights: NoneType = None,\n    ) -> Tensor:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Tensor,\n        x_0: Tensor,\n        edge_index: Tensor,\n        edge_weight: OptTensor = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Tensor,\n        x_0: Tensor,\n        edge_index: SparseTensor,\n        edge_weight: OptTensor = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, SparseTensor]:\n        pass\n\n    def forward(  # noqa: F811\n        self,\n        x: Tensor,\n        x_0: Tensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n        return_attention_weights: Optional[bool] = None,\n    ) -> Union[\n            Tensor,\n            Tuple[Tensor, Tuple[Tensor, Tensor]],\n            Tuple[Tensor, SparseTensor],\n    ]:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): The node features.\n            x_0 (torch.Tensor): The initial input node features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_weight 
(torch.Tensor, optional): The edge weights.\n                (default: :obj:`None`)\n            return_attention_weights (bool, optional):\n                Will additionally return the tuple\n                :obj:`(edge_index, attention_weights)` whenever it is set to\n                a value, regardless of its actual value\n                (might be `True` or `False`), holding the computed attention\n                weights for each edge.\n                (default: :obj:`None`)\n        \"\"\"\n        if self.normalize:\n            if isinstance(edge_index, Tensor):\n                assert edge_weight is None\n                cache = self._cached_edge_index\n                if cache is None:\n                    edge_index, edge_weight = gcn_norm(  # yapf: disable\n                        edge_index, None, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_edge_index = (edge_index, edge_weight)\n                else:\n                    edge_index, edge_weight = cache[0], cache[1]\n\n            elif isinstance(edge_index, SparseTensor):\n                assert not edge_index.has_value()\n                cache = self._cached_adj_t\n                if cache is None:\n                    edge_index = gcn_norm(  # yapf: disable\n                        edge_index, None, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_adj_t = edge_index\n                else:\n                    edge_index = cache\n        else:\n            if isinstance(edge_index,\n                          Tensor) and not is_torch_sparse_tensor(edge_index):\n                assert edge_weight is not None\n            elif isinstance(edge_index, SparseTensor):\n                assert edge_index.has_value()\n\n        alpha_l = self.att_l(x)\n  
      alpha_r = self.att_r(x)\n\n        # propagate_type: (x: Tensor, alpha: PairTensor,\n        #                  edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r),\n                             edge_weight=edge_weight)\n\n        alpha = self._alpha\n        self._alpha = None\n\n        if self.eps != 0.0:\n            out = out + self.eps * x_0\n\n        if isinstance(return_attention_weights, bool):\n            assert alpha is not None\n            if isinstance(edge_index, Tensor):\n                if is_torch_sparse_tensor(edge_index):\n                    # TODO TorchScript requires to return a tuple\n                    adj = set_sparse_value(edge_index, alpha)\n                    return out, (adj, alpha)\n                else:\n                    return out, (edge_index, alpha)\n            elif isinstance(edge_index, SparseTensor):\n                return out, edge_index.set_value(alpha, layout='coo')\n        else:\n            return out\n\n    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor,\n                edge_weight: OptTensor) -> Tensor:\n        assert edge_weight is not None\n        alpha = (alpha_j + alpha_i).tanh().squeeze(-1)\n        self._alpha = alpha\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        return x_j * (alpha * edge_weight).view(-1, 1)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.channels}, eps={self.eps})'\n"
  },
  {
    "path": "torch_geometric/nn/conv/feast_conv.py",
    "content": "from typing import Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import normal\nfrom torch_geometric.typing import Adj, PairTensor, SparseTensor, torch_sparse\nfrom torch_geometric.utils import add_self_loops, remove_self_loops\n\n\nclass FeaStConv(MessagePassing):\n    r\"\"\"The (translation-invariant) feature-steered convolutional operator from\n    the `\"FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis\"\n    <https://arxiv.org/abs/1706.05206>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{1}{|\\mathcal{N}(i)|}\n        \\sum_{j \\in \\mathcal{N}(i)} \\sum_{h=1}^H\n        q_h(\\mathbf{x}_i, \\mathbf{x}_j) \\mathbf{W}_h \\mathbf{x}_j\n\n    with :math:`q_h(\\mathbf{x}_i, \\mathbf{x}_j) = \\mathrm{softmax}_j\n    (\\mathbf{u}_h^{\\top} (\\mathbf{x}_j - \\mathbf{x}_i) + c_h)`, where :math:`H`\n    denotes the number of attention heads, and :math:`\\mathbf{W}_h`,\n    :math:`\\mathbf{u}_h` and :math:`c_h` are trainable parameters.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        heads (int, optional): Number of attention heads :math:`H`.\n            (default: :obj:`1`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{in}), (|\\mathcal{V_t}|, F_{in}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V_t}|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, heads: int = 1,\n                 add_self_loops: bool = True, bias: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'mean')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.add_self_loops = add_self_loops\n\n        self.lin = Linear(in_channels, heads * out_channels, bias=False,\n                          weight_initializer='uniform')\n        self.u = Linear(in_channels, heads, bias=False,\n                        weight_initializer='uniform')\n        self.c = Parameter(torch.empty(heads))\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin.reset_parameters()\n        self.u.reset_parameters()\n        normal(self.c, mean=0, std=0.1)\n        normal(self.bias, mean=0, std=0.1)\n\n    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        if self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                edge_index, _ = remove_self_loops(edge_index)\n                edge_index, _ = add_self_loops(edge_index,\n                  
                             num_nodes=x[1].size(0))\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = torch_sparse.set_diag(edge_index)\n\n        # propagate_type: (x: PairTensor)\n        out = self.propagate(edge_index, x=x)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_i: Tensor, x_j: Tensor) -> Tensor:\n        q = self.u(x_j - x_i) + self.c  # Translation invariance.\n        q = F.softmax(q, dim=1)\n        x_j = self.lin(x_j).view(x_j.size(0), self.heads, -1)\n        return (x_j * q.view(-1, self.heads, 1)).sum(dim=1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/film_conv.py",
    "content": "import copy\nfrom typing import Callable, Optional, Tuple, Union\n\nfrom torch import Tensor\nfrom torch.nn import ModuleList, ReLU\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import (\n    Adj,\n    OptTensor,\n    PairTensor,\n    SparseTensor,\n    torch_sparse,\n)\n\n\nclass FiLMConv(MessagePassing):\n    r\"\"\"The FiLM graph convolutional operator from the\n    `\"GNN-FiLM: Graph Neural Networks with Feature-wise Linear Modulation\"\n    <https://arxiv.org/abs/1906.12192>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\sum_{r \\in \\mathcal{R}}\n        \\sum_{j \\in \\mathcal{N}(i)} \\sigma \\left(\n        \\boldsymbol{\\gamma}_{r,i} \\odot \\mathbf{W}_r \\mathbf{x}_j +\n        \\boldsymbol{\\beta}_{r,i} \\right)\n\n    where :math:`\\boldsymbol{\\beta}_{r,i}, \\boldsymbol{\\gamma}_{r,i} =\n    g(\\mathbf{x}_i)` with :math:`g` being a single linear layer by default.\n    Self-loops are automatically added to the input graph and represented as\n    their own relation type.\n\n    .. note::\n\n        For an example of using FiLM, see `examples/film.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        film.py>`_.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        num_relations (int, optional): Number of relations. 
(default: :obj:`1`)\n        nn (torch.nn.Module, optional): The neural network :math:`g` that\n            maps node features :obj:`x_i` of shape\n            :obj:`[-1, in_channels]` to shape :obj:`[-1, 2 * out_channels]`.\n            If set to :obj:`None`, :math:`g` will be implemented as a single\n            linear layer. (default: :obj:`None`)\n        act (callable, optional): Activation function :math:`\\sigma`.\n            (default: :meth:`torch.nn.ReLU()`)\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"mean\"`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge types :math:`(|\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V_t}|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(\n            self,\n            in_channels: Union[int, Tuple[int, int]],\n            out_channels: int,\n            num_relations: int = 1,\n            nn: Optional[Callable] = None,\n            act: Optional[Callable] = ReLU(),\n            aggr: str = 'mean',\n            **kwargs,\n    ):\n        super().__init__(aggr=aggr, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_relations = max(num_relations, 1)\n        self.act = act\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.lins = ModuleList()\n        self.films = ModuleList()\n        for _ in range(num_relations):\n            self.lins.append(Linear(in_channels[0], out_channels, 
bias=False))\n            if nn is None:\n                film = Linear(in_channels[1], 2 * out_channels)\n            else:\n                film = copy.deepcopy(nn)\n            self.films.append(film)\n\n        self.lin_skip = Linear(in_channels[1], self.out_channels, bias=False)\n        if nn is None:\n            self.film_skip = Linear(in_channels[1], 2 * self.out_channels,\n                                    bias=False)\n        else:\n            self.film_skip = copy.deepcopy(nn)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        for lin, film in zip(self.lins, self.films):\n            lin.reset_parameters()\n            reset(film)\n        self.lin_skip.reset_parameters()\n        reset(self.film_skip)\n\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Adj,\n        edge_type: OptTensor = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        beta, gamma = self.film_skip(x[1]).split(self.out_channels, dim=-1)\n        out = gamma * self.lin_skip(x[1]) + beta\n        if self.act is not None:\n            out = self.act(out)\n\n        # propagate_type: (x: Tensor, beta: Tensor, gamma: Tensor)\n        if self.num_relations <= 1:\n            beta, gamma = self.films[0](x[1]).split(self.out_channels, dim=-1)\n            out = out + self.propagate(edge_index, x=self.lins[0](x[0]),\n                                       beta=beta, gamma=gamma)\n        else:\n            for i, (lin, film) in enumerate(zip(self.lins, self.films)):\n                beta, gamma = film(x[1]).split(self.out_channels, dim=-1)\n                if isinstance(edge_index, SparseTensor):\n                    _edge_type = edge_index.storage.value()\n                    assert _edge_type is not None\n                    mask = _edge_type == i\n                    adj_t = torch_sparse.masked_select_nnz(\n                        
edge_index, mask, layout='coo')\n                    out = out + self.propagate(adj_t, x=lin(x[0]), beta=beta,\n                                               gamma=gamma)\n                else:\n                    assert edge_type is not None\n                    mask = edge_type == i\n                    out = out + self.propagate(edge_index[:, mask], x=lin(\n                        x[0]), beta=beta, gamma=gamma)\n\n        return out\n\n    def message(self, x_j: Tensor, beta_i: Tensor, gamma_i: Tensor) -> Tensor:\n        out = gamma_i * x_j + beta_i\n        if self.act is not None:\n            out = self.act(out)\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_relations={self.num_relations})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/fused_gat_conv.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.index import index2ptr\nfrom torch_geometric.nn.conv import GATConv\nfrom torch_geometric.utils import sort_edge_index\n\n\nclass FusedGATConv(GATConv):  # pragma: no cover\n    r\"\"\"The fused graph attention operator from the\n    `\"Understanding GNN Computational Graph: A Coordinated Computation, IO, and\n    Memory Perspective\"\n    <https://proceedings.mlsys.org/paper/2022/file/\n    9a1158154dfa42caddbd0694a4e9bdc8-Paper.pdf>`_ paper.\n\n    :class:`FusedGATConv` is an optimized version of\n    :class:`~torch_geometric.nn.conv.GATConv` based on the :obj:`dgNN` package\n    that fuses message passing computation for accelerated execution and lower\n    memory footprint.\n\n    .. note::\n\n        This implementation is based on the :obj:`dgNN` package.\n        See `here <https://github.com/dgSPARSE/dgNN>`__ for instructions on how\n        to install.\n    \"\"\"\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n        if self.add_self_loops:\n            raise ValueError(f\"'{self.__class__.__name__}' does not support \"\n                             f\"adding self-loops. Please add them manually \"\n                             f\"in a pre-processing step and set \"\n                             f\"`add_self_loops=False`.\")\n\n        if self.edge_dim is not None:\n            raise ValueError(f\"'{self.__class__.__name__}' does not support \"\n                             f\"edge features. 
Set `edge_dim=None` in order \"\n                             f\"to proceed.\")\n\n        from dgNN.operators import GATConvFuse\n        self.op = GATConvFuse\n\n    @staticmethod\n    def to_graph_format(\n        edge_index: Tensor,\n        size: Optional[Tuple[int, int]] = None,\n    ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tensor]:\n        r\"\"\"Converts an :obj:`edge_index` representation of a graph to the\n        desired input format of :class:`FusedGATConv`.\n\n        Args:\n            edge_index (torch.Tensor): The edge indices.\n            size ((int, int), optional): The shape of :obj:`edge_index` in each\n                dimension. (default: :obj:`None`)\n        \"\"\"\n        edge_index = edge_index.to(torch.int)\n\n        edge_index = sort_edge_index(edge_index, sort_by_row=True)\n        rowptr = index2ptr(edge_index[0], size=size[0] if size else None)\n        col = edge_index[1]\n\n        device = edge_index.device\n        perm = torch.arange(edge_index.size(1), dtype=torch.int, device=device)\n        edge_index, perm = sort_edge_index(edge_index, perm, sort_by_row=False)\n        row = edge_index[0]\n        colptr = index2ptr(edge_index[1], size=size[1] if size else None)\n\n        return (rowptr, col), (row, colptr), perm\n\n    def forward(\n        self,\n        x: Tensor,\n        csr: Tuple[Tensor, Tensor],\n        csc: Tuple[Tensor, Tensor],\n        perm: Tensor,\n    ) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): The node features.\n            csr ((torch.Tensor, torch.Tensor)): A tuple containing the CSR\n                representation of a graph, given as a tuple of\n                :obj:`(rowptr, col)`.\n            csc ((torch.Tensor, torch.Tensor)): A tuple containing the CSC\n                representation of a graph, given as a tuple of\n                :obj:`(row, colptr)`.\n            perm (torch.Tensor): Permutation tensor to 
map the CSR\n                representation to the CSC representation.\n\n        .. note::\n\n            Use the\n            :meth:`~torch_geometric.nn.conv.FusedGATConv.to_graph_format`\n            method to obtain the :obj:`(csr, csc, perm)` graph format from an\n            existing :obj:`edge_index` representation.\n        \"\"\"\n        H, C = self.heads, self.out_channels\n\n        assert x.dim() == 2, \"Static graphs not supported in 'GATConv'\"\n        x = self.lin_src(x).view(-1, H, C)\n\n        alpha_src = (x * self.att_src).sum(dim=-1)\n        alpha_dst = (x * self.att_dst).sum(dim=-1)\n\n        dropout = self.dropout if self.training else 0.0\n\n        (rowptr, col), (row, colptr) = csr, csc\n        out = self.op(alpha_dst, alpha_src, rowptr, col, colptr, row, perm,\n                      self.negative_slope, x, dropout)\n\n        if self.concat:\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n\n        if self.bias is not None:\n            out += self.bias\n\n        return out\n"
  },
  {
    "path": "torch_geometric/nn/conv/gat_conv.py",
    "content": "import typing\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import (\n    Adj,\n    NoneType,\n    OptPairTensor,\n    OptTensor,\n    Size,\n    SparseTensor,\n    torch_sparse,\n)\nfrom torch_geometric.utils import (\n    add_self_loops,\n    is_torch_sparse_tensor,\n    remove_self_loops,\n    softmax,\n)\nfrom torch_geometric.utils.sparse import set_sparse_value\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload_method as overload\n\n\nclass GATConv(MessagePassing):\n    r\"\"\"The graph attentional operator from the `\"Graph Attention Networks\"\n    <https://arxiv.org/abs/1710.10903>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\sum_{j \\in \\mathcal{N}(i) \\cup \\{ i \\}}\n        \\alpha_{i,j}\\mathbf{\\Theta}_t\\mathbf{x}_{j},\n\n    where the attention coefficients :math:`\\alpha_{i,j}` are computed as\n\n    .. math::\n        \\alpha_{i,j} =\n        \\frac{\n        \\exp\\left(\\mathrm{LeakyReLU}\\left(\n        \\mathbf{a}^{\\top}_{s} \\mathbf{\\Theta}_{s}\\mathbf{x}_i\n        + \\mathbf{a}^{\\top}_{t} \\mathbf{\\Theta}_{t}\\mathbf{x}_j\n        \\right)\\right)}\n        {\\sum_{k \\in \\mathcal{N}(i) \\cup \\{ i \\}}\n        \\exp\\left(\\mathrm{LeakyReLU}\\left(\n        \\mathbf{a}^{\\top}_{s} \\mathbf{\\Theta}_{s}\\mathbf{x}_i\n        + \\mathbf{a}^{\\top}_{t}\\mathbf{\\Theta}_{t}\\mathbf{x}_k\n        \\right)\\right)}.\n\n    If the graph has multi-dimensional edge features :math:`\\mathbf{e}_{i,j}`,\n    the attention coefficients :math:`\\alpha_{i,j}` are computed as\n\n    .. 
math::\n        \\alpha_{i,j} =\n        \\frac{\n        \\exp\\left(\\mathrm{LeakyReLU}\\left(\n        \\mathbf{a}^{\\top}_{s} \\mathbf{\\Theta}_{s}\\mathbf{x}_i\n        + \\mathbf{a}^{\\top}_{t} \\mathbf{\\Theta}_{t}\\mathbf{x}_j\n        + \\mathbf{a}^{\\top}_{e} \\mathbf{\\Theta}_{e} \\mathbf{e}_{i,j}\n        \\right)\\right)}\n        {\\sum_{k \\in \\mathcal{N}(i) \\cup \\{ i \\}}\n        \\exp\\left(\\mathrm{LeakyReLU}\\left(\n        \\mathbf{a}^{\\top}_{s} \\mathbf{\\Theta}_{s}\\mathbf{x}_i\n        + \\mathbf{a}^{\\top}_{t} \\mathbf{\\Theta}_{t}\\mathbf{x}_k\n        + \\mathbf{a}^{\\top}_{e} \\mathbf{\\Theta}_{e} \\mathbf{e}_{i,k}\n        \\right)\\right)}.\n\n    If the graph is not bipartite, :math:`\\mathbf{\\Theta}_{s} =\n    \\mathbf{\\Theta}_{t}`.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities in case of a bipartite graph.\n        out_channels (int): Size of each output sample.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the multi-head\n            attentions are averaged instead of concatenated.\n            (default: :obj:`True`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. (default: :obj:`0.2`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        edge_dim (int, optional): Edge feature dimensionality (in case\n            there are any). 
(default: :obj:`None`)\n        fill_value (float or torch.Tensor or str, optional): The way to\n            generate edge features of self-loops (in case\n            :obj:`edge_dim != None`).\n            If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n            self-loops will be directly given by :obj:`fill_value`.\n            If given as :obj:`str`, edge features of self-loops are computed by\n            aggregating all features of edges that point to the specific node,\n            according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n            :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). (default: :obj:`\"mean\"`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        residual (bool, optional): If set to :obj:`True`, the layer will add\n            a learnable skip-connection. (default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, H * F_{out})` or\n          :math:`(|\\mathcal{V}_t|, H * F_{out})` if bipartite.\n          If :obj:`return_attention_weights=True`, then\n          :math:`((|\\mathcal{V}|, H * F_{out}),\n          ((2, |\\mathcal{E}|), (|\\mathcal{E}|, H)))`\n          or :math:`((|\\mathcal{V_t}|, H * F_{out}), ((2, |\\mathcal{E}|),\n          (|\\mathcal{E}|, H)))` if bipartite\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        heads: int = 1,\n        concat: bool = True,\n   
     negative_slope: float = 0.2,\n        dropout: float = 0.0,\n        add_self_loops: bool = True,\n        edge_dim: Optional[int] = None,\n        fill_value: Union[float, Tensor, str] = 'mean',\n        bias: bool = True,\n        residual: bool = False,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(node_dim=0, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.concat = concat\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n        self.add_self_loops = add_self_loops\n        self.edge_dim = edge_dim\n        self.fill_value = fill_value\n        self.residual = residual\n\n        # In case we are operating in bipartite graphs, we apply separate\n        # transformations 'lin_src' and 'lin_dst' to source and target nodes:\n        self.lin = self.lin_src = self.lin_dst = None\n        if isinstance(in_channels, int):\n            self.lin = Linear(in_channels, heads * out_channels, bias=False,\n                              weight_initializer='glorot')\n        else:\n            self.lin_src = Linear(in_channels[0], heads * out_channels, False,\n                                  weight_initializer='glorot')\n            self.lin_dst = Linear(in_channels[1], heads * out_channels, False,\n                                  weight_initializer='glorot')\n\n        # The learnable parameters to compute attention coefficients:\n        self.att_src = Parameter(torch.empty(1, heads, out_channels))\n        self.att_dst = Parameter(torch.empty(1, heads, out_channels))\n\n        if edge_dim is not None:\n            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,\n                                   weight_initializer='glorot')\n            self.att_edge = Parameter(torch.empty(1, heads, out_channels))\n        else:\n            self.lin_edge = None\n            
self.register_parameter('att_edge', None)\n\n        # The number of output channels:\n        total_out_channels = out_channels * (heads if concat else 1)\n\n        if residual:\n            self.res = Linear(\n                in_channels\n                if isinstance(in_channels, int) else in_channels[1],\n                total_out_channels,\n                bias=False,\n                weight_initializer='glorot',\n            )\n        else:\n            self.register_parameter('res', None)\n\n        if bias:\n            self.bias = Parameter(torch.empty(total_out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if self.lin is not None:\n            self.lin.reset_parameters()\n        if self.lin_src is not None:\n            self.lin_src.reset_parameters()\n        if self.lin_dst is not None:\n            self.lin_dst.reset_parameters()\n        if self.lin_edge is not None:\n            self.lin_edge.reset_parameters()\n        if self.res is not None:\n            self.res.reset_parameters()\n        glorot(self.att_src)\n        glorot(self.att_dst)\n        glorot(self.att_edge)\n        zeros(self.bias)\n\n    @overload\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n        return_attention_weights: NoneType = None,\n    ) -> Tensor:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Tensor,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: 
SparseTensor,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, SparseTensor]:\n        pass\n\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n        return_attention_weights: Optional[bool] = None,\n    ) -> Union[\n            Tensor,\n            Tuple[Tensor, Tuple[Tensor, Tensor]],\n            Tuple[Tensor, SparseTensor],\n    ]:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node\n                features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            size ((int, int), optional): The shape of the adjacency matrix.\n                (default: :obj:`None`)\n            return_attention_weights (bool, optional):\n                Will additionally return the tuple\n                :obj:`(edge_index, attention_weights)` if set to :obj:`True`,\n                holding the computed attention weights for each edge.\n                (default: :obj:`None`)\n        \"\"\"\n        H, C = self.heads, self.out_channels\n\n        res: Optional[Tensor] = None\n\n        # We first transform the input node features. 
If a tuple is passed, we\n        # transform source and target node features via separate weights:\n        if isinstance(x, Tensor):\n            assert x.dim() == 2, \"Static graphs not supported in 'GATConv'\"\n\n            if self.res is not None:\n                res = self.res(x)\n\n            if self.lin is not None:\n                x_src = x_dst = self.lin(x).view(-1, H, C)\n            else:\n                # If the module is initialized as bipartite, transform source\n                # and destination node features separately:\n                assert self.lin_src is not None and self.lin_dst is not None\n                x_src = self.lin_src(x).view(-1, H, C)\n                x_dst = self.lin_dst(x).view(-1, H, C)\n\n        else:  # Tuple of source and target node features:\n            x_src, x_dst = x\n            assert x_src.dim() == 2, \"Static graphs not supported in 'GATConv'\"\n\n            if x_dst is not None and self.res is not None:\n                res = self.res(x_dst)\n\n            if self.lin is not None:\n                # If the module is initialized as non-bipartite, we expect that\n                # source and destination node features have the same shape and\n                # that their transformations are shared:\n                x_src = self.lin(x_src).view(-1, H, C)\n                if x_dst is not None:\n                    x_dst = self.lin(x_dst).view(-1, H, C)\n            else:\n                assert self.lin_src is not None and self.lin_dst is not None\n\n                x_src = self.lin_src(x_src).view(-1, H, C)\n                if x_dst is not None:\n                    x_dst = self.lin_dst(x_dst).view(-1, H, C)\n\n        x = (x_src, x_dst)\n\n        # Next, we compute node-level attention coefficients, both for source\n        # and target nodes (if present):\n        alpha_src = (x_src * self.att_src).sum(dim=-1)\n        alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)\n        alpha = 
(alpha_src, alpha_dst)\n\n        if self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                # We only want to add self-loops for nodes that appear both as\n                # source and target nodes:\n                num_nodes = x_src.size(0)\n                if x_dst is not None:\n                    num_nodes = min(num_nodes, x_dst.size(0))\n                num_nodes = min(size) if size is not None else num_nodes\n                edge_index, edge_attr = remove_self_loops(\n                    edge_index, edge_attr)\n                edge_index, edge_attr = add_self_loops(\n                    edge_index, edge_attr, fill_value=self.fill_value,\n                    num_nodes=num_nodes)\n            elif isinstance(edge_index, SparseTensor):\n                if self.edge_dim is None:\n                    edge_index = torch_sparse.set_diag(edge_index)\n                else:\n                    raise NotImplementedError(\n                        \"The usage of 'edge_attr' and 'add_self_loops' \"\n                        \"simultaneously is currently not yet supported for \"\n                        \"'edge_index' in a 'SparseTensor' form\")\n\n        # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)\n        alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr,\n                                  size=size)\n\n        # propagate_type: (x: OptPairTensor, alpha: Tensor)\n        out = self.propagate(edge_index, x=x, alpha=alpha, size=size)\n\n        if self.concat:\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n\n        if res is not None:\n            out = out + res\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        if return_attention_weights:\n            if isinstance(edge_index, Tensor):\n                if is_torch_sparse_tensor(edge_index):\n                    # TODO TorchScript requires to return 
a tuple\n                    adj = set_sparse_value(edge_index, alpha)\n                    return out, (adj, alpha)\n                else:\n                    return out, (edge_index, alpha)\n            elif isinstance(edge_index, SparseTensor):\n                return out, edge_index.set_value(alpha, layout='coo')\n\n        return out\n\n    def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,\n                    edge_attr: OptTensor, index: Tensor, ptr: OptTensor,\n                    dim_size: Optional[int]) -> Tensor:\n        # Given edge-level attention coefficients for source and target nodes,\n        # we simply need to sum them up to \"emulate\" concatenation:\n        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i\n        if index.numel() == 0:\n            return alpha\n        if edge_attr is not None and self.lin_edge is not None:\n            if edge_attr.dim() == 1:\n                edge_attr = edge_attr.view(-1, 1)\n            edge_attr = self.lin_edge(edge_attr)\n            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)\n            alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)\n            alpha = alpha + alpha_edge\n\n        alpha = F.leaky_relu(alpha, self.negative_slope)\n        alpha = softmax(alpha, index, ptr, dim_size)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        return alpha\n\n    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:\n        return alpha.unsqueeze(-1) * x_j\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/gated_graph_conv.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.nn import Parameter as Param\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import uniform\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import spmm\n\n\nclass GatedGraphConv(MessagePassing):\n    r\"\"\"The gated graph convolution operator from the `\"Gated Graph Sequence\n    Neural Networks\" <https://arxiv.org/abs/1511.05493>`_ paper.\n\n    .. math::\n        \\mathbf{h}_i^{(0)} &= \\mathbf{x}_i \\, \\Vert \\, \\mathbf{0}\n\n        \\mathbf{m}_i^{(l+1)} &= \\sum_{j \\in \\mathcal{N}(i)} e_{j,i} \\cdot\n        \\mathbf{\\Theta} \\cdot \\mathbf{h}_j^{(l)}\n\n        \\mathbf{h}_i^{(l+1)} &= \\textrm{GRU} (\\mathbf{m}_i^{(l+1)},\n        \\mathbf{h}_i^{(l)})\n\n    up to representation :math:`\\mathbf{h}_i^{(L)}`.\n    The number of input channels of :math:`\\mathbf{x}_i` needs to be less or\n    equal than :obj:`out_channels`.\n    :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target\n    node :obj:`i` (default: :obj:`1`)\n\n    Args:\n        out_channels (int): Size of each output sample.\n        num_layers (int): The sequence length :math:`L`.\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"add\"`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n\n    \"\"\"\n    def __init__(self, out_channels: int, num_layers: int, aggr: str = 'add',\n                 bias: bool = True, **kwargs):\n        super().__init__(aggr=aggr, **kwargs)\n\n        self.out_channels = out_channels\n        self.num_layers = num_layers\n\n        self.weight = Param(Tensor(num_layers, out_channels, out_channels))\n        self.rnn = torch.nn.GRUCell(out_channels, out_channels, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        uniform(self.out_channels, self.weight)\n        self.rnn.reset_parameters()\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if x.size(-1) > self.out_channels:\n            raise ValueError('The number of input channels is not allowed to '\n                             'be larger than the number of output channels')\n\n        if x.size(-1) < self.out_channels:\n            zero = x.new_zeros(x.size(0), self.out_channels - x.size(-1))\n            x = torch.cat([x, zero], dim=1)\n\n        for i in range(self.num_layers):\n            m = torch.matmul(x, self.weight[i])\n            # propagate_type: (x: Tensor, edge_weight: OptTensor)\n            m = self.propagate(edge_index, x=m, edge_weight=edge_weight)\n            x = self.rnn(m, x)\n\n        return x\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor):\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, 
reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.out_channels}, '\n                f'num_layers={self.num_layers})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/gatv2_conv.py",
    "content": "import typing\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import (\n    Adj,\n    NoneType,\n    OptTensor,\n    PairTensor,\n    SparseTensor,\n    torch_sparse,\n)\nfrom torch_geometric.utils import (\n    add_self_loops,\n    is_torch_sparse_tensor,\n    remove_self_loops,\n    softmax,\n)\nfrom torch_geometric.utils.sparse import set_sparse_value\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload_method as overload\n\n\nclass GATv2Conv(MessagePassing):\n    r\"\"\"The GATv2 operator from the `\"How Attentive are Graph Attention\n    Networks?\" <https://arxiv.org/abs/2105.14491>`_ paper, which fixes the\n    static attention problem of the standard\n    :class:`~torch_geometric.conv.GATConv` layer.\n    Since the linear layers in the standard GAT are applied right after each\n    other, the ranking of attended nodes is unconditioned on the query node.\n    In contrast, in :class:`GATv2`, every node can attend to any other node.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\sum_{j \\in \\mathcal{N}(i) \\cup \\{ i \\}}\n        \\alpha_{i,j}\\mathbf{\\Theta}_{t}\\mathbf{x}_{j},\n\n    where the attention coefficients :math:`\\alpha_{i,j}` are computed as\n\n    .. 
math::\n        \\alpha_{i,j} =\n        \\frac{\n        \\exp\\left(\\mathbf{a}^{\\top}\\mathrm{LeakyReLU}\\left(\n        \\mathbf{\\Theta}_{s} \\mathbf{x}_i + \\mathbf{\\Theta}_{t} \\mathbf{x}_j\n        \\right)\\right)}\n        {\\sum_{k \\in \\mathcal{N}(i) \\cup \\{ i \\}}\n        \\exp\\left(\\mathbf{a}^{\\top}\\mathrm{LeakyReLU}\\left(\n        \\mathbf{\\Theta}_{s} \\mathbf{x}_i + \\mathbf{\\Theta}_{t} \\mathbf{x}_k\n        \\right)\\right)}.\n\n    If the graph has multi-dimensional edge features :math:`\\mathbf{e}_{i,j}`,\n    the attention coefficients :math:`\\alpha_{i,j}` are computed as\n\n    .. math::\n        \\alpha_{i,j} =\n        \\frac{\n        \\exp\\left(\\mathbf{a}^{\\top}\\mathrm{LeakyReLU}\\left(\n        \\mathbf{\\Theta}_{s} \\mathbf{x}_i\n        + \\mathbf{\\Theta}_{t} \\mathbf{x}_j\n        + \\mathbf{\\Theta}_{e} \\mathbf{e}_{i,j}\n        \\right)\\right)}\n        {\\sum_{k \\in \\mathcal{N}(i) \\cup \\{ i \\}}\n        \\exp\\left(\\mathbf{a}^{\\top}\\mathrm{LeakyReLU}\\left(\n        \\mathbf{\\Theta}_{s} \\mathbf{x}_i\n        + \\mathbf{\\Theta}_{t} \\mathbf{x}_k\n        + \\mathbf{\\Theta}_{e} \\mathbf{e}_{i,k}\n        \\right)\\right)}.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities in case of a bipartite graph.\n        out_channels (int): Size of each output sample.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the multi-head\n            attentions are averaged instead of concatenated.\n            (default: :obj:`True`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. 
(default: :obj:`0.2`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        edge_dim (int, optional): Edge feature dimensionality (in case\n            there are any). (default: :obj:`None`)\n        fill_value (float or torch.Tensor or str, optional): The way to\n            generate edge features of self-loops\n            (in case :obj:`edge_dim != None`).\n            If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n            self-loops will be directly given by :obj:`fill_value`.\n            If given as :obj:`str`, edge features of self-loops are computed by\n            aggregating all features of edges that point to the specific node,\n            according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n            :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). (default: :obj:`\"mean\"`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        share_weights (bool, optional): If set to :obj:`True`, the same matrix\n            will be applied to the source and the target node of every edge,\n            *i.e.* :math:`\\mathbf{\\Theta}_{s} = \\mathbf{\\Theta}_{t}`.\n            (default: :obj:`False`)\n        residual (bool, optional): If set to :obj:`True`, the layer will add\n            a learnable skip-connection. 
(default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, H * F_{out})` or\n          :math:`(|\\mathcal{V}_t|, H * F_{out})` if bipartite.\n          If :obj:`return_attention_weights=True`, then\n          :math:`((|\\mathcal{V}|, H * F_{out}),\n          ((2, |\\mathcal{E}|), (|\\mathcal{E}|, H)))`\n          or :math:`((|\\mathcal{V_t}|, H * F_{out}), ((2, |\\mathcal{E}|),\n          (|\\mathcal{E}|, H)))` if bipartite\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        heads: int = 1,\n        concat: bool = True,\n        negative_slope: float = 0.2,\n        dropout: float = 0.0,\n        add_self_loops: bool = True,\n        edge_dim: Optional[int] = None,\n        fill_value: Union[float, Tensor, str] = 'mean',\n        bias: bool = True,\n        share_weights: bool = False,\n        residual: bool = False,\n        **kwargs,\n    ):\n        super().__init__(node_dim=0, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.concat = concat\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n        self.add_self_loops = add_self_loops\n        self.edge_dim = edge_dim\n        self.fill_value = fill_value\n        self.residual = residual\n        self.share_weights = share_weights\n\n        if isinstance(in_channels, int):\n            self.lin_l = Linear(in_channels, heads * out_channels, bias=bias,\n              
                  weight_initializer='glorot')\n            if share_weights:\n                self.lin_r = self.lin_l\n            else:\n                self.lin_r = Linear(in_channels, heads * out_channels,\n                                    bias=bias, weight_initializer='glorot')\n        else:\n            self.lin_l = Linear(in_channels[0], heads * out_channels,\n                                bias=bias, weight_initializer='glorot')\n            if share_weights:\n                self.lin_r = self.lin_l\n            else:\n                self.lin_r = Linear(in_channels[1], heads * out_channels,\n                                    bias=bias, weight_initializer='glorot')\n\n        self.att = Parameter(torch.empty(1, heads, out_channels))\n\n        if edge_dim is not None:\n            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False,\n                                   weight_initializer='glorot')\n        else:\n            self.lin_edge = None\n\n        # The number of output channels:\n        total_out_channels = out_channels * (heads if concat else 1)\n\n        if residual:\n            self.res = Linear(\n                in_channels\n                if isinstance(in_channels, int) else in_channels[1],\n                total_out_channels,\n                bias=False,\n                weight_initializer='glorot',\n            )\n        else:\n            self.register_parameter('res', None)\n\n        if bias:\n            self.bias = Parameter(torch.empty(total_out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_l.reset_parameters()\n        self.lin_r.reset_parameters()\n        if self.lin_edge is not None:\n            self.lin_edge.reset_parameters()\n        if self.res is not None:\n            self.res.reset_parameters()\n        glorot(self.att)\n        
zeros(self.bias)\n\n    @overload\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        return_attention_weights: NoneType = None,\n    ) -> Tensor:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Tensor,\n        edge_attr: OptTensor = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: SparseTensor,\n        edge_attr: OptTensor = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, SparseTensor]:\n        pass\n\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        return_attention_weights: Optional[bool] = None,\n    ) -> Union[\n            Tensor,\n            Tuple[Tensor, Tuple[Tensor, Tensor]],\n            Tuple[Tensor, SparseTensor],\n    ]:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node\n                features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            return_attention_weights (bool, optional):\n                Will additionally return the tuple\n                :obj:`(edge_index, attention_weights)` whenever it is set to\n                a value, regardless of its actual value\n                (might be `True` or `False`), holding the computed attention\n                weights for each edge.\n                (default: :obj:`None`)\n        \"\"\"\n        H, C = self.heads, self.out_channels\n\n        res: 
Optional[Tensor] = None\n\n        x_l: OptTensor = None\n        x_r: OptTensor = None\n        if isinstance(x, Tensor):\n            assert x.dim() == 2\n\n            if self.res is not None:\n                res = self.res(x)\n\n            x_l = self.lin_l(x).view(-1, H, C)\n            if self.share_weights:\n                x_r = x_l\n            else:\n                x_r = self.lin_r(x).view(-1, H, C)\n        else:\n            x_l, x_r = x[0], x[1]\n            assert x[0].dim() == 2\n\n            if x_r is not None and self.res is not None:\n                res = self.res(x_r)\n\n            x_l = self.lin_l(x_l).view(-1, H, C)\n            if x_r is not None:\n                x_r = self.lin_r(x_r).view(-1, H, C)\n\n        assert x_l is not None\n        assert x_r is not None\n\n        if self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                num_nodes = x_l.size(0)\n                if x_r is not None:\n                    num_nodes = min(num_nodes, x_r.size(0))\n                edge_index, edge_attr = remove_self_loops(\n                    edge_index, edge_attr)\n                edge_index, edge_attr = add_self_loops(\n                    edge_index, edge_attr, fill_value=self.fill_value,\n                    num_nodes=num_nodes)\n            elif isinstance(edge_index, SparseTensor):\n                if self.edge_dim is None:\n                    edge_index = torch_sparse.set_diag(edge_index)\n                else:\n                    raise NotImplementedError(\n                        \"The usage of 'edge_attr' and 'add_self_loops' \"\n                        \"simultaneously is currently not yet supported for \"\n                        \"'edge_index' in a 'SparseTensor' form\")\n\n        # edge_updater_type: (x: PairTensor, edge_attr: OptTensor)\n        alpha = self.edge_updater(edge_index, x=(x_l, x_r),\n                                  edge_attr=edge_attr)\n\n        # propagate_type: (x: PairTensor, 
alpha: Tensor)\n        out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha)\n\n        if self.concat:\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n\n        if res is not None:\n            out = out + res\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        if return_attention_weights:\n            if isinstance(edge_index, Tensor):\n                if is_torch_sparse_tensor(edge_index):\n                    # TODO TorchScript requires to return a tuple\n                    adj = set_sparse_value(edge_index, alpha)\n                    return out, (adj, alpha)\n                else:\n                    return out, (edge_index, alpha)\n            elif isinstance(edge_index, SparseTensor):\n                return out, edge_index.set_value(alpha, layout='coo')\n\n        return out\n\n    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: OptTensor,\n                    index: Tensor, ptr: OptTensor,\n                    dim_size: Optional[int]) -> Tensor:\n        x = x_i + x_j\n\n        if edge_attr is not None:\n            if edge_attr.dim() == 1:\n                edge_attr = edge_attr.view(-1, 1)\n            assert self.lin_edge is not None\n            edge_attr = self.lin_edge(edge_attr)\n            edge_attr = edge_attr.view(-1, self.heads, self.out_channels)\n            x = x + edge_attr\n\n        x = F.leaky_relu(x, self.negative_slope)\n        alpha = (x * self.att).sum(dim=-1)\n        alpha = softmax(alpha, index, ptr, dim_size)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        return alpha\n\n    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:\n        return x_j * alpha.unsqueeze(-1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/gcn2_conv.py",
    "content": "from math import log\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.inits import glorot\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass GCN2Conv(MessagePassing):\n    r\"\"\"The graph convolutional operator with initial residual connections and\n    identity mapping (GCNII) from the `\"Simple and Deep Graph Convolutional\n    Networks\" <https://arxiv.org/abs/2007.02133>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\left( (1 - \\alpha) \\mathbf{\\hat{P}}\\mathbf{X} +\n        \\alpha \\mathbf{X^{(0)}}\\right) \\left( (1 - \\beta) \\mathbf{I} + \\beta\n        \\mathbf{\\Theta} \\right)\n\n    with :math:`\\mathbf{\\hat{P}} = \\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n    \\mathbf{\\hat{D}}^{-1/2}`, where\n    :math:`\\mathbf{\\hat{A}} = \\mathbf{A} + \\mathbf{I}` denotes the adjacency\n    matrix with inserted self-loops and\n    :math:`\\hat{D}_{ii} = \\sum_{j=0} \\hat{A}_{ij}` its diagonal degree matrix,\n    and :math:`\\mathbf{X}^{(0)}` being the initial feature representation.\n    Here, :math:`\\alpha` models the strength of the initial residual\n    connection, while :math:`\\beta` models the strength of the identity\n    mapping.\n    The adjacency matrix can include other values than :obj:`1` representing\n    edge weights via the optional :obj:`edge_weight` tensor.\n\n    Args:\n        channels (int): Size of each input and output sample.\n        alpha (float): The strength of the initial residual connection\n            :math:`\\alpha`.\n        theta (float, optional): The hyperparameter :math:`\\theta` to compute\n            the strength of the identity mapping\n            :math:`\\beta = \\log \\left( \\frac{\\theta}{\\ell} + 1 \\right)`.\n       
     (default: :obj:`None`)\n        layer (int, optional): The layer :math:`\\ell` in which this module is\n            executed. (default: :obj:`None`)\n        shared_weights (bool, optional): If set to :obj:`False`, will use\n            different weight matrices for the smoothed representation and the\n            initial residual (\"GCNII*\"). (default: :obj:`True`)\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of :math:`\\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n            \\mathbf{\\hat{D}}^{-1/2}` on first execution, and will use the\n            cached version for further executions.\n            This parameter should only be set to :obj:`True` in transductive\n            learning scenarios. (default: :obj:`False`)\n        normalize (bool, optional): Whether to add self-loops and apply\n            symmetric normalization. (default: :obj:`True`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F)`,\n          initial node features :math:`(|\\mathcal{V}|, F)`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F)`\n    \"\"\"\n    _cached_edge_index: Optional[OptPairTensor]\n    _cached_adj_t: Optional[SparseTensor]\n\n    def __init__(self, channels: int, alpha: float, theta: float = None,\n                 layer: int = None, shared_weights: bool = True,\n                 cached: bool = False, add_self_loops: bool = True,\n                 normalize: bool = True, **kwargs):\n\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.channels = channels\n        self.alpha = alpha\n        self.beta = 1.\n        if theta is not None or layer is not None:\n            assert theta is not None and layer is not None\n            self.beta = log(theta / layer + 1)\n        self.cached = cached\n        self.normalize = normalize\n        self.add_self_loops = add_self_loops\n\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n        self.weight1 = Parameter(torch.empty(channels, channels))\n\n        if shared_weights:\n            self.register_parameter('weight2', None)\n        else:\n            self.weight2 = Parameter(torch.empty(channels, channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        glorot(self.weight1)\n        glorot(self.weight2)\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n    def forward(self, x: Tensor, x_0: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if self.normalize:\n       
     if isinstance(edge_index, Tensor):\n                cache = self._cached_edge_index\n                if cache is None:\n                    edge_index, edge_weight = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_edge_index = (edge_index, edge_weight)\n                else:\n                    edge_index, edge_weight = cache[0], cache[1]\n\n            elif isinstance(edge_index, SparseTensor):\n                cache = self._cached_adj_t\n                if cache is None:\n                    edge_index = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim), False,\n                        self.add_self_loops, self.flow, dtype=x.dtype)\n                    if self.cached:\n                        self._cached_adj_t = edge_index\n                else:\n                    edge_index = cache\n\n        # propagate_type: (x: Tensor, edge_weight: OptTensor)\n        x = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n\n        x.mul_(1 - self.alpha)\n        x_0 = self.alpha * x_0[:x.size(0)]\n\n        if self.weight2 is None:\n            out = x.add_(x_0)\n            out = torch.addmm(out, out, self.weight1, beta=1. - self.beta,\n                              alpha=self.beta)\n        else:\n            out = torch.addmm(x, x, self.weight1, beta=1. - self.beta,\n                              alpha=self.beta)\n            out = out + torch.addmm(x_0, x_0, self.weight2,\n                                    beta=1. 
- self.beta, alpha=self.beta)\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.channels}, '\n                f'alpha={self.alpha}, beta={self.beta})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/gcn_conv.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import (\n    Adj,\n    OptPairTensor,\n    OptTensor,\n    SparseTensor,\n    torch_sparse,\n)\nfrom torch_geometric.utils import add_remaining_self_loops\nfrom torch_geometric.utils import add_self_loops as add_self_loops_fn\nfrom torch_geometric.utils import (\n    is_torch_sparse_tensor,\n    scatter,\n    spmm,\n    to_edge_index,\n)\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\nfrom torch_geometric.utils.sparse import set_sparse_value\n\n\n@torch.jit._overload\ndef gcn_norm(  # noqa: F811\n        edge_index, edge_weight, num_nodes, improved, add_self_loops, flow,\n        dtype):\n    # type: (Tensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> OptPairTensor  # noqa\n    pass\n\n\n@torch.jit._overload\ndef gcn_norm(  # noqa: F811\n        edge_index, edge_weight, num_nodes, improved, add_self_loops, flow,\n        dtype):\n    # type: (SparseTensor, OptTensor, Optional[int], bool, bool, str, Optional[int]) -> SparseTensor  # noqa\n    pass\n\n\ndef gcn_norm(  # noqa: F811\n    edge_index: Adj,\n    edge_weight: OptTensor = None,\n    num_nodes: Optional[int] = None,\n    improved: bool = False,\n    add_self_loops: bool = True,\n    flow: str = \"source_to_target\",\n    dtype: Optional[torch.dtype] = None,\n):\n    fill_value = 2. 
if improved else 1.\n\n    if isinstance(edge_index, SparseTensor):\n        assert edge_index.size(0) == edge_index.size(1)\n\n        adj_t = edge_index\n\n        if not adj_t.has_value():\n            adj_t = adj_t.fill_value(1., dtype=dtype)\n        if add_self_loops:\n            adj_t = torch_sparse.fill_diag(adj_t, fill_value)\n\n        deg = torch_sparse.sum(adj_t, dim=1)\n        deg_inv_sqrt = deg.pow_(-0.5)\n        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)\n        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(-1, 1))\n        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(1, -1))\n\n        return adj_t\n\n    if is_torch_sparse_tensor(edge_index):\n        assert edge_index.size(0) == edge_index.size(1)\n\n        if edge_index.layout == torch.sparse_csc:\n            raise NotImplementedError(\"Sparse CSC matrices are not yet \"\n                                      \"supported in 'gcn_norm'\")\n\n        adj_t = edge_index\n        if add_self_loops:\n            adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes)\n\n        edge_index, value = to_edge_index(adj_t)\n        col, row = edge_index[0], edge_index[1]\n\n        deg = scatter(value, col, 0, dim_size=num_nodes, reduce='sum')\n        deg_inv_sqrt = deg.pow_(-0.5)\n        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)\n        value = deg_inv_sqrt[row] * value * deg_inv_sqrt[col]\n\n        return set_sparse_value(adj_t, value), None\n\n    assert flow in ['source_to_target', 'target_to_source']\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    if add_self_loops:\n        edge_index, edge_weight = add_remaining_self_loops(\n            edge_index, edge_weight, fill_value, num_nodes)\n\n    if edge_weight is None:\n        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,\n                                 device=edge_index.device)\n\n    row, col = edge_index[0], edge_index[1]\n    idx = col if flow == 
'source_to_target' else row\n    deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')\n    deg_inv_sqrt = deg.pow_(-0.5)\n    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)\n    edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n\n    return edge_index, edge_weight\n\n\nclass GCNConv(MessagePassing):\n    r\"\"\"The graph convolutional operator from the `\"Semi-supervised\n    Classification with Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1609.02907>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n        \\mathbf{\\hat{D}}^{-1/2} \\mathbf{X} \\mathbf{\\Theta},\n\n    where :math:`\\mathbf{\\hat{A}} = \\mathbf{A} + \\mathbf{I}` denotes the\n    adjacency matrix with inserted self-loops and\n    :math:`\\hat{D}_{ii} = \\sum_{j=0} \\hat{A}_{ij}` its diagonal degree matrix.\n    The adjacency matrix can include other values than :obj:`1` representing\n    edge weights via the optional :obj:`edge_weight` tensor.\n\n    Its node-wise formulation is given by:\n\n    .. 
math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{\\Theta}^{\\top} \\sum_{j \\in\n        \\mathcal{N}(i) \\cup \\{ i \\}} \\frac{e_{j,i}}{\\sqrt{\\hat{d}_j\n        \\hat{d}_i}} \\mathbf{x}_j\n\n    with :math:`\\hat{d}_i = 1 + \\sum_{j \\in \\mathcal{N}(i)} e_{j,i}`, where\n    :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target\n    node :obj:`i` (default: :obj:`1.0`)\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        improved (bool, optional): If set to :obj:`True`, the layer computes\n            :math:`\\mathbf{\\hat{A}}` as :math:`\\mathbf{A} + 2\\mathbf{I}`.\n            (default: :obj:`False`)\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of :math:`\\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n            \\mathbf{\\hat{D}}^{-1/2}` on first execution, and will use the\n            cached version for further executions.\n            This parameter should only be set to :obj:`True` in transductive\n            learning scenarios. (default: :obj:`False`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. By default, self-loops will be added\n            in case :obj:`normalize` is set to :obj:`True`, and not added\n            otherwise. (default: :obj:`None`)\n        normalize (bool, optional): Whether to add self-loops and compute\n            symmetric normalization coefficients on-the-fly.\n            (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n          or sparse matrix :math:`(|\\mathcal{V}|, |\\mathcal{V}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    _cached_edge_index: Optional[OptPairTensor]\n    _cached_adj_t: Optional[SparseTensor]\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        improved: bool = False,\n        cached: bool = False,\n        add_self_loops: Optional[bool] = None,\n        normalize: bool = True,\n        bias: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        if add_self_loops is None:\n            add_self_loops = normalize\n\n        if add_self_loops and not normalize:\n            raise ValueError(f\"'{self.__class__.__name__}' does not support \"\n                             f\"adding self-loops to the graph when no \"\n                             f\"on-the-fly normalization is applied\")\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.improved = improved\n        self.cached = cached\n        self.add_self_loops = add_self_loops\n        self.normalize = normalize\n\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n        self.lin = Linear(in_channels, out_channels, bias=False,\n                          weight_initializer='glorot')\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        
super().reset_parameters()\n        self.lin.reset_parameters()\n        zeros(self.bias)\n        self._cached_edge_index = None\n        self._cached_adj_t = None\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if isinstance(x, (tuple, list)):\n            raise ValueError(f\"'{self.__class__.__name__}' received a tuple \"\n                             f\"of node features as input while this layer \"\n                             f\"does not support bipartite message passing. \"\n                             f\"Please try other layers such as 'SAGEConv' or \"\n                             f\"'GraphConv' instead\")\n\n        if self.normalize:\n            if isinstance(edge_index, Tensor):\n                cache = self._cached_edge_index\n                if cache is None:\n                    edge_index, edge_weight = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim),\n                        self.improved, self.add_self_loops, self.flow, x.dtype)\n                    if self.cached:\n                        self._cached_edge_index = (edge_index, edge_weight)\n                else:\n                    edge_index, edge_weight = cache[0], cache[1]\n\n            elif isinstance(edge_index, SparseTensor):\n                cache = self._cached_adj_t\n                if cache is None:\n                    edge_index = gcn_norm(  # yapf: disable\n                        edge_index, edge_weight, x.size(self.node_dim),\n                        self.improved, self.add_self_loops, self.flow, x.dtype)\n                    if self.cached:\n                        self._cached_adj_t = edge_index\n                else:\n                    edge_index = cache\n\n        x = self.lin(x)\n\n        # propagate_type: (x: Tensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n\n        if self.bias is not 
None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n"
  },
  {
    "path": "torch_geometric/nn/conv/gen_conv.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nfrom torch import Tensor\nfrom torch.nn import (\n    BatchNorm1d,\n    Dropout,\n    InstanceNorm1d,\n    LayerNorm,\n    ReLU,\n    Sequential,\n)\n\nfrom torch_geometric.nn.aggr import Aggregation, MultiAggregation\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.nn.norm import MessageNorm\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n\n\nclass MLP(Sequential):\n    def __init__(self, channels: List[int], norm: Optional[str] = None,\n                 bias: bool = True, dropout: float = 0.):\n        m = []\n        for i in range(1, len(channels)):\n            m.append(Linear(channels[i - 1], channels[i], bias=bias))\n\n            if i < len(channels) - 1:\n                if norm and norm == 'batch':\n                    m.append(BatchNorm1d(channels[i], affine=True))\n                elif norm and norm == 'layer':\n                    m.append(LayerNorm(channels[i], elementwise_affine=True))\n                elif norm and norm == 'instance':\n                    m.append(InstanceNorm1d(channels[i], affine=False))\n                elif norm:\n                    raise NotImplementedError(\n                        f'Normalization layer \"{norm}\" not supported.')\n                m.append(ReLU())\n                m.append(Dropout(dropout))\n\n        super().__init__(*m)\n\n\nclass GENConv(MessagePassing):\n    r\"\"\"The GENeralized Graph Convolution (GENConv) from the `\"DeeperGCN: All\n    You Need to Train Deeper GCNs\" <https://arxiv.org/abs/2006.07739>`_ paper.\n\n    :class:`GENConv` supports both :math:`\\textrm{softmax}` (see\n    :class:`~torch_geometric.nn.aggr.SoftmaxAggregation`) and\n    :math:`\\textrm{powermean}` (see\n    :class:`~torch_geometric.nn.aggr.PowerMeanAggregation`) aggregation.\n    Its message construction is given 
by:\n\n    .. math::\n        \\mathbf{x}_i^{\\prime} = \\mathrm{MLP} \\left( \\mathbf{x}_i +\n        \\mathrm{AGG} \\left( \\left\\{\n        \\mathrm{ReLU} \\left( \\mathbf{x}_j + \\mathbf{e}_{ji} \\right) +\\epsilon\n        : j \\in \\mathcal{N}(i) \\right\\} \\right)\n        \\right)\n\n    .. note::\n\n        For an example of using :obj:`GENConv`, see\n        `examples/ogbn_proteins_deepgcn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        ogbn_proteins_deepgcn.py>`_.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        aggr (str or Aggregation, optional): The aggregation scheme to use.\n            Any aggregation of :obj:`torch_geometric.nn.aggr` can be used,\n            (:obj:`\"softmax\"`, :obj:`\"powermean\"`, :obj:`\"add\"`, :obj:`\"mean\"`,\n            :obj:`\"max\"`). (default: :obj:`\"softmax\"`)\n        t (float, optional): Initial inverse temperature for softmax\n            aggregation. (default: :obj:`1.0`)\n        learn_t (bool, optional): If set to :obj:`True`, will learn the value\n            :obj:`t` for softmax aggregation dynamically.\n            (default: :obj:`False`)\n        p (float, optional): Initial power for power mean aggregation.\n            (default: :obj:`1.0`)\n        learn_p (bool, optional): If set to :obj:`True`, will learn the value\n            :obj:`p` for power mean aggregation dynamically.\n            (default: :obj:`False`)\n        msg_norm (bool, optional): If set to :obj:`True`, will use message\n            normalization. (default: :obj:`False`)\n        learn_msg_scale (bool, optional): If set to :obj:`True`, will learn the\n            scaling factor of message normalization. 
(default: :obj:`False`)\n        norm (str, optional): Norm layer of MLP layers (:obj:`\"batch\"`,\n            :obj:`\"layer\"`, :obj:`\"instance\"`). (default: :obj:`\"batch\"`)\n        num_layers (int, optional): The number of MLP layers.\n            (default: :obj:`2`)\n        expansion (int, optional): The expansion factor of hidden channels in\n            MLP layers. (default: :obj:`2`)\n        eps (float, optional): The epsilon value of the message construction\n            function. (default: :obj:`1e-7`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`False`)\n        edge_dim (int, optional): Edge feature dimensionality. If set to\n            :obj:`None`, edge feature dimensionality is expected to match\n            the `out_channels`. Otherwise, edge features are linearly\n            transformed to match `out_channels` of node feature dimensionality.\n            (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge attributes :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        aggr: Optional[Union[str, List[str], Aggregation]] = 'softmax',\n        t: float = 1.0,\n        learn_t: bool = False,\n        p: float = 1.0,\n        learn_p: bool = False,\n        msg_norm: bool = False,\n        learn_msg_scale: bool = False,\n        norm: str = 'batch',\n        num_layers: 
int = 2,\n        expansion: int = 2,\n        eps: float = 1e-7,\n        bias: bool = False,\n        edge_dim: Optional[int] = None,\n        **kwargs,\n    ):\n\n        # Backward compatibility:\n        semi_grad = True if aggr == 'softmax_sg' else False\n        aggr = 'softmax' if aggr == 'softmax_sg' else aggr\n        aggr = 'powermean' if aggr == 'power' else aggr\n\n        # Override args of aggregator if `aggr_kwargs` is specified\n        if 'aggr_kwargs' not in kwargs:\n            if aggr == 'softmax':\n                kwargs['aggr_kwargs'] = dict(t=t, learn=learn_t,\n                                             semi_grad=semi_grad)\n            elif aggr == 'powermean':\n                kwargs['aggr_kwargs'] = dict(p=p, learn=learn_p)\n\n        super().__init__(aggr=aggr, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.eps = eps\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        if in_channels[0] != out_channels:\n            self.lin_src = Linear(in_channels[0], out_channels, bias=bias)\n\n        if edge_dim is not None and edge_dim != out_channels:\n            self.lin_edge = Linear(edge_dim, out_channels, bias=bias)\n\n        if isinstance(self.aggr_module, MultiAggregation):\n            aggr_out_channels = self.aggr_module.get_out_channels(out_channels)\n        else:\n            aggr_out_channels = out_channels\n\n        if aggr_out_channels != out_channels:\n            self.lin_aggr_out = Linear(aggr_out_channels, out_channels,\n                                       bias=bias)\n\n        if in_channels[1] != out_channels:\n            self.lin_dst = Linear(in_channels[1], out_channels, bias=bias)\n\n        channels = [out_channels]\n        for _ in range(num_layers - 1):\n            channels.append(out_channels * expansion)\n        channels.append(out_channels)\n        self.mlp = MLP(channels, norm=norm, 
bias=bias)\n\n        if msg_norm:\n            self.msg_norm = MessageNorm(learn_msg_scale)\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.mlp)\n        if hasattr(self, 'msg_norm'):\n            self.msg_norm.reset_parameters()\n        if hasattr(self, 'lin_src'):\n            self.lin_src.reset_parameters()\n        if hasattr(self, 'lin_edge'):\n            self.lin_edge.reset_parameters()\n        if hasattr(self, 'lin_aggr_out'):\n            self.lin_aggr_out.reset_parameters()\n        if hasattr(self, 'lin_dst'):\n            self.lin_dst.reset_parameters()\n\n    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,\n                edge_attr: OptTensor = None, size: Size = None) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        if hasattr(self, 'lin_src'):\n            x = (self.lin_src(x[0]), x[1])\n\n        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n\n        if hasattr(self, 'lin_aggr_out'):\n            out = self.lin_aggr_out(out)\n\n        if hasattr(self, 'msg_norm'):\n            h = x[1] if x[1] is not None else x[0]\n            assert h is not None\n            out = self.msg_norm(h, out)\n\n        x_dst = x[1]\n        if x_dst is not None:\n            if hasattr(self, 'lin_dst'):\n                x_dst = self.lin_dst(x_dst)\n            out = out + x_dst\n\n        return self.mlp(out)\n\n    def message(self, x_j: Tensor, edge_attr: OptTensor) -> Tensor:\n        if edge_attr is not None and hasattr(self, 'lin_edge'):\n            edge_attr = self.lin_edge(edge_attr)\n\n        if edge_attr is not None:\n            assert x_j.size(-1) == edge_attr.size(-1)\n\n        msg = x_j if edge_attr is None else x_j + edge_attr\n        return msg.relu() + self.eps\n\n    def __repr__(self) -> str:\n        return 
(f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, aggr={self.aggr})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/general_conv.py",
    "content": "from typing import Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot\nfrom torch_geometric.typing import (\n    Adj,\n    Optional,\n    OptPairTensor,\n    OptTensor,\n    Size,\n)\nfrom torch_geometric.utils import softmax\n\n\nclass GeneralConv(MessagePassing):\n    r\"\"\"A general GNN layer adapted from the `\"Design Space for Graph Neural\n    Networks\" <https://arxiv.org/abs/2011.08843>`_ paper.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        in_edge_channels (int, optional): Size of each input edge.\n            (default: :obj:`None`)\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"mean\"`)\n        skip_linear (bool, optional): Whether apply linear function in skip\n            connection. (default: :obj:`False`)\n        directed_msg (bool, optional): If message passing is directed;\n            otherwise, message passing is bi-directed. (default: :obj:`True`)\n        heads (int, optional): Number of message passing ensembles.\n            If :obj:`heads > 1`, the GNN layer will output an ensemble of\n            multiple messages.\n            If attention is used (:obj:`attention=True`), this corresponds to\n            multi-head attention. (default: :obj:`1`)\n        attention (bool, optional): Whether to add attention to message\n            computation. 
(default: :obj:`False`)\n        attention_type (str, optional): Type of attention: :obj:`\"additive\"`,\n            :obj:`\"dot_product\"`. (default: :obj:`\"additive\"`)\n        l2_normalize (bool, optional): If set to :obj:`True`, output features\n            will be :math:`\\ell_2`-normalized, *i.e.*,\n            :math:`\\frac{\\mathbf{x}^{\\prime}_i}\n            {\\| \\mathbf{x}^{\\prime}_i \\|_2}`.\n            (default: :obj:`False`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge attributes :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: Optional[int],\n        in_edge_channels: Optional[int] = None,\n        aggr: str = \"add\",\n        skip_linear: bool = False,\n        directed_msg: bool = True,\n        heads: int = 1,\n        attention: bool = False,\n        attention_type: str = \"additive\",\n        l2_normalize: bool = False,\n        bias: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', aggr)\n        super().__init__(node_dim=0, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.in_edge_channels = in_edge_channels\n        self.aggr = aggr\n        self.skip_linear = skip_linear\n        self.directed_msg = directed_msg\n        self.heads = heads\n       
 self.attention = attention\n        self.attention_type = attention_type\n        self.normalize_l2 = l2_normalize\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        if self.directed_msg:\n            self.lin_msg = Linear(in_channels[0], out_channels * self.heads,\n                                  bias=bias)\n        else:\n            self.lin_msg = Linear(in_channels[0], out_channels * self.heads,\n                                  bias=bias)\n            self.lin_msg_i = Linear(in_channels[0], out_channels * self.heads,\n                                    bias=bias)\n\n        if self.skip_linear or self.in_channels != self.out_channels:\n            self.lin_self = Linear(in_channels[1], out_channels, bias=bias)\n        else:\n            self.lin_self = torch.nn.Identity()\n\n        if self.in_edge_channels is not None:\n            self.lin_edge = Linear(in_edge_channels, out_channels * self.heads,\n                                   bias=bias)\n\n        # TODO: A general torch_geometric.nn.AttentionLayer\n        if self.attention:\n            if self.attention_type == 'additive':\n                self.att_msg = Parameter(\n                    torch.empty(1, self.heads, self.out_channels))\n            elif self.attention_type == 'dot_product':\n                scaler = torch.tensor(out_channels, dtype=torch.float).sqrt()\n                self.register_buffer('scaler', scaler)\n            else:\n                raise ValueError(\n                    f\"Attention type '{self.attention_type}' not supported\")\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_msg.reset_parameters()\n        if hasattr(self.lin_self, 'reset_parameters'):\n            self.lin_self.reset_parameters()\n        if self.in_edge_channels is not None:\n            self.lin_edge.reset_parameters()\n        if self.attention and self.attention_type 
== 'additive':\n            glorot(self.att_msg)\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x: OptPairTensor = (x, x)\n        x_self = x[1]\n        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, size=size, edge_attr=edge_attr)\n        out = out.mean(dim=1)  # todo: other approach to aggregate heads\n        out = out + self.lin_self(x_self)\n        if self.normalize_l2:\n            out = F.normalize(out, p=2, dim=-1)\n        return out\n\n    def message_basic(self, x_i: Tensor, x_j: Tensor, edge_attr: OptTensor):\n        if self.directed_msg:\n            x_j = self.lin_msg(x_j)\n        else:\n            x_j = self.lin_msg(x_j) + self.lin_msg_i(x_i)\n        if edge_attr is not None:\n            x_j = x_j + self.lin_edge(edge_attr)\n        return x_j\n\n    def message(self, x_i: Tensor, x_j: Tensor, edge_index_i: Tensor,\n                size_i: Tensor, edge_attr: Tensor) -> Tensor:\n        x_j_out = self.message_basic(x_i, x_j, edge_attr)\n        x_j_out = x_j_out.view(-1, self.heads, self.out_channels)\n        if self.attention:\n            if self.attention_type == 'dot_product':\n                x_i_out = self.message_basic(x_j, x_i, edge_attr)\n                x_i_out = x_i_out.view(-1, self.heads, self.out_channels)\n                alpha = (x_i_out * x_j_out).sum(dim=-1) / self.scaler\n            else:\n                alpha = (x_j_out * self.att_msg).sum(dim=-1)\n            alpha = F.leaky_relu(alpha, negative_slope=0.2)\n            alpha = softmax(alpha, edge_index_i, num_nodes=size_i)\n            alpha = alpha.view(-1, self.heads, 1)\n            return x_j_out * alpha\n        else:\n            return x_j_out\n"
  },
  {
    "path": "torch_geometric/nn/conv/gin_conv.py",
    "content": "from typing import Callable, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import (\n    Adj,\n    OptPairTensor,\n    OptTensor,\n    Size,\n    SparseTensor,\n)\nfrom torch_geometric.utils import spmm\n\n\nclass GINConv(MessagePassing):\n    r\"\"\"The graph isomorphism operator from the `\"How Powerful are\n    Graph Neural Networks?\" <https://arxiv.org/abs/1810.00826>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = h_{\\mathbf{\\Theta}} \\left( (1 + \\epsilon) \\cdot\n        \\mathbf{x}_i + \\sum_{j \\in \\mathcal{N}(i)} \\mathbf{x}_j \\right)\n\n    or\n\n    .. math::\n        \\mathbf{X}^{\\prime} = h_{\\mathbf{\\Theta}} \\left( \\left( \\mathbf{A} +\n        (1 + \\epsilon) \\cdot \\mathbf{I} \\right) \\cdot \\mathbf{X} \\right),\n\n    here :math:`h_{\\mathbf{\\Theta}}` denotes a neural network, *.i.e.* an MLP.\n\n    Args:\n        nn (torch.nn.Module): A neural network :math:`h_{\\mathbf{\\Theta}}` that\n            maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to\n            shape :obj:`[-1, out_channels]`, *e.g.*, defined by\n            :class:`torch.nn.Sequential`.\n        eps (float, optional): (Initial) :math:`\\epsilon`-value.\n            (default: :obj:`0.`)\n        train_eps (bool, optional): If set to :obj:`True`, :math:`\\epsilon`\n            will be a trainable parameter. 
(default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,\n                 **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n        self.nn = nn\n        self.initial_eps = eps\n        if train_eps:\n            self.eps = torch.nn.Parameter(torch.empty(1))\n        else:\n            self.register_buffer('eps', torch.empty(1))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.nn)\n        self.eps.data.fill_(self.initial_eps)\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: OptPairTensor)\n        out = self.propagate(edge_index, x=x, size=size)\n\n        x_r = x[1]\n        if x_r is not None:\n            out = out + (1 + self.eps) * x_r\n\n        return self.nn(out)\n\n    def message(self, x_j: Tensor) -> Tensor:\n        return x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n        if isinstance(adj_t, SparseTensor):\n            adj_t = adj_t.set_value(None, layout=None)\n        return spmm(adj_t, x[0], reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(nn={self.nn})'\n\n\nclass GINEConv(MessagePassing):\n    r\"\"\"The 
modified :class:`GINConv` operator from the `\"Strategies for\n    Pre-training Graph Neural Networks\" <https://arxiv.org/abs/1905.12265>`_\n    paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = h_{\\mathbf{\\Theta}} \\left( (1 + \\epsilon) \\cdot\n        \\mathbf{x}_i + \\sum_{j \\in \\mathcal{N}(i)} \\mathrm{ReLU}\n        ( \\mathbf{x}_j + \\mathbf{e}_{j,i} ) \\right)\n\n    that is able to incorporate edge features :math:`\\mathbf{e}_{j,i}` into\n    the aggregation procedure.\n\n    Args:\n        nn (torch.nn.Module): A neural network :math:`h_{\\mathbf{\\Theta}}` that\n            maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to\n            shape :obj:`[-1, out_channels]`, *e.g.*, defined by\n            :class:`torch.nn.Sequential`.\n        eps (float, optional): (Initial) :math:`\\epsilon`-value.\n            (default: :obj:`0.`)\n        train_eps (bool, optional): If set to :obj:`True`, :math:`\\epsilon`\n            will be a trainable parameter. (default: :obj:`False`)\n        edge_dim (int, optional): Edge feature dimensionality. If set to\n            :obj:`None`, node and edge feature dimensionality is expected to\n            match. Otherwise, edge features are linearly transformed to match\n            node feature dimensionality. 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, nn: torch.nn.Module, eps: float = 0.,\n                 train_eps: bool = False, edge_dim: Optional[int] = None,\n                 **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n        self.nn = nn\n        self.initial_eps = eps\n        if train_eps:\n            self.eps = torch.nn.Parameter(torch.empty(1))\n        else:\n            self.register_buffer('eps', torch.empty(1))\n        if edge_dim is not None:\n            if isinstance(self.nn, torch.nn.Sequential):\n                nn = self.nn[0]\n            if hasattr(nn, 'in_features'):\n                in_channels = nn.in_features\n            elif hasattr(nn, 'in_channels'):\n                in_channels = nn.in_channels\n            else:\n                raise ValueError(\"Could not infer input channels from `nn`.\")\n            self.lin = Linear(edge_dim, in_channels)\n\n        else:\n            self.lin = None\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        reset(self.nn)\n        self.eps.data.fill_(self.initial_eps)\n        if self.lin is not None:\n            self.lin.reset_parameters()\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n 
           x = (x, x)\n\n        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n\n        x_r = x[1]\n        if x_r is not None:\n            out = out + (1 + self.eps) * x_r\n\n        return self.nn(out)\n\n    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:\n        if self.lin is None and x_j.size(-1) != edge_attr.size(-1):\n            raise ValueError(\"Node and edge feature dimensionalities do not \"\n                             \"match. Consider setting the 'edge_dim' \"\n                             \"attribute of 'GINEConv'\")\n\n        if self.lin is not None:\n            edge_attr = self.lin(edge_attr)\n\n        return (x_j + edge_attr).relu()\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(nn={self.nn})'\n"
  },
  {
    "path": "torch_geometric/nn/conv/gmm_conv.py",
    "content": "from typing import Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n\n\nclass GMMConv(MessagePassing):\n    r\"\"\"The gaussian mixture model convolutional operator from the `\"Geometric\n    Deep Learning on Graphs and Manifolds using Mixture Model CNNs\"\n    <https://arxiv.org/abs/1611.08402>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{1}{|\\mathcal{N}(i)|}\n        \\sum_{j \\in \\mathcal{N}(i)} \\frac{1}{K} \\sum_{k=1}^K\n        \\mathbf{w}_k(\\mathbf{e}_{i,j}) \\odot \\mathbf{\\Theta}_k \\mathbf{x}_j,\n\n    where\n\n    .. math::\n        \\mathbf{w}_k(\\mathbf{e}) = \\exp \\left( -\\frac{1}{2} {\\left(\n        \\mathbf{e} - \\mathbf{\\mu}_k \\right)}^{\\top} \\Sigma_k^{-1}\n        \\left( \\mathbf{e} - \\mathbf{\\mu}_k \\right) \\right)\n\n    denotes a weighting function based on trainable mean vector\n    :math:`\\mathbf{\\mu}_k` and diagonal covariance matrix\n    :math:`\\mathbf{\\Sigma}_k`.\n\n    .. 
note::\n\n        The edge attribute :math:`\\mathbf{e}_{ij}` is usually given by\n        :math:`\\mathbf{e}_{ij} = \\mathbf{p}_j - \\mathbf{p}_i`, where\n        :math:`\\mathbf{p}_i` denotes the position of node :math:`i` (see\n        :class:`torch_geometric.transforms.Cartesian`).\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        dim (int): Pseudo-coordinate dimensionality.\n        kernel_size (int): Number of kernels :math:`K`.\n        separate_gaussians (bool, optional): If set to :obj:`True`, will\n            learn separate GMMs for every pair of input and output channel,\n            inspired by traditional CNNs. (default: :obj:`False`)\n        aggr (str, optional): The aggregation operator to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"mean\"`)\n        root_weight (bool, optional): If set to :obj:`False`, the layer will\n            not add transformed root node features to the output.\n            (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: Union[int, Tuple[int, int]],\n                 out_channels: int, dim: int, kernel_size: int,\n                 separate_gaussians: bool = False, aggr: str = 'mean',\n                 root_weight: bool = True, bias: bool = True, **kwargs):\n        super().__init__(aggr=aggr, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.dim = dim\n        self.kernel_size = kernel_size\n        self.separate_gaussians = separate_gaussians\n        self.root_weight = root_weight\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n        self.rel_in_channels = in_channels[0]\n\n        if in_channels[0] > 0:\n            self.g = Parameter(\n                Tensor(in_channels[0], out_channels * kernel_size))\n\n            if not self.separate_gaussians:\n                self.mu = Parameter(Tensor(kernel_size, dim))\n                self.sigma = Parameter(Tensor(kernel_size, dim))\n            if self.separate_gaussians:\n                self.mu = Parameter(\n                    Tensor(in_channels[0], out_channels, kernel_size, dim))\n                self.sigma = Parameter(\n                    Tensor(in_channels[0], out_channels, kernel_size, dim))\n        else:\n            self.g = torch.nn.parameter.UninitializedParameter()\n            self.mu 
= torch.nn.parameter.UninitializedParameter()\n            self.sigma = torch.nn.parameter.UninitializedParameter()\n            self._hook = self.register_forward_pre_hook(\n                self.initialize_parameters)\n\n        if root_weight:\n            self.root = Linear(in_channels[1], out_channels, bias=False,\n                               weight_initializer='glorot')\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if not isinstance(self.g, torch.nn.UninitializedParameter):\n            glorot(self.g)\n            glorot(self.mu)\n            glorot(self.sigma)\n        if self.root_weight:\n            self.root.reset_parameters()\n        zeros(self.bias)\n\n    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,\n                edge_attr: OptTensor = None, size: Size = None):\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n        if not self.separate_gaussians:\n            out: OptPairTensor = (torch.matmul(x[0], self.g), x[1])\n            out = self.propagate(edge_index, x=out, edge_attr=edge_attr,\n                                 size=size)\n        else:\n            out = self.propagate(edge_index, x=x, edge_attr=edge_attr,\n                                 size=size)\n\n        x_r = x[1]\n        if x_r is not None and self.root_weight:\n            out = out + self.root(x_r)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:\n        EPS = 1e-15\n        F, M = self.rel_in_channels, self.out_channels\n        (E, D), K = edge_attr.size(), self.kernel_size\n\n        if not self.separate_gaussians:\n            
gaussian = -0.5 * (edge_attr.view(E, 1, D) -\n                               self.mu.view(1, K, D)).pow(2)\n            gaussian = gaussian / (EPS + self.sigma.view(1, K, D).pow(2))\n            gaussian = torch.exp(gaussian.sum(dim=-1))  # [E, K]\n\n            return (x_j.view(E, K, M) * gaussian.view(E, K, 1)).sum(dim=-2)\n\n        else:\n            gaussian = -0.5 * (edge_attr.view(E, 1, 1, 1, D) -\n                               self.mu.view(1, F, M, K, D)).pow(2)\n            gaussian = gaussian / (EPS + self.sigma.view(1, F, M, K, D).pow(2))\n            gaussian = torch.exp(gaussian.sum(dim=-1))  # [E, F, M, K]\n\n            gaussian = gaussian * self.g.view(1, F, M, K)\n            gaussian = gaussian.sum(dim=-1)  # [E, F, M]\n\n            return (x_j.view(E, F, 1) * gaussian).sum(dim=-2)  # [E, M]\n\n    @torch.no_grad()\n    def initialize_parameters(self, module, input):\n        if isinstance(self.g, torch.nn.parameter.UninitializedParameter):\n            x = input[0][0] if isinstance(input, tuple) else input[0]\n            in_channels = x.size(-1)\n            out_channels, kernel_size = self.out_channels, self.kernel_size\n            self.g.materialize((in_channels, out_channels * kernel_size))\n            if not self.separate_gaussians:\n                self.mu.materialize((kernel_size, self.dim))\n                self.sigma.materialize((kernel_size, self.dim))\n            else:\n                self.mu.materialize(\n                    (in_channels, out_channels, kernel_size, self.dim))\n                self.sigma.materialize(\n                    (in_channels, out_channels, kernel_size, self.dim))\n            glorot(self.g)\n            glorot(self.mu)\n            glorot(self.sigma)\n\n        module._hook.remove()\n        delattr(module, '_hook')\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, dim={self.dim})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/gps_conv.py",
    "content": "import inspect\nfrom typing import Any, Dict, Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Dropout, Linear, Sequential\n\nfrom torch_geometric.nn.attention import PerformerAttention\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.nn.resolver import (\n    activation_resolver,\n    normalization_resolver,\n)\nfrom torch_geometric.typing import Adj\nfrom torch_geometric.utils import to_dense_batch\n\n\nclass GPSConv(torch.nn.Module):\n    r\"\"\"The general, powerful, scalable (GPS) graph transformer layer from the\n    `\"Recipe for a General, Powerful, Scalable Graph Transformer\"\n    <https://arxiv.org/abs/2205.12454>`_ paper.\n\n    The GPS layer is based on a 3-part recipe:\n\n    1. Inclusion of positional (PE) and structural encodings (SE) to the input\n       features (done in a pre-processing step via\n       :class:`torch_geometric.transforms`).\n    2. A local message passing layer (MPNN) that operates on the input graph.\n    3. A global attention layer that operates on the entire graph.\n\n    .. note::\n\n        For an example of using :class:`GPSConv`, see\n        `examples/graph_gps.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        graph_gps.py>`_.\n\n    Args:\n        channels (int): Size of each input sample.\n        conv (MessagePassing, optional): The local message passing layer.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        dropout (float, optional): Dropout probability of intermediate\n            embeddings. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. 
(default: :obj:`\"relu\"`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`\"batch_norm\"`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        attn_type (str): Global attention type, :obj:`multihead` or\n            :obj:`performer`. (default: :obj:`multihead`)\n        attn_kwargs (Dict[str, Any], optional): Arguments passed to the\n            attention layer. (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        channels: int,\n        conv: Optional[MessagePassing],\n        heads: int = 1,\n        dropout: float = 0.0,\n        act: str = 'relu',\n        act_kwargs: Optional[Dict[str, Any]] = None,\n        norm: Optional[str] = 'batch_norm',\n        norm_kwargs: Optional[Dict[str, Any]] = None,\n        attn_type: str = 'multihead',\n        attn_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        super().__init__()\n\n        self.channels = channels\n        self.conv = conv\n        self.heads = heads\n        self.dropout = dropout\n        self.attn_type = attn_type\n\n        attn_kwargs = attn_kwargs or {}\n        if attn_type == 'multihead':\n            self.attn = torch.nn.MultiheadAttention(\n                channels,\n                heads,\n                batch_first=True,\n                **attn_kwargs,\n            )\n        elif attn_type == 'performer':\n            self.attn = PerformerAttention(\n                channels=channels,\n                heads=heads,\n                **attn_kwargs,\n            )\n        else:\n            # TODO: Support BigBird\n            raise ValueError(f'{attn_type} is not supported')\n\n        self.mlp = 
Sequential(\n            Linear(channels, channels * 2),\n            activation_resolver(act, **(act_kwargs or {})),\n            Dropout(dropout),\n            Linear(channels * 2, channels),\n            Dropout(dropout),\n        )\n\n        norm_kwargs = norm_kwargs or {}\n        self.norm1 = normalization_resolver(norm, channels, **norm_kwargs)\n        self.norm2 = normalization_resolver(norm, channels, **norm_kwargs)\n        self.norm3 = normalization_resolver(norm, channels, **norm_kwargs)\n\n        self.norm_with_batch = False\n        if self.norm1 is not None:\n            signature = inspect.signature(self.norm1.forward)\n            self.norm_with_batch = 'batch' in signature.parameters\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        if self.conv is not None:\n            self.conv.reset_parameters()\n        self.attn._reset_parameters()\n        reset(self.mlp)\n        if self.norm1 is not None:\n            self.norm1.reset_parameters()\n        if self.norm2 is not None:\n            self.norm2.reset_parameters()\n        if self.norm3 is not None:\n            self.norm3.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        batch: Optional[torch.Tensor] = None,\n        **kwargs,\n    ) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\"\"\"\n        hs = []\n        if self.conv is not None:  # Local MPNN.\n            h = self.conv(x, edge_index, **kwargs)\n            h = F.dropout(h, p=self.dropout, training=self.training)\n            h = h + x\n            if self.norm1 is not None:\n                if self.norm_with_batch:\n                    h = self.norm1(h, batch=batch)\n                else:\n                    h = self.norm1(h)\n            hs.append(h)\n\n        # Global attention transformer-style model.\n        h, mask = to_dense_batch(x, batch)\n\n        if isinstance(self.attn, 
torch.nn.MultiheadAttention):\n            h, _ = self.attn(h, h, h, key_padding_mask=~mask,\n                             need_weights=False)\n        elif isinstance(self.attn, PerformerAttention):\n            h = self.attn(h, mask=mask)\n\n        h = h[mask]\n        h = F.dropout(h, p=self.dropout, training=self.training)\n        h = h + x  # Residual connection.\n        if self.norm2 is not None:\n            if self.norm_with_batch:\n                h = self.norm2(h, batch=batch)\n            else:\n                h = self.norm2(h)\n        hs.append(h)\n\n        out = sum(hs)  # Combine local and global outputs.\n\n        out = out + self.mlp(out)\n        if self.norm3 is not None:\n            if self.norm_with_batch:\n                out = self.norm3(out, batch=batch)\n            else:\n                out = self.norm3(out)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.channels}, '\n                f'conv={self.conv}, heads={self.heads}, '\n                f'attn_type={self.attn_type})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/graph_conv.py",
    "content": "from typing import Final, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\nfrom torch_geometric.utils import spmm\n\n\nclass GraphConv(MessagePassing):\n    r\"\"\"The graph neural network operator from the `\"Weisfeiler and Leman Go\n    Neural: Higher-order Graph Neural Networks\"\n    <https://arxiv.org/abs/1810.02244>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i + \\mathbf{W}_2\n        \\sum_{j \\in \\mathcal{N}(i)} e_{j,i} \\cdot \\mathbf{x}_j\n\n    where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to\n    target node :obj:`i` (default: :obj:`1`)\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"add\"`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = True\n\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        aggr: str = 'add',\n        bias: bool = True,\n        **kwargs,\n    ):\n        super().__init__(aggr=aggr, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.lin_rel = Linear(in_channels[0], out_channels, bias=bias)\n        self.lin_root = Linear(in_channels[1], out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_rel.reset_parameters()\n        self.lin_root.reset_parameters()\n\n    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,\n                edge_weight: OptTensor = None, size: Size = None) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,\n                             size=size)\n        out = self.lin_rel(out)\n\n        x_r = x[1]\n        if x_r is not None:\n            out = out + self.lin_root(x_r)\n\n        return out\n\n    def message(self, x_j: Tensor, 
edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(\n        self,\n        edge_index: Adj,\n        x: OptPairTensor,\n        edge_weight: OptTensor,\n    ) -> Tensor:\n\n        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n            return edge_index.matmul(\n                other=x[0],\n                input_value=edge_weight,\n                reduce=self.aggr,\n                transpose=True,\n            )\n\n        return spmm(edge_index, x[0], reduce=self.aggr)\n"
  },
  {
    "path": "torch_geometric/nn/conv/gravnet_conv.py",
    "content": "import warnings\nfrom typing import Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import OptPairTensor  # noqa\nfrom torch_geometric.typing import OptTensor, PairOptTensor, PairTensor\n\nif torch_geometric.typing.WITH_TORCH_CLUSTER:\n    from torch_cluster import knn\nelse:\n    knn = None\n\n\nclass GravNetConv(MessagePassing):\n    r\"\"\"The GravNet operator from the `\"Learning Representations of Irregular\n    Particle-detector Geometry with Distance-weighted Graph\n    Networks\" <https://arxiv.org/abs/1902.07987>`_ paper, where the graph is\n    dynamically constructed using nearest neighbors.\n    The neighbors are constructed in a learnable low-dimensional projection of\n    the feature space.\n    A second projection of the input feature space is then propagated from the\n    neighbors to each vertex using distance weights that are derived by\n    applying a Gaussian function to the distances.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): The number of output channels.\n        space_dimensions (int): The dimensionality of the space used to\n           construct the neighbors; referred to as :math:`S` in the paper.\n        propagate_dimensions (int): The number of features to be propagated\n           between the vertices; referred to as :math:`F_{\\textrm{LR}}` in the\n           paper.\n        k (int): The number of nearest neighbors.\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{in}), (|\\mathcal{V_t}|, F_{in}))`\n 
         if bipartite,\n          batch vector :math:`(|\\mathcal{V}|)` or\n          :math:`((|\\mathcal{V}_s|), (|\\mathcal{V}_t|))` if bipartite\n          *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int,\n                 space_dimensions: int, propagate_dimensions: int, k: int,\n                 num_workers: Optional[int] = None, **kwargs):\n        super().__init__(aggr=['mean', 'max'], flow='source_to_target',\n                         **kwargs)\n\n        if knn is None:\n            raise ImportError('`GravNetConv` requires `torch-cluster`.')\n\n        if num_workers is not None:\n            warnings.warn(\n                f\"'num_workers' attribute in '{self.__class__.__name__}' is \"\n                \"deprecated and will be removed in a future release\",\n                stacklevel=2)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.k = k\n\n        self.lin_s = Linear(in_channels, space_dimensions)\n        self.lin_h = Linear(in_channels, propagate_dimensions)\n\n        self.lin_out1 = Linear(in_channels, out_channels, bias=False)\n        self.lin_out2 = Linear(2 * propagate_dimensions, out_channels)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_s.reset_parameters()\n        self.lin_h.reset_parameters()\n        self.lin_out1.reset_parameters()\n        self.lin_out2.reset_parameters()\n\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        batch: Union[OptTensor, Optional[PairTensor]] = None,\n    ) -> Tensor:\n\n        is_bipartite: bool = True\n        if isinstance(x, Tensor):\n            x = (x, x)\n            is_bipartite = False\n\n        if x[0].dim() != 2:\n            raise ValueError(\"Static graphs not supported 
in 'GravNetConv'\")\n\n        b: PairOptTensor = (None, None)\n        if isinstance(batch, Tensor):\n            b = (batch, batch)\n        elif isinstance(batch, tuple):\n            assert batch is not None\n            b = (batch[0], batch[1])\n\n        h_l: Tensor = self.lin_h(x[0])\n\n        s_l: Tensor = self.lin_s(x[0])\n        s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l\n\n        edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0])\n\n        edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1)\n        edge_weight = torch.exp(-10. * edge_weight)  # 10 gives a better spread\n\n        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=(h_l, None),\n                             edge_weight=edge_weight,\n                             size=(s_l.size(0), s_r.size(0)))\n\n        return self.lin_out1(x[1]) + self.lin_out2(out)\n\n    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        return x_j * edge_weight.unsqueeze(1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, k={self.k})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/han_conv.py",
    "content": "from typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor, nn\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense import Linear\nfrom torch_geometric.nn.inits import glorot, reset\nfrom torch_geometric.typing import PairTensor  # noqa\nfrom torch_geometric.typing import Adj, EdgeType, Metadata, NodeType, OptTensor\nfrom torch_geometric.utils import softmax\n\n\ndef group(\n    xs: List[Tensor],\n    q: nn.Parameter,\n    k_lin: nn.Module,\n) -> Tuple[OptTensor, OptTensor]:\n\n    if len(xs) == 0:\n        return None, None\n    else:\n        num_edge_types = len(xs)\n        out = torch.stack(xs)\n        if out.numel() == 0:\n            return out.view(0, out.size(-1)), None\n        attn_score = (q * torch.tanh(k_lin(out)).mean(1)).sum(-1)\n        attn = F.softmax(attn_score, dim=0)\n        out = torch.sum(attn.view(num_edge_types, 1, -1) * out, dim=0)\n        return out, attn\n\n\nclass HANConv(MessagePassing):\n    r\"\"\"The Heterogenous Graph Attention Operator from the\n    `\"Heterogenous Graph Attention Network\"\n    <https://arxiv.org/abs/1903.07293>`_ paper.\n\n    .. 
note::\n\n        For an example of using HANConv, see `examples/hetero/han_imdb.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        hetero/han_imdb.py>`_.\n\n    Args:\n        in_channels (int or Dict[str, int]): Size of each input sample of every\n            node type, or :obj:`-1` to derive the size from the first input(s)\n            to the forward method.\n        out_channels (int): Size of each output sample.\n        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata\n            of the heterogeneous graph, *i.e.* its node and edge types given\n            by a list of strings and a list of string triplets, respectively.\n            See :meth:`torch_geometric.data.HeteroData.metadata` for more\n            information.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. (default: :obj:`0.2`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. 
(default: :obj:`0`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Dict[str, int]],\n        out_channels: int,\n        metadata: Metadata,\n        heads: int = 1,\n        negative_slope=0.2,\n        dropout: float = 0.0,\n        **kwargs,\n    ):\n        super().__init__(aggr='add', node_dim=0, **kwargs)\n\n        if not isinstance(in_channels, dict):\n            in_channels = {node_type: in_channels for node_type in metadata[0]}\n\n        self.heads = heads\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.negative_slope = negative_slope\n        self.metadata = metadata\n        self.dropout = dropout\n        self.k_lin = nn.Linear(out_channels, out_channels)\n        self.q = nn.Parameter(torch.empty(1, out_channels))\n\n        self.proj = nn.ModuleDict()\n        for node_type, in_channels in self.in_channels.items():\n            self.proj[node_type] = Linear(in_channels, out_channels)\n\n        self.lin_src = nn.ParameterDict()\n        self.lin_dst = nn.ParameterDict()\n        dim = out_channels // heads\n        for edge_type in metadata[1]:\n            edge_type = '__'.join(edge_type)\n            self.lin_src[edge_type] = nn.Parameter(torch.empty(1, heads, dim))\n            self.lin_dst[edge_type] = nn.Parameter(torch.empty(1, heads, dim))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.proj)\n        glorot(self.lin_src)\n        glorot(self.lin_dst)\n        self.k_lin.reset_parameters()\n        glorot(self.q)\n\n    def forward(\n        self,\n        x_dict: Dict[NodeType, Tensor],\n        edge_index_dict: Dict[EdgeType, Adj],\n        return_semantic_attention_weights: bool = False,\n    ) -> Union[Dict[NodeType, OptTensor], Tuple[Dict[NodeType, OptTensor],\n       
                                         Dict[NodeType, OptTensor]]]:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x_dict (Dict[str, torch.Tensor]): A dictionary holding node feature\n                information for each individual node type.\n            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A\n                dictionary holding graph connectivity information for each\n                individual edge type, either as a :class:`torch.Tensor` of\n                shape :obj:`[2, num_edges]` or a\n                :class:`torch_sparse.SparseTensor`.\n            return_semantic_attention_weights (bool, optional): If set to\n                :obj:`True`, will additionally return the semantic-level\n                attention weights for each destination node type.\n                (default: :obj:`False`)\n        \"\"\"\n        H, D = self.heads, self.out_channels // self.heads\n        x_node_dict, out_dict = {}, {}\n\n        # Iterate over node types:\n        for node_type, x in x_dict.items():\n            x_node_dict[node_type] = self.proj[node_type](x).view(-1, H, D)\n            out_dict[node_type] = []\n\n        # Iterate over edge types:\n        for edge_type, edge_index in edge_index_dict.items():\n            src_type, _, dst_type = edge_type\n            edge_type = '__'.join(edge_type)\n            lin_src = self.lin_src[edge_type]\n            lin_dst = self.lin_dst[edge_type]\n            x_src = x_node_dict[src_type]\n            x_dst = x_node_dict[dst_type]\n            alpha_src = (x_src * lin_src).sum(dim=-1)\n            alpha_dst = (x_dst * lin_dst).sum(dim=-1)\n            # propagate_type: (x: PairTensor, alpha: PairTensor)\n            out = self.propagate(edge_index, x=(x_src, x_dst),\n                                 alpha=(alpha_src, alpha_dst))\n\n            out = F.relu(out)\n            out_dict[dst_type].append(out)\n\n        # iterate over node types:\n        
semantic_attn_dict = {}\n        for node_type, outs in out_dict.items():\n            out, attn = group(outs, self.q, self.k_lin)\n            out_dict[node_type] = out\n            semantic_attn_dict[node_type] = attn\n\n        if return_semantic_attention_weights:\n            return out_dict, semantic_attn_dict\n\n        return out_dict\n\n    def message(self, x_j: Tensor, alpha_i: Tensor, alpha_j: Tensor,\n                index: Tensor, ptr: Optional[Tensor],\n                size_i: Optional[int]) -> Tensor:\n\n        alpha = alpha_j + alpha_i\n        alpha = F.leaky_relu(alpha, self.negative_slope)\n        alpha = softmax(alpha, index, ptr, size_i)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        out = x_j * alpha.view(-1, self.heads, 1)\n        return out.view(-1, self.out_channels)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.out_channels}, '\n                f'heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/heat_conv.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Embedding\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import HeteroLinear, Linear\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import softmax\n\n\nclass HEATConv(MessagePassing):\n    r\"\"\"The heterogeneous edge-enhanced graph attentional operator from the\n    `\"Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent\n    Trajectory Prediction\" <https://arxiv.org/abs/2106.07161>`_ paper.\n\n    :class:`HEATConv` enhances :class:`~torch_geometric.nn.conv.GATConv` by:\n\n    1. type-specific transformations of nodes of different types\n    2. edge type and edge feature incorporation, in which edges are assumed to\n       have different types but contain the same kind of attributes\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        num_node_types (int): The number of node types.\n        num_edge_types (int): The number of edge types.\n        edge_type_emb_dim (int): The embedding size of edge types.\n        edge_dim (int): Edge feature dimensionality.\n        edge_attr_emb_dim (int): The embedding size of edge features.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the multi-head\n            attentions are averaged instead of concatenated.\n            (default: :obj:`True`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. 
(default: :obj:`0.2`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        root_weight (bool, optional): If set to :obj:`False`, the layer will\n            not add transformed root node features to the output.\n            (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          node types :math:`(|\\mathcal{V}|)`,\n          edge types :math:`(|\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int,\n                 num_node_types: int, num_edge_types: int,\n                 edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int,\n                 heads: int = 1, concat: bool = True,\n                 negative_slope: float = 0.2, dropout: float = 0.0,\n                 root_weight: bool = True, bias: bool = True, **kwargs):\n\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(node_dim=0, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.concat = concat\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n        self.root_weight = root_weight\n\n        self.hetero_lin = HeteroLinear(in_channels, out_channels,\n                                       num_node_types, bias=bias)\n\n        self.edge_type_emb = 
Embedding(num_edge_types, edge_type_emb_dim)\n        self.edge_attr_emb = Linear(edge_dim, edge_attr_emb_dim, bias=False)\n\n        self.att = Linear(\n            2 * out_channels + edge_type_emb_dim + edge_attr_emb_dim,\n            self.heads, bias=False)\n\n        self.lin = Linear(out_channels + edge_attr_emb_dim, out_channels,\n                          bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.hetero_lin.reset_parameters()\n        self.edge_type_emb.reset_parameters()\n        self.edge_attr_emb.reset_parameters()\n        self.att.reset_parameters()\n        self.lin.reset_parameters()\n\n    def forward(self, x: Tensor, edge_index: Adj, node_type: Tensor,\n                edge_type: Tensor, edge_attr: OptTensor = None) -> Tensor:\n\n        x = self.hetero_lin(x, node_type)\n\n        edge_type_emb = F.leaky_relu(self.edge_type_emb(edge_type),\n                                     self.negative_slope)\n\n        # propagate_type: (x: Tensor, edge_type_emb: Tensor,\n        #                  edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb,\n                             edge_attr=edge_attr)\n\n        if self.concat:\n            if self.root_weight:\n                out = out + x.view(-1, 1, self.out_channels)\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n            if self.root_weight:\n                out = out + x\n\n        return out\n\n    def message(self, x_i: Tensor, x_j: Tensor, edge_type_emb: Tensor,\n                edge_attr: Tensor, index: Tensor, ptr: OptTensor,\n                size_i: Optional[int]) -> Tensor:\n\n        edge_attr = F.leaky_relu(self.edge_attr_emb(edge_attr),\n                                 self.negative_slope)\n\n        alpha = torch.cat([x_i, x_j, edge_type_emb, edge_attr], dim=-1)\n        alpha = 
F.leaky_relu(self.att(alpha), self.negative_slope)\n        alpha = softmax(alpha, index, ptr, size_i)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n\n        out = self.lin(torch.cat([x_j, edge_attr], dim=-1)).unsqueeze(-2)\n        return out * alpha.unsqueeze(-1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/hetero_conv.py",
    "content": "import warnings\nfrom typing import Dict, List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.module_dict import ModuleDict\nfrom torch_geometric.typing import EdgeType, NodeType\nfrom torch_geometric.utils.hetero import check_add_self_loops\n\n\ndef group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]:\n    if len(xs) == 0:\n        return None\n    elif aggr is None:\n        return torch.stack(xs, dim=1)\n    elif len(xs) == 1:\n        return xs[0]\n    elif aggr == \"cat\":\n        return torch.cat(xs, dim=-1)\n    else:\n        out = torch.stack(xs, dim=0)\n        out = getattr(torch, aggr)(out, dim=0)\n        out = out[0] if isinstance(out, tuple) else out\n        return out\n\n\nclass HeteroConv(torch.nn.Module):\n    r\"\"\"A generic wrapper for computing graph convolution on heterogeneous\n    graphs.\n    This layer will pass messages from source nodes to target nodes based on\n    the bipartite GNN layer given for a specific edge type.\n    If multiple relations point to the same destination, their results will be\n    aggregated according to :attr:`aggr`.\n    In comparison to :meth:`torch_geometric.nn.to_hetero`, this layer is\n    especially useful if you want to apply different message passing modules\n    for different edge types.\n\n    .. 
code-block:: python\n\n        hetero_conv = HeteroConv({\n            ('paper', 'cites', 'paper'): GCNConv(-1, 64),\n            ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64),\n            ('paper', 'written_by', 'author'): GATConv((-1, -1), 64),\n        }, aggr='sum')\n\n        out_dict = hetero_conv(x_dict, edge_index_dict)\n\n        print(list(out_dict.keys()))\n        >>> ['paper', 'author']\n\n    Args:\n        convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary\n            holding a bipartite\n            :class:`~torch_geometric.nn.conv.MessagePassing` layer for each\n            individual edge type.\n        aggr (str, optional): The aggregation scheme to use for grouping node\n            embeddings generated by different relations\n            (:obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`,\n            :obj:`\"cat\"`, :obj:`None`). (default: :obj:`\"sum\"`)\n    \"\"\"\n    def __init__(\n        self,\n        convs: Dict[EdgeType, MessagePassing],\n        aggr: Optional[str] = \"sum\",\n    ):\n        super().__init__()\n\n        for edge_type, module in convs.items():\n            check_add_self_loops(module, [edge_type])\n\n        src_node_types = {key[0] for key in convs.keys()}\n        dst_node_types = {key[-1] for key in convs.keys()}\n        if len(src_node_types - dst_node_types) > 0:\n            warnings.warn(\n                f\"There exist node types ({src_node_types - dst_node_types}) \"\n                f\"whose representations do not get updated during message \"\n                f\"passing as they do not occur as destination type in any \"\n                f\"edge type. 
This may lead to unexpected behavior.\",\n                stacklevel=2)\n\n        self.convs = ModuleDict(convs)\n        self.aggr = aggr\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for conv in self.convs.values():\n            conv.reset_parameters()\n\n    def forward(\n        self,\n        *args_dict,\n        **kwargs_dict,\n    ) -> Dict[NodeType, Tensor]:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x_dict (Dict[str, torch.Tensor]): A dictionary holding node feature\n                information for each individual node type.\n            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A\n                dictionary holding graph connectivity information for each\n                individual edge type, either as a :class:`torch.Tensor` of\n                shape :obj:`[2, num_edges]` or a\n                :class:`torch_sparse.SparseTensor`.\n            *args_dict (optional): Additional forward arguments of individual\n                :class:`torch_geometric.nn.conv.MessagePassing` layers.\n            **kwargs_dict (optional): Additional forward arguments of\n                individual :class:`torch_geometric.nn.conv.MessagePassing`\n                layers.\n                For example, if a specific GNN layer at edge type\n                :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a\n                forward argument, then you can pass them to\n                :meth:`~torch_geometric.nn.conv.HeteroConv.forward` via\n                :obj:`edge_attr_dict = { edge_type: edge_attr }`.\n        \"\"\"\n        out_dict: Dict[str, List[Tensor]] = {}\n\n        for edge_type, conv in self.convs.items():\n            src, rel, dst = edge_type\n\n            has_edge_level_arg = False\n\n            args = []\n            for value_dict in args_dict:\n                if edge_type in value_dict:\n                    
has_edge_level_arg = True\n                    args.append(value_dict[edge_type])\n                elif src == dst and src in value_dict:\n                    args.append(value_dict[src])\n                elif src in value_dict or dst in value_dict:\n                    args.append((\n                        value_dict.get(src, None),\n                        value_dict.get(dst, None),\n                    ))\n\n            kwargs = {}\n            for arg, value_dict in kwargs_dict.items():\n                if not arg.endswith('_dict'):\n                    raise ValueError(\n                        f\"Keyword arguments in '{self.__class__.__name__}' \"\n                        f\"need to end with '_dict' (got '{arg}')\")\n\n                arg = arg[:-5]  # `{*}_dict`\n                if edge_type in value_dict:\n                    has_edge_level_arg = True\n                    kwargs[arg] = value_dict[edge_type]\n                elif src == dst and src in value_dict:\n                    kwargs[arg] = value_dict[src]\n                elif src in value_dict or dst in value_dict:\n                    kwargs[arg] = (\n                        value_dict.get(src, None),\n                        value_dict.get(dst, None),\n                    )\n\n            if not has_edge_level_arg:\n                continue\n\n            out = conv(*args, **kwargs)\n\n            if dst not in out_dict:\n                out_dict[dst] = [out]\n            else:\n                out_dict[dst].append(out)\n\n        for key, value in out_dict.items():\n            out_dict[key] = group(value, self.aggr)\n\n        return out_dict\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(num_relations={len(self.convs)})'\n"
  },
  {
    "path": "torch_geometric/nn/conv/hgt_conv.py",
    "content": "import math\nfrom typing import Dict, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense import HeteroDictLinear, HeteroLinear\nfrom torch_geometric.nn.inits import ones\nfrom torch_geometric.nn.parameter_dict import ParameterDict\nfrom torch_geometric.typing import Adj, EdgeType, Metadata, NodeType\nfrom torch_geometric.utils import softmax\nfrom torch_geometric.utils.hetero import construct_bipartite_edge_index\n\n\nclass HGTConv(MessagePassing):\n    r\"\"\"The Heterogeneous Graph Transformer (HGT) operator from the\n    `\"Heterogeneous Graph Transformer\" <https://arxiv.org/abs/2003.01332>`_\n    paper.\n\n    .. note::\n\n        For an example of using HGT, see `examples/hetero/hgt_dblp.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        hetero/hgt_dblp.py>`_.\n\n    Args:\n        in_channels (int or Dict[str, int]): Size of each input sample of every\n            node type, or :obj:`-1` to derive the size from the first input(s)\n            to the forward method.\n        out_channels (int): Size of each output sample.\n        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata\n            of the heterogeneous graph, *i.e.* its node and edge types given\n            by a list of strings and a list of string triplets, respectively.\n            See :meth:`torch_geometric.data.HeteroData.metadata` for more\n            information.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Dict[str, int]],\n        out_channels: int,\n        metadata: Metadata,\n        heads: int = 1,\n        **kwargs,\n    ):\n       
 super().__init__(aggr='add', node_dim=0, **kwargs)\n\n        if out_channels % heads != 0:\n            raise ValueError(f\"'out_channels' (got {out_channels}) must be \"\n                             f\"divisible by the number of heads (got {heads})\")\n\n        if not isinstance(in_channels, dict):\n            in_channels = {node_type: in_channels for node_type in metadata[0]}\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.node_types = metadata[0]\n        self.edge_types = metadata[1]\n        self.edge_types_map = {\n            edge_type: i\n            for i, edge_type in enumerate(metadata[1])\n        }\n\n        self.dst_node_types = {key[-1] for key in self.edge_types}\n\n        self.kqv_lin = HeteroDictLinear(self.in_channels,\n                                        self.out_channels * 3)\n\n        self.out_lin = HeteroDictLinear(self.out_channels, self.out_channels,\n                                        types=self.node_types)\n\n        dim = out_channels // heads\n        num_types = heads * len(self.edge_types)\n\n        self.k_rel = HeteroLinear(dim, dim, num_types, bias=False,\n                                  is_sorted=True)\n        self.v_rel = HeteroLinear(dim, dim, num_types, bias=False,\n                                  is_sorted=True)\n\n        self.skip = ParameterDict({\n            node_type: Parameter(torch.empty(1))\n            for node_type in self.node_types\n        })\n\n        self.p_rel = ParameterDict()\n        for edge_type in self.edge_types:\n            edge_type = '__'.join(edge_type)\n            self.p_rel[edge_type] = Parameter(torch.empty(1, heads))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.kqv_lin.reset_parameters()\n        self.out_lin.reset_parameters()\n        self.k_rel.reset_parameters()\n        self.v_rel.reset_parameters()\n        
ones(self.skip)\n        ones(self.p_rel)\n\n    def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]:\n        \"\"\"Concatenates a dictionary of features.\"\"\"\n        cumsum = 0\n        outs: List[Tensor] = []\n        offset: Dict[str, int] = {}\n        for key, x in x_dict.items():\n            outs.append(x)\n            offset[key] = cumsum\n            cumsum += x.size(0)\n        return torch.cat(outs, dim=0), offset\n\n    def _construct_src_node_feat(\n        self, k_dict: Dict[str, Tensor], v_dict: Dict[str, Tensor],\n        edge_index_dict: Dict[EdgeType, Adj]\n    ) -> Tuple[Tensor, Tensor, Dict[EdgeType, int]]:\n        \"\"\"Constructs the source node representations.\"\"\"\n        cumsum = 0\n        num_edge_types = len(self.edge_types)\n        H, D = self.heads, self.out_channels // self.heads\n\n        # Flatten into a single tensor with shape [num_edge_types * heads, D]:\n        ks: List[Tensor] = []\n        vs: List[Tensor] = []\n        type_list: List[Tensor] = []\n        offset: Dict[EdgeType, int] = {}\n        for edge_type in edge_index_dict.keys():\n            src = edge_type[0]\n            N = k_dict[src].size(0)\n            offset[edge_type] = cumsum\n            cumsum += N\n\n            # Construct type_vec for the current edge type with shape [H, N]:\n            edge_type_offset = self.edge_types_map[edge_type]\n            type_vec = torch.arange(H, dtype=torch.long).view(-1, 1).repeat(\n                1, N) * num_edge_types + edge_type_offset\n\n            type_list.append(type_vec)\n            ks.append(k_dict[src])\n            vs.append(v_dict[src])\n\n        ks = torch.cat(ks, dim=0).transpose(0, 1).reshape(-1, D)\n        vs = torch.cat(vs, dim=0).transpose(0, 1).reshape(-1, D)\n        type_vec = torch.cat(type_list, dim=1).flatten()\n\n        k = self.k_rel(ks, type_vec).view(H, -1, D).transpose(0, 1)\n        v = self.v_rel(vs, type_vec).view(H, -1, D).transpose(0, 1)\n\n        return k, v, 
offset\n\n    def forward(\n        self,\n        x_dict: Dict[NodeType, Tensor],\n        edge_index_dict: Dict[EdgeType, Adj]  # Support both.\n    ) -> Dict[NodeType, Optional[Tensor]]:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x_dict (Dict[str, torch.Tensor]): A dictionary holding input node\n                features for each individual node type.\n            edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A\n                dictionary holding graph connectivity information for each\n                individual edge type, either as a :class:`torch.Tensor` of\n                shape :obj:`[2, num_edges]` or a\n                :class:`torch_sparse.SparseTensor`.\n\n        :rtype: :obj:`Dict[str, Optional[torch.Tensor]]` - The output node\n            embeddings for each node type.\n            In case a node type does not receive any message, its output will\n            be set to :obj:`None`.\n        \"\"\"\n        F = self.out_channels\n        H = self.heads\n        D = F // H\n\n        k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {}\n\n        # Compute K, Q, V over node types:\n        kqv_dict = self.kqv_lin(x_dict)\n        for key, val in kqv_dict.items():\n            k, q, v = torch.tensor_split(val, 3, dim=1)\n            k_dict[key] = k.view(-1, H, D)\n            q_dict[key] = q.view(-1, H, D)\n            v_dict[key] = v.view(-1, H, D)\n\n        q, dst_offset = self._cat(q_dict)\n        k, v, src_offset = self._construct_src_node_feat(\n            k_dict, v_dict, edge_index_dict)\n\n        edge_index, edge_attr = construct_bipartite_edge_index(\n            edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel,\n            num_nodes=k.size(0))\n\n        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr)\n\n        # Reconstruct output node embeddings dict:\n        for node_type, start_offset in dst_offset.items():\n            end_offset = start_offset 
+ q_dict[node_type].size(0)\n            if node_type in self.dst_node_types:\n                out_dict[node_type] = out[start_offset:end_offset]\n\n        # Transform output node embeddings:\n        a_dict = self.out_lin({\n            k:\n            torch.nn.functional.gelu(v) if v is not None else v\n            for k, v in out_dict.items()\n        })\n\n        # Iterate over node types:\n        for node_type, out in out_dict.items():\n            out = a_dict[node_type]\n\n            if out.size(-1) == x_dict[node_type].size(-1):\n                alpha = self.skip[node_type].sigmoid()\n                out = alpha * out + (1 - alpha) * x_dict[node_type]\n            out_dict[node_type] = out\n\n        return out_dict\n\n    def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor,\n                index: Tensor, ptr: Optional[Tensor],\n                size_i: Optional[int]) -> Tensor:\n        alpha = (q_i * k_j).sum(dim=-1) * edge_attr\n        alpha = alpha / math.sqrt(q_i.size(-1))\n        alpha = softmax(alpha, index, ptr, size_i)\n        out = v_j * alpha.view(-1, self.heads, 1)\n        return out.view(-1, self.out_channels)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(-1, {self.out_channels}, '\n                f'heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/hypergraph_conv.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.experimental import disable_dynamic_shapes\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.utils import scatter, softmax\n\n\nclass HypergraphConv(MessagePassing):\n    r\"\"\"The hypergraph convolutional operator from the `\"Hypergraph Convolution\n    and Hypergraph Attention\" <https://arxiv.org/abs/1901.08150>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\mathbf{D}^{-1} \\mathbf{H} \\mathbf{W}\n        \\mathbf{B}^{-1} \\mathbf{H}^{\\top} \\mathbf{X} \\mathbf{\\Theta}\n\n    where :math:`\\mathbf{H} \\in {\\{ 0, 1 \\}}^{N \\times M}` is the incidence\n    matrix, :math:`\\mathbf{W} \\in \\mathbb{R}^M` is the diagonal hyperedge\n    weight matrix, and\n    :math:`\\mathbf{D}` and :math:`\\mathbf{B}` are the corresponding degree\n    matrices.\n\n    For example, in the hypergraph scenario\n    :math:`\\mathcal{G} = (\\mathcal{V}, \\mathcal{E})` with\n    :math:`\\mathcal{V} = \\{ 0, 1, 2, 3 \\}` and\n    :math:`\\mathcal{E} = \\{ \\{ 0, 1, 2 \\}, \\{ 1, 2, 3 \\} \\}`, the\n    :obj:`hyperedge_index` is represented as:\n\n    .. code-block:: python\n\n        hyperedge_index = torch.tensor([\n            [0, 1, 2, 1, 2, 3],\n            [0, 0, 0, 1, 1, 1],\n        ])\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        use_attention (bool, optional): If set to :obj:`True`, attention\n            will be added to this layer. 
(default: :obj:`False`)\n        attention_mode (str, optional): The mode used to compute attention.\n            If set to :obj:`\"node\"`, will compute attention scores of nodes\n            within all nodes belonging to the same hyperedge.\n            If set to :obj:`\"edge\"`, will compute attention scores of nodes\n            across all hyperedges this node belongs to.\n            (default: :obj:`\"node\"`)\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the multi-head\n            attentions are averaged instead of concatenated.\n            (default: :obj:`True`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. (default: :obj:`0.2`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          hyperedge indices :math:`(|\\mathcal{V}|, |\\mathcal{E}|)`,\n          hyperedge weights :math:`(|\\mathcal{E}|)` *(optional)*,\n          hyperedge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        use_attention: bool = False,\n        attention_mode: str = 'node',\n        heads: int = 1,\n        concat: bool = True,\n        negative_slope: float = 0.2,\n        dropout: float = 0,\n        bias: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(flow='source_to_target', node_dim=0, **kwargs)\n\n        assert attention_mode in ['node', 'edge']\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.use_attention = use_attention\n        self.attention_mode = attention_mode\n\n        if self.use_attention:\n            self.heads = heads\n            self.concat = concat\n            self.negative_slope = negative_slope\n            self.dropout = dropout\n            self.lin = Linear(in_channels, heads * out_channels, bias=False,\n                              weight_initializer='glorot')\n            self.att = Parameter(torch.empty(1, heads, 2 * out_channels))\n        else:\n            self.heads = 1\n            self.concat = True\n            self.lin = Linear(in_channels, out_channels, bias=False,\n                              weight_initializer='glorot')\n\n        # Size the bias from the effective `self.heads`/`self.concat`, which\n        # are reset to `1`/`True` when attention is disabled:\n        if bias and self.concat:\n            self.bias = Parameter(torch.empty(self.heads * out_channels))\n        elif bias and not self.concat:\n            self.bias = 
Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin.reset_parameters()\n        if self.use_attention:\n            glorot(self.att)\n        zeros(self.bias)\n\n    @disable_dynamic_shapes(required_args=['num_edges'])\n    def forward(self, x: Tensor, hyperedge_index: Tensor,\n                hyperedge_weight: Optional[Tensor] = None,\n                hyperedge_attr: Optional[Tensor] = None,\n                num_edges: Optional[int] = None) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): Node feature matrix\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n            hyperedge_index (torch.Tensor): The hyperedge indices, *i.e.*\n                the sparse incidence matrix\n                :math:`\\mathbf{H} \\in {\\{ 0, 1 \\}}^{N \\times M}` mapping from\n                nodes to edges.\n            hyperedge_weight (torch.Tensor, optional): Hyperedge weights\n                :math:`\\mathbf{W} \\in \\mathbb{R}^M`. (default: :obj:`None`)\n            hyperedge_attr (torch.Tensor, optional): Hyperedge feature matrix\n                in :math:`\\mathbb{R}^{M \\times F}`.\n                These features only need to get passed in case\n                :obj:`use_attention=True`. 
(default: :obj:`None`)\n            num_edges (int, optional) : The number of edges :math:`M`.\n                (default: :obj:`None`)\n        \"\"\"\n        num_nodes = x.size(0)\n\n        if num_edges is None:\n            num_edges = 0\n            if hyperedge_index.numel() > 0:\n                num_edges = int(hyperedge_index[1].max()) + 1\n\n        if hyperedge_weight is None:\n            hyperedge_weight = x.new_ones(num_edges)\n\n        x = self.lin(x)\n\n        alpha = None\n        if self.use_attention:\n            assert hyperedge_attr is not None\n            x = x.view(-1, self.heads, self.out_channels)\n            hyperedge_attr = self.lin(hyperedge_attr)\n            hyperedge_attr = hyperedge_attr.view(-1, self.heads,\n                                                 self.out_channels)\n            x_i = x[hyperedge_index[0]]\n            x_j = hyperedge_attr[hyperedge_index[1]]\n            alpha = (torch.cat([x_i, x_j], dim=-1) * self.att).sum(dim=-1)\n            alpha = F.leaky_relu(alpha, self.negative_slope)\n            if self.attention_mode == 'node':\n                alpha = softmax(alpha, hyperedge_index[1], num_nodes=num_edges)\n            else:\n                alpha = softmax(alpha, hyperedge_index[0], num_nodes=num_nodes)\n            alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n\n        D = scatter(hyperedge_weight[hyperedge_index[1]], hyperedge_index[0],\n                    dim=0, dim_size=num_nodes, reduce='sum')\n        D = 1.0 / D\n        D[D == float(\"inf\")] = 0\n\n        B = scatter(x.new_ones(hyperedge_index.size(1)), hyperedge_index[1],\n                    dim=0, dim_size=num_edges, reduce='sum')\n        B = 1.0 / B\n        B[B == float(\"inf\")] = 0\n\n        out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha,\n                             size=(num_nodes, num_edges))\n        out = self.propagate(hyperedge_index.flip([0]), x=out, norm=D,\n                             
alpha=alpha, size=(num_edges, num_nodes))\n\n        if self.concat is True:\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, norm_i: Tensor, alpha: Tensor) -> Tensor:\n        H, F = self.heads, self.out_channels\n\n        out = norm_i.view(-1, 1, 1) * x_j.view(-1, H, F)\n\n        if alpha is not None:\n            out = alpha.view(-1, self.heads, 1) * out\n\n        return out\n"
  },
  {
    "path": "torch_geometric/nn/conv/le_conv.py",
    "content": "from typing import Tuple, Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptTensor, PairTensor\n\n\nclass LEConv(MessagePassing):\n    r\"\"\"The local extremum graph neural network operator from the\n    `\"ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph\n    Representations\" <https://arxiv.org/abs/1911.07979>`_ paper.\n\n    :class:`LEConv` finds the importance of nodes with respect to their\n    neighbors using the difference operator:\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{x}_i \\cdot \\mathbf{\\Theta}_1 +\n        \\sum_{j \\in \\mathcal{N}(i)} e_{j,i} \\cdot\n        (\\mathbf{\\Theta}_2 \\mathbf{x}_i - \\mathbf{\\Theta}_3 \\mathbf{x}_j)\n\n    where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to\n    target node :obj:`i` (default: :obj:`1`)\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        bias (bool, optional): If set to :obj:`False`, the layer will\n            not learn an additive bias. 
(default: :obj:`True`).\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: Union[int, Tuple[int, int]],\n                 out_channels: int, bias: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.lin1 = Linear(in_channels[0], out_channels, bias=bias)\n        self.lin2 = Linear(in_channels[1], out_channels, bias=False)\n        self.lin3 = Linear(in_channels[1], out_channels, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin1.reset_parameters()\n        self.lin2.reset_parameters()\n        self.lin3.reset_parameters()\n\n    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        a = self.lin1(x[0])\n        b = self.lin2(x[1])\n\n        # propagate_type: (a: Tensor, b: Tensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, a=a, b=b, edge_weight=edge_weight)\n\n        return out + self.lin3(x[1])\n\n    def message(self, a_j: Tensor, b_i: Tensor,\n                edge_weight: OptTensor) -> Tensor:\n        out = a_j - b_i\n 
       return out if edge_weight is None else out * edge_weight.view(-1, 1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/lg_conv.py",
    "content": "from torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass LGConv(MessagePassing):\n    r\"\"\"The Light Graph Convolution (LGC) operator from the `\"LightGCN:\n    Simplifying and Powering Graph Convolution Network for Recommendation\"\n    <https://arxiv.org/abs/2002.02126>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\sum_{j \\in \\mathcal{N}(i)}\n        \\frac{e_{j,i}}{\\sqrt{\\deg(i)\\deg(j)}} \\mathbf{x}_j\n\n    Args:\n        normalize (bool, optional): If set to :obj:`False`, output features\n            will not be normalized via symmetric normalization.\n            (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F)`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F)`\n    \"\"\"\n    def __init__(self, normalize: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n        self.normalize = normalize\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if self.normalize and isinstance(edge_index, Tensor):\n            out = gcn_norm(edge_index, edge_weight, x.size(self.node_dim),\n                           add_self_loops=False, flow=self.flow, dtype=x.dtype)\n            edge_index, edge_weight = out\n        elif self.normalize and isinstance(edge_index, SparseTensor):\n            edge_index = gcn_norm(edge_index, None, x.size(self.node_dim),\n                                  add_self_loops=False, 
flow=self.flow,\n                                  dtype=x.dtype)\n\n        # propagate_type: (x: Tensor, edge_weight: OptTensor)\n        return self.propagate(edge_index, x=x, edge_weight=edge_weight)\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n"
  },
  {
    "path": "torch_geometric/nn/conv/meshcnn_conv.py",
    "content": "# The below is to suppress the warning on torch.nn.conv.MeshCNNConv::update\n# pyright: reportIncompatibleMethodOverride=false\nimport warnings\nfrom typing import Optional\n\nimport torch\nfrom torch.nn import Linear, Module, ModuleList\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.typing import Tensor\n\n\nclass MeshCNNConv(MessagePassing):\n    r\"\"\"The convolutional layer introduced by the paper\n    `\"MeshCNN: A Network With An Edge\" <https://arxiv.org/abs/1809.05910>`_.\n\n    Recall that, given a set of categories :math:`C`,\n    MeshCNN is a function that takes as its input\n    a triangular mesh\n    :math:`\\mathcal{m} = (V, F) \\in \\mathbb{R}^{|V| \\times 3} \\times\n    \\{0,...,|V|-1\\}^{3 \\times |F|}`, and returns as its output\n    a :math:`|C|`-dimensional vector, whose :math:`i` th component denotes\n    the probability of the input mesh belonging to category :math:`c_i \\in C`.\n\n    Let :math:`X^{(k)} \\in \\mathbb{R}^{|E| \\times \\text{Dim-Out}(k)}`\n    denote the output value of the prior (e.g. :math:`k` th )\n    layer of our neural network. The :math:`i` th row of :math:`X^{(k)}` is a\n    :math:`\\text{Dim-Out}(k)`-dimensional vector that represents the features\n    computed by the :math:`k` th layer for edge :math:`e_i` of the input mesh\n    :math:`\\mathcal{m}`. 
Let :math:`A \\in \\{0, ..., |E|-1\\}^{2 \\times 4*|E|}`\n    denote the *edge adjacency* matrix of our input mesh :math:`\\mathcal{m}`.\n    The :math:`j` th column of :math:`A` returns a pair of indices\n    :math:`k,l \\in \\{0,...,|E|-1\\}`, which means that edge\n    :math:`e_k` is adjacent to edge :math:`e_l`\n    in our input mesh :math:`\\mathcal{m}`.\n    The definition of edge adjacency in a triangular\n    mesh is illustrated in Figure 1.\n    In a triangular\n    mesh, each edge :math:`e_i` is expected to be adjacent to exactly :math:`4`\n    neighboring edges, hence the number of columns of :math:`A`: :math:`4*|E|`.\n    We write *the neighborhood* of edge :math:`e_i` as\n    :math:`\\mathcal{N}(i) = (a(i), b(i), c(i), d(i))` where\n\n    1. :math:`a(i)` denotes the index of the *first* counter-clockwise\n    edge of the face *above* :math:`e_i`.\n\n    2. :math:`b(i)` denotes the index of the *second* counter-clockwise\n    edge of the face *above* :math:`e_i`.\n\n    3. :math:`c(i)` denotes the index of the *first* counter-clockwise edge\n    of the face *below* :math:`e_i`.\n\n    4. :math:`d(i)` denotes the index of the *second*\n    counter-clockwise edge of the face *below* :math:`e_i`.\n\n    .. figure:: ../_figures/meshcnn_edge_adjacency.svg\n        :align: center\n        :width: 80%\n\n        **Figure 1:** The neighbors of edge :math:`\\mathbf{e_1}`\n        are :math:`\\mathbf{e_2}, \\mathbf{e_3}, \\mathbf{e_4}` and\n        :math:`\\mathbf{e_5}`, respectively.\n        We write this as\n        :math:`\\mathcal{N}(1) = (a(1), b(1), c(1), d(1)) = (2, 3, 4, 5)`\n\n\n    Because of this ordering constraint, :obj:`MeshCNNConv` **requires\n    that the columns of** :math:`A`\n    **be ordered in the following way**:\n\n    .. 
math::\n        &A[:,0] = (0, \\text{The index of the \"a\" edge for edge } 0) \\\\\n        &A[:,1] = (0, \\text{The index of the \"b\" edge for edge } 0) \\\\\n        &A[:,2] = (0, \\text{The index of the \"c\" edge for edge } 0) \\\\\n        &A[:,3] = (0, \\text{The index of the \"d\" edge for edge } 0) \\\\\n        \\vdots \\\\\n        &A[:,4*|E|-4] =\n            \\bigl(|E|-1,\n                a\\bigl(|E|-1\\bigr)\\bigr) \\\\\n        &A[:,4*|E|-3] =\n            \\bigl(|E|-1,\n                b\\bigl(|E|-1\\bigr)\\bigr) \\\\\n        &A[:,4*|E|-2] =\n            \\bigl(|E|-1,\n                c\\bigl(|E|-1\\bigr)\\bigr) \\\\\n        &A[:,4*|E|-1] =\n            \\bigl(|E|-1,\n                d\\bigl(|E|-1\\bigr)\\bigr)\n\n\n    Stated a bit more compactly, for every edge :math:`e_i` in the input mesh,\n    :math:`A`, should have the following entries\n\n    .. math::\n        A[:, 4*i] &= (i, a(i)) \\\\\n        A[:, 4*i + 1] &= (i, b(i)) \\\\\n        A[:, 4*i + 2] &= (i, c(i)) \\\\\n        A[:, 4*i + 3] &= (i, d(i))\n\n    To summarize so far, we have defined 3 things:\n\n    1. The activation of the prior (e.g. :math:`k` th) layer,\n    :math:`X^{(k)} \\in \\mathbb{R}^{|E| \\times \\text{Dim-Out}(k)}`\n\n    2. The edge adjacency matrix and the definition of edge adjacency.\n    :math:`A \\in \\{0,...,|E|-1\\}^{2 \\times 4*|E|}`\n\n    3. The ways the columns of :math:`A` must be ordered.\n\n\n\n    We are now finally able to define the :obj:`MeshCNNConv` class/layer.\n    In the following definition\n    we assume :obj:`MeshCNNConv` is at the :math:`k+1` th layer of our\n    neural network.\n\n    The :obj:`MeshCNNConv` layer is a function,\n\n    .. 
math::\n        \\text{MeshCNNConv}^{(k+1)}(X^{(k)}, A) = X^{(k+1)},\n\n    that, given the prior layer's output\n    :math:`X^{(k)} \\in \\mathbb{R}^{|E| \\times \\text{Dim-Out}(k)}`\n    and the edge adjacency matrix :math:`A`\n    of the input mesh (graph) :math:`\\mathcal{m}` ,\n    returns a new edge feature tensor\n    :math:`X^{(k+1)} \\in \\mathbb{R}^{|E| \\times \\text{Dim-Out}(k+1)}`,\n    where the :math:`i` th row of :math:`X^{(k+1)}`, denoted by\n    :math:`x^{(k+1)}_i`,\n    represents the :math:`\\text{Dim-Out}(k+1)`-dimensional feature vector\n    of edge :math:`e_i`, **and is defined as follows**:\n\n    .. math::\n        x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\\\\n        &+ W^{(k+1)}_1 \\bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \\bigr| \\\\\n        &+ W^{(k+1)}_2 \\bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \\bigr) \\\\\n        &+ W^{(k+1)}_3 \\bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \\bigr| \\\\\n        &+ W^{(k+1)}_4 \\bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \\bigr).\n\n    :math:`W_0^{(k+1)},W_1^{(k+1)},W_2^{(k+1)},W_3^{(k+1)}, W_4^{(k+1)}\n    \\in \\mathbb{R}^{\\text{Dim-Out}(k+1) \\times \\text{Dim-Out}(k)}`\n    are trainable linear functions (i.e. \"the weights\" of this layer).\n    :math:`x^{(k)}_i` is the :math:`\\text{Dim-Out}(k)`-dimensional feature\n    vector of edge :math:`e_i` computed by the prior (i.e. :math:`k` th) layer.\n    :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`, and\n    :math:`x^{(k)}_{d(i)}` are the :math:`\\text{Dim-Out}(k)`-feature vectors,\n    computed in the :math:`k` th layer, that are associated with the :math:`4`\n    neighboring edges of :math:`e_i`.\n\n\n    Args:\n        in_channels (int): Corresponds to :math:`\\text{Dim-Out}(k)`\n            in the above overview. This\n            represents the output dimension of the prior layer. 
For the given\n            input mesh :math:`\\mathcal{m} = (V, F)`, the prior layer is\n            expected to output a\n            :math:`X \\in \\mathbb{R}^{|E| \\times \\textit{in_channels}}`\n            feature matrix.\n            Assuming the instance of this class\n            is situated at layer :math:`k+1`, we write that\n            :math:`X^{(k)} \\in \\mathbb{R}^{|E| \\times \\textit{in_channels}}`.\n        out_channels (int): Corresponds to :math:`\\text{Dim-Out}(k+1)` in the\n            above overview. This represents the output dimension of this layer.\n            Assuming the instance of this class\n            is situated at layer :math:`k+1`, we write that\n            :math:`X^{(k+1)}\n            \\in \\mathbb{R}^{|E| \\times \\textit{out_channels}}`.\n        kernels (torch.nn.ModuleList, optional): A list of length 5,\n            where each\n            element is a :class:`torch.nn.Module` (i.e. a neural network),\n            that each MUST take as input a vector\n            of dimension :obj:`in_channels` and return a vector of dimension\n            :obj:`out_channels`. 
In particular,\n            :obj:`kernels[0]` is :math:`W^{(k+1)}_0` in the above overview\n            (see :obj:`MeshCNNConv`), :obj:`kernels[1]` is :math:`W^{(k+1)}_1`,\n            :obj:`kernels[2]` is :math:`W^{(k+1)}_2`,\n            :obj:`kernels[3]` is :math:`W^{(k+1)}_3`, and\n            :obj:`kernels[4]` is :math:`W^{(k+1)}_4`.\n            Note that this input is optional; if it is omitted,\n            each of the 5 elements in the kernels will be a linear\n            neural network :class:`torch.nn.Linear`\n            correctly configured to take as input\n            :attr:`in_channels`-dimensional vectors and return\n            a vector of dimension :attr:`out_channels`.\n\n    Discussion:\n        The key difference that separates :obj:`MeshCNNConv` from a traditional\n        message passing graph neural network is that :obj:`MeshCNNConv`\n        requires the set of neighbors for a node\n        :math:`\\mathcal{N}(u) = (v_1, v_2, ...)`\n        to *be an ordered set* (i.e. a tuple). In\n        fact, :obj:`MeshCNNConv` goes further, requiring\n        that :math:`\\mathcal{N}(u)` always return a set of size :math:`4`.\n        This is different to most message passing graph neural networks,\n        which assume that :math:`\\mathcal{N}(u) = \\{v_1, v_2, ...\\}` returns an\n        unordered set. This lends :obj:`MeshCNNConv` more expressive power,\n        at the cost of no longer being permutation invariant to\n        :math:`\\mathbb{S}_4`. Put more plainly, in traditional message passing\n        GNNs, the network is *unable* to distinguish one neighboring node\n        from another.\n        In contrast, in :obj:`MeshCNNConv`, each of the 4 neighbors has a\n        \"role\", either the \"a\", \"b\", \"c\", or \"d\" neighbor. 
We encode this fact\n        by requiring that :math:`\\mathcal{N}` return a 4-tuple,\n        where the first component is the \"a\" neighbor, and so on.\n\n        To summarize this comparison, we may re-define\n        :obj:`MeshCNNConv` in terms of :math:`\\text{UPDATE}` and\n        :math:`\\text{AGGREGATE}`\n        functions, which is a general way to define a traditional GNN layer.\n        If we let :math:`x_i^{(k+1)}`\n        denote the output of a GNN layer for node :math:`i` at\n        layer :math:`k+1`, and let\n        :math:`\\mathcal{N}(i)` denote the set of nodes adjacent\n        to node :math:`i`,\n        then we can describe the :math:`k+1` th layer of a traditional GNN\n        as\n\n        .. math::\n            x_i^{(k+1)} = \\text{UPDATE}^{(k+1)}\\bigl(x^{(k)}_i,\n            \\text{AGGREGATE}^{(k+1)}\\bigl(\\mathcal{N}(i)\\bigr)\\bigr).\n\n        Here, :math:`\\text{UPDATE}^{(k+1)}` is a function of :math:`2`\n        :math:`\\text{Dim-Out}(k)`-dimensional vectors, and returns a\n        :math:`\\text{Dim-Out}(k+1)`-dimensional vector.\n        The :math:`\\text{AGGREGATE}^{(k+1)}` function\n        is a function of an *unordered set*\n        of nodes that are neighbors of node :math:`i`, as defined by\n        :math:`\\mathcal{N}(i)`. Usually the size of this set varies across\n        different nodes :math:`i`, and one of the most basic examples\n        of such a function is the \"sum aggregation\", defined as\n        :math:`\\text{AGGREGATE}^{(k+1)}(\\mathcal{N}(i)) =\n        \\sum_{j \\in \\mathcal{N}(i)} x^{(k)}_j`.\n        See\n        :class:`SumAggregation <torch_geometric.nn.aggr.basic.SumAggregation>`\n        for more.\n\n        In contrast, while :obj:`MeshCNNConv` 's :math:`\\text{UPDATE}`\n        function follows that of\n        a traditional GNN, its :math:`\\text{AGGREGATE}` is a function of a\n        tuple (i.e. 
an ordered set) of neighbors\n        rather than an unordered set of neighbors.\n        In particular, the :math:`\\text{UPDATE}`\n        function of :obj:`MeshCNNConv` for :math:`e_i` is\n\n        .. math::\n            x_i^{(k+1)} = \\text{UPDATE}^{(k+1)}(x_i^{(k)}, s_i^{(k+1)})\n            = W_0^{(k+1)}x_i^{(k)} + s_i^{(k+1)},\n\n        whereas its :math:`\\text{AGGREGATE}` function is\n\n        .. math::\n            s_i^{(k+1)} = \\text{AGGREGATE}^{(k+1)}(A, B, C, D)\n            &= W_1^{(k+1)}\\bigl|A - C \\bigr| \\\\\n            &+ W_2^{(k+1)}\\bigl(A + C \\bigr) \\\\\n            &+ W_3^{(k+1)}\\bigl|B - D \\bigr| \\\\\n            &+ W_4^{(k+1)}\\bigl(B + D \\bigr),\n\n        where :math:`A=x_{a(i)}^{(k)}, B=x_{b(i)}^{(k)}, C=x_{c(i)}^{(k)},`\n        and :math:`D=x_{d(i)}^{(k)}`.\n\n        ..\n\n            The :math:`i` th row of\n            :math:`V \\in \\mathbb{R}^{|V| \\times 3}`\n            holds the Cartesian :math:`xyz`\n            coordinates for node :math:`v_i` in the mesh, and the :math:`j` th\n            column in :math:`F \\in \\{0,...,|V|-1\\}^{3 \\times |F|}`\n            holds the :math:`3` indices\n            :math:`(k,l,m)` that correspond to the :math:`3` nodes\n            :math:`(v_k, v_l, v_m)` that construct face :math:`j` of the mesh.\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int,\n                 kernels: Optional[ModuleList] = None):\n        super().__init__(aggr='add')\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        if kernels is None:\n            self.kernels = ModuleList(\n                [Linear(in_channels, out_channels) for _ in range(5)])\n\n        else:\n            # ensures kernels is properly formed, otherwise throws\n            # the appropriate error.\n            self._assert_kernels(kernels)\n            self.kernels = kernels\n\n    def forward(self, x: Tensor, edge_index: Tensor):\n        
r\"\"\"Forward pass.\n\n        Args:\n            x(torch.Tensor): :math:`X^{(k)} \\in\n                \\mathbb{R}^{|E| \\times \\textit{in_channels}}`.\n                The edge feature tensor returned by the prior layer\n                (e.g. :math:`k`). The tensor is of shape\n                :math:`|E| \\times \\text{Dim-Out}(k)`, or equivalently,\n                :obj:`(|E|, self.in_channels)`.\n\n            edge_index(torch.Tensor):\n                :math:`A \\in \\{0,...,|E|-1\\}^{2 \\times 4*|E|}`.\n                The edge adjacency tensor of the networks input mesh\n                :math:`\\mathcal{m} = (V, F)`. The edge adjacency tensor\n                **MUST** have the following form:\n\n                .. math::\n                    &A[:,0] = (0,\n                        \\text{The index of the \"a\" edge for edge } 0) \\\\\n                    &A[:,1] = (0,\n                        \\text{The index of the \"b\" edge for edge } 0) \\\\\n                    &A[:,2] = (0,\n                        \\text{The index of the \"c\" edge for edge } 0) \\\\\n                    &A[:,3] = (0,\n                        \\text{The index of the \"d\" edge for edge } 0) \\\\\n                    \\vdots \\\\\n                    &A[:,4*|E|-4] =\n                        \\bigl(|E|-1,\n                            a\\bigl(|E|-1\\bigr)\\bigr) \\\\\n                    &A[:,4*|E|-3] =\n                        \\bigl(|E|-1,\n                            b\\bigl(|E|-1\\bigr)\\bigr) \\\\\n                    &A[:,4*|E|-2] =\n                        \\bigl(|E|-1,\n                            c\\bigl(|E|-1\\bigr)\\bigr) \\\\\n                    &A[:,4*|E|-1] =\n                        \\bigl(|E|-1,\n                            d\\bigl(|E|-1\\bigr)\\bigr)\n\n                See :obj:`MeshCNNConv` for what\n                \"index of the 'a'(b,c,d) edge for edge i\" means, and also\n                for the general definition of edge adjacency in MeshCNN.\n                
These definitions are also provided in the\n                `paper <https://arxiv.org/abs/1809.05910>`_ itself.\n\n        Returns:\n           torch.Tensor:\n           :math:`X^{(k+1)} \\in \\mathbb{R}^{|E| \\times \\textit{out_channels}}`.\n           The edge feature tensor for this (e.g. the :math:`k+1` th) layer.\n           The :math:`i` th row of :math:`X^{(k+1)}` is computed according\n           to the formula\n\n            .. math::\n                x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\\\\n                &+ W^{(k+1)}_1 \\bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \\bigr| \\\\\n                &+ W^{(k+1)}_2 \\bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \\bigr) \\\\\n                &+ W^{(k+1)}_3 \\bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \\bigr| \\\\\n                &+ W^{(k+1)}_4 \\bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \\bigr),\n\n            where :math:`W_0^{(k+1)},W_1^{(k+1)},\n            W_2^{(k+1)},W_3^{(k+1)}, W_4^{(k+1)}\n            \\in \\mathbb{R}^{\\text{Dim-Out}(k+1) \\times \\text{Dim-Out}(k)}`\n            are the trainable linear functions (i.e. 
the trainable\n            \"weights\") of this layer, and\n            :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`,\n            :math:`x^{(k)}_{d(i)}` are the\n            :math:`\\text{Dim-Out}(k)`-dimensional edge feature vectors\n            computed by the prior (:math:`k` th) layer,\n            that are associated with the :math:`4`\n            neighboring edges of :math:`e_i`.\n\n        \"\"\"\n        return self.propagate(edge_index, x=x)\n\n    def message(self, x_j: Tensor) -> Tensor:\n        r\"\"\"The message passing step of :obj:`MeshCNNConv`.\n\n\n        Args:\n          x_j: A :obj:`[4*|E|, num_node_features]` tensor.\n          Its ith row holds the value\n            stored by the source node in the previous layer of edge i.\n\n        Returns:\n            A :obj:`[4*|E|, out_channels]` tensor,\n            whose ith row will be the value\n            that the target node of edge i will receive.\n        \"\"\"\n        # The following variable names are taken from the paper.\n        # MeshCNN computes the features associated with edge\n        # e by (|a - c|, a + c, |b - d|, b + d), where a, b, c, d are the\n        # neighboring edges of e, a being the first edge of the upper face,\n        # b being the second edge of the upper face, c being the first edge\n        # of the lower face,\n        # and d being the second edge of the lower face of the input Mesh\n\n        # TODO: It is unclear if view is faster. If it is not,\n        # then we should prefer the strided method commented out below\n\n        E4, in_channels = x_j.size()  # E4 = 4|E|, i.e. 
num edges in line graph\n        # Option 1\n        n_a = x_j[0::4]  # shape: |E| x in_channels\n        n_b = x_j[1::4]  # shape: |E| x in_channels\n        n_c = x_j[2::4]  # shape: |E| x in_channels\n        n_d = x_j[3::4]  # shape: |E| x in_channels\n        # Allocate `m` on the same device and with the same dtype as `x_j`.\n        m = x_j.new_empty((E4, self.out_channels))\n        m[0::4] = self.kernels[1].forward(torch.abs(n_a - n_c))\n        m[1::4] = self.kernels[2].forward(n_a + n_c)\n        m[2::4] = self.kernels[3].forward(torch.abs(n_b - n_d))\n        m[3::4] = self.kernels[4].forward(n_b + n_d)\n        return m\n\n        # Option 2\n        # E4, in_channels = x_j.size()\n        # E = E4 // 4\n        # x_j = x_j.view(E, 4, in_channels)  # shape: (|E| x 4 x in_channels)\n        # n_a, n_b, n_c, n_d = x_j.unbind(\n        #     dim=1)  # 4 tensors, each of shape |E| x in_channels\n        # m = torch.stack(\n        #     [\n        #         (n_a - n_c).abs(),  # shape: |E| x in_channels\n        #         n_a + n_c,\n        #         (n_b - n_d).abs(),\n        #         n_b + n_d,\n        #     ],\n        #     dim=1)  # shape: (|E| x 4 x in_channels)\n        # m = m.view(E4, in_channels)  # shape 4*|E| x in_channels\n        # return m\n\n    def update(self, inputs: Tensor, x: Tensor) -> Tensor:\n        r\"\"\"The UPDATE step, in reference to the UPDATE and AGGREGATE\n        formulation of message passing convolution.\n\n        Args:\n           inputs(torch.Tensor): The :attr:`out_channels`-dimensional vector\n            returned by aggregate.\n           x(torch.Tensor): :math:`X^{(k)}`. The original inputs to this layer.\n\n        Returns:\n            torch.Tensor: :math:`X^{(k+1)}`. The output of this layer, which\n            has shape :obj:`(|E|, out_channels)`.\n        \"\"\"\n        return self.kernels[0].forward(x) + inputs\n\n    def _assert_kernels(self, kernels: ModuleList):\n        r\"\"\"Ensures that :obj:`kernels` is a list of 5 :obj:`torch.nn.Module`\n        modules (i.e. networks). 
In addition, it also ensures that each network\n        takes input of dimension :attr:`in_channels`, and returns output\n        of dimension :attr:`out_channels`.\n        This method throws an error otherwise.\n\n        .. warning::\n            This method throws an error if :obj:`kernels` is\n            not valid. (Otherwise this method returns nothing)\n\n        \"\"\"\n        assert isinstance(kernels, ModuleList), \\\n            f\"Parameter 'kernels' must be a \\\n            torch.nn.ModuleList with 5 members, but we got \\\n            {type(kernels)}.\"\n\n        assert len(kernels) == 5, \"Parameter 'kernels' must be a \\\n            torch.nn.ModuleList with exactly 5 members\"\n\n        for i, network in enumerate(kernels):\n            assert isinstance(network, Module), \\\n                f\"kernels[{i}] must be torch.nn.Module, got \\\n                {type(network)}\"\n            if not hasattr(network, \"in_channels\") and \\\n                    not hasattr(network, \"in_features\"):\n                warnings.warn(\n                    f\"kernel[{i}] does not have attribute 'in_channels' nor \"\n                    f\"'in_features'. 
The network must take as input a \"\n                    f\"{self.in_channels}-dimensional tensor.\", stacklevel=2)\n            else:\n                # Nested `getattr` avoids eagerly evaluating a missing\n                # `in_features` attribute when only `in_channels` exists.\n                input_dimension = getattr(\n                    network, \"in_channels\",\n                    getattr(network, \"in_features\", None))\n                assert input_dimension == self.in_channels, f\"The input \\\n                dimension of the neural network in kernel[{i}] must \\\n                be \\\n                equal to 'in_channels', but input_dimension = \\\n                {input_dimension}, and \\\n                self.in_channels={self.in_channels}.\"\n\n            if not hasattr(network, \"out_channels\") and \\\n                    not hasattr(network, \"out_features\"):\n                warnings.warn(\n                    f\"kernel[{i}] does not have attribute 'out_channels' nor \"\n                    f\"'out_features'. The network must return a \"\n                    f\"{self.out_channels}-dimensional tensor.\", stacklevel=2)\n            else:\n                output_dimension = getattr(\n                    network, \"out_channels\",\n                    getattr(network, \"out_features\", None))\n                assert output_dimension == self.out_channels, f\"The output \\\n                    dimension of the neural network in kernel[{i}] must \\\n                    be \\\n                    equal to 'out_channels', but out_dimension = \\\n                    {output_dimension}, and \\\n                    self.out_channels={self.out_channels}.\"\n"
  },
  {
    "path": "torch_geometric/nn/conv/message_passing.py",
    "content": "import os.path as osp\nimport warnings\nfrom abc import abstractmethod\nfrom inspect import Parameter\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    Final,\n    List,\n    Optional,\n    OrderedDict,\n    Set,\n    Tuple,\n    Union,\n)\n\nimport torch\nfrom torch import Tensor\nfrom torch.utils.hooks import RemovableHandle\n\nfrom torch_geometric import EdgeIndex, is_compiling\nfrom torch_geometric.index import ptr2index\nfrom torch_geometric.inspector import Inspector, Signature\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver\nfrom torch_geometric.template import module_from_template\nfrom torch_geometric.typing import Adj, Size, SparseTensor\nfrom torch_geometric.utils import (\n    is_sparse,\n    is_torch_sparse_tensor,\n    to_edge_index,\n)\n\nFUSE_AGGRS = {'add', 'sum', 'mean', 'min', 'max'}\nHookDict = OrderedDict[int, Callable]\n\n\nclass MessagePassing(torch.nn.Module):\n    r\"\"\"Base class for creating message passing layers.\n\n    Message passing layers follow the form\n\n    .. 
math::\n        \\mathbf{x}_i^{\\prime} = \\gamma_{\\mathbf{\\Theta}} \\left( \\mathbf{x}_i,\n        \\bigoplus_{j \\in \\mathcal{N}(i)} \\, \\phi_{\\mathbf{\\Theta}}\n        \\left(\\mathbf{x}_i, \\mathbf{x}_j,\\mathbf{e}_{j,i}\\right) \\right),\n\n    where :math:`\\bigoplus` denotes a differentiable, permutation invariant\n    function, *e.g.*, sum, mean, min, max or mul, and\n    :math:`\\gamma_{\\mathbf{\\Theta}}` and :math:`\\phi_{\\mathbf{\\Theta}}` denote\n    differentiable functions such as MLPs.\n    See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/\n    create_gnn.html>`__ for the accompanying tutorial.\n\n    Args:\n        aggr (str or [str] or Aggregation, optional): The aggregation scheme\n            to use, *e.g.*, :obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`,\n            :obj:`\"max\"` or :obj:`\"mul\"`.\n            In addition, can be any\n            :class:`~torch_geometric.nn.aggr.Aggregation` module (or any string\n            that automatically resolves to it).\n            If given as a list, will make use of multiple aggregations in which\n            different outputs will get concatenated in the last dimension.\n            If set to :obj:`None`, the :class:`MessagePassing` instantiation is\n            expected to implement its own aggregation logic via\n            :meth:`aggregate`. (default: :obj:`\"sum\"`)\n        aggr_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective aggregation function in case it gets automatically\n            resolved. 
(default: :obj:`None`)\n        flow (str, optional): The flow direction of message passing\n            (:obj:`\"source_to_target\"` or :obj:`\"target_to_source\"`).\n            (default: :obj:`\"source_to_target\"`)\n        node_dim (int, optional): The axis along which to propagate.\n            (default: :obj:`-2`)\n        decomposed_layers (int, optional): The number of feature decomposition\n            layers, as introduced in the `\"Optimizing Memory Efficiency of\n            Graph Neural Networks on Edge Computing Platforms\"\n            <https://arxiv.org/abs/2104.03058>`_ paper.\n            Feature decomposition reduces the peak memory usage by slicing\n            the feature dimensions into separated feature decomposition layers\n            during GNN aggregation.\n            This method can accelerate GNN execution on CPU-based platforms\n            (*e.g.*, 2-3x speedup on the\n            :class:`~torch_geometric.datasets.Reddit` dataset) for common GNN\n            models such as :class:`~torch_geometric.nn.models.GCN`,\n            :class:`~torch_geometric.nn.models.GraphSAGE`,\n            :class:`~torch_geometric.nn.models.GIN`, etc.\n            However, this method is not applicable to all GNN operators\n            available, in particular for operators in which message computation\n            can not easily be decomposed, *e.g.* in attention-based GNNs.\n            The selection of the optimal value of :obj:`decomposed_layers`\n            depends both on the specific graph dataset and available hardware\n            resources.\n            A value of :obj:`2` is suitable in most cases.\n            Although the peak memory usage is directly associated with the\n            granularity of feature decomposition, the same is not necessarily\n            true for execution speedups. 
(default: :obj:`1`)\n    \"\"\"\n\n    special_args: Set[str] = {\n        'edge_index', 'adj_t', 'edge_index_i', 'edge_index_j', 'size',\n        'size_i', 'size_j', 'ptr', 'index', 'dim_size'\n    }\n\n    # Supports `message_and_aggregate` via `EdgeIndex`.\n    # TODO Remove once migration is finished.\n    SUPPORTS_FUSED_EDGE_INDEX: Final[bool] = False\n\n    def __init__(\n        self,\n        aggr: Optional[Union[str, List[str], Aggregation]] = 'sum',\n        *,\n        aggr_kwargs: Optional[Dict[str, Any]] = None,\n        flow: str = \"source_to_target\",\n        node_dim: int = -2,\n        decomposed_layers: int = 1,\n    ) -> None:\n        super().__init__()\n\n        if flow not in ['source_to_target', 'target_to_source']:\n            raise ValueError(f\"Expected 'flow' to be either 'source_to_target'\"\n                             f\" or 'target_to_source' (got '{flow}')\")\n\n        # Cast `aggr` into a string representation for backward compatibility:\n        self.aggr: Optional[Union[str, List[str]]]\n        if aggr is None:\n            self.aggr = None\n        elif isinstance(aggr, (str, Aggregation)):\n            self.aggr = str(aggr)\n        elif isinstance(aggr, (tuple, list)):\n            self.aggr = [str(x) for x in aggr]\n\n        self.aggr_module = aggr_resolver(aggr, **(aggr_kwargs or {}))\n        self.flow = flow\n        self.node_dim = node_dim\n\n        # Collect attribute names requested in message passing hooks:\n        self.inspector = Inspector(self.__class__)\n        self.inspector.inspect_signature(self.message)\n        self.inspector.inspect_signature(self.aggregate, exclude=[0, 'aggr'])\n        self.inspector.inspect_signature(self.message_and_aggregate, [0])\n        self.inspector.inspect_signature(self.update, exclude=[0])\n        self.inspector.inspect_signature(self.edge_update)\n\n        self._user_args: List[str] = self.inspector.get_flat_param_names(\n            ['message', 'aggregate', 
'update'], exclude=self.special_args)\n        self._fused_user_args: List[str] = self.inspector.get_flat_param_names(\n            ['message_and_aggregate', 'update'], exclude=self.special_args)\n        self._edge_user_args: List[str] = self.inspector.get_param_names(\n            'edge_update', exclude=self.special_args)\n\n        # Support for \"fused\" message passing:\n        self.fuse = self.inspector.implements('message_and_aggregate')\n        if self.aggr is not None:\n            self.fuse &= isinstance(self.aggr, str) and self.aggr in FUSE_AGGRS\n\n        # Hooks:\n        self._propagate_forward_pre_hooks: HookDict = OrderedDict()\n        self._propagate_forward_hooks: HookDict = OrderedDict()\n        self._message_forward_pre_hooks: HookDict = OrderedDict()\n        self._message_forward_hooks: HookDict = OrderedDict()\n        self._aggregate_forward_pre_hooks: HookDict = OrderedDict()\n        self._aggregate_forward_hooks: HookDict = OrderedDict()\n        self._message_and_aggregate_forward_pre_hooks: HookDict = OrderedDict()\n        self._message_and_aggregate_forward_hooks: HookDict = OrderedDict()\n        self._edge_update_forward_pre_hooks: HookDict = OrderedDict()\n        self._edge_update_forward_hooks: HookDict = OrderedDict()\n\n        # Set jittable `propagate` and `edge_updater` function templates:\n        self._set_jittable_templates()\n\n        # Explainability:\n        self._explain: Optional[bool] = None\n        self._edge_mask: Optional[Tensor] = None\n        self._loop_mask: Optional[Tensor] = None\n        self._apply_sigmoid: bool = True\n\n        # Inference Decomposition:\n        self._decomposed_layers = 1\n        self.decomposed_layers = decomposed_layers\n\n    def reset_parameters(self) -> None:\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        if self.aggr_module is not None:\n            self.aggr_module.reset_parameters()\n\n    def __setstate__(self, data: Dict[str, Any]) -> 
None:\n        self.inspector = data['inspector']\n        self.fuse = data['fuse']\n        self._set_jittable_templates()\n        super().__setstate__(data)\n\n    def __repr__(self) -> str:\n        channels_repr = ''\n        if hasattr(self, 'in_channels') and hasattr(self, 'out_channels'):\n            channels_repr = f'{self.in_channels}, {self.out_channels}'\n        elif hasattr(self, 'channels'):\n            channels_repr = f'{self.channels}'\n        return f'{self.__class__.__name__}({channels_repr})'\n\n    # Utilities ###############################################################\n\n    def _check_input(\n        self,\n        edge_index: Union[Tensor, SparseTensor],\n        size: Optional[Tuple[Optional[int], Optional[int]]],\n    ) -> List[Optional[int]]:\n\n        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n            return [edge_index.num_rows, edge_index.num_cols]\n\n        if is_sparse(edge_index):\n            if self.flow == 'target_to_source':\n                raise ValueError(\n                    'Flow direction \"target_to_source\" is invalid for '\n                    'message propagation via `torch_sparse.SparseTensor` '\n                    'or `torch.sparse.Tensor`. 
If you really want to make '\n                    'use of a reverse message passing flow, pass in the '\n                    'transposed sparse tensor to the message passing module, '\n                    'e.g., `adj_t.t()`.')\n\n            if isinstance(edge_index, SparseTensor):\n                return [edge_index.size(1), edge_index.size(0)]\n            return [edge_index.size(1), edge_index.size(0)]\n\n        elif isinstance(edge_index, Tensor):\n            int_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32,\n                          torch.int64)\n\n            if edge_index.dtype not in int_dtypes:\n                raise ValueError(f\"Expected 'edge_index' to be of integer \"\n                                 f\"type (got '{edge_index.dtype}')\")\n            if edge_index.dim() != 2:\n                raise ValueError(f\"Expected 'edge_index' to be two-dimensional\"\n                                 f\" (got {edge_index.dim()} dimensions)\")\n            if not torch.jit.is_tracing() and edge_index.size(0) != 2:\n                raise ValueError(f\"Expected 'edge_index' to have size '2' in \"\n                                 f\"the first dimension (got \"\n                                 f\"'{edge_index.size(0)}')\")\n\n            return list(size) if size is not None else [None, None]\n\n        raise ValueError(\n            '`MessagePassing.propagate` only supports integer tensors of '\n            'shape `[2, num_messages]`, `torch_sparse.SparseTensor` or '\n            '`torch.sparse.Tensor` for argument `edge_index`.')\n\n    def _set_size(\n        self,\n        size: List[Optional[int]],\n        dim: int,\n        src: Tensor,\n    ) -> None:\n        the_size = size[dim]\n        if the_size is None:\n            size[dim] = src.size(self.node_dim)\n        elif the_size != src.size(self.node_dim):\n            raise ValueError(\n                f'Encountered tensor with size {src.size(self.node_dim)} in '\n                
f'dimension {self.node_dim}, but expected size {the_size}.')\n\n    def _index_select(self, src: Tensor, index) -> Tensor:\n        if torch.jit.is_scripting() or is_compiling():\n            return src.index_select(self.node_dim, index)\n        else:\n            return self._index_select_safe(src, index)\n\n    def _index_select_safe(self, src: Tensor, index: Tensor) -> Tensor:\n        try:\n            return src.index_select(self.node_dim, index)\n        except (IndexError, RuntimeError) as e:\n            if index.numel() > 0 and index.min() < 0:\n                raise IndexError(\n                    f\"Found negative indices in 'edge_index' (got \"\n                    f\"{index.min().item()}). Please ensure that all \"\n                    f\"indices in 'edge_index' point to valid indices \"\n                    f\"in the interval [0, {src.size(self.node_dim)}) in \"\n                    f\"your node feature matrix and try again.\") from e\n\n            if (index.numel() > 0 and index.max() >= src.size(self.node_dim)):\n                raise IndexError(\n                    f\"Found indices in 'edge_index' that are larger \"\n                    f\"than {src.size(self.node_dim) - 1} (got \"\n                    f\"{index.max().item()}). 
Please ensure that all \"\n                    f\"indices in 'edge_index' point to valid indices \"\n                    f\"in the interval [0, {src.size(self.node_dim)}) in \"\n                    f\"your node feature matrix and try again.\") from e\n\n            raise e\n\n    def _lift(\n        self,\n        src: Tensor,\n        edge_index: Union[Tensor, SparseTensor],\n        dim: int,\n    ) -> Tensor:\n        if not torch.jit.is_scripting() and is_torch_sparse_tensor(edge_index):\n            assert dim == 0 or dim == 1\n            if edge_index.layout == torch.sparse_coo:\n                index = edge_index._indices()[1 - dim]\n            elif edge_index.layout == torch.sparse_csr:\n                if dim == 0:\n                    index = edge_index.col_indices()\n                else:\n                    index = ptr2index(edge_index.crow_indices())\n            elif edge_index.layout == torch.sparse_csc:\n                if dim == 0:\n                    index = ptr2index(edge_index.ccol_indices())\n                else:\n                    index = edge_index.row_indices()\n            else:\n                raise ValueError(f\"Unsupported sparse tensor layout \"\n                                 f\"(got '{edge_index.layout}')\")\n            return src.index_select(self.node_dim, index)\n\n        elif isinstance(edge_index, Tensor):\n            if torch.jit.is_scripting():  # Try/catch blocks are not supported.\n                index = edge_index[dim]\n                return src.index_select(self.node_dim, index)\n            return self._index_select(src, edge_index[dim])\n\n        elif isinstance(edge_index, SparseTensor):\n            row, col, _ = edge_index.coo()\n            if dim == 0:\n                return src.index_select(self.node_dim, col)\n            elif dim == 1:\n                return src.index_select(self.node_dim, row)\n\n        raise ValueError(\n            '`MessagePassing.propagate` only supports integer tensors of 
'\n            'shape `[2, num_messages]`, `torch_sparse.SparseTensor` '\n            'or `torch.sparse.Tensor` for argument `edge_index`.')\n\n    def _collect(\n        self,\n        args: Set[str],\n        edge_index: Union[Tensor, SparseTensor],\n        size: List[Optional[int]],\n        kwargs: Dict[str, Any],\n    ) -> Dict[str, Any]:\n\n        i, j = (1, 0) if self.flow == 'source_to_target' else (0, 1)\n\n        out = {}\n        for arg in args:\n            if arg[-2:] not in ['_i', '_j']:\n                out[arg] = kwargs.get(arg, Parameter.empty)\n            else:\n                dim = j if arg[-2:] == '_j' else i\n                data = kwargs.get(arg[:-2], Parameter.empty)\n\n                if isinstance(data, (tuple, list)):\n                    assert len(data) == 2\n                    if isinstance(data[1 - dim], Tensor):\n                        self._set_size(size, 1 - dim, data[1 - dim])\n                    data = data[dim]\n\n                if isinstance(data, Tensor):\n                    self._set_size(size, dim, data)\n                    data = self._lift(data, edge_index, dim)\n\n                out[arg] = data\n\n        if is_torch_sparse_tensor(edge_index):\n            indices, values = to_edge_index(edge_index)\n            out['adj_t'] = edge_index\n            out['edge_index'] = None\n            out['edge_index_i'] = indices[0]\n            out['edge_index_j'] = indices[1]\n            out['ptr'] = None  # TODO Get `rowptr` from CSR representation.\n            if out.get('edge_weight', None) is None:\n                out['edge_weight'] = values\n            if out.get('edge_attr', None) is None:\n                out['edge_attr'] = None if values.dim() == 1 else values\n            if out.get('edge_type', None) is None:\n                out['edge_type'] = values\n\n        elif isinstance(edge_index, Tensor):\n            out['adj_t'] = None\n            out['edge_index'] = edge_index\n            out['edge_index_i'] 
= edge_index[i]\n            out['edge_index_j'] = edge_index[j]\n\n            out['ptr'] = None\n            if isinstance(edge_index, EdgeIndex):\n                if i == 0 and edge_index.is_sorted_by_row:\n                    (out['ptr'], _), _ = edge_index.get_csr()\n                elif i == 1 and edge_index.is_sorted_by_col:\n                    (out['ptr'], _), _ = edge_index.get_csc()\n\n        elif isinstance(edge_index, SparseTensor):\n            row, col, value = edge_index.coo()\n            rowptr, _, _ = edge_index.csr()\n\n            out['adj_t'] = edge_index\n            out['edge_index'] = None\n            out['edge_index_i'] = row\n            out['edge_index_j'] = col\n            out['ptr'] = rowptr\n            if out.get('edge_weight', None) is None:\n                out['edge_weight'] = value\n            if out.get('edge_attr', None) is None:\n                out['edge_attr'] = value\n            if out.get('edge_type', None) is None:\n                out['edge_type'] = value\n\n        out['index'] = out['edge_index_i']\n        out['size'] = size\n        out['size_i'] = size[i] if size[i] is not None else size[j]\n        out['size_j'] = size[j] if size[j] is not None else size[i]\n        out['dim_size'] = out['size_i']\n\n        return out\n\n    # Message Passing #########################################################\n\n    def forward(self, *args: Any, **kwargs: Any) -> Any:\n        r\"\"\"Runs the forward pass of the module.\"\"\"\n\n    def propagate(\n        self,\n        edge_index: Adj,\n        size: Size = None,\n        **kwargs: Any,\n    ) -> Tensor:\n        r\"\"\"The initial call to start propagating messages.\n\n        Args:\n            edge_index (torch.Tensor or SparseTensor): A :class:`torch.Tensor`,\n                a :class:`torch_sparse.SparseTensor` or a\n                :class:`torch.sparse.Tensor` that defines the underlying\n                graph connectivity/message passing flow.\n                
:obj:`edge_index` holds the indices of a general (sparse)\n                assignment matrix of shape :obj:`[N, M]`.\n                If :obj:`edge_index` is a :obj:`torch.Tensor`, its :obj:`dtype`\n                should be :obj:`torch.long` and its shape needs to be defined\n                as :obj:`[2, num_messages]` where messages from nodes in\n                :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]`\n                (in case :obj:`flow=\"source_to_target\"`).\n                If :obj:`edge_index` is a :class:`torch_sparse.SparseTensor` or\n                a :class:`torch.sparse.Tensor`, its sparse indices\n                :obj:`(row, col)` should relate to :obj:`row = edge_index[1]`\n                and :obj:`col = edge_index[0]`.\n                The major difference between both formats is that we need to\n                input the *transposed* sparse adjacency matrix into\n                :meth:`propagate`.\n            size ((int, int), optional): The size :obj:`(N, M)` of the\n                assignment matrix in case :obj:`edge_index` is a\n                :class:`torch.Tensor`.\n                If set to :obj:`None`, the size will be automatically inferred\n                and assumed to be quadratic.\n                This argument is ignored in case :obj:`edge_index` is a\n                :class:`torch_sparse.SparseTensor` or\n                a :class:`torch.sparse.Tensor`. 
(default: :obj:`None`)\n            **kwargs: Any additional data which is needed to construct and\n                aggregate messages, and to update node embeddings.\n        \"\"\"\n        decomposed_layers = 1 if self.explain else self.decomposed_layers\n\n        for hook in self._propagate_forward_pre_hooks.values():\n            res = hook(self, (edge_index, size, kwargs))\n            if res is not None:\n                edge_index, size, kwargs = res\n\n        mutable_size = self._check_input(edge_index, size)\n\n        # Run \"fused\" message and aggregation (if applicable).\n        fuse = False\n        if self.fuse and not self.explain:\n            if is_sparse(edge_index):\n                fuse = True\n            elif (not torch.jit.is_scripting()\n                  and isinstance(edge_index, EdgeIndex)):\n                if (self.SUPPORTS_FUSED_EDGE_INDEX\n                        and edge_index.is_sorted_by_col):\n                    fuse = True\n\n        if fuse:\n            coll_dict = self._collect(self._fused_user_args, edge_index,\n                                      mutable_size, kwargs)\n\n            msg_aggr_kwargs = self.inspector.collect_param_data(\n                'message_and_aggregate', coll_dict)\n            for hook in self._message_and_aggregate_forward_pre_hooks.values():\n                res = hook(self, (edge_index, msg_aggr_kwargs))\n                if res is not None:\n                    edge_index, msg_aggr_kwargs = res\n            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)\n            for hook in self._message_and_aggregate_forward_hooks.values():\n                res = hook(self, (edge_index, msg_aggr_kwargs), out)\n                if res is not None:\n                    out = res\n\n            update_kwargs = self.inspector.collect_param_data(\n                'update', coll_dict)\n            out = self.update(out, **update_kwargs)\n\n        else:  # Otherwise, run both functions in 
separation.\n            if decomposed_layers > 1:\n                user_args = self._user_args\n                decomp_args = {a[:-2] for a in user_args if a[-2:] == '_j'}\n                decomp_kwargs = {\n                    a: kwargs[a].chunk(decomposed_layers, -1)\n                    for a in decomp_args\n                }\n                decomp_out = []\n\n            for i in range(decomposed_layers):\n                if decomposed_layers > 1:\n                    for arg in decomp_args:\n                        kwargs[arg] = decomp_kwargs[arg][i]\n\n                coll_dict = self._collect(self._user_args, edge_index,\n                                          mutable_size, kwargs)\n\n                msg_kwargs = self.inspector.collect_param_data(\n                    'message', coll_dict)\n                for hook in self._message_forward_pre_hooks.values():\n                    res = hook(self, (msg_kwargs, ))\n                    if res is not None:\n                        msg_kwargs = res[0] if isinstance(res, tuple) else res\n                out = self.message(**msg_kwargs)\n                for hook in self._message_forward_hooks.values():\n                    res = hook(self, (msg_kwargs, ), out)\n                    if res is not None:\n                        out = res\n\n                if self.explain:\n                    explain_msg_kwargs = self.inspector.collect_param_data(\n                        'explain_message', coll_dict)\n                    out = self.explain_message(out, **explain_msg_kwargs)\n\n                aggr_kwargs = self.inspector.collect_param_data(\n                    'aggregate', coll_dict)\n                for hook in self._aggregate_forward_pre_hooks.values():\n                    res = hook(self, (aggr_kwargs, ))\n                    if res is not None:\n                        aggr_kwargs = res[0] if isinstance(res, tuple) else res\n\n                out = self.aggregate(out, **aggr_kwargs)\n\n                for 
hook in self._aggregate_forward_hooks.values():\n                    res = hook(self, (aggr_kwargs, ), out)\n                    if res is not None:\n                        out = res\n\n                update_kwargs = self.inspector.collect_param_data(\n                    'update', coll_dict)\n                out = self.update(out, **update_kwargs)\n\n                if decomposed_layers > 1:\n                    decomp_out.append(out)\n\n            if decomposed_layers > 1:\n                out = torch.cat(decomp_out, dim=-1)\n\n        for hook in self._propagate_forward_hooks.values():\n            res = hook(self, (edge_index, mutable_size, kwargs), out)\n            if res is not None:\n                out = res\n\n        return out\n\n    def message(self, x_j: Tensor) -> Tensor:\n        r\"\"\"Constructs messages from node :math:`j` to node :math:`i`\n        in analogy to :math:`\\phi_{\\mathbf{\\Theta}}` for each edge in\n        :obj:`edge_index`.\n        This function can take any argument as input which was initially\n        passed to :meth:`propagate`.\n        Furthermore, tensors passed to :meth:`propagate` can be mapped to the\n        respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or\n        :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.\n        \"\"\"\n        return x_j\n\n    def aggregate(\n        self,\n        inputs: Tensor,\n        index: Tensor,\n        ptr: Optional[Tensor] = None,\n        dim_size: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Aggregates messages from neighbors as\n        :math:`\\bigoplus_{j \\in \\mathcal{N}(i)}`.\n\n        Takes in the output of message computation as first argument and any\n        argument which was initially passed to :meth:`propagate`.\n\n        By default, this function will delegate its call to the underlying\n        :class:`~torch_geometric.nn.aggr.Aggregation` module to reduce messages\n        as specified in :meth:`__init__` by 
the :obj:`aggr` argument.\n        \"\"\"\n        return self.aggr_module(inputs, index, ptr=ptr, dim_size=dim_size,\n                                dim=self.node_dim)\n\n    @abstractmethod\n    def message_and_aggregate(self, edge_index: Adj) -> Tensor:\n        r\"\"\"Fuses computations of :func:`message` and :func:`aggregate` into a\n        single function.\n        If applicable, this saves both time and memory since messages do not\n        explicitly need to be materialized.\n        This function will only get called in case it is implemented and\n        propagation takes place based on a :obj:`torch_sparse.SparseTensor`\n        or a :obj:`torch.sparse.Tensor`.\n        \"\"\"\n        raise NotImplementedError\n\n    def update(self, inputs: Tensor) -> Tensor:\n        r\"\"\"Updates node embeddings in analogy to\n        :math:`\\gamma_{\\mathbf{\\Theta}}` for each node\n        :math:`i \\in \\mathcal{V}`.\n        Takes in the output of aggregation as first argument and any argument\n        which was initially passed to :meth:`propagate`.\n        \"\"\"\n        return inputs\n\n    # Edge-level Updates ######################################################\n\n    def edge_updater(\n        self,\n        edge_index: Adj,\n        size: Size = None,\n        **kwargs: Any,\n    ) -> Tensor:\n        r\"\"\"The initial call to compute or update features for each edge in the\n        graph.\n\n        Args:\n            edge_index (torch.Tensor or SparseTensor): A :obj:`torch.Tensor`, a\n                :class:`torch_sparse.SparseTensor` or a\n                :class:`torch.sparse.Tensor` that defines the underlying graph\n                connectivity/message passing flow.\n                See :meth:`propagate` for more information.\n            size ((int, int), optional): The size :obj:`(N, M)` of the\n                assignment matrix in case :obj:`edge_index` is a\n                :class:`torch.Tensor`.\n                If set to :obj:`None`, 
the size will be automatically inferred\n                and assumed to be quadratic.\n                This argument is ignored in case :obj:`edge_index` is a\n                :class:`torch_sparse.SparseTensor` or\n                a :class:`torch.sparse.Tensor`. (default: :obj:`None`)\n            **kwargs: Any additional data which is needed to compute or update\n                features for each edge in the graph.\n        \"\"\"\n        for hook in self._edge_update_forward_pre_hooks.values():\n            res = hook(self, (edge_index, size, kwargs))\n            if res is not None:\n                edge_index, size, kwargs = res\n\n        mutable_size = self._check_input(edge_index, size=None)\n\n        coll_dict = self._collect(self._edge_user_args, edge_index,\n                                  mutable_size, kwargs)\n\n        edge_kwargs = self.inspector.collect_param_data(\n            'edge_update', coll_dict)\n        out = self.edge_update(**edge_kwargs)\n\n        for hook in self._edge_update_forward_hooks.values():\n            res = hook(self, (edge_index, size, kwargs), out)\n            if res is not None:\n                out = res\n\n        return out\n\n    @abstractmethod\n    def edge_update(self) -> Tensor:\n        r\"\"\"Computes or updates features for each edge in the graph.\n        This function can take any argument as input which was initially passed\n        to :meth:`edge_updater`.\n        Furthermore, tensors passed to :meth:`edge_updater` can be mapped to\n        the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or\n        :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.\n        \"\"\"\n        raise NotImplementedError\n\n    # Inference Decomposition #################################################\n\n    @property\n    def decomposed_layers(self) -> int:\n        return self._decomposed_layers\n\n    @decomposed_layers.setter\n    def decomposed_layers(self, decomposed_layers: int) 
-> None:\n        if torch.jit.is_scripting():\n            raise ValueError(\"Inference decomposition of message passing \"\n                             \"modules is only supported on the Python module\")\n\n        if decomposed_layers == self._decomposed_layers:\n            return  # Abort early if nothing to do.\n\n        self._decomposed_layers = decomposed_layers\n\n        if decomposed_layers != 1:\n            if hasattr(self.__class__, '_orig_propagate'):\n                self.propagate = self.__class__._orig_propagate.__get__(\n                    self, MessagePassing)\n\n        elif self.explain is None or self.explain is False:\n            if hasattr(self.__class__, '_jinja_propagate'):\n                self.propagate = self.__class__._jinja_propagate.__get__(\n                    self, MessagePassing)\n\n    # Explainability ##########################################################\n\n    @property\n    def explain(self) -> Optional[bool]:\n        return self._explain\n\n    @explain.setter\n    def explain(self, explain: Optional[bool]) -> None:\n        if torch.jit.is_scripting():\n            raise ValueError(\"Explainability of message passing modules \"\n                             \"is only supported on the Python module\")\n\n        if explain == self._explain:\n            return  # Abort early if nothing to do.\n\n        self._explain = explain\n\n        if explain is True:\n            assert self.decomposed_layers == 1\n            self.inspector.remove_signature(self.explain_message)\n            self.inspector.inspect_signature(self.explain_message, exclude=[0])\n            self._user_args = self.inspector.get_flat_param_names(\n                funcs=['message', 'explain_message', 'aggregate', 'update'],\n                exclude=self.special_args,\n            )\n            if hasattr(self.__class__, '_orig_propagate'):\n                self.propagate = self.__class__._orig_propagate.__get__(\n                    self, 
MessagePassing)\n        else:\n            self._user_args = self.inspector.get_flat_param_names(\n                funcs=['message', 'aggregate', 'update'],\n                exclude=self.special_args,\n            )\n            if self.decomposed_layers == 1:\n                if hasattr(self.__class__, '_jinja_propagate'):\n                    self.propagate = self.__class__._jinja_propagate.__get__(\n                        self, MessagePassing)\n\n    def explain_message(\n        self,\n        inputs: Tensor,\n        dim_size: Optional[int],\n    ) -> Tensor:\n        # NOTE Replace this method in custom explainers per message-passing\n        # layer to customize how messages shall be explained, e.g., via:\n        # conv.explain_message = explain_message.__get__(conv, MessagePassing)\n        # see stackoverflow.com: 394770/override-a-method-at-instance-level\n        edge_mask = self._edge_mask\n\n        if edge_mask is None:\n            raise ValueError(\"Could not find a pre-defined 'edge_mask' \"\n                             \"to explain. Did you forget to initialize it?\")\n\n        if self._apply_sigmoid:\n            edge_mask = edge_mask.sigmoid()\n\n        # Some ops add self-loops to `edge_index`. 
We need to do the same for\n        # `edge_mask` (but do not train these entries).\n        if inputs.size(self.node_dim) != edge_mask.size(0):\n            assert dim_size is not None\n            edge_mask = edge_mask[self._loop_mask]\n            loop = edge_mask.new_ones(dim_size)\n            edge_mask = torch.cat([edge_mask, loop], dim=0)\n        assert inputs.size(self.node_dim) == edge_mask.size(0)\n\n        size = [1] * inputs.dim()\n        size[self.node_dim] = -1\n        return inputs * edge_mask.view(size)\n\n    # Hooks ###################################################################\n\n    def register_propagate_forward_pre_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward pre-hook on the module.\n\n        The hook will be called every time before :meth:`propagate` is invoked.\n        It should have the following signature:\n\n        .. code-block:: python\n\n            hook(module, inputs) -> None or modified input\n\n        The hook can modify the input.\n        Input keyword arguments are passed to the hook as a dictionary in\n        :obj:`inputs[-1]`.\n\n        Returns a :class:`torch.utils.hooks.RemovableHandle` that can be used\n        to remove the added hook by calling :obj:`handle.remove()`.\n        \"\"\"\n        handle = RemovableHandle(self._propagate_forward_pre_hooks)\n        self._propagate_forward_pre_hooks[handle.id] = hook\n        return handle\n\n    def register_propagate_forward_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward hook on the module.\n\n        The hook will be called every time after :meth:`propagate` has computed\n        an output.\n        It should have the following signature:\n\n        .. 
code-block:: python\n\n            hook(module, inputs, output) -> None or modified output\n\n        The hook can modify the output.\n        Input keyword arguments are passed to the hook as a dictionary in\n        :obj:`inputs[-1]`.\n\n        Returns a :class:`torch.utils.hooks.RemovableHandle` that can be used\n        to remove the added hook by calling :obj:`handle.remove()`.\n        \"\"\"\n        handle = RemovableHandle(self._propagate_forward_hooks)\n        self._propagate_forward_hooks[handle.id] = hook\n        return handle\n\n    def register_message_forward_pre_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward pre-hook on the module.\n        The hook will be called every time before :meth:`message` is invoked.\n        See :meth:`register_propagate_forward_pre_hook` for more information.\n        \"\"\"\n        handle = RemovableHandle(self._message_forward_pre_hooks)\n        self._message_forward_pre_hooks[handle.id] = hook\n        return handle\n\n    def register_message_forward_hook(self, hook: Callable) -> RemovableHandle:\n        r\"\"\"Registers a forward hook on the module.\n        The hook will be called every time after :meth:`message` has computed\n        an output.\n        See :meth:`register_propagate_forward_hook` for more information.\n        \"\"\"\n        handle = RemovableHandle(self._message_forward_hooks)\n        self._message_forward_hooks[handle.id] = hook\n        return handle\n\n    def register_aggregate_forward_pre_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward pre-hook on the module.\n        The hook will be called every time before :meth:`aggregate` is invoked.\n        See :meth:`register_propagate_forward_pre_hook` for more information.\n        \"\"\"\n        handle = RemovableHandle(self._aggregate_forward_pre_hooks)\n        self._aggregate_forward_pre_hooks[handle.id] = hook\n    
    return handle\n\n    def register_aggregate_forward_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward hook on the module.\n        The hook will be called every time after :meth:`aggregate` has computed\n        an output.\n        See :meth:`register_propagate_forward_hook` for more information.\n        \"\"\"\n        handle = RemovableHandle(self._aggregate_forward_hooks)\n        self._aggregate_forward_hooks[handle.id] = hook\n        return handle\n\n    def register_message_and_aggregate_forward_pre_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward pre-hook on the module.\n        The hook will be called every time before :meth:`message_and_aggregate`\n        is invoked.\n        See :meth:`register_propagate_forward_pre_hook` for more information.\n        \"\"\"\n        handle = RemovableHandle(self._message_and_aggregate_forward_pre_hooks)\n        self._message_and_aggregate_forward_pre_hooks[handle.id] = hook\n        return handle\n\n    def register_message_and_aggregate_forward_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward hook on the module.\n        The hook will be called every time after :meth:`message_and_aggregate`\n        has computed an output.\n        See :meth:`register_propagate_forward_hook` for more information.\n        \"\"\"\n        handle = RemovableHandle(self._message_and_aggregate_forward_hooks)\n        self._message_and_aggregate_forward_hooks[handle.id] = hook\n        return handle\n\n    def register_edge_update_forward_pre_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward pre-hook on the module.\n        The hook will be called every time before :meth:`edge_update` is\n        invoked. 
See :meth:`register_propagate_forward_pre_hook` for more\n        information.\n        \"\"\"\n        handle = RemovableHandle(self._edge_update_forward_pre_hooks)\n        self._edge_update_forward_pre_hooks[handle.id] = hook\n        return handle\n\n    def register_edge_update_forward_hook(\n        self,\n        hook: Callable,\n    ) -> RemovableHandle:\n        r\"\"\"Registers a forward hook on the module.\n        The hook will be called every time after :meth:`edge_update` has\n        computed an output.\n        See :meth:`register_propagate_forward_hook` for more information.\n        \"\"\"\n        handle = RemovableHandle(self._edge_update_forward_hooks)\n        self._edge_update_forward_hooks[handle.id] = hook\n        return handle\n\n    # TorchScript Support #####################################################\n\n    def _set_jittable_templates(self, raise_on_error: bool = False) -> None:\n        root_dir = osp.dirname(osp.realpath(__file__))\n        jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'\n        # Optimize `propagate()` via `*.jinja` templates:\n        if not self.propagate.__module__.startswith(jinja_prefix):\n            try:\n                if ('propagate' in self.__class__.__dict__\n                        and self.__class__.__dict__['propagate']\n                        != MessagePassing.propagate):\n                    raise ValueError(\"Cannot compile custom 'propagate' \"\n                                     \"method\")\n\n                module = module_from_template(\n                    module_name=f'{jinja_prefix}_propagate',\n                    template_path=osp.join(root_dir, 'propagate.jinja'),\n                    tmp_dirname='message_passing',\n                    # Keyword arguments:\n                    modules=self.inspector._modules,\n                    collect_name='collect',\n                    signature=self._get_propagate_signature(),\n                    
collect_param_dict=self.inspector.get_flat_param_dict(\n                        ['message', 'aggregate', 'update']),\n                    message_args=self.inspector.get_param_names('message'),\n                    aggregate_args=self.inspector.get_param_names('aggregate'),\n                    message_and_aggregate_args=self.inspector.get_param_names(\n                        'message_and_aggregate'),\n                    update_args=self.inspector.get_param_names('update'),\n                    fuse=self.fuse,\n                )\n\n                self.__class__._orig_propagate = self.__class__.propagate\n                self.__class__._jinja_propagate = module.propagate\n\n                self.__class__.propagate = module.propagate\n                self.__class__.collect = module.collect\n            except Exception as e:  # pragma: no cover\n                if raise_on_error:\n                    raise e\n                self.__class__._orig_propagate = self.__class__.propagate\n                self.__class__._jinja_propagate = self.__class__.propagate\n\n        # Optimize `edge_updater()` via `*.jinja` templates (if implemented):\n        if (self.inspector.implements('edge_update')\n                and not self.edge_updater.__module__.startswith(jinja_prefix)):\n            try:\n                if ('edge_updater' in self.__class__.__dict__\n                        and self.__class__.__dict__['edge_updater']\n                        != MessagePassing.edge_updater):\n                    raise ValueError(\"Cannot compile custom 'edge_updater' \"\n                                     \"method\")\n\n                module = module_from_template(\n                    module_name=f'{jinja_prefix}_edge_updater',\n                    template_path=osp.join(root_dir, 'edge_updater.jinja'),\n                    tmp_dirname='message_passing',\n                    # Keyword arguments:\n                    modules=self.inspector._modules,\n                    
collect_name='edge_collect',\n                    signature=self._get_edge_updater_signature(),\n                    collect_param_dict=self.inspector.get_param_dict(\n                        'edge_update'),\n                )\n\n                self.__class__._orig_edge_updater = self.__class__.edge_updater\n                self.__class__._jinja_edge_updater = module.edge_updater\n\n                self.__class__.edge_updater = module.edge_updater\n                self.__class__.edge_collect = module.edge_collect\n            except Exception as e:  # pragma: no cover\n                if raise_on_error:\n                    raise e\n                self.__class__._orig_edge_updater = self.__class__.edge_updater\n                self.__class__._jinja_edge_updater = (\n                    self.__class__.edge_updater)\n\n    def _get_propagate_signature(self) -> Signature:\n        param_dict = self.inspector.get_params_from_method_call(\n            'propagate', exclude=[0, 'edge_index', 'size'])\n        update_signature = self.inspector.get_signature('update')\n\n        return Signature(\n            param_dict=param_dict,\n            return_type=update_signature.return_type,\n            return_type_repr=update_signature.return_type_repr,\n        )\n\n    def _get_edge_updater_signature(self) -> Signature:\n        param_dict = self.inspector.get_params_from_method_call(\n            'edge_updater', exclude=[0, 'edge_index', 'size'])\n        edge_update_signature = self.inspector.get_signature('edge_update')\n\n        return Signature(\n            param_dict=param_dict,\n            return_type=edge_update_signature.return_type,\n            return_type_repr=edge_update_signature.return_type_repr,\n        )\n\n    def jittable(self, typing: Optional[str] = None) -> 'MessagePassing':\n        r\"\"\"Analyzes the :class:`MessagePassing` instance and produces a new\n        jittable module that can be used in combination with\n        
:meth:`torch.jit.script`.\n\n        .. note::\n            :meth:`jittable` is deprecated and a no-op from :pyg:`PyG` 2.5\n            onwards.\n        \"\"\"\n        warnings.warn(\n            f\"'{self.__class__.__name__}.jittable' is deprecated \"\n            f\"and a no-op. Please remove its usage.\", stacklevel=2)\n        return self\n"
  },
  {
    "path": "torch_geometric/nn/conv/mf_conv.py",
    "content": "from typing import Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import ModuleList\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor\nfrom torch_geometric.utils import degree, spmm\n\n\nclass MFConv(MessagePassing):\n    r\"\"\"The graph neural network operator from the\n    `\"Convolutional Networks on Graphs for Learning Molecular Fingerprints\"\n    <https://arxiv.org/abs/1509.09292>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{W}^{(\\deg(i))}_1 \\mathbf{x}_i +\n        \\mathbf{W}^{(\\deg(i))}_2 \\sum_{j \\in \\mathcal{N}(i)} \\mathbf{x}_j\n\n    which trains a distinct weight matrix for each possible vertex degree.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        max_degree (int, optional): The maximum node degree to consider when\n            updating weights (default: :obj:`10`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **inputs:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **outputs:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V_t}|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: Union[int, Tuple[int, int]],\n                 out_channels: int, max_degree: int = 10, bias=True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.max_degree = max_degree\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.lins_l = ModuleList([\n            Linear(in_channels[0], out_channels, bias=bias)\n            for _ in range(max_degree + 1)\n        ])\n\n        self.lins_r = ModuleList([\n            Linear(in_channels[1], out_channels, bias=False)\n            for _ in range(max_degree + 1)\n        ])\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        for lin in self.lins_l:\n            lin.reset_parameters()\n        for lin in self.lins_r:\n            lin.reset_parameters()\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        x_r = x[1]\n\n        deg = x[0]  # Dummy.\n        if isinstance(edge_index, SparseTensor):\n            deg = edge_index.storage.rowcount()\n        elif isinstance(edge_index, Tensor):\n            i = 1 if self.flow == 'source_to_target' 
else 0\n            N = x[0].size(self.node_dim)\n            N = size[1] if size is not None else N\n            N = x_r.size(self.node_dim) if x_r is not None else N\n            deg = degree(edge_index[i], N, dtype=torch.long)\n        deg.clamp_(max=self.max_degree)\n\n        # propagate_type: (x: OptPairTensor)\n        h = self.propagate(edge_index, x=x, size=size)\n\n        out = h.new_empty(list(h.size())[:-1] + [self.out_channels])\n        for i, (lin_l, lin_r) in enumerate(zip(self.lins_l, self.lins_r)):\n            idx = (deg == i).nonzero().view(-1)\n            r = lin_l(h.index_select(self.node_dim, idx))\n\n            if x_r is not None:\n                r = r + lin_r(x_r.index_select(self.node_dim, idx))\n\n            out.index_copy_(self.node_dim, idx, r)\n\n        return out\n\n    def message(self, x_j: Tensor) -> Tensor:\n        return x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n        if isinstance(adj_t, SparseTensor):\n            adj_t = adj_t.set_value(None, layout=None)\n        return spmm(adj_t, x[0], reduce=self.aggr)\n"
  },
  {
    "path": "torch_geometric/nn/conv/mixhop_conv.py",
    "content": "from typing import List, Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass MixHopConv(MessagePassing):\n    r\"\"\"The Mix-Hop graph convolutional operator from the\n    `\"MixHop: Higher-Order Graph Convolutional Architectures via Sparsified\n    Neighborhood Mixing\" <https://arxiv.org/abs/1905.00067>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime}={\\Bigg\\Vert}_{p\\in P}\n        {\\left( \\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n        \\mathbf{\\hat{D}}^{-1/2} \\right)}^p \\mathbf{X} \\mathbf{\\Theta},\n\n    where :math:`\\mathbf{\\hat{A}} = \\mathbf{A} + \\mathbf{I}` denotes the\n    adjacency matrix with inserted self-loops and\n    :math:`\\hat{D}_{ii} = \\sum_{j=0} \\hat{A}_{ij}` its diagonal degree matrix.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        powers (List[int], optional): The powers of the adjacency matrix to\n            use. (default: :obj:`[0, 1, 2]`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:**\n          node features :math:`(|\\mathcal{V}|, |P| \\cdot F_{out})`\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        powers: Optional[List[int]] = None,\n        add_self_loops: bool = True,\n        bias: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        if powers is None:\n            powers = [0, 1, 2]\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.powers = powers\n        self.add_self_loops = add_self_loops\n\n        self.lins = torch.nn.ModuleList([\n            Linear(in_channels, out_channels, bias=False)\n            if p in powers else torch.nn.Identity()\n            for p in range(max(powers) + 1)\n        ])\n\n        if bias:\n            self.bias = Parameter(torch.empty(len(powers) * out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        for lin in self.lins:\n            if hasattr(lin, 'reset_parameters'):\n                lin.reset_parameters()\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if isinstance(edge_index, Tensor):\n            edge_index, edge_weight = gcn_norm(  # yapf: disable\n                edge_index, edge_weight, x.size(self.node_dim), False,\n                self.add_self_loops, self.flow, x.dtype)\n        elif isinstance(edge_index, SparseTensor):\n            
edge_index = gcn_norm(  # yapf: disable\n                edge_index, edge_weight, x.size(self.node_dim), False,\n                self.add_self_loops, self.flow, x.dtype)\n\n        outs = [self.lins[0](x)]\n\n        for lin in self.lins[1:]:\n            # propagate_type: (x: Tensor, edge_weight: OptTensor)\n            x = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n\n            outs.append(lin.forward(x))\n\n        out = torch.cat([outs[p] for p in self.powers], dim=-1)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, powers={self.powers})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/nn_conv.py",
    "content": "from typing import Callable, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import reset, zeros\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\n\n\nclass NNConv(MessagePassing):\n    r\"\"\"The continuous kernel-based convolutional operator from the\n    `\"Neural Message Passing for Quantum Chemistry\"\n    <https://arxiv.org/abs/1704.01212>`_ paper.\n\n    This convolution is also known as the edge-conditioned convolution from the\n    `\"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on\n    Graphs\" <https://arxiv.org/abs/1704.02901>`_ paper (see\n    :class:`torch_geometric.nn.conv.ECConv` for an alias):\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{\\Theta} \\mathbf{x}_i +\n        \\sum_{j \\in \\mathcal{N}(i)} \\mathbf{x}_j \\cdot\n        h_{\\mathbf{\\Theta}}(\\mathbf{e}_{i,j}),\n\n    where :math:`h_{\\mathbf{\\Theta}}` denotes a neural network, *i.e.*,\n    an MLP.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        nn (torch.nn.Module): A neural network :math:`h_{\\mathbf{\\Theta}}` that\n            maps edge features :obj:`edge_attr` of shape :obj:`[-1,\n            num_edge_features]` to shape\n            :obj:`[-1, in_channels * out_channels]`, *e.g.*, defined by\n            :class:`torch.nn.Sequential`.\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"add\"`)\n        root_weight (bool, optional): If set to 
:obj:`False`, the layer will\n            not add the transformed root node features to the output.\n            (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: Union[int, Tuple[int, int]],\n                 out_channels: int, nn: Callable, aggr: str = 'add',\n                 root_weight: bool = True, bias: bool = True, **kwargs):\n        super().__init__(aggr=aggr, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.nn = nn\n        self.root_weight = root_weight\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.in_channels_l = in_channels[0]\n\n        if root_weight:\n            self.lin = Linear(in_channels[1], out_channels, bias=False,\n                              weight_initializer='uniform')\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.nn)\n        if self.root_weight:\n            self.lin.reset_parameters()\n        zeros(self.bias)\n\n    def forward(\n        self,\n        x: Union[Tensor, 
OptPairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n\n        x_r = x[1]\n        if x_r is not None and self.root_weight:\n            out = out + self.lin(x_r)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:\n        weight = self.nn(edge_attr)\n        weight = weight.view(-1, self.in_channels_l, self.out_channels)\n        return torch.matmul(x_j.unsqueeze(1), weight).squeeze(1)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, aggr={self.aggr}, nn={self.nn})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/pan_conv.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, SparseTensor\nfrom torch_geometric.utils import is_torch_sparse_tensor, spmm\n\n\nclass PANConv(MessagePassing):\n    r\"\"\"The path integral based convolutional operator from the\n    `\"Path Integral Based Convolution and Pooling for Graph Neural Networks\"\n    <https://arxiv.org/abs/2006.16811>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\mathbf{M} \\mathbf{X} \\mathbf{W}\n\n    where :math:`\\mathbf{M}` denotes the normalized and learned maximal entropy\n    transition (MET) matrix that includes neighbors up to :obj:`filter_size`\n    hops:\n\n    .. math::\n\n        \\mathbf{M} = \\mathbf{Z}^{-1/2} \\sum_{n=0}^L e^{-\\frac{E(n)}{T}}\n        \\mathbf{A}^n \\mathbf{Z}^{-1/2}\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        filter_size (int): The filter size :math:`L`.\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, filter_size: int,\n                 **kwargs):\n\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.filter_size = filter_size\n\n        self.lin = Linear(in_channels, out_channels)\n        self.weight = 
Parameter(torch.empty(filter_size + 1))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin.reset_parameters()\n        self.weight.data.fill_(0.5)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n    ) -> Tuple[Tensor, SparseTensor]:\n\n        adj_t: Optional[SparseTensor] = None\n        if isinstance(edge_index, Tensor):\n            if is_torch_sparse_tensor(edge_index):\n                # TODO Handle PyTorch sparse tensor directly.\n                if edge_index.layout == torch.sparse_coo:\n                    adj_t = SparseTensor.from_torch_sparse_coo_tensor(\n                        edge_index)\n                elif edge_index.layout == torch.sparse_csr:\n                    adj_t = SparseTensor.from_torch_sparse_csr_tensor(\n                        edge_index)\n                else:\n                    raise ValueError(f\"Unexpected sparse tensor layout \"\n                                     f\"(got '{edge_index.layout}')\")\n            else:\n                adj_t = SparseTensor(row=edge_index[1], col=edge_index[0],\n                                     sparse_sizes=(x.size(0), x.size(0)))\n\n        elif isinstance(edge_index, SparseTensor):\n            adj_t = edge_index.set_value(None)\n\n        adj_t = self.panentropy(adj_t, dtype=x.dtype)\n\n        deg = adj_t.storage.rowcount().to(x.dtype)\n        deg_inv_sqrt = deg.pow_(-0.5)\n        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0.\n        M = deg_inv_sqrt.view(1, -1) * adj_t * deg_inv_sqrt.view(-1, 1)\n\n        out = self.propagate(M, x=x, edge_weight=None)\n        out = self.lin(out)\n        return out, M\n\n    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        return edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def panentropy(self, adj_t: 
SparseTensor,\n                   dtype: Optional[torch.dtype] = None) -> SparseTensor:\n\n        if not adj_t.has_value():\n            adj_t = adj_t.fill_value(1.0)\n\n        tmp = SparseTensor.eye(adj_t.size(0), adj_t.size(1), has_value=True,\n                               dtype=dtype, device=adj_t.device())\n        tmp = tmp.mul_nnz(self.weight[0], layout='coo')\n\n        outs = [tmp]\n        for i in range(1, self.filter_size + 1):\n            tmp = tmp @ adj_t\n            tmp = tmp.mul_nnz(self.weight[i], layout='coo')\n            outs += [tmp]\n\n        row = torch.cat([out.storage.row() for out in outs], dim=0)\n        col = torch.cat([out.storage.col() for out in outs], dim=0)\n        value = torch.cat([out.storage.value() for out in outs], dim=0)\n\n        out = SparseTensor(row=row, col=col, value=value,\n                           sparse_sizes=adj_t.sparse_sizes()).coalesce()\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, filter_size={self.filter_size})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/pdn_conv.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.nn import Linear, Parameter, ReLU, Sequential, Sigmoid\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass PDNConv(MessagePassing):\n    r\"\"\"The pathfinder discovery network convolutional operator from the\n    `\"Pathfinder Discovery Networks for Neural Message Passing\"\n    <https://arxiv.org/abs/2010.12878>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\sum_{j \\in \\mathcal{N}(i) \\cup\n        \\{i\\}}f_{\\Theta}(\\textbf{e}_{(j,i)}) \\cdot f_{\\Omega}(\\mathbf{x}_{j})\n\n    where :math:`\\mathbf{e}_{(j,i)}` denotes the edge feature vector from\n    source node :math:`j` to target node :math:`i`, and :math:`\\mathbf{x}_{j}`\n    denotes the node feature vector of node :math:`j`.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        edge_dim (int): Edge feature dimensionality.\n        hidden_channels (int): Hidden edge feature dimensionality.\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        normalize (bool, optional): Whether to add self-loops and compute\n            symmetric normalization coefficients on the fly.\n            (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, edge_dim: int,\n                 hidden_channels: int, add_self_loops: bool = True,\n                 normalize: bool = True, bias: bool = True, **kwargs):\n\n        kwargs.setdefault(\"aggr\", \"add\")\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.edge_dim = edge_dim\n        self.hidden_channels = hidden_channels\n        self.add_self_loops = add_self_loops\n        self.normalize = normalize\n\n        self.lin = Linear(in_channels, out_channels, bias=False)\n\n        self.mlp = Sequential(\n            Linear(edge_dim, hidden_channels),\n            ReLU(inplace=True),\n            Linear(hidden_channels, 1),\n            Sigmoid(),\n        )\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter(\"bias\", None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        glorot(self.lin.weight)\n        glorot(self.mlp[0].weight)\n        glorot(self.mlp[2].weight)\n        zeros(self.mlp[0].bias)\n        zeros(self.mlp[2].bias)\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_attr: OptTensor = None) -> Tensor:\n\n        if isinstance(edge_index, SparseTensor):\n            edge_attr = edge_index.storage.value()\n\n        if edge_attr is not None:\n            edge_attr = 
self.mlp(edge_attr).squeeze(-1)\n\n        if isinstance(edge_index, SparseTensor):\n            edge_index = edge_index.set_value(edge_attr, layout='coo')\n\n        if self.normalize:\n            if isinstance(edge_index, Tensor):\n                edge_index, edge_attr = gcn_norm(edge_index, edge_attr,\n                                                 x.size(self.node_dim), False,\n                                                 self.add_self_loops,\n                                                 self.flow, x.dtype)\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = gcn_norm(edge_index, None, x.size(self.node_dim),\n                                      False, self.add_self_loops, self.flow,\n                                      x.dtype)\n\n        x = self.lin(x)\n\n        # propagate_type: (x: Tensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_weight=edge_attr)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        return edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self):\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/pna_conv.py",
    "content": "from typing import Any, Callable, Dict, List, Optional, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import ModuleList, Sequential\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.nn.aggr import DegreeScalerAggregation\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.nn.resolver import activation_resolver\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import degree\n\n\nclass PNAConv(MessagePassing):\n    r\"\"\"The Principal Neighbourhood Aggregation graph convolution operator\n    from the `\"Principal Neighbourhood Aggregation for Graph Nets\"\n    <https://arxiv.org/abs/2004.05718>`_ paper.\n\n    .. math::\n        \\mathbf{x}_i^{\\prime} = \\gamma_{\\mathbf{\\Theta}} \\left(\n        \\mathbf{x}_i, \\underset{j \\in \\mathcal{N}(i)}{\\bigoplus}\n        h_{\\mathbf{\\Theta}} \\left( \\mathbf{x}_i, \\mathbf{x}_j \\right)\n        \\right)\n\n    with\n\n    .. math::\n        \\bigoplus = \\underbrace{\\begin{bmatrix}\n            1 \\\\\n            S(\\mathbf{D}, \\alpha=1) \\\\\n            S(\\mathbf{D}, \\alpha=-1)\n        \\end{bmatrix} }_{\\text{scalers}}\n        \\otimes \\underbrace{\\begin{bmatrix}\n            \\mu \\\\\n            \\sigma \\\\\n            \\max \\\\\n            \\min\n        \\end{bmatrix}}_{\\text{aggregators}},\n\n    where :math:`\\gamma_{\\mathbf{\\Theta}}` and :math:`h_{\\mathbf{\\Theta}}`\n    denote MLPs.\n\n    .. 
note::\n\n        For an example of using :obj:`PNAConv`, see `examples/pna.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/\n        examples/pna.py>`_.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        aggregators (List[str]): Set of aggregation function identifiers,\n            namely :obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`,\n            :obj:`\"var\"` and :obj:`\"std\"`.\n        scalers (List[str]): Set of scaling function identifiers, namely\n            :obj:`\"identity\"`, :obj:`\"amplification\"`,\n            :obj:`\"attenuation\"`, :obj:`\"linear\"` and\n            :obj:`\"inverse_linear\"`.\n        deg (torch.Tensor): Histogram of in-degrees of nodes in the training\n            set, used by scalers to normalize.\n        edge_dim (int, optional): Edge feature dimensionality (in case\n            there are any). (default: :obj:`None`)\n        towers (int, optional): Number of towers (default: :obj:`1`).\n        pre_layers (int, optional): Number of transformation layers before\n            aggregation (default: :obj:`1`).\n        post_layers (int, optional): Number of transformation layers after\n            aggregation (default: :obj:`1`).\n        divide_input (bool, optional): Whether the input features should\n            be split between towers or not (default: :obj:`False`).\n        act (str or callable, optional): Pre- and post-layer activation\n            function to use. (default: :obj:`\"relu\"`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        train_norm (bool, optional): Whether normalization parameters\n            are trainable. 
(default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge features :math:`(|\\mathcal{E}|, D)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        aggregators: List[str],\n        scalers: List[str],\n        deg: Tensor,\n        edge_dim: Optional[int] = None,\n        towers: int = 1,\n        pre_layers: int = 1,\n        post_layers: int = 1,\n        divide_input: bool = False,\n        act: Union[str, Callable, None] = \"relu\",\n        act_kwargs: Optional[Dict[str, Any]] = None,\n        train_norm: bool = False,\n        **kwargs,\n    ):\n\n        aggr = DegreeScalerAggregation(aggregators, scalers, deg, train_norm)\n        super().__init__(aggr=aggr, node_dim=0, **kwargs)\n\n        if divide_input:\n            assert in_channels % towers == 0\n        assert out_channels % towers == 0\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.edge_dim = edge_dim\n        self.towers = towers\n        self.divide_input = divide_input\n\n        self.F_in = in_channels // towers if divide_input else in_channels\n        self.F_out = self.out_channels // towers\n\n        if self.edge_dim is not None:\n            self.edge_encoder = Linear(edge_dim, self.F_in)\n\n        self.pre_nns = ModuleList()\n        self.post_nns = ModuleList()\n        for _ in range(towers):\n            modules = [Linear((3 if edge_dim else 2) * self.F_in, self.F_in)]\n            for _ in range(pre_layers - 1):\n                modules += [activation_resolver(act, **(act_kwargs or {}))]\n                modules += [Linear(self.F_in, self.F_in)]\n       
     self.pre_nns.append(Sequential(*modules))\n\n            in_channels = (len(aggregators) * len(scalers) + 1) * self.F_in\n            modules = [Linear(in_channels, self.F_out)]\n            for _ in range(post_layers - 1):\n                modules += [activation_resolver(act, **(act_kwargs or {}))]\n                modules += [Linear(self.F_out, self.F_out)]\n            self.post_nns.append(Sequential(*modules))\n\n        self.lin = Linear(out_channels, out_channels)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if self.edge_dim is not None:\n            self.edge_encoder.reset_parameters()\n        for nn in self.pre_nns:\n            reset(nn)\n        for nn in self.post_nns:\n            reset(nn)\n        self.lin.reset_parameters()\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_attr: OptTensor = None) -> Tensor:\n\n        if self.divide_input:\n            x = x.view(-1, self.towers, self.F_in)\n        else:\n            x = x.view(-1, 1, self.F_in).repeat(1, self.towers, 1)\n\n        # propagate_type: (x: Tensor, edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_attr=edge_attr)\n\n        out = torch.cat([x, out], dim=-1)\n        outs = [nn(out[:, i]) for i, nn in enumerate(self.post_nns)]\n        out = torch.cat(outs, dim=1)\n\n        return self.lin(out)\n\n    def message(self, x_i: Tensor, x_j: Tensor,\n                edge_attr: OptTensor) -> Tensor:\n\n        h: Tensor = x_i  # Dummy.\n        if edge_attr is not None:\n            edge_attr = self.edge_encoder(edge_attr)\n            edge_attr = edge_attr.view(-1, 1, self.F_in)\n            edge_attr = edge_attr.repeat(1, self.towers, 1)\n            h = torch.cat([x_i, x_j, edge_attr], dim=-1)\n        else:\n            h = torch.cat([x_i, x_j], dim=-1)\n\n        hs = [nn(h[:, i]) for i, nn in enumerate(self.pre_nns)]\n        return torch.stack(hs, dim=1)\n\n    
def __repr__(self):\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, towers={self.towers}, '\n                f'edge_dim={self.edge_dim})')\n\n    @staticmethod\n    def get_degree_histogram(loader: DataLoader) -> Tensor:\n        r\"\"\"Returns the degree histogram to be used as input for the :obj:`deg`\n        argument in :class:`PNAConv`.\n        \"\"\"\n        deg_histogram = torch.zeros(1, dtype=torch.long)\n        for data in loader:\n            deg = degree(data.edge_index[1], num_nodes=data.num_nodes,\n                         dtype=torch.long)\n            deg_bincount = torch.bincount(deg, minlength=deg_histogram.numel())\n            deg_histogram = deg_histogram.to(deg_bincount.device)\n            if deg_bincount.numel() > deg_histogram.numel():\n                deg_bincount[:deg_histogram.size(0)] += deg_histogram\n                deg_histogram = deg_bincount\n            else:\n                assert deg_bincount.numel() == deg_histogram.numel()\n                deg_histogram += deg_bincount\n\n        return deg_histogram\n"
  },
  {
    "path": "torch_geometric/nn/conv/point_conv.py",
    "content": "from typing import Callable, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import (\n    Adj,\n    OptTensor,\n    PairOptTensor,\n    PairTensor,\n    SparseTensor,\n    torch_sparse,\n)\nfrom torch_geometric.utils import add_self_loops, remove_self_loops\n\n\nclass PointNetConv(MessagePassing):\n    r\"\"\"The PointNet set layer from the `\"PointNet: Deep Learning on Point Sets\n    for 3D Classification and Segmentation\"\n    <https://arxiv.org/abs/1612.00593>`_ and `\"PointNet++: Deep Hierarchical\n    Feature Learning on Point Sets in a Metric Space\"\n    <https://arxiv.org/abs/1706.02413>`_ papers.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\gamma_{\\mathbf{\\Theta}} \\left( \\max_{j \\in\n        \\mathcal{N}(i) \\cup \\{ i \\}} h_{\\mathbf{\\Theta}} ( \\mathbf{x}_j,\n        \\mathbf{p}_j - \\mathbf{p}_i) \\right),\n\n    where :math:`\\gamma_{\\mathbf{\\Theta}}` and :math:`h_{\\mathbf{\\Theta}}`\n    denote neural networks, *i.e.* MLPs, and\n    :math:`\\mathbf{P} \\in \\mathbb{R}^{N \\times D}` defines the position of\n    each point.\n\n    Args:\n        local_nn (torch.nn.Module, optional): A neural network\n            :math:`h_{\\mathbf{\\Theta}}` that maps node features :obj:`x` and\n            relative spatial coordinates :obj:`pos_j - pos_i` of shape\n            :obj:`[-1, in_channels + num_dimensions]` to shape\n            :obj:`[-1, out_channels]`, *e.g.*, defined by\n            :class:`torch.nn.Sequential`. (default: :obj:`None`)\n        global_nn (torch.nn.Module, optional): A neural network\n            :math:`\\gamma_{\\mathbf{\\Theta}}` that maps aggregated node features\n            of shape :obj:`[-1, out_channels]` to shape :obj:`[-1,\n            final_out_channels]`, *e.g.*, defined by\n            :class:`torch.nn.Sequential`. 
(default: :obj:`None`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          positions :math:`(|\\mathcal{V}|, 3)` or\n          :math:`((|\\mathcal{V_s}|, 3), (|\\mathcal{V_t}|, 3))` if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, local_nn: Optional[Callable] = None,\n                 global_nn: Optional[Callable] = None,\n                 add_self_loops: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'max')\n        super().__init__(**kwargs)\n\n        self.local_nn = local_nn\n        self.global_nn = global_nn\n        self.add_self_loops = add_self_loops\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.local_nn)\n        reset(self.global_nn)\n\n    def forward(\n        self,\n        x: Union[OptTensor, PairOptTensor],\n        pos: Union[Tensor, PairTensor],\n        edge_index: Adj,\n    ) -> Tensor:\n\n        if not isinstance(x, tuple):\n            x = (x, None)\n\n        if isinstance(pos, Tensor):\n            pos = (pos, pos)\n\n        if self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                edge_index, _ = remove_self_loops(edge_index)\n                edge_index, _ = add_self_loops(\n                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))\n            elif isinstance(edge_index, SparseTensor):\n                
edge_index = torch_sparse.set_diag(edge_index)\n\n        # propagate_type: (x: PairOptTensor, pos: PairTensor)\n        out = self.propagate(edge_index, x=x, pos=pos)\n\n        if self.global_nn is not None:\n            out = self.global_nn(out)\n\n        return out\n\n    def message(self, x_j: Optional[Tensor], pos_i: Tensor,\n                pos_j: Tensor) -> Tensor:\n        msg = pos_j - pos_i\n        if x_j is not None:\n            msg = torch.cat([x_j, msg], dim=1)\n        if self.local_nn is not None:\n            msg = self.local_nn(msg)\n        return msg\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(local_nn={self.local_nn}, '\n                f'global_nn={self.global_nn})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/point_gnn_conv.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import Adj\n\n\nclass PointGNNConv(MessagePassing):\n    r\"\"\"The PointGNN operator from the `\"Point-GNN: Graph Neural Network for\n    3D Object Detection in a Point Cloud\" <https://arxiv.org/abs/2003.01251>`_\n    paper.\n\n    .. math::\n\n        \\Delta \\textrm{pos}_i &= h_{\\mathbf{\\Theta}}(\\mathbf{x}_i)\n\n        \\mathbf{e}_{j,i} &= f_{\\mathbf{\\Theta}}(\\textrm{pos}_j -\n        \\textrm{pos}_i + \\Delta \\textrm{pos}_i, \\mathbf{x}_j)\n\n        \\mathbf{x}^{\\prime}_i &= g_{\\mathbf{\\Theta}}(\\max_{j \\in\n        \\mathcal{N}(i)} \\mathbf{e}_{j,i}) + \\mathbf{x}_i\n\n    The relative position is used in the message passing step to introduce\n    global translation invariance.\n    To also counter shifts in the local neighborhood of the center node, the\n    authors propose to utilize an alignment offset.\n    The graph should be statically constructed using radius-based cutoff.\n\n    Args:\n        mlp_h (torch.nn.Module): A neural network :math:`h_{\\mathbf{\\Theta}}`\n            that maps node features of size :math:`F_{in}` to three-dimensional\n            coordination offsets :math:`\\Delta \\textrm{pos}_i`.\n        mlp_f (torch.nn.Module): A neural network :math:`f_{\\mathbf{\\Theta}}`\n            that computes :math:`\\mathbf{e}_{j,i}` from the features of\n            neighbors of size :math:`F_{in}` and the three-dimensional vector\n            :math:`\\textrm{pos_j} - \\textrm{pos_i} + \\Delta \\textrm{pos}_i`.\n        mlp_g (torch.nn.Module): A neural network :math:`g_{\\mathbf{\\Theta}}`\n            that maps the aggregated edge features back to :math:`F_{in}`.\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features 
:math:`(|\\mathcal{V}|, F_{in})`,\n          positions :math:`(|\\mathcal{V}|, 3)`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{in})`\n    \"\"\"\n    def __init__(\n        self,\n        mlp_h: torch.nn.Module,\n        mlp_f: torch.nn.Module,\n        mlp_g: torch.nn.Module,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'max')\n        super().__init__(**kwargs)\n\n        self.mlp_h = mlp_h\n        self.mlp_f = mlp_f\n        self.mlp_g = mlp_g\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.mlp_h)\n        reset(self.mlp_f)\n        reset(self.mlp_g)\n\n    def forward(self, x: Tensor, pos: Tensor, edge_index: Adj) -> Tensor:\n        # propagate_type: (x: Tensor, pos: Tensor)\n        out = self.propagate(edge_index, x=x, pos=pos)\n        out = self.mlp_g(out)\n        return x + out\n\n    def message(self, pos_j: Tensor, pos_i: Tensor, x_i: Tensor,\n                x_j: Tensor) -> Tensor:\n        delta = self.mlp_h(x_i)\n        e = torch.cat([pos_j - pos_i + delta, x_j], dim=-1)\n        return self.mlp_f(e)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(\\n'\n                f'  mlp_h={self.mlp_h},\\n'\n                f'  mlp_f={self.mlp_f},\\n'\n                f'  mlp_g={self.mlp_g},\\n'\n                f')')\n"
  },
  {
    "path": "torch_geometric/nn/conv/point_transformer_conv.py",
    "content": "from typing import Callable, Optional, Tuple, Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import (\n    Adj,\n    OptTensor,\n    PairTensor,\n    SparseTensor,\n    torch_sparse,\n)\nfrom torch_geometric.utils import add_self_loops, remove_self_loops, softmax\n\n\nclass PointTransformerConv(MessagePassing):\n    r\"\"\"The Point Transformer layer from the `\"Point Transformer\"\n    <https://arxiv.org/abs/2012.09164>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i =  \\sum_{j \\in\n        \\mathcal{N}(i) \\cup \\{ i \\}} \\alpha_{i,j} \\left(\\mathbf{W}_3\n        \\mathbf{x}_j + \\delta_{ij} \\right),\n\n    where the attention coefficients :math:`\\alpha_{i,j}` and\n    positional embedding :math:`\\delta_{ij}` are computed as\n\n    .. math::\n        \\alpha_{i,j}= \\textrm{softmax} \\left( \\gamma_\\mathbf{\\Theta}\n        (\\mathbf{W}_1 \\mathbf{x}_i - \\mathbf{W}_2 \\mathbf{x}_j +\n        \\delta_{i,j}) \\right)\n\n    and\n\n    .. 
math::\n        \\delta_{i,j}= h_{\\mathbf{\\Theta}}(\\mathbf{p}_i - \\mathbf{p}_j),\n\n    with :math:`\\gamma_\\mathbf{\\Theta}` and :math:`h_\\mathbf{\\Theta}`\n    denoting neural networks, *i.e.* MLPs, and\n    :math:`\\mathbf{P} \\in \\mathbb{R}^{N \\times D}` defining the position of\n    each point.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        pos_nn (torch.nn.Module, optional): A neural network\n            :math:`h_\\mathbf{\\Theta}` which maps relative spatial coordinates\n            :obj:`pos_i - pos_j` of shape :obj:`[-1, 3]` to shape\n            :obj:`[-1, out_channels]`.\n            Will default to a :class:`torch.nn.Linear` transformation if not\n            further specified. (default: :obj:`None`)\n        attn_nn (torch.nn.Module, optional): A neural network\n            :math:`\\gamma_\\mathbf{\\Theta}` which maps transformed\n            node features of shape :obj:`[-1, out_channels]`\n            to shape :obj:`[-1, out_channels]`. (default: :obj:`None`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          positions :math:`(|\\mathcal{V}|, 3)` or\n          :math:`((|\\mathcal{V_s}|, 3), (|\\mathcal{V_t}|, 3))` if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: Union[int, Tuple[int, int]],\n                 out_channels: int, pos_nn: Optional[Callable] = None,\n                 attn_nn: Optional[Callable] = None,\n                 add_self_loops: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.add_self_loops = add_self_loops\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.pos_nn = pos_nn\n        if self.pos_nn is None:\n            self.pos_nn = Linear(3, out_channels)\n\n        self.attn_nn = attn_nn\n        self.lin = Linear(in_channels[0], out_channels, bias=False)\n        self.lin_src = Linear(in_channels[0], out_channels, bias=False)\n        self.lin_dst = Linear(in_channels[1], out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.pos_nn)\n        if self.attn_nn is not None:\n            reset(self.attn_nn)\n        self.lin.reset_parameters()\n        self.lin_src.reset_parameters()\n        self.lin_dst.reset_parameters()\n\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        pos: 
Union[Tensor, PairTensor],\n        edge_index: Adj,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            alpha = (self.lin_src(x), self.lin_dst(x))\n            x = (self.lin(x), x)\n        else:\n            alpha = (self.lin_src(x[0]), self.lin_dst(x[1]))\n            x = (self.lin(x[0]), x[1])\n\n        if isinstance(pos, Tensor):\n            pos = (pos, pos)\n\n        if self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                edge_index, _ = remove_self_loops(edge_index)\n                edge_index, _ = add_self_loops(\n                    edge_index, num_nodes=min(pos[0].size(0), pos[1].size(0)))\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = torch_sparse.set_diag(edge_index)\n\n        # propagate_type: (x: PairTensor, pos: PairTensor, alpha: PairTensor)\n        out = self.propagate(edge_index, x=x, pos=pos, alpha=alpha)\n        return out\n\n    def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor,\n                alpha_i: Tensor, alpha_j: Tensor, index: Tensor,\n                ptr: OptTensor, size_i: Optional[int]) -> Tensor:\n\n        delta = self.pos_nn(pos_i - pos_j)\n        alpha = alpha_i - alpha_j + delta\n        if self.attn_nn is not None:\n            alpha = self.attn_nn(alpha)\n        alpha = softmax(alpha, index, ptr, size_i)\n        return alpha * (x_j + delta)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/ppf_conv.py",
    "content": "from typing import Callable, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.typing import (\n    Adj,\n    OptTensor,\n    PairOptTensor,\n    PairTensor,\n    SparseTensor,\n    torch_sparse,\n)\nfrom torch_geometric.utils import add_self_loops, remove_self_loops\n\n\ndef get_angle(v1: Tensor, v2: Tensor) -> Tensor:\n    return torch.atan2(\n        torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1))\n\n\ndef point_pair_features(pos_i: Tensor, pos_j: Tensor, normal_i: Tensor,\n                        normal_j: Tensor) -> Tensor:\n    pseudo = pos_j - pos_i\n    return torch.stack([\n        pseudo.norm(p=2, dim=1),\n        get_angle(normal_i, pseudo),\n        get_angle(normal_j, pseudo),\n        get_angle(normal_i, normal_j)\n    ], dim=1)\n\n\nclass PPFConv(MessagePassing):\n    r\"\"\"The PPFNet operator from the `\"PPFNet: Global Context Aware Local\n    Features for Robust 3D Point Matching\" <https://arxiv.org/abs/1802.02669>`_\n    paper.\n\n    .. 
math::\n        \\mathbf{x}^{\\prime}_i = \\gamma_{\\mathbf{\\Theta}} \\left( \\max_{j \\in\n        \\mathcal{N}(i) \\cup \\{ i \\}} h_{\\mathbf{\\Theta}} ( \\mathbf{x}_j, \\|\n        \\mathbf{d_{j,i}} \\|, \\angle(\\mathbf{n}_i, \\mathbf{d_{j,i}}),\n        \\angle(\\mathbf{n}_j, \\mathbf{d_{j,i}}), \\angle(\\mathbf{n}_i,\n        \\mathbf{n}_j) ) \\right)\n\n    where :math:`\\gamma_{\\mathbf{\\Theta}}` and :math:`h_{\\mathbf{\\Theta}}`\n    denote neural networks, *i.e.* MLPs, which take in node features and\n    :class:`torch_geometric.transforms.PointPairFeatures`.\n\n    Args:\n        local_nn (torch.nn.Module, optional): A neural network\n            :math:`h_{\\mathbf{\\Theta}}` that maps node features :obj:`x` and\n            point pair features of shape :obj:`[-1, in_channels + 4]` to shape\n            :obj:`[-1, out_channels]`, *e.g.*, defined by\n            :class:`torch.nn.Sequential`. (default: :obj:`None`)\n        global_nn (torch.nn.Module, optional): A neural network\n            :math:`\\gamma_{\\mathbf{\\Theta}}` that maps aggregated node features\n            of shape :obj:`[-1, out_channels]` to shape :obj:`[-1,\n            final_out_channels]`, *e.g.*, defined by\n            :class:`torch.nn.Sequential`. (default: :obj:`None`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. 
(default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          positions :math:`(|\\mathcal{V}|, 3)` or\n          :math:`((|\\mathcal{V_s}|, 3), (|\\mathcal{V_t}|, 3))` if bipartite,\n          point normals :math:`(|\\mathcal{V}|, 3)` or\n          :math:`((|\\mathcal{V_s}|, 3), (|\\mathcal{V_t}|, 3))` if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V}_t|, F_{out})` if bipartite\n\n    \"\"\"\n    def __init__(self, local_nn: Optional[Callable] = None,\n                 global_nn: Optional[Callable] = None,\n                 add_self_loops: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'max')\n        super().__init__(**kwargs)\n\n        self.local_nn = local_nn\n        self.global_nn = global_nn\n        self.add_self_loops = add_self_loops\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.local_nn)\n        reset(self.global_nn)\n\n    def forward(\n        self,\n        x: Union[OptTensor, PairOptTensor],\n        pos: Union[Tensor, PairTensor],\n        normal: Union[Tensor, PairTensor],\n        edge_index: Adj,\n    ) -> Tensor:\n\n        if not isinstance(x, tuple):\n            x = (x, None)\n\n        if isinstance(pos, Tensor):\n            pos = (pos, pos)\n\n        if isinstance(normal, Tensor):\n            normal = (normal, normal)\n\n        if self.add_self_loops:\n            if isinstance(edge_index, Tensor):\n                edge_index, _ = remove_self_loops(edge_index)\n                edge_index, _ = add_self_loops(edge_index,\n                          
                     num_nodes=pos[1].size(0))\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = torch_sparse.set_diag(edge_index)\n\n        # propagate_type: (x: PairOptTensor, pos: PairTensor,\n        #                  normal: PairTensor)\n        out = self.propagate(edge_index, x=x, pos=pos, normal=normal)\n\n        if self.global_nn is not None:\n            out = self.global_nn(out)\n\n        return out\n\n    def message(self, x_j: OptTensor, pos_i: Tensor, pos_j: Tensor,\n                normal_i: Tensor, normal_j: Tensor) -> Tensor:\n        msg = point_pair_features(pos_i, pos_j, normal_i, normal_j)\n        if x_j is not None:\n            msg = torch.cat([x_j, msg], dim=1)\n        if self.local_nn is not None:\n            msg = self.local_nn(msg)\n        return msg\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(local_nn={self.local_nn}, '\n                f'global_nn={self.global_nn})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/propagate.jinja",
    "content": "import typing\nfrom typing import Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.utils import is_sparse\nfrom torch_geometric.typing import Size, SparseTensor\n{% for module in modules %}\nfrom {{module}} import *\n{%- endfor %}\n\n\n{% include \"collect.jinja\" %}\n\n\ndef propagate(\n    self,\n    edge_index: Union[Tensor, SparseTensor],\n{%- for param in signature.param_dict.values() %}\n    {{param.name}}: {{param.type_repr}},\n{%- endfor %}\n    size: Size = None,\n) -> {{signature.return_type_repr}}:\n\n    # Begin Propagate Forward Pre Hook #########################################\n    if not torch.jit.is_scripting() and not is_compiling():\n        for hook in self._propagate_forward_pre_hooks.values():\n            hook_kwargs = dict(\n{%- for name in signature.param_dict %}\n                {{name}}={{name}},\n{%- endfor %}\n            )\n            res = hook(self, (edge_index, size, hook_kwargs))\n            if res is not None:\n                edge_index, size, hook_kwargs = res\n{%- for name in signature.param_dict %}\n                {{name}} = hook_kwargs['{{name}}']\n{%- endfor %}\n    # End Propagate Forward Pre Hook ###########################################\n\n    mutable_size = self._check_input(edge_index, size)\n\n    # Run \"fused\" message and aggregation (if applicable).\n    fuse = False\n    if self.fuse:\n        if is_sparse(edge_index):\n            fuse = True\n        elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n            if self.SUPPORTS_FUSED_EDGE_INDEX and edge_index.is_sorted_by_col:\n                fuse = True\n\n    if fuse:\n\n{%- if fuse %}\n        # Begin Message and Aggregate Forward Pre Hook #########################\n        if not torch.jit.is_scripting() and not is_compiling():\n            for hook in self._message_and_aggregate_forward_pre_hooks.values():\n           
     hook_kwargs = dict(\n{%- for name in message_and_aggregate_args %}\n                    {{name}}={{name}},\n{%- endfor %}\n                )\n                res = hook(self, (edge_index, hook_kwargs))\n                if res is not None:\n                    edge_index, hook_kwargs = res\n{%- for name in message_and_aggregate_args %}\n                    {{name}} = hook_kwargs['{{name}}']\n{%- endfor %}\n        # End Message and Aggregate Forward Pre Hook ##########################\n\n        out = self.message_and_aggregate(\n            edge_index,\n{%- for name in message_and_aggregate_args %}\n            {{name}},\n{%- endfor %}\n        )\n\n        # Begin Message and Aggregate Forward Hook #############################\n        if not torch.jit.is_scripting() and not is_compiling():\n            for hook in self._message_and_aggregate_forward_hooks.values():\n                hook_kwargs = dict(\n{%- for name in message_and_aggregate_args %}\n                    {{name}}={{name}},\n{%- endfor %}\n                )\n                res = hook(self, (edge_index, hook_kwargs, ), out)\n                out = res if res is not None else out\n        # End Message and Aggregate Forward Hook ###############################\n\n        out = self.update(\n            out,\n{%- for name in update_args %}\n            {{name}}={{name}},\n{%- endfor %}\n        )\n{%- else %}\n        raise NotImplementedError(\"'message_and_aggregate' not implemented\")\n{%- endif %}\n\n    else:\n\n        kwargs = self.{{collect_name}}(\n            edge_index,\n{%- for name in signature.param_dict %}\n            {{name}},\n{%- endfor %}\n            mutable_size,\n        )\n\n        # Begin Message Forward Pre Hook #######################################\n        if not torch.jit.is_scripting() and not is_compiling():\n            for hook in self._message_forward_pre_hooks.values():\n                hook_kwargs = dict(\n{%- for name in message_args %}\n                    
{{name}}=kwargs.{{name}},\n{%- endfor %}\n                )\n                res = hook(self, (hook_kwargs, ))\n                hook_kwargs = res[0] if isinstance(res, tuple) else res\n                if res is not None:\n                    kwargs = CollectArgs(\n{%- for name in collect_param_dict %}\n{%- if name in message_args %}\n                        {{name}}=hook_kwargs['{{name}}'],\n{%- else %}\n                        {{name}}=kwargs.{{name}},\n{%- endif %}\n{%- endfor %}\n                    )\n        # End Message Forward Pre Hook #########################################\n\n        out = self.message(\n{%- for name in message_args %}\n            {{name}}=kwargs.{{name}},\n{%- endfor %}\n        )\n\n        # Begin Message Forward Hook ###########################################\n        if not torch.jit.is_scripting() and not is_compiling():\n            for hook in self._message_forward_hooks.values():\n                hook_kwargs = dict(\n{%- for name in message_args %}\n                    {{name}}=kwargs.{{name}},\n{%- endfor %}\n                )\n                res = hook(self, (hook_kwargs, ), out)\n                out = res if res is not None else out\n        # End Message Forward Hook #############################################\n\n        # Begin Aggregate Forward Pre Hook #####################################\n        if not torch.jit.is_scripting() and not is_compiling():\n            for hook in self._aggregate_forward_pre_hooks.values():\n                hook_kwargs = dict(\n{%- for name in aggregate_args %}\n                    {{name}}=kwargs.{{name}},\n{%- endfor %}\n                )\n                res = hook(self, (hook_kwargs, ))\n                hook_kwargs = res[0] if isinstance(res, tuple) else res\n                if res is not None:\n                    kwargs = CollectArgs(\n{%- for name in collect_param_dict %}\n{%- if name in aggregate_args %}\n                        {{name}}=hook_kwargs['{{name}}'],\n{%- else %}\n  
                      {{name}}=kwargs.{{name}},\n{%- endif %}\n{%- endfor %}\n                    )\n        # End Aggregate Forward Pre Hook #######################################\n\n        out = self.aggregate(\n            out,\n{%- for name in aggregate_args %}\n            {{name}}=kwargs.{{name}},\n{%- endfor %}\n        )\n\n        # Begin Aggregate Forward Hook #########################################\n        if not torch.jit.is_scripting() and not is_compiling():\n            for hook in self._aggregate_forward_hooks.values():\n                hook_kwargs = dict(\n{%- for name in aggregate_args %}\n                    {{name}}=kwargs.{{name}},\n{%- endfor %}\n                )\n                res = hook(self, (hook_kwargs, ), out)\n                out = res if res is not None else out\n        # End Aggregate Forward Hook ###########################################\n\n        out = self.update(\n            out,\n{%- for name in update_args %}\n            {{name}}=kwargs.{{name}},\n{%- endfor %}\n        )\n\n    # Begin Propagate Forward Hook ############################################\n    if not torch.jit.is_scripting() and not is_compiling():\n        for hook in self._propagate_forward_hooks.values():\n            hook_kwargs = dict(\n{%- for name in signature.param_dict %}\n                {{name}}={{name}},\n{%- endfor %}\n            )\n            res = hook(self, (edge_index, mutable_size, hook_kwargs), out)\n            out = res if res is not None else out\n    # End Propagate Forward Hook ##############################################\n\n    return out\n"
  },
  {
    "path": "torch_geometric/nn/conv/res_gated_graph_conv.py",
    "content": "from typing import Callable, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter, Sigmoid\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import Adj, OptTensor, PairTensor\n\n\nclass ResGatedGraphConv(MessagePassing):\n    r\"\"\"The residual gated graph convolutional operator from the\n    `\"Residual Gated Graph ConvNets\" <https://arxiv.org/abs/1711.07553>`_\n    paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i +\n        \\sum_{j \\in \\mathcal{N}(i)} \\eta_{i,j} \\odot \\mathbf{W}_2 \\mathbf{x}_j\n\n    where the gate :math:`\\eta_{i,j}` is defined as\n\n    .. math::\n        \\eta_{i,j} = \\sigma(\\mathbf{W}_3 \\mathbf{x}_i + \\mathbf{W}_4\n        \\mathbf{x}_j)\n\n    with :math:`\\sigma` denoting the sigmoid function.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        act (callable, optional): Gating function :math:`\\sigma`.\n            (default: :meth:`torch.nn.Sigmoid()`)\n        edge_dim (int, optional): Edge feature dimensionality (in case\n            there are any). (default: :obj:`None`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        root_weight (bool, optional): If set to :obj:`False`, the layer will\n            not add transformed root node features to the output.\n            (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **inputs:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **outputs:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V_t}|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        act: Optional[Callable] = Sigmoid(),\n        edge_dim: Optional[int] = None,\n        root_weight: bool = True,\n        bias: bool = True,\n        **kwargs,\n    ):\n\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.act = act\n        self.edge_dim = edge_dim\n        self.root_weight = root_weight\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        edge_dim = edge_dim if edge_dim is not None else 0\n        self.lin_key = Linear(in_channels[1] + edge_dim, out_channels)\n        self.lin_query = Linear(in_channels[0] + edge_dim, out_channels)\n        self.lin_value = Linear(in_channels[0] + edge_dim, out_channels)\n\n        if root_weight:\n            self.lin_skip = Linear(in_channels[1], out_channels, bias=False)\n        else:\n            self.register_parameter('lin_skip', None)\n\n        if bias:\n            self.bias = Parameter(Tensor(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        
self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_key.reset_parameters()\n        self.lin_query.reset_parameters()\n        self.lin_value.reset_parameters()\n        if self.lin_skip is not None:\n            self.lin_skip.reset_parameters()\n        if self.bias is not None:\n            zeros(self.bias)\n\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # In case edge features are not given, we can compute key, query and\n        # value tensors in node-level space, which is a bit more efficient:\n        if self.edge_dim is None:\n            k = self.lin_key(x[1])\n            q = self.lin_query(x[0])\n            v = self.lin_value(x[0])\n        else:\n            k, q, v = x[1], x[0], x[0]\n\n        # propagate_type: (k: Tensor, q: Tensor, v: Tensor,\n        #                  edge_attr: OptTensor)\n        out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr)\n\n        if self.root_weight:\n            out = out + self.lin_skip(x[1])\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, k_i: Tensor, q_j: Tensor, v_j: Tensor,\n                edge_attr: OptTensor) -> Tensor:\n\n        assert (edge_attr is not None) == (self.edge_dim is not None)\n\n        if edge_attr is not None:\n            k_i = self.lin_key(torch.cat([k_i, edge_attr], dim=-1))\n            q_j = self.lin_query(torch.cat([q_j, edge_attr], dim=-1))\n            v_j = self.lin_value(torch.cat([v_j, edge_attr], dim=-1))\n\n        return self.act(k_i + q_j) * v_j\n"
  },
  {
    "path": "torch_geometric/nn/conv/rgat_conv.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter, ReLU\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot, ones, zeros\nfrom torch_geometric.typing import Adj, OptTensor, Size, SparseTensor\nfrom torch_geometric.utils import is_torch_sparse_tensor, scatter, softmax\nfrom torch_geometric.utils.sparse import set_sparse_value\n\n\nclass RGATConv(MessagePassing):\n    r\"\"\"The relational graph attentional operator from the `\"Relational Graph\n    Attention Networks\" <https://arxiv.org/abs/1904.05811>`_ paper.\n\n    Here, attention logits :math:`\\mathbf{a}^{(r)}_{i,j}` are computed for each\n    relation type :math:`r` with the help of both query and key kernels, *i.e.*\n\n    .. math::\n        \\mathbf{q}^{(r)}_i = \\mathbf{W}_1^{(r)}\\mathbf{x}_{i} \\cdot\n        \\mathbf{Q}^{(r)}\n        \\quad \\textrm{and} \\quad\n        \\mathbf{k}^{(r)}_i = \\mathbf{W}_1^{(r)}\\mathbf{x}_{i} \\cdot\n        \\mathbf{K}^{(r)}.\n\n    Two schemes have been proposed to compute attention logits\n    :math:`\\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r`:\n\n    **Additive attention**\n\n    .. math::\n        \\mathbf{a}^{(r)}_{i,j} = \\mathrm{LeakyReLU}(\\mathbf{q}^{(r)}_i +\n        \\mathbf{k}^{(r)}_j)\n\n    or **multiplicative attention**\n\n    .. math::\n        \\mathbf{a}^{(r)}_{i,j} = \\mathbf{q}^{(r)}_i \\cdot \\mathbf{k}^{(r)}_j.\n\n    If the graph has multi-dimensional edge features\n    :math:`\\mathbf{e}^{(r)}_{i,j}`, the attention logits\n    :math:`\\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r` are\n    computed as\n\n    .. math::\n        \\mathbf{a}^{(r)}_{i,j} = \\mathrm{LeakyReLU}(\\mathbf{q}^{(r)}_i +\n        \\mathbf{k}^{(r)}_j + \\mathbf{W}_2^{(r)}\\mathbf{e}^{(r)}_{i,j})\n\n    or\n\n    .. 
math::\n        \\mathbf{a}^{(r)}_{i,j} = \\mathbf{q}^{(r)}_i \\cdot \\mathbf{k}^{(r)}_j\n        \\cdot \\mathbf{W}_2^{(r)} \\mathbf{e}^{(r)}_{i,j},\n\n    respectively.\n    The attention coefficients :math:`\\alpha^{(r)}_{i,j}` for each relation\n    type :math:`r` are then obtained via two different attention mechanisms:\n    The **within-relation** attention mechanism\n\n    .. math::\n        \\alpha^{(r)}_{i,j} =\n        \\frac{\\exp(\\mathbf{a}^{(r)}_{i,j})}\n        {\\sum_{k \\in \\mathcal{N}_r(i)} \\exp(\\mathbf{a}^{(r)}_{i,k})}\n\n    or the **across-relation** attention mechanism\n\n    .. math::\n        \\alpha^{(r)}_{i,j} =\n        \\frac{\\exp(\\mathbf{a}^{(r)}_{i,j})}\n        {\\sum_{r^{\\prime} \\in \\mathcal{R}}\n        \\sum_{k \\in \\mathcal{N}_{r^{\\prime}}(i)}\n        \\exp(\\mathbf{a}^{(r^{\\prime})}_{i,k})}\n\n    where :math:`\\mathcal{R}` denotes the set of relations, *i.e.* edge types.\n    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which\n    stores a relation identifier :math:`\\in \\{ 0, \\ldots, |\\mathcal{R}| - 1\\}`\n    for each edge.\n\n    To enhance the discriminative power of attention-based GNNs, this layer\n    further implements four different cardinality preservation options as\n    proposed in the `\"Improving Attention Mechanism in Graph Neural Networks\n    via Cardinality Preservation\" <https://arxiv.org/abs/1907.02204>`_ paper:\n\n    .. 
math::\n        \\text{additive:}~~~\\mathbf{x}^{{\\prime}(r)}_i &=\n        \\sum_{j \\in \\mathcal{N}_r(i)}\n        \\alpha^{(r)}_{i,j} \\mathbf{x}^{(r)}_j + \\mathcal{W} \\odot\n        \\sum_{j \\in \\mathcal{N}_r(i)} \\mathbf{x}^{(r)}_j\n\n        \\text{scaled:}~~~\\mathbf{x}^{{\\prime}(r)}_i &=\n        \\psi(|\\mathcal{N}_r(i)|) \\odot\n        \\sum_{j \\in \\mathcal{N}_r(i)} \\alpha^{(r)}_{i,j} \\mathbf{x}^{(r)}_j\n\n        \\text{f-additive:}~~~\\mathbf{x}^{{\\prime}(r)}_i &=\n        \\sum_{j \\in \\mathcal{N}_r(i)}\n        (\\alpha^{(r)}_{i,j} + 1) \\cdot \\mathbf{x}^{(r)}_j\n\n        \\text{f-scaled:}~~~\\mathbf{x}^{{\\prime}(r)}_i &=\n        |\\mathcal{N}_r(i)| \\odot \\sum_{j \\in \\mathcal{N}_r(i)}\n        \\alpha^{(r)}_{i,j} \\mathbf{x}^{(r)}_j\n\n    * If :obj:`attention_mode=\"additive-self-attention\"` and\n      :obj:`concat=True`, the layer outputs :obj:`heads * out_channels`\n      features for each node.\n\n    * If :obj:`attention_mode=\"multiplicative-self-attention\"` and\n      :obj:`concat=True`, the layer outputs :obj:`heads * dim * out_channels`\n      features for each node.\n\n    * If :obj:`attention_mode=\"additive-self-attention\"` and\n      :obj:`concat=False`, the layer outputs :obj:`out_channels` features for\n      each node.\n\n    * If :obj:`attention_mode=\"multiplicative-self-attention\"` and\n      :obj:`concat=False`, the layer outputs :obj:`dim * out_channels` features\n      for each node.\n\n    Please make sure to set the :obj:`in_channels` argument of the next\n    layer accordingly if more than one instance of this layer is used.\n\n    .. 
note::\n\n        For an example of using :class:`RGATConv`, see\n        `examples/rgat.py <https://github.com/pyg-team/pytorch_geometric/blob\n        /master/examples/rgat.py>`_.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        num_relations (int): Number of relations.\n        num_bases (int, optional): If set, this layer will use the\n            basis-decomposition regularization scheme where :obj:`num_bases`\n            denotes the number of bases to use. (default: :obj:`None`)\n        num_blocks (int, optional): If set, this layer will use the\n            block-diagonal-decomposition regularization scheme where\n            :obj:`num_blocks` denotes the number of blocks to use.\n            (default: :obj:`None`)\n        mod (str, optional): The cardinality preservation option to use.\n            (:obj:`\"additive\"`, :obj:`\"scaled\"`, :obj:`\"f-additive\"`,\n            :obj:`\"f-scaled\"`, :obj:`None`). (default: :obj:`None`)\n        attention_mechanism (str, optional): The attention mechanism to use\n            (:obj:`\"within-relation\"`, :obj:`\"across-relation\"`).\n            (default: :obj:`\"across-relation\"`)\n        attention_mode (str, optional): The mode to calculate attention logits.\n            (:obj:`\"additive-self-attention\"`,\n            :obj:`\"multiplicative-self-attention\"`).\n            (default: :obj:`\"additive-self-attention\"`)\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        dim (int): Number of dimensions for query and key kernels.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the multi-head\n            attentions are averaged instead of concatenated.\n            (default: :obj:`True`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. 
(default: :obj:`0.2`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        edge_dim (int, optional): Edge feature dimensionality (in case there\n            are any). (default: :obj:`None`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not\n            learn an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n    \"\"\"\n\n    _alpha: OptTensor\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        num_relations: int,\n        num_bases: Optional[int] = None,\n        num_blocks: Optional[int] = None,\n        mod: Optional[str] = None,\n        attention_mechanism: str = \"across-relation\",\n        attention_mode: str = \"additive-self-attention\",\n        heads: int = 1,\n        dim: int = 1,\n        concat: bool = True,\n        negative_slope: float = 0.2,\n        dropout: float = 0.0,\n        edge_dim: Optional[int] = None,\n        bias: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(node_dim=0, **kwargs)\n\n        self.heads = heads\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n        self.mod = mod\n        self.activation = ReLU()\n        self.concat = concat\n        self.attention_mode = attention_mode\n        self.attention_mechanism = attention_mechanism\n        self.dim = dim\n        self.edge_dim = edge_dim\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_relations = num_relations\n        self.num_bases = num_bases\n        self.num_blocks = num_blocks\n\n        mod_types = ['additive', 'scaled', 'f-additive', 'f-scaled']\n\n        if 
(self.attention_mechanism != \"within-relation\"\n                and self.attention_mechanism != \"across-relation\"):\n            raise ValueError('attention mechanism must either be '\n                             '\"within-relation\" or \"across-relation\"')\n\n        if (self.attention_mode != \"additive-self-attention\"\n                and self.attention_mode != \"multiplicative-self-attention\"):\n            raise ValueError('attention mode must either be '\n                             '\"additive-self-attention\" or '\n                             '\"multiplicative-self-attention\"')\n\n        if self.attention_mode == \"additive-self-attention\" and self.dim > 1:\n            raise ValueError('\"additive-self-attention\" mode cannot be '\n                             'applied when value of d is greater than 1. '\n                             'Use \"multiplicative-self-attention\" instead.')\n\n        if self.dropout > 0.0 and self.mod in mod_types:\n            raise ValueError('mod must be None with dropout value greater '\n                             'than 0 in order to sample attention '\n                             'coefficients stochastically')\n\n        if num_bases is not None and num_blocks is not None:\n            raise ValueError('Can not apply both basis-decomposition and '\n                             'block-diagonal-decomposition at the same time.')\n\n        # The learnable parameters to compute both attention logits and\n        # attention coefficients:\n        self.q = Parameter(\n            torch.empty(self.heads * self.out_channels, self.heads * self.dim))\n        self.k = Parameter(\n            torch.empty(self.heads * self.out_channels, self.heads * self.dim))\n\n        if bias and concat:\n            self.bias = Parameter(\n                torch.empty(self.heads * self.dim * self.out_channels))\n        elif bias and not concat:\n            self.bias = Parameter(torch.empty(self.dim * self.out_channels))\n        
else:\n            self.register_parameter('bias', None)\n\n        if edge_dim is not None:\n            self.lin_edge = Linear(self.edge_dim,\n                                   self.heads * self.out_channels, bias=False,\n                                   weight_initializer='glorot')\n            self.e = Parameter(\n                torch.empty(self.heads * self.out_channels,\n                            self.heads * self.dim))\n        else:\n            self.lin_edge = None\n            self.register_parameter('e', None)\n\n        if num_bases is not None:\n            self.att = Parameter(\n                torch.empty(self.num_relations, self.num_bases))\n            self.basis = Parameter(\n                torch.empty(self.num_bases, self.in_channels,\n                            self.heads * self.out_channels))\n        elif num_blocks is not None:\n            assert (\n                self.in_channels % self.num_blocks == 0\n                and (self.heads * self.out_channels) % self.num_blocks == 0), (\n                    \"both 'in_channels' and 'heads * out_channels' must be \"\n                    \"multiple of 'num_blocks' used\")\n            self.weight = Parameter(\n                torch.empty(self.num_relations, self.num_blocks,\n                            self.in_channels // self.num_blocks,\n                            (self.heads * self.out_channels) //\n                            self.num_blocks))\n        else:\n            self.weight = Parameter(\n                torch.empty(self.num_relations, self.in_channels,\n                            self.heads * self.out_channels))\n\n        self.w = Parameter(torch.ones(self.out_channels))\n        self.l1 = Parameter(torch.empty(1, self.out_channels))\n        self.b1 = Parameter(torch.empty(1, self.out_channels))\n        self.l2 = Parameter(torch.empty(self.out_channels, self.out_channels))\n        self.b2 = Parameter(torch.empty(1, self.out_channels))\n\n        self._alpha = None\n\n   
     self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if self.num_bases is not None:\n            glorot(self.basis)\n            glorot(self.att)\n        else:\n            glorot(self.weight)\n        glorot(self.q)\n        glorot(self.k)\n        zeros(self.bias)\n        ones(self.l1)\n        zeros(self.b1)\n        # Fill `l2` in place; `torch.full` would only build an unused tensor:\n        torch.nn.init.constant_(self.l2, 1 / self.out_channels)\n        zeros(self.b2)\n        if self.lin_edge is not None:\n            glorot(self.lin_edge)\n            glorot(self.e)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        edge_type: OptTensor = None,\n        edge_attr: OptTensor = None,\n        size: Size = None,\n        return_attention_weights=None,\n    ):\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): The input node features.\n                Can be either a :obj:`[num_nodes, in_channels]` node feature\n                matrix, or an optional one-dimensional node index tensor (in\n                which case input features are treated as trainable node\n                embeddings).\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_type (torch.Tensor, optional): The one-dimensional relation\n                type/index for each edge in :obj:`edge_index`.\n                Should only be :obj:`None` in case :obj:`edge_index` is of type\n                :class:`torch_sparse.SparseTensor` or\n                :class:`torch.sparse.Tensor`. 
(default: :obj:`None`)\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            size ((int, int), optional): The shape of the adjacency matrix.\n                (default: :obj:`None`)\n            return_attention_weights (bool, optional):\n                Will additionally return the tuple\n                :obj:`(edge_index, attention_weights)` whenever it is set to\n                a value, regardless of its actual value\n                (might be `True` or `False`), holding the computed attention\n                weights for each edge.\n                (default: :obj:`None`)\n        \"\"\"\n        # propagate_type: (x: Tensor, edge_type: OptTensor,\n        #                  edge_attr: OptTensor)\n        out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x,\n                             size=size, edge_attr=edge_attr)\n\n        alpha = self._alpha\n        assert alpha is not None\n        self._alpha = None\n\n        if isinstance(return_attention_weights, bool):\n            if isinstance(edge_index, Tensor):\n                if is_torch_sparse_tensor(edge_index):\n                    # TODO TorchScript requires to return a tuple\n                    adj = set_sparse_value(edge_index, alpha)\n                    return out, (adj, alpha)\n                else:\n                    return out, (edge_index, alpha)\n            elif isinstance(edge_index, SparseTensor):\n                return out, edge_index.set_value(alpha, layout='coo')\n        else:\n            return out\n\n    def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor,\n                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,\n                size_i: Optional[int]) -> Tensor:\n\n        if self.num_bases is not None:  # Basis-decomposition =================\n            w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))\n            w = w.view(self.num_relations, 
self.in_channels,\n                       self.heads * self.out_channels)\n        if self.num_blocks is not None:  # Block-diagonal-decomposition =======\n            if (x_i.dtype == torch.long and x_j.dtype == torch.long\n                    and self.num_blocks is not None):\n                raise ValueError('Block-diagonal decomposition not supported '\n                                 'for non-continuous input features.')\n            w = self.weight\n            x_i = x_i.view(-1, 1, w.size(1), w.size(2))\n            x_j = x_j.view(-1, 1, w.size(1), w.size(2))\n            w = torch.index_select(w, 0, edge_type)\n            outi = torch.einsum('abcd,acde->ace', x_i, w)\n            outi = outi.contiguous().view(-1, self.heads * self.out_channels)\n            outj = torch.einsum('abcd,acde->ace', x_j, w)\n            outj = outj.contiguous().view(-1, self.heads * self.out_channels)\n        else:  # No regularization/Basis-decomposition ========================\n            if self.num_bases is None:\n                w = self.weight\n            w = torch.index_select(w, 0, edge_type)\n            outi = torch.bmm(x_i.unsqueeze(1), w).squeeze(-2)\n            outj = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)\n\n        qi = torch.matmul(outi, self.q)\n        kj = torch.matmul(outj, self.k)\n\n        alpha_edge, alpha = 0, torch.tensor([0])\n        if edge_attr is not None:\n            if edge_attr.dim() == 1:\n                edge_attr = edge_attr.view(-1, 1)\n            assert self.lin_edge is not None, (\n                \"Please set 'edge_dim = edge_attr.size(-1)' while calling the \"\n                \"RGATConv layer\")\n            edge_attributes = self.lin_edge(edge_attr).view(\n                -1, self.heads * self.out_channels)\n            if edge_attributes.size(0) != edge_attr.size(0):\n                edge_attributes = torch.index_select(edge_attributes, 0,\n                                                     edge_type)\n            
alpha_edge = torch.matmul(edge_attributes, self.e)\n\n        if self.attention_mode == \"additive-self-attention\":\n            if edge_attr is not None:\n                alpha = torch.add(qi, kj) + alpha_edge\n            else:\n                alpha = torch.add(qi, kj)\n            alpha = F.leaky_relu(alpha, self.negative_slope)\n        elif self.attention_mode == \"multiplicative-self-attention\":\n            if edge_attr is not None:\n                alpha = (qi * kj) * alpha_edge\n            else:\n                alpha = qi * kj\n\n        if self.attention_mechanism == \"within-relation\":\n            across_out = torch.zeros_like(alpha)\n            for r in range(self.num_relations):\n                mask = edge_type == r\n                across_out[mask] = softmax(alpha[mask], index[mask])\n            alpha = across_out\n        elif self.attention_mechanism == \"across-relation\":\n            alpha = softmax(alpha, index, ptr, size_i)\n\n        self._alpha = alpha\n\n        if self.mod == \"additive\":\n            if self.attention_mode == \"additive-self-attention\":\n                ones = torch.ones_like(alpha)\n                h = (outj.view(-1, self.heads, self.out_channels) *\n                     ones.view(-1, self.heads, 1))\n                h = torch.mul(self.w, h)\n\n                return (outj.view(-1, self.heads, self.out_channels) *\n                        alpha.view(-1, self.heads, 1) + h)\n            elif self.attention_mode == \"multiplicative-self-attention\":\n                ones = torch.ones_like(alpha)\n                h = (outj.view(-1, self.heads, 1, self.out_channels) *\n                     ones.view(-1, self.heads, self.dim, 1))\n                h = torch.mul(self.w, h)\n\n                return (outj.view(-1, self.heads, 1, self.out_channels) *\n                        alpha.view(-1, self.heads, self.dim, 1) + h)\n\n        elif self.mod == \"scaled\":\n            if self.attention_mode == 
\"additive-self-attention\":\n                ones = alpha.new_ones(index.size())\n                degree = scatter(ones, index, dim_size=size_i,\n                                 reduce='sum')[index].unsqueeze(-1)\n                degree = torch.matmul(degree, self.l1) + self.b1\n                degree = self.activation(degree)\n                degree = torch.matmul(degree, self.l2) + self.b2\n\n                return torch.mul(\n                    outj.view(-1, self.heads, self.out_channels) *\n                    alpha.view(-1, self.heads, 1),\n                    degree.view(-1, 1, self.out_channels))\n            elif self.attention_mode == \"multiplicative-self-attention\":\n                ones = alpha.new_ones(index.size())\n                degree = scatter(ones, index, dim_size=size_i,\n                                 reduce='sum')[index].unsqueeze(-1)\n                degree = torch.matmul(degree, self.l1) + self.b1\n                degree = self.activation(degree)\n                degree = torch.matmul(degree, self.l2) + self.b2\n\n                return torch.mul(\n                    outj.view(-1, self.heads, 1, self.out_channels) *\n                    alpha.view(-1, self.heads, self.dim, 1),\n                    degree.view(-1, 1, 1, self.out_channels))\n\n        elif self.mod == \"f-additive\":\n            alpha = torch.where(alpha > 0, alpha + 1, alpha)\n\n        elif self.mod == \"f-scaled\":\n            ones = alpha.new_ones(index.size())\n            degree = scatter(ones, index, dim_size=size_i,\n                             reduce='sum')[index].unsqueeze(-1)\n            alpha = alpha * degree\n\n        elif self.training and self.dropout > 0:\n            alpha = F.dropout(alpha, p=self.dropout, training=True)\n\n        else:\n            alpha = alpha  # original\n\n        if self.attention_mode == \"additive-self-attention\":\n            return alpha.view(-1, self.heads, 1) * outj.view(\n                -1, self.heads, 
self.out_channels)\n        else:\n            return (alpha.view(-1, self.heads, self.dim, 1) *\n                    outj.view(-1, self.heads, 1, self.out_channels))\n\n    def update(self, aggr_out: Tensor) -> Tensor:\n        if self.attention_mode == \"additive-self-attention\":\n            if self.concat is True:\n                aggr_out = aggr_out.view(-1, self.heads * self.out_channels)\n            else:\n                aggr_out = aggr_out.mean(dim=1)\n\n            if self.bias is not None:\n                aggr_out = aggr_out + self.bias\n\n            return aggr_out\n        else:\n            if self.concat is True:\n                aggr_out = aggr_out.view(\n                    -1, self.heads * self.dim * self.out_channels)\n            else:\n                aggr_out = aggr_out.mean(dim=1)\n                aggr_out = aggr_out.view(-1, self.dim * self.out_channels)\n\n            if self.bias is not None:\n                aggr_out = aggr_out + self.bias\n\n            return aggr_out\n\n    def __repr__(self) -> str:\n        return '{}({}, {}, heads={})'.format(self.__class__.__name__,\n                                             self.in_channels,\n                                             self.out_channels, self.heads)\n"
  },
  {
    "path": "torch_geometric/nn/conv/rgcn_conv.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nimport torch_geometric.backend\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.index import index2ptr\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import (\n    Adj,\n    OptTensor,\n    SparseTensor,\n    pyg_lib,\n    torch_sparse,\n)\nfrom torch_geometric.utils import index_sort, one_hot, scatter, spmm\n\n\ndef masked_edge_index(edge_index: Adj, edge_mask: Tensor) -> Adj:\n    if isinstance(edge_index, Tensor):\n        return edge_index[:, edge_mask]\n    return torch_sparse.masked_select_nnz(edge_index, edge_mask, layout='coo')\n\n\nclass RGCNConv(MessagePassing):\n    r\"\"\"The relational graph convolutional operator from the `\"Modeling\n    Relational Data with Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1703.06103>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{\\Theta}_{\\textrm{root}} \\cdot\n        \\mathbf{x}_i + \\sum_{r \\in \\mathcal{R}} \\sum_{j \\in \\mathcal{N}_r(i)}\n        \\frac{1}{|\\mathcal{N}_r(i)|} \\mathbf{\\Theta}_r \\cdot \\mathbf{x}_j,\n\n    where :math:`\\mathcal{R}` denotes the set of relations, *i.e.* edge types.\n    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which\n    stores a relation identifier\n    :math:`\\in \\{ 0, \\ldots, |\\mathcal{R}| - 1\\}` for each edge.\n\n    .. 
note::\n        This implementation is as memory-efficient as possible by iterating\n        over each individual relation type.\n        Therefore, it may result in low GPU utilization in case the graph has a\n        large number of relations.\n        As an alternative approach, :class:`FastRGCNConv` does not iterate over\n        each individual type, but may consume a large amount of memory to\n        compensate.\n        We advise checking out both implementations to see which one fits your\n        needs.\n\n    .. note::\n        :class:`RGCNConv` can use `dynamic shapes\n        <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index\n        .html#work_dynamic_shapes>`_, which means that the shape of the interim\n        tensors can be determined at runtime.\n        If your device doesn't support dynamic shapes, use\n        :class:`FastRGCNConv` instead.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample. A tuple\n            corresponds to the sizes of source and target dimensionalities.\n            In case no input features are given, this argument should\n            correspond to the number of nodes in your graph.\n        out_channels (int): Size of each output sample.\n        num_relations (int): Number of relations.\n        num_bases (int, optional): If set, this layer will use the\n            basis-decomposition regularization scheme where :obj:`num_bases`\n            denotes the number of bases to use. 
(default: :obj:`None`)\n        num_blocks (int, optional): If set, this layer will use the\n            block-diagonal-decomposition regularization scheme where\n            :obj:`num_blocks` denotes the number of blocks to use.\n            (default: :obj:`None`)\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"mean\"`)\n        root_weight (bool, optional): If set to :obj:`False`, the layer will\n            not add transformed root node features to the output.\n            (default: :obj:`True`)\n        is_sorted (bool, optional): If set to :obj:`True`, assumes that\n            :obj:`edge_index` is sorted by :obj:`edge_type`. This avoids\n            internal re-sorting of the data and can improve runtime and memory\n            efficiency. (default: :obj:`False`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        num_relations: int,\n        num_bases: Optional[int] = None,\n        num_blocks: Optional[int] = None,\n        aggr: str = 'mean',\n        root_weight: bool = True,\n        is_sorted: bool = False,\n        bias: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', aggr)\n        super().__init__(node_dim=0, **kwargs)\n\n        if num_bases is not None and num_blocks is not None:\n            raise ValueError('Can not apply both basis-decomposition and '\n                             'block-diagonal-decomposition at the same time.')\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_relations = num_relations\n        self.num_bases = 
num_bases\n        self.num_blocks = num_blocks\n        self.is_sorted = is_sorted\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n        self.in_channels_l = in_channels[0]\n\n        self._use_segment_matmul_heuristic_output: torch.jit.Attribute(\n            None, Optional[float])\n\n        if num_bases is not None:\n            self.weight = Parameter(\n                torch.empty(num_bases, in_channels[0], out_channels))\n            self.comp = Parameter(torch.empty(num_relations, num_bases))\n\n        elif num_blocks is not None:\n            assert (in_channels[0] % num_blocks == 0\n                    and out_channels % num_blocks == 0)\n            self.weight = Parameter(\n                torch.empty(num_relations, num_blocks,\n                            in_channels[0] // num_blocks,\n                            out_channels // num_blocks))\n            self.register_parameter('comp', None)\n\n        else:\n            self.weight = Parameter(\n                torch.empty(num_relations, in_channels[0], out_channels))\n            self.register_parameter('comp', None)\n\n        if root_weight:\n            self.root = Parameter(torch.empty(in_channels[1], out_channels))\n        else:\n            self.register_parameter('root', None)\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        glorot(self.weight)\n        glorot(self.comp)\n        glorot(self.root)\n        zeros(self.bias)\n\n    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],\n                edge_index: Adj, edge_type: OptTensor = None):\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor or tuple, optional): The input node features.\n                Can be either a 
:obj:`[num_nodes, in_channels]` node feature\n                matrix, or an optional one-dimensional node index tensor (in\n                which case input features are treated as trainable node\n                embeddings).\n                Furthermore, :obj:`x` can be of type :obj:`tuple` denoting\n                source and destination node features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_type (torch.Tensor, optional): The one-dimensional relation\n                type/index for each edge in :obj:`edge_index`.\n                Should only be :obj:`None` in case :obj:`edge_index` is of type\n                :class:`torch_sparse.SparseTensor`. (default: :obj:`None`)\n        \"\"\"\n        # Convert input features to a pair of node features or node indices.\n        x_l: OptTensor = None\n        if isinstance(x, tuple):\n            x_l = x[0]\n        else:\n            x_l = x\n        if x_l is None:\n            x_l = torch.arange(self.in_channels_l, device=self.weight.device)\n\n        x_r: Tensor = x_l\n        if isinstance(x, tuple):\n            x_r = x[1]\n\n        size = (x_l.size(0), x_r.size(0))\n        if isinstance(edge_index, SparseTensor):\n            edge_type = edge_index.storage.value()\n        assert edge_type is not None\n\n        # propagate_type: (x: Tensor, edge_type_ptr: OptTensor)\n        out = torch.zeros(x_r.size(0), self.out_channels, device=x_r.device)\n\n        weight = self.weight\n        if self.num_bases is not None:  # Basis-decomposition =================\n            weight = (self.comp @ weight.view(self.num_bases, -1)).view(\n                self.num_relations, self.in_channels_l, self.out_channels)\n\n        if self.num_blocks is not None:  # Block-diagonal-decomposition =====\n\n            if not torch.is_floating_point(\n                    x_r) and self.num_blocks is not None:\n                raise ValueError('Block-diagonal decomposition not supported 
'\n                                 'for non-continuous input features.')\n\n            for i in range(self.num_relations):\n                tmp = masked_edge_index(edge_index, edge_type == i)\n                h = self.propagate(tmp, x=x_l, edge_type_ptr=None, size=size)\n                h = h.view(-1, weight.size(1), weight.size(2))\n                h = torch.einsum('abc,bcd->abd', h, weight[i])\n                out = out + h.contiguous().view(-1, self.out_channels)\n\n        else:  # No regularization/Basis-decomposition ========================\n\n            use_segment_matmul = torch_geometric.backend.use_segment_matmul\n            # If `use_segment_matmul` is not specified, use a simple heuristic\n            # to determine whether `segment_matmul` can speed up computation\n            # given the observed input sizes:\n            if use_segment_matmul is None:\n                segment_count = scatter(torch.ones_like(edge_type), edge_type,\n                                        dim_size=self.num_relations)\n\n                self._use_segment_matmul_heuristic_output = (\n                    torch_geometric.backend.use_segment_matmul_heuristic(\n                        num_segments=self.num_relations,\n                        max_segment_size=int(segment_count.max()),\n                        in_channels=self.weight.size(1),\n                        out_channels=self.weight.size(2),\n                    ))\n\n                assert self._use_segment_matmul_heuristic_output is not None\n                use_segment_matmul = self._use_segment_matmul_heuristic_output\n\n            if (use_segment_matmul and torch_geometric.typing.WITH_SEGMM\n                    and not is_compiling() and self.num_bases is None\n                    and x_l.is_floating_point()\n                    and isinstance(edge_index, Tensor)):\n\n                if not self.is_sorted:\n                    if (edge_type[1:] < edge_type[:-1]).any():\n                        edge_type, 
perm = index_sort(\n                            edge_type, max_value=self.num_relations)\n                        edge_index = edge_index[:, perm]\n                edge_type_ptr = index2ptr(edge_type, self.num_relations)\n                out = self.propagate(edge_index, x=x_l,\n                                     edge_type_ptr=edge_type_ptr, size=size)\n            else:\n                for i in range(self.num_relations):\n                    tmp = masked_edge_index(edge_index, edge_type == i)\n\n                    if not torch.is_floating_point(x_r):\n                        out = out + self.propagate(\n                            tmp,\n                            x=weight[i, x_l],\n                            edge_type_ptr=None,\n                            size=size,\n                        )\n                    else:\n                        h = self.propagate(tmp, x=x_l, edge_type_ptr=None,\n                                           size=size)\n                        out = out + (h @ weight[i])\n\n        root = self.root\n        if root is not None:\n            if not torch.is_floating_point(x_r):\n                out = out + root[x_r]\n            else:\n                out = out + x_r @ root\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_type_ptr: OptTensor) -> Tensor:\n        if (torch_geometric.typing.WITH_SEGMM and not is_compiling()\n                and edge_type_ptr is not None):\n            # TODO Re-weight according to edge type degree for `aggr=mean`.\n            return pyg_lib.ops.segment_matmul(x_j, edge_type_ptr, self.weight)\n\n        return x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        if isinstance(adj_t, SparseTensor):\n            adj_t = adj_t.set_value(None)\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return 
(f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_relations={self.num_relations})')\n\n\nclass FastRGCNConv(RGCNConv):\n    r\"\"\"See :class:`RGCNConv`.\"\"\"\n    def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],\n                edge_index: Adj, edge_type: OptTensor = None):\n\n        self.fuse = False\n        assert self.aggr in ['add', 'sum', 'mean']\n\n        # Convert input features to a pair of node features or node indices.\n        x_l: OptTensor = None\n        if isinstance(x, tuple):\n            x_l = x[0]\n        else:\n            x_l = x\n        if x_l is None:\n            x_l = torch.arange(self.in_channels_l, device=self.weight.device)\n\n        x_r: Tensor = x_l\n        if isinstance(x, tuple):\n            x_r = x[1]\n\n        size = (x_l.size(0), x_r.size(0))\n\n        # propagate_type: (x: Tensor, edge_type: OptTensor)\n        out = self.propagate(edge_index, x=x_l, edge_type=edge_type, size=size)\n\n        root = self.root\n        if root is not None:\n            if not torch.is_floating_point(x_r):\n                out = out + root[x_r]\n            else:\n                out = out + x_r @ root\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_type: Tensor,\n                edge_index_j: Tensor) -> Tensor:\n        weight = self.weight\n        if self.num_bases is not None:  # Basis-decomposition =================\n            weight = (self.comp @ weight.view(self.num_bases, -1)).view(\n                self.num_relations, self.in_channels_l, self.out_channels)\n\n        if self.num_blocks is not None:  # Block-diagonal-decomposition =======\n            if not torch.is_floating_point(x_j):\n                raise ValueError('Block-diagonal decomposition not supported '\n                                 'for non-continuous input features.')\n\n            weight = 
weight[edge_type].view(-1, weight.size(2), weight.size(3))\n            x_j = x_j.view(-1, 1, weight.size(1))\n            return torch.bmm(x_j, weight).view(-1, self.out_channels)\n\n        else:  # No regularization/Basis-decomposition ========================\n            if not torch.is_floating_point(x_j):\n                weight_index = edge_type * weight.size(1) + edge_index_j\n                return weight.view(-1, self.out_channels)[weight_index]\n\n            return torch.bmm(x_j.unsqueeze(-2), weight[edge_type]).squeeze(-2)\n\n    def aggregate(self, inputs: Tensor, edge_type: Tensor, index: Tensor,\n                  dim_size: Optional[int] = None) -> Tensor:\n\n        # Compute normalization in separation for each `edge_type`.\n        if self.aggr == 'mean':\n            norm = one_hot(edge_type, self.num_relations, dtype=inputs.dtype)\n            norm = scatter(norm, index, dim=0, dim_size=dim_size)[index]\n            norm = torch.gather(norm, 1, edge_type.view(-1, 1))\n            norm = 1. / norm.clamp_(1.)\n            inputs = norm * inputs\n\n        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size)\n"
  },
  {
    "path": "torch_geometric/nn/conv/sage_conv.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation, MultiAggregation\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass SAGEConv(MessagePassing):\n    r\"\"\"The GraphSAGE operator from the `\"Inductive Representation Learning on\n    Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i + \\mathbf{W}_2 \\cdot\n        \\mathrm{mean}_{j \\in \\mathcal{N(i)}} \\mathbf{x}_j\n\n    If :obj:`project = True`, then :math:`\\mathbf{x}_j` will first get\n    projected via\n\n    .. math::\n        \\mathbf{x}_j \\leftarrow \\sigma ( \\mathbf{W}_3 \\mathbf{x}_j +\n        \\mathbf{b})\n\n    as described in Eq. (3) of the paper.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        aggr (str or Aggregation, optional): The aggregation scheme to use.\n            Any aggregation of :obj:`torch_geometric.nn.aggr` can be used,\n            *e.g.*, :obj:`\"mean\"`, :obj:`\"max\"`, or :obj:`\"lstm\"`.\n            (default: :obj:`\"mean\"`)\n        normalize (bool, optional): If set to :obj:`True`, output features\n            will be :math:`\\ell_2`-normalized, *i.e.*,\n            :math:`\\frac{\\mathbf{x}^{\\prime}_i}\n            {\\| \\mathbf{x}^{\\prime}_i \\|_2}`.\n            (default: :obj:`False`)\n        root_weight (bool, optional): If set to :obj:`False`, the layer will\n            not add transformed root 
node features to the output.\n            (default: :obj:`True`)\n        project (bool, optional): If set to :obj:`True`, the layer will apply a\n            linear transformation followed by an activation function before\n            aggregation (as described in Eq. (3) of the paper).\n            (default: :obj:`False`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **inputs:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{s}), (|\\mathcal{V_t}|, F_{t}))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **outputs:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V_t}|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        aggr: Optional[Union[str, List[str], Aggregation]] = \"mean\",\n        normalize: bool = False,\n        root_weight: bool = True,\n        project: bool = False,\n        bias: bool = True,\n        **kwargs,\n    ):\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.normalize = normalize\n        self.root_weight = root_weight\n        self.project = project\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        if aggr == 'lstm':\n            kwargs.setdefault('aggr_kwargs', {})\n            kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0])\n            kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0])\n\n        super().__init__(aggr, **kwargs)\n\n        if self.project:\n            if in_channels[0] <= 0:\n                raise 
ValueError(f\"'{self.__class__.__name__}' does not \"\n                                 f\"support lazy initialization with \"\n                                 f\"`project=True`\")\n            self.lin = Linear(in_channels[0], in_channels[0], bias=True)\n\n        if isinstance(self.aggr_module, MultiAggregation):\n            aggr_out_channels = self.aggr_module.get_out_channels(\n                in_channels[0])\n        else:\n            aggr_out_channels = in_channels[0]\n\n        self.lin_l = Linear(aggr_out_channels, out_channels, bias=bias)\n        if self.root_weight:\n            self.lin_r = Linear(in_channels[1], out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if self.project:\n            self.lin.reset_parameters()\n        self.lin_l.reset_parameters()\n        if self.root_weight:\n            self.lin_r.reset_parameters()\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        if self.project and hasattr(self, 'lin'):\n            x = (self.lin(x[0]).relu(), x[1])\n\n        # propagate_type: (x: OptPairTensor)\n        out = self.propagate(edge_index, x=x, size=size)\n        out = self.lin_l(out)\n\n        x_r = x[1]\n        if self.root_weight and x_r is not None:\n            out = out + self.lin_r(x_r)\n\n        if self.normalize:\n            out = F.normalize(out, p=2., dim=-1)\n\n        return out\n\n    def message(self, x_j: Tensor) -> Tensor:\n        return x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n        if isinstance(adj_t, SparseTensor):\n            adj_t = adj_t.set_value(None, layout=None)\n        return spmm(adj_t, x[0], reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return 
(f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, aggr={self.aggr})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/sg_conv.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass SGConv(MessagePassing):\n    r\"\"\"The simple graph convolutional operator from the `\"Simplifying Graph\n    Convolutional Networks\" <https://arxiv.org/abs/1902.07153>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = {\\left(\\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n        \\mathbf{\\hat{D}}^{-1/2} \\right)}^K \\mathbf{X} \\mathbf{\\Theta},\n\n    where :math:`\\mathbf{\\hat{A}} = \\mathbf{A} + \\mathbf{I}` denotes the\n    adjacency matrix with inserted self-loops and\n    :math:`\\hat{D}_{ii} = \\sum_{j=0} \\hat{A}_{ij}` its diagonal degree matrix.\n    The adjacency matrix can include other values than :obj:`1` representing\n    edge weights via the optional :obj:`edge_weight` tensor.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        K (int, optional): Number of hops :math:`K`. (default: :obj:`1`)\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of :math:`{\\left(\\mathbf{\\hat{D}}^{-1/2}\n            \\mathbf{\\hat{A}} \\mathbf{\\hat{D}}^{-1/2} \\right)}^K \\mathbf{X}` on\n            first execution, and will use the cached version for further\n            executions.\n            This parameter should only be set to :obj:`True` in transductive\n            learning scenarios. (default: :obj:`False`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. 
(default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:**\n          node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n\n    _cached_x: Optional[Tensor]\n\n    def __init__(self, in_channels: int, out_channels: int, K: int = 1,\n                 cached: bool = False, add_self_loops: bool = True,\n                 bias: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.K = K\n        self.cached = cached\n        self.add_self_loops = add_self_loops\n\n        self._cached_x = None\n\n        self.lin = Linear(in_channels, out_channels, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin.reset_parameters()\n        self._cached_x = None\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        cache = self._cached_x\n        if cache is None:\n            if isinstance(edge_index, Tensor):\n                edge_index, edge_weight = gcn_norm(  # yapf: disable\n                    edge_index, edge_weight, x.size(self.node_dim), False,\n                    self.add_self_loops, self.flow, dtype=x.dtype)\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = gcn_norm(  # yapf: disable\n                    edge_index, edge_weight, x.size(self.node_dim), False,\n                    
self.add_self_loops, self.flow, dtype=x.dtype)\n\n            for _ in range(self.K):\n                # propagate_type: (x: Tensor, edge_weight: OptTensor)\n                x = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n                if self.cached:\n                    self._cached_x = x\n        else:\n            x = cache.detach()\n\n        return self.lin(x)\n\n    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        return edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, K={self.K})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/signed_conv.py",
    "content": "from typing import Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, PairTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass SignedConv(MessagePassing):\n    r\"\"\"The signed graph convolutional operator from the `\"Signed Graph\n    Convolutional Network\" <https://arxiv.org/abs/1808.06354>`_ paper.\n\n    .. math::\n        \\mathbf{x}_v^{(\\textrm{pos})} &= \\mathbf{\\Theta}^{(\\textrm{pos})}\n        \\left[ \\frac{1}{|\\mathcal{N}^{+}(v)|} \\sum_{w \\in \\mathcal{N}^{+}(v)}\n        \\mathbf{x}_w , \\mathbf{x}_v \\right]\n\n        \\mathbf{x}_v^{(\\textrm{neg})} &= \\mathbf{\\Theta}^{(\\textrm{neg})}\n        \\left[ \\frac{1}{|\\mathcal{N}^{-}(v)|} \\sum_{w \\in \\mathcal{N}^{-}(v)}\n        \\mathbf{x}_w , \\mathbf{x}_v \\right]\n\n    if :obj:`first_aggr` is set to :obj:`True`, and\n\n    .. math::\n        \\mathbf{x}_v^{(\\textrm{pos})} &= \\mathbf{\\Theta}^{(\\textrm{pos})}\n        \\left[ \\frac{1}{|\\mathcal{N}^{+}(v)|} \\sum_{w \\in \\mathcal{N}^{+}(v)}\n        \\mathbf{x}_w^{(\\textrm{pos})}, \\frac{1}{|\\mathcal{N}^{-}(v)|}\n        \\sum_{w \\in \\mathcal{N}^{-}(v)} \\mathbf{x}_w^{(\\textrm{neg})},\n        \\mathbf{x}_v^{(\\textrm{pos})} \\right]\n\n        \\mathbf{x}_v^{(\\textrm{neg})} &= \\mathbf{\\Theta}^{(\\textrm{pos})}\n        \\left[ \\frac{1}{|\\mathcal{N}^{+}(v)|} \\sum_{w \\in \\mathcal{N}^{+}(v)}\n        \\mathbf{x}_w^{(\\textrm{neg})}, \\frac{1}{|\\mathcal{N}^{-}(v)|}\n        \\sum_{w \\in \\mathcal{N}^{-}(v)} \\mathbf{x}_w^{(\\textrm{pos})},\n        \\mathbf{x}_v^{(\\textrm{neg})} \\right]\n\n    otherwise.\n    In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to be\n    a tensor where :obj:`x[:, :in_channels]` denotes the positive node features\n    :math:`\\mathbf{X}^{(\\textrm{pos})}` and :obj:`x[:, in_channels:]` denotes\n 
   the negative node features :math:`\\mathbf{X}^{(\\textrm{neg})}`.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        first_aggr (bool): Denotes which aggregation formula to use.\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})` or\n          :math:`((|\\mathcal{V_s}|, F_{in}), (|\\mathcal{V_t}|, F_{in}))`\n          if bipartite,\n          positive edge indices :math:`(2, |\\mathcal{E}^{(+)}|)`,\n          negative edge indices :math:`(2, |\\mathcal{E}^{(-)}|)`\n        - **outputs:** node features :math:`(|\\mathcal{V}|, F_{out})` or\n          :math:`(|\\mathcal{V_t}|, F_{out})` if bipartite\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, first_aggr: bool,\n                 bias: bool = True, **kwargs):\n\n        kwargs.setdefault('aggr', 'mean')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.first_aggr = first_aggr\n\n        if first_aggr:\n            self.lin_pos_l = Linear(in_channels, out_channels, False)\n            self.lin_pos_r = Linear(in_channels, out_channels, bias)\n            self.lin_neg_l = Linear(in_channels, out_channels, False)\n            self.lin_neg_r = Linear(in_channels, out_channels, bias)\n        else:\n            self.lin_pos_l = Linear(2 * in_channels, out_channels, False)\n            self.lin_pos_r = Linear(in_channels, out_channels, bias)\n            self.lin_neg_l = Linear(2 * in_channels, out_channels, False)\n            self.lin_neg_r = 
Linear(in_channels, out_channels, bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_pos_l.reset_parameters()\n        self.lin_pos_r.reset_parameters()\n        self.lin_neg_l.reset_parameters()\n        self.lin_neg_r.reset_parameters()\n\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        pos_edge_index: Adj,\n        neg_edge_index: Adj,\n    ):\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: PairTensor)\n        if self.first_aggr:\n\n            out_pos = self.propagate(pos_edge_index, x=x)\n            out_pos = self.lin_pos_l(out_pos)\n            out_pos = out_pos + self.lin_pos_r(x[1])\n\n            out_neg = self.propagate(neg_edge_index, x=x)\n            out_neg = self.lin_neg_l(out_neg)\n            out_neg = out_neg + self.lin_neg_r(x[1])\n\n            return torch.cat([out_pos, out_neg], dim=-1)\n\n        else:\n            F_in = self.in_channels\n\n            out_pos1 = self.propagate(pos_edge_index,\n                                      x=(x[0][..., :F_in], x[1][..., :F_in]))\n            out_pos2 = self.propagate(neg_edge_index,\n                                      x=(x[0][..., F_in:], x[1][..., F_in:]))\n            out_pos = torch.cat([out_pos1, out_pos2], dim=-1)\n            out_pos = self.lin_pos_l(out_pos)\n            out_pos = out_pos + self.lin_pos_r(x[1][..., :F_in])\n\n            out_neg1 = self.propagate(pos_edge_index,\n                                      x=(x[0][..., F_in:], x[1][..., F_in:]))\n            out_neg2 = self.propagate(neg_edge_index,\n                                      x=(x[0][..., :F_in], x[1][..., :F_in]))\n            out_neg = torch.cat([out_neg1, out_neg2], dim=-1)\n            out_neg = self.lin_neg_l(out_neg)\n            out_neg = out_neg + self.lin_neg_r(x[1][..., F_in:])\n\n            return torch.cat([out_pos, out_neg], dim=-1)\n\n    def 
message(self, x_j: Tensor) -> Tensor:\n        return x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: PairTensor) -> Tensor:\n        if isinstance(adj_t, SparseTensor):\n            adj_t = adj_t.set_value(None, layout=None)\n        return spmm(adj_t, x[0], reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, first_aggr={self.first_aggr})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/simple_conv.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.aggr import Aggregation\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.typing import (\n    Adj,\n    OptPairTensor,\n    OptTensor,\n    Size,\n    SparseTensor,\n    torch_sparse,\n)\nfrom torch_geometric.utils import add_self_loops, spmm\n\n\nclass SimpleConv(MessagePassing):\n    r\"\"\"A simple message passing operator that performs (non-trainable)\n    propagation.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\bigoplus_{j \\in \\mathcal{N}(i)} e_{ji} \\cdot\n        \\mathbf{x}_j\n\n    where :math:`\\bigoplus` defines a custom aggregation scheme.\n\n    Args:\n        aggr (str or [str] or Aggregation, optional): The aggregation scheme\n            to use, *e.g.*, :obj:`\"add\"`, :obj:`\"sum\"`, :obj:`\"mean\"`,\n            :obj:`\"min\"`, :obj:`\"max\"` or :obj:`\"mul\"`.\n            In addition, can be any\n            :class:`~torch_geometric.nn.aggr.Aggregation` module (or any string\n            that automatically resolves to it). (default: :obj:`\"sum\"`)\n        combine_root (str, optional): Specifies whether or how to combine the\n            central node representation (one of :obj:`\"sum\"`, :obj:`\"cat\"`,\n            :obj:`\"self_loop\"`, :obj:`None`). 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **inputs:**\n          node features :math:`(|\\mathcal{V}|, F)` or\n          :math:`((|\\mathcal{V_s}|, F), (|\\mathcal{V_t}|, *))`\n          if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **outputs:** node features :math:`(|\\mathcal{V}|, F)` or\n          :math:`(|\\mathcal{V_t}|, F)` if bipartite\n    \"\"\"\n    def __init__(\n        self,\n        aggr: Optional[Union[str, List[str], Aggregation]] = \"sum\",\n        combine_root: Optional[str] = None,\n        **kwargs,\n    ):\n        if combine_root not in ['sum', 'cat', 'self_loop', None]:\n            raise ValueError(f\"Received invalid value for 'combine_root' \"\n                             f\"(got '{combine_root}')\")\n\n        super().__init__(aggr, **kwargs)\n        self.combine_root = combine_root\n\n    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,\n                edge_weight: OptTensor = None, size: Size = None) -> Tensor:\n\n        if self.combine_root is not None:\n            if self.combine_root == 'self_loop':\n                if not isinstance(x, Tensor) or (size is not None\n                                                 and size[0] != size[1]):\n                    raise ValueError(\"Cannot use `combine_root='self_loop'` \"\n                                     \"for bipartite message passing\")\n                if isinstance(edge_index, Tensor):\n                    edge_index, edge_weight = add_self_loops(\n                        edge_index, edge_weight, num_nodes=x.size(0))\n                elif isinstance(edge_index, SparseTensor):\n                    edge_index = torch_sparse.set_diag(edge_index)\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n        out = 
self.propagate(edge_index, x=x, edge_weight=edge_weight,\n                             size=size)\n\n        x_dst = x[1]\n        if x_dst is not None and self.combine_root is not None:\n            if self.combine_root == 'sum':\n                out = out + x_dst\n            elif self.combine_root == 'cat':\n                out = torch.cat([x_dst, out], dim=-1)\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n        assert isinstance(self.aggr, str)\n        return spmm(adj_t, x[0], reduce=self.aggr)\n"
  },
  {
    "path": "torch_geometric/nn/conv/spline_conv.py",
    "content": "import warnings\nfrom typing import List, Tuple, Union\n\nimport torch\nfrom torch import Tensor, nn\nfrom torch.nn import Parameter\n\nimport torch_geometric.typing\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import uniform, zeros\nfrom torch_geometric.typing import Adj, OptPairTensor, OptTensor, Size\nfrom torch_geometric.utils.repeat import repeat\n\nif torch_geometric.typing.WITH_SPLINE:\n    from pyg_lib.ops import spline_basis, spline_weighting\nelse:\n    spline_basis = spline_weighting = None\n\n\nclass SplineConv(MessagePassing):\n    r\"\"\"The spline-based convolutional operator from the `\"SplineCNN: Fast\n    Geometric Deep Learning with Continuous B-Spline Kernels\"\n    <https://arxiv.org/abs/1711.08920>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{1}{|\\mathcal{N}(i)|} \\sum_{j \\in\n        \\mathcal{N}(i)} \\mathbf{x}_j \\cdot\n        h_{\\mathbf{\\Theta}}(\\mathbf{e}_{i,j}),\n\n    where :math:`h_{\\mathbf{\\Theta}}` denotes a kernel function defined\n    over the weighted B-Spline tensor product basis.\n\n    .. note::\n\n        Pseudo-coordinates must lie in the fixed interval :math:`[0, 1]` for\n        this method to work as intended.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        dim (int): Pseudo-coordinate dimensionality.\n        kernel_size (int or [int]): Size of the convolving kernel.\n        is_open_spline (bool or [bool], optional): If set to :obj:`False`, the\n            operator will use a closed B-spline basis in this dimension.\n            (default: :obj:`True`)\n        degree (int, optional): B-spline basis degrees. 
(default: :obj:`1`)\n        aggr (str, optional): The aggregation scheme to use\n            (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"max\"`).\n            (default: :obj:`\"mean\"`)\n        root_weight (bool, optional): If set to :obj:`False`, the layer will\n            not add transformed root node features to the output.\n            (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        dim: int,\n        kernel_size: Union[int, List[int]],\n        is_open_spline: bool = True,\n        degree: int = 1,\n        aggr: str = 'mean',\n        root_weight: bool = True,\n        bias: bool = True,\n        **kwargs,\n    ):\n        super().__init__(aggr=aggr, **kwargs)\n\n        if spline_basis is None:\n            raise ImportError(\"'SplineConv' requires 'pyg-lib>=0.6.0'\")\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.dim = dim\n        self.degree = degree\n        self.root_weight = root_weight\n\n        kernel_size = torch.tensor(repeat(kernel_size, dim), dtype=torch.long)\n        self.register_buffer('kernel_size', kernel_size)\n\n        is_open_spline = repeat(is_open_spline, dim)\n        is_open_spline = torch.tensor(is_open_spline, dtype=torch.uint8)\n        self.register_buffer('is_open_spline', is_open_spline)\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.K = kernel_size.prod().item()\n\n        if in_channels[0] > 0:\n            self.weight = Parameter(\n                torch.empty(self.K, in_channels[0], out_channels))\n        else:\n            self.weight = 
torch.nn.parameter.UninitializedParameter()\n            self._hook = self.register_forward_pre_hook(\n                self.initialize_parameters)\n\n        if root_weight:\n            self.lin = Linear(in_channels[1], out_channels, bias=False,\n                              weight_initializer='uniform')\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        if not isinstance(self.weight, nn.UninitializedParameter):\n            size = self.weight.size(0) * self.weight.size(1)\n            uniform(size, self.weight)\n        if self.root_weight:\n            self.lin.reset_parameters()\n        zeros(self.bias)\n\n    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,\n                edge_attr: OptTensor = None, size: Size = None) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        if not x[0].is_cuda:\n            warnings.warn(\n                'We do not recommend using the non-optimized CPU version of '\n                '`SplineConv`. 
If possible, please move your data to GPU.',\n                stacklevel=2)\n\n        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)\n\n        x_r = x[1]\n        if x_r is not None and self.root_weight:\n            out = out + self.lin(x_r)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:\n        data = spline_basis(edge_attr, self.kernel_size, self.is_open_spline,\n                            self.degree)\n        return spline_weighting(x_j, self.weight, *data)\n\n    @torch.no_grad()\n    def initialize_parameters(self, module, input):\n        if isinstance(self.weight, torch.nn.parameter.UninitializedParameter):\n            x = input[0][0] if isinstance(input, tuple) else input[0]\n            in_channels = x.size(-1)\n            self.weight.materialize((self.K, in_channels, self.out_channels))\n            size = self.weight.size(0) * self.weight.size(1)\n            uniform(size, self.weight)\n        module._hook.remove()\n        delattr(module, '_hook')\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, dim={self.dim})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/ssg_conv.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass SSGConv(MessagePassing):\n    r\"\"\"The simple spectral graph convolutional operator from the\n    `\"Simple Spectral Graph Convolution\"\n    <https://openreview.net/forum?id=CYO5T-YjWZV>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\frac{1}{K} \\sum_{k=1}^K\\left((1-\\alpha)\n        {\\left(\\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n        \\mathbf{\\hat{D}}^{-1/2} \\right)}^k\n        \\mathbf{X}+\\alpha \\mathbf{X}\\right) \\mathbf{\\Theta},\n\n    where :math:`\\mathbf{\\hat{A}} = \\mathbf{A} + \\mathbf{I}` denotes the\n    adjacency matrix with inserted self-loops and\n    :math:`\\hat{D}_{ii} = \\sum_{j=0} \\hat{A}_{ij}` its diagonal degree matrix.\n    The adjacency matrix can include other values than :obj:`1` representing\n    edge weights via the optional :obj:`edge_weight` tensor.\n    :class:`~torch_geometric.nn.conv.SSGConv` is an improved operator of\n    :class:`~torch_geometric.nn.conv.SGConv` by introducing the :obj:`alpha`\n    parameter to address the oversmoothing issue.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        alpha (float): Teleport probability :math:`\\alpha \\in [0, 1]`.\n        K (int, optional): Number of hops :math:`K`. 
(default: :obj:`1`)\n        cached (bool, optional): If set to :obj:`True`, the layer will cache\n            the computation of :math:`\\frac{1}{K} \\sum_{k=1}^K\\left((1-\\alpha)\n            {\\left(\\mathbf{\\hat{D}}^{-1/2} \\mathbf{\\hat{A}}\n            \\mathbf{\\hat{D}}^{-1/2} \\right)}^k \\mathbf{X}+\n            \\alpha \\mathbf{X}\\right)` on first execution, and will use the\n            cached version for further executions.\n            This parameter should only be set to :obj:`True` in transductive\n            learning scenarios. (default: :obj:`False`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:**\n          node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n\n    _cached_h: Optional[Tensor]\n\n    def __init__(self, in_channels: int, out_channels: int, alpha: float,\n                 K: int = 1, cached: bool = False, add_self_loops: bool = True,\n                 bias: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.alpha = alpha\n        self.K = K\n        self.cached = cached\n        self.add_self_loops = add_self_loops\n\n        self._cached_h = None\n\n        self.lin = Linear(in_channels, out_channels, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        
super().reset_parameters()\n        self.lin.reset_parameters()\n        self._cached_h = None\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        cache = self._cached_h\n        if cache is None:\n            if isinstance(edge_index, Tensor):\n                edge_index, edge_weight = gcn_norm(  # yapf: disable\n                    edge_index, edge_weight, x.size(self.node_dim), False,\n                    self.add_self_loops, self.flow, dtype=x.dtype)\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = gcn_norm(  # yapf: disable\n                    edge_index, edge_weight, x.size(self.node_dim), False,\n                    self.add_self_loops, self.flow, dtype=x.dtype)\n\n            h = x * self.alpha\n            for _ in range(self.K):\n                # propagate_type: (x: Tensor, edge_weight: OptTensor)\n                x = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n                h = h + (1 - self.alpha) / self.K * x\n            if self.cached:\n                self._cached_h = h\n        else:\n            h = cache.detach()\n\n        return self.lin(h)\n\n    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:\n        return edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, K={self.K}, alpha={self.alpha})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/supergat_conv.py",
    "content": "import math\nfrom typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor, torch_sparse\nfrom torch_geometric.utils import (\n    add_self_loops,\n    batched_negative_sampling,\n    dropout_edge,\n    is_undirected,\n    negative_sampling,\n    remove_self_loops,\n    softmax,\n    to_undirected,\n)\n\n\nclass SuperGATConv(MessagePassing):\n    r\"\"\"The self-supervised graph attentional operator from the `\"How to Find\n    Your Friendly Neighborhood: Graph Attention Design with Self-Supervision\"\n    <https://openreview.net/forum?id=Wi5KUNlqWty>`_ paper.\n\n    .. math::\n\n        \\mathbf{x}^{\\prime}_i = \\alpha_{i,i}\\mathbf{\\Theta}\\mathbf{x}_{i} +\n        \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i,j}\\mathbf{\\Theta}\\mathbf{x}_{j},\n\n    where the two types of attention :math:`\\alpha_{i,j}^{\\mathrm{MX\\ or\\ SD}}`\n    are computed as:\n\n    .. 
math::\n\n        \\alpha_{i,j}^{\\mathrm{MX\\ or\\ SD}} &=\n        \\frac{\n        \\exp\\left(\\mathrm{LeakyReLU}\\left(\n            e_{i,j}^{\\mathrm{MX\\ or\\ SD}}\n        \\right)\\right)}\n        {\\sum_{k \\in \\mathcal{N}(i) \\cup \\{ i \\}}\n        \\exp\\left(\\mathrm{LeakyReLU}\\left(\n            e_{i,k}^{\\mathrm{MX\\ or\\ SD}}\n        \\right)\\right)}\n\n        e_{i,j}^{\\mathrm{MX}} &= \\mathbf{a}^{\\top}\n            [\\mathbf{\\Theta}\\mathbf{x}_i \\, \\Vert \\,\n             \\mathbf{\\Theta}\\mathbf{x}_j]\n            \\cdot \\sigma \\left(\n                \\left( \\mathbf{\\Theta}\\mathbf{x}_i \\right)^{\\top}\n                \\mathbf{\\Theta}\\mathbf{x}_j\n            \\right)\n\n        e_{i,j}^{\\mathrm{SD}} &= \\frac{\n            \\left( \\mathbf{\\Theta}\\mathbf{x}_i \\right)^{\\top}\n            \\mathbf{\\Theta}\\mathbf{x}_j\n        }{ \\sqrt{d} }\n\n    The self-supervised task is a link prediction using the attention values\n    as input to predict the likelihood :math:`\\phi_{i,j}^{\\mathrm{MX\\ or\\ SD}}`\n    that an edge exists between nodes:\n\n    .. math::\n\n        \\phi_{i,j}^{\\mathrm{MX}} &= \\sigma \\left(\n            \\left( \\mathbf{\\Theta}\\mathbf{x}_i \\right)^{\\top}\n            \\mathbf{\\Theta}\\mathbf{x}_j\n        \\right)\n\n        \\phi_{i,j}^{\\mathrm{SD}} &= \\sigma \\left(\n            \\frac{\n                \\left( \\mathbf{\\Theta}\\mathbf{x}_i \\right)^{\\top}\n                \\mathbf{\\Theta}\\mathbf{x}_j\n            }{ \\sqrt{d} }\n        \\right)\n\n    .. 
note::\n\n        For an example of using SuperGAT, see `examples/super_gat.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        super_gat.py>`_.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the multi-head\n            attentions are averaged instead of concatenated.\n            (default: :obj:`True`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. (default: :obj:`0.2`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        attention_type (str, optional): Type of attention to use\n            (:obj:`'MX'`, :obj:`'SD'`). (default: :obj:`'MX'`)\n        neg_sample_ratio (float, optional): The ratio of the number of sampled\n            negative edges to the number of positive edges.\n            (default: :obj:`0.5`)\n        edge_sample_ratio (float, optional): The ratio of samples to use for\n            training among the number of training edges. (default: :obj:`1.0`)\n        is_undirected (bool, optional): Whether the input graph is undirected.\n            If not given, will be automatically computed with the input graph\n            when negative sampling is performed. 
(default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          negative edge indices :math:`(2, |\\mathcal{E}^{(-)}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, H * F_{out})`\n    \"\"\"\n    att_x: OptTensor\n    att_y: OptTensor\n\n    def __init__(self, in_channels: int, out_channels: int, heads: int = 1,\n                 concat: bool = True, negative_slope: float = 0.2,\n                 dropout: float = 0.0, add_self_loops: bool = True,\n                 bias: bool = True, attention_type: str = 'MX',\n                 neg_sample_ratio: float = 0.5, edge_sample_ratio: float = 1.0,\n                 is_undirected: bool = False, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(node_dim=0, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.concat = concat\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n        self.add_self_loops = add_self_loops\n        self.attention_type = attention_type\n        self.neg_sample_ratio = neg_sample_ratio\n        self.edge_sample_ratio = edge_sample_ratio\n        self.is_undirected = is_undirected\n\n        assert attention_type in ['MX', 'SD']\n        assert 0.0 < neg_sample_ratio and 0.0 < edge_sample_ratio <= 1.0\n\n        self.lin = Linear(in_channels, heads * out_channels, bias=False,\n                          weight_initializer='glorot')\n\n        if self.attention_type == 'MX':\n            self.att_l = Parameter(torch.empty(1, heads, out_channels))\n            self.att_r = Parameter(torch.empty(1, heads, out_channels))\n        else:  # self.attention_type == 'SD'\n            
self.register_parameter('att_l', None)\n            self.register_parameter('att_r', None)\n\n        self.att_x = self.att_y = None  # x/y for self-supervision\n\n        if bias and concat:\n            self.bias = Parameter(torch.empty(heads * out_channels))\n        elif bias and not concat:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin.reset_parameters()\n        glorot(self.att_l)\n        glorot(self.att_r)\n        zeros(self.bias)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        neg_edge_index: OptTensor = None,\n        batch: OptTensor = None,\n    ) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor): The input node features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            neg_edge_index (torch.Tensor, optional): The negative edges to\n                train against. If not given, uses negative sampling to\n                calculate negative edges. (default: :obj:`None`)\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example.\n                Used when sampling negatives on-the-fly in mini-batch\n                scenarios. 
(default: :obj:`None`)\n        \"\"\"\n        N, H, C = x.size(0), self.heads, self.out_channels\n\n        if self.add_self_loops:\n            if isinstance(edge_index, SparseTensor):\n                edge_index = torch_sparse.fill_diag(edge_index, 1.)\n            else:\n                edge_index, _ = remove_self_loops(edge_index)\n                edge_index, _ = add_self_loops(edge_index, num_nodes=N)\n\n        x = self.lin(x).view(-1, H, C)\n\n        # propagate_type: (x: Tensor)\n        out = self.propagate(edge_index, x=x)\n\n        if self.training:\n            if isinstance(edge_index, SparseTensor):\n                col, row, _ = edge_index.coo()\n                edge_index = torch.stack([row, col], dim=0)\n            pos_edge_index = self.positive_sampling(edge_index)\n\n            pos_att = self.get_attention(\n                edge_index_i=pos_edge_index[1],\n                x_i=x[pos_edge_index[1]],\n                x_j=x[pos_edge_index[0]],\n                num_nodes=x.size(0),\n                return_logits=True,\n            )\n\n            if neg_edge_index is None:\n                neg_edge_index = self.negative_sampling(edge_index, N, batch)\n\n            neg_att = self.get_attention(\n                edge_index_i=neg_edge_index[1],\n                x_i=x[neg_edge_index[1]],\n                x_j=x[neg_edge_index[0]],\n                num_nodes=x.size(0),\n                return_logits=True,\n            )\n\n            self.att_x = torch.cat([pos_att, neg_att], dim=0)\n            self.att_y = self.att_x.new_zeros(self.att_x.size(0))\n            self.att_y[:pos_edge_index.size(1)] = 1.\n\n        if self.concat is True:\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, edge_index_i: Tensor, x_i: Tensor, x_j: Tensor,\n                size_i: 
Optional[int]) -> Tensor:\n        alpha = self.get_attention(edge_index_i, x_i, x_j, num_nodes=size_i)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        return x_j * alpha.view(-1, self.heads, 1)\n\n    def negative_sampling(self, edge_index: Tensor, num_nodes: int,\n                          batch: OptTensor = None) -> Tensor:\n\n        num_neg_samples = int(self.neg_sample_ratio * self.edge_sample_ratio *\n                              edge_index.size(1))\n\n        if not self.is_undirected and not is_undirected(\n                edge_index, num_nodes=num_nodes):\n            edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n\n        if batch is None:\n            neg_edge_index = negative_sampling(edge_index, num_nodes,\n                                               num_neg_samples=num_neg_samples)\n        else:\n            neg_edge_index = batched_negative_sampling(\n                edge_index, batch, num_neg_samples=num_neg_samples)\n\n        return neg_edge_index\n\n    def positive_sampling(self, edge_index: Tensor) -> Tensor:\n        pos_edge_index, _ = dropout_edge(edge_index,\n                                         p=1. 
- self.edge_sample_ratio,\n                                         training=self.training)\n        return pos_edge_index\n\n    def get_attention(self, edge_index_i: Tensor, x_i: Tensor, x_j: Tensor,\n                      num_nodes: Optional[int],\n                      return_logits: bool = False) -> Tensor:\n\n        if self.attention_type == 'MX':\n            logits = (x_i * x_j).sum(dim=-1)\n            if return_logits:\n                return logits\n\n            alpha = (x_j * self.att_l).sum(-1) + (x_i * self.att_r).sum(-1)\n            alpha = alpha * logits.sigmoid()\n\n        else:  # self.attention_type == 'SD'\n            alpha = (x_i * x_j).sum(dim=-1) / math.sqrt(self.out_channels)\n            if return_logits:\n                return alpha\n\n        alpha = F.leaky_relu(alpha, self.negative_slope)\n        alpha = softmax(alpha, edge_index_i, num_nodes=num_nodes)\n        return alpha\n\n    def get_attention_loss(self) -> Tensor:\n        r\"\"\"Computes the self-supervised graph attention loss.\"\"\"\n        if not self.training:\n            return torch.tensor([0], device=self.lin.weight.device)\n\n        return F.binary_cross_entropy_with_logits(\n            self.att_x.mean(dim=-1),\n            self.att_y,\n        )\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads}, '\n                f'type={self.attention_type})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/tag_conv.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import spmm\n\n\nclass TAGConv(MessagePassing):\n    r\"\"\"The topology adaptive graph convolutional networks operator from the\n    `\"Topology Adaptive Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1710.10370>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} = \\sum_{k=0}^K \\left( \\mathbf{D}^{-1/2} \\mathbf{A}\n        \\mathbf{D}^{-1/2} \\right)^k \\mathbf{X} \\mathbf{W}_{k},\n\n    where :math:`\\mathbf{A}` denotes the adjacency matrix and\n    :math:`D_{ii} = \\sum_{j=0} A_{ij}` its diagonal degree matrix.\n    The adjacency matrix can include other values than :obj:`1` representing\n    edge weights via the optional :obj:`edge_weight` tensor.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        out_channels (int): Size of each output sample.\n        K (int, optional): Number of hops :math:`K`. (default: :obj:`3`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        normalize (bool, optional): Whether to apply symmetric normalization.\n            (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node_features :math:`(|\\mathcal{V}|, F_{in})`,\n          edge_index :math:`(2, |\\mathcal{E}|)`,\n          edge_weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, K: int = 3,\n                 bias: bool = True, normalize: bool = True, **kwargs):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(**kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.K = K\n        self.normalize = normalize\n\n        self.lins = torch.nn.ModuleList([\n            Linear(in_channels, out_channels, bias=False) for _ in range(K + 1)\n        ])\n\n        if bias:\n            self.bias = torch.nn.Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        for lin in self.lins:\n            lin.reset_parameters()\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, edge_index: Adj,\n                edge_weight: OptTensor = None) -> Tensor:\n\n        if self.normalize:\n            if isinstance(edge_index, Tensor):\n                edge_index, edge_weight = gcn_norm(  # yapf: disable\n                    edge_index, edge_weight, x.size(self.node_dim),\n                    improved=False, add_self_loops=False, flow=self.flow,\n                    dtype=x.dtype)\n\n            elif isinstance(edge_index, SparseTensor):\n                edge_index = gcn_norm(  # yapf: disable\n                    edge_index, 
edge_weight, x.size(self.node_dim),\n                    add_self_loops=False, flow=self.flow, dtype=x.dtype)\n\n        out = self.lins[0](x)\n        for lin in self.lins[1:]:\n            # propagate_type: (x: Tensor, edge_weight: OptTensor)\n            x = self.propagate(edge_index, x=x, edge_weight=edge_weight)\n            out = out + lin.forward(x)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, K={self.K})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/transformer_conv.py",
    "content": "import math\nimport typing\nfrom typing import Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.typing import (\n    Adj,\n    NoneType,\n    OptTensor,\n    PairTensor,\n    SparseTensor,\n)\nfrom torch_geometric.utils import softmax\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload_method as overload\n\n\nclass TransformerConv(MessagePassing):\n    r\"\"\"The graph transformer operator from the `\"Masked Label Prediction:\n    Unified Message Passing Model for Semi-Supervised Classification\"\n    <https://arxiv.org/abs/2009.03509>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i +\n        \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i,j} \\mathbf{W}_2 \\mathbf{x}_{j},\n\n    where the attention coefficients :math:`\\alpha_{i,j}` are computed via\n    multi-head dot product attention:\n\n    .. math::\n        \\alpha_{i,j} = \\textrm{softmax} \\left(\n        \\frac{(\\mathbf{W}_3\\mathbf{x}_i)^{\\top} (\\mathbf{W}_4\\mathbf{x}_j)}\n        {\\sqrt{d}} \\right)\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        out_channels (int): Size of each output sample.\n        heads (int, optional): Number of multi-head-attentions.\n            (default: :obj:`1`)\n        concat (bool, optional): If set to :obj:`False`, the multi-head\n            attentions are averaged instead of concatenated.\n            (default: :obj:`True`)\n        beta (bool, optional): If set, will combine aggregation and\n            skip information via\n\n            .. 
math::\n                \\mathbf{x}^{\\prime}_i = \\beta_i \\mathbf{W}_1 \\mathbf{x}_i +\n                (1 - \\beta_i) \\underbrace{\\left(\\sum_{j \\in \\mathcal{N}(i)}\n                \\alpha_{i,j} \\mathbf{W}_2 \\mathbf{x}_j \\right)}_{=\\mathbf{m}_i}\n\n            with :math:`\\beta_i = \\textrm{sigmoid}(\\mathbf{w}_5^{\\top}\n            [ \\mathbf{W}_1 \\mathbf{x}_i, \\mathbf{m}_i, \\mathbf{W}_1\n            \\mathbf{x}_i - \\mathbf{m}_i ])` (default: :obj:`False`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        edge_dim (int, optional): Edge feature dimensionality (in case\n            there are any). Edge features are added to the keys after\n            linear transformation, that is, prior to computing the\n            attention dot product. They are also added to final values\n            after the same linear transformation. The model is:\n\n            .. math::\n                \\mathbf{x}^{\\prime}_i = \\mathbf{W}_1 \\mathbf{x}_i +\n                \\sum_{j \\in \\mathcal{N}(i)} \\alpha_{i,j} \\left(\n                \\mathbf{W}_2 \\mathbf{x}_{j} + \\mathbf{W}_6 \\mathbf{e}_{ij}\n                \\right),\n\n            where the attention coefficients :math:`\\alpha_{i,j}` are now\n            computed via:\n\n            .. math::\n                \\alpha_{i,j} = \\textrm{softmax} \\left(\n                \\frac{(\\mathbf{W}_3\\mathbf{x}_i)^{\\top}\n                (\\mathbf{W}_4\\mathbf{x}_j + \\mathbf{W}_6 \\mathbf{e}_{ij})}\n                {\\sqrt{d}} \\right)\n\n            (default: :obj:`None`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. 
(default: :obj:`True`)\n        root_weight (bool, optional): If set to :obj:`False`, the layer will\n            not add the transformed root node features to the output and the\n            option  :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n    \"\"\"\n    _alpha: OptTensor\n\n    def __init__(\n        self,\n        in_channels: Union[int, Tuple[int, int]],\n        out_channels: int,\n        heads: int = 1,\n        concat: bool = True,\n        beta: bool = False,\n        dropout: float = 0.,\n        edge_dim: Optional[int] = None,\n        bias: bool = True,\n        root_weight: bool = True,\n        **kwargs,\n    ):\n        kwargs.setdefault('aggr', 'add')\n        super().__init__(node_dim=0, **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.beta = beta and root_weight\n        self.root_weight = root_weight\n        self.concat = concat\n        self.dropout = dropout\n        self.edge_dim = edge_dim\n        self._alpha = None\n\n        if isinstance(in_channels, int):\n            in_channels = (in_channels, in_channels)\n\n        self.lin_key = Linear(in_channels[0], heads * out_channels, bias=bias)\n        self.lin_query = Linear(in_channels[1], heads * out_channels,\n                                bias=bias)\n        self.lin_value = Linear(in_channels[0], heads * out_channels,\n                                bias=bias)\n        if edge_dim is not None:\n            self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)\n        else:\n            self.lin_edge = self.register_parameter('lin_edge', None)\n\n        if concat:\n            self.lin_skip = Linear(in_channels[1], heads * out_channels,\n                                   bias=bias)\n            if self.beta:\n                self.lin_beta = 
Linear(3 * heads * out_channels, 1, bias=False)\n            else:\n                self.lin_beta = self.register_parameter('lin_beta', None)\n        else:\n            self.lin_skip = Linear(in_channels[1], out_channels, bias=bias)\n            if self.beta:\n                self.lin_beta = Linear(3 * out_channels, 1, bias=False)\n            else:\n                self.lin_beta = self.register_parameter('lin_beta', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        self.lin_key.reset_parameters()\n        self.lin_query.reset_parameters()\n        self.lin_value.reset_parameters()\n        if self.edge_dim:\n            self.lin_edge.reset_parameters()\n        self.lin_skip.reset_parameters()\n        if self.beta:\n            self.lin_beta.reset_parameters()\n\n    @overload\n    def forward(\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        return_attention_weights: NoneType = None,\n    ) -> Tensor:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Tensor,\n        edge_attr: OptTensor = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:\n        pass\n\n    @overload\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: SparseTensor,\n        edge_attr: OptTensor = None,\n        return_attention_weights: bool = None,\n    ) -> Tuple[Tensor, SparseTensor]:\n        pass\n\n    def forward(  # noqa: F811\n        self,\n        x: Union[Tensor, PairTensor],\n        edge_index: Adj,\n        edge_attr: OptTensor = None,\n        return_attention_weights: Optional[bool] = None,\n    ) -> Union[\n            Tensor,\n            Tuple[Tensor, Tuple[Tensor, Tensor]],\n            Tuple[Tensor, SparseTensor],\n    ]:\n        
r\"\"\"Runs the forward pass of the module.\n\n        Args:\n            x (torch.Tensor or (torch.Tensor, torch.Tensor)): The input node\n                features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            return_attention_weights (bool, optional):\n                Will additionally return the tuple\n                :obj:`(edge_index, attention_weights)` whenever it is set to\n                a value, regardless of its actual value\n                (might be `True` or `False`), holding the computed attention\n                weights for each edge.\n                (default: :obj:`None`)\n        \"\"\"\n        H, C = self.heads, self.out_channels\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        query = self.lin_query(x[1]).view(-1, H, C)\n        key = self.lin_key(x[0]).view(-1, H, C)\n        value = self.lin_value(x[0]).view(-1, H, C)\n\n        # propagate_type: (query: Tensor, key:Tensor, value: Tensor,\n        #                  edge_attr: OptTensor)\n        out = self.propagate(edge_index, query=query, key=key, value=value,\n                             edge_attr=edge_attr)\n\n        alpha = self._alpha\n        self._alpha = None\n\n        if self.concat:\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n\n        if self.root_weight:\n            x_r = self.lin_skip(x[1])\n            if self.lin_beta is not None:\n                beta = self.lin_beta(torch.cat([out, x_r, out - x_r], dim=-1))\n                beta = beta.sigmoid()\n                out = beta * x_r + (1 - beta) * out\n            else:\n                out = out + x_r\n\n        if isinstance(return_attention_weights, bool):\n            assert alpha is not None\n            if isinstance(edge_index, Tensor):\n                return out, 
(edge_index, alpha)\n            elif isinstance(edge_index, SparseTensor):\n                return out, edge_index.set_value(alpha, layout='coo')\n        else:\n            return out\n\n    def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor,\n                edge_attr: OptTensor, index: Tensor, ptr: OptTensor,\n                size_i: Optional[int]) -> Tensor:\n\n        if self.lin_edge is not None:\n            assert edge_attr is not None\n            edge_attr = self.lin_edge(edge_attr).view(-1, self.heads,\n                                                      self.out_channels)\n            key_j = key_j + edge_attr\n\n        alpha = (query_i * key_j).sum(dim=-1) / math.sqrt(self.out_channels)\n        alpha = softmax(alpha, index, ptr, size_i)\n        self._alpha = alpha\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n\n        out = value_j\n        if edge_attr is not None:\n            out = out + edge_attr\n\n        out = out * alpha.view(-1, self.heads, 1)\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/conv/utils/__init__.py",
    "content": "r\"\"\"GNN utility package.\"\"\"\n\nfrom .cheatsheet import paper_title, paper_link\nfrom .cheatsheet import supports_sparse_tensor\nfrom .cheatsheet import supports_edge_weights\nfrom .cheatsheet import supports_edge_features\nfrom .cheatsheet import supports_bipartite_graphs\nfrom .cheatsheet import supports_static_graphs\nfrom .cheatsheet import supports_lazy_initialization\nfrom .cheatsheet import processes_heterogeneous_graphs\nfrom .cheatsheet import processes_hypergraphs\nfrom .cheatsheet import processes_point_clouds\n\n__all__ = [\n    'paper_title',\n    'paper_link',\n    'supports_sparse_tensor',\n    'supports_edge_weights',\n    'supports_edge_features',\n    'supports_bipartite_graphs',\n    'supports_static_graphs',\n    'supports_lazy_initialization',\n    'processes_heterogeneous_graphs',\n    'processes_hypergraphs',\n    'processes_point_clouds',\n]\n"
  },
  {
    "path": "torch_geometric/nn/conv/utils/cheatsheet.py",
    "content": "import importlib\nimport inspect\nimport re\nfrom typing import Optional\n\n\ndef paper_title(cls: str) -> Optional[str]:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    match = re.search('`\\\".+?\\\"', inspect.getdoc(cls), flags=re.DOTALL)\n    return None if match is None else match.group().replace('\\n', ' ')[2:-1]\n\n\ndef paper_link(cls: str) -> Optional[str]:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    match = re.search('<.+?>', inspect.getdoc(cls), flags=re.DOTALL)\n    return None if match is None else match.group().replace('\\n', ' ')[1:-1]\n\n\ndef supports_sparse_tensor(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    signature = inspect.signature(cls.forward)\n    return 'SparseTensor' in str(signature)\n\n\ndef supports_edge_weights(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    signature = inspect.signature(cls.forward)\n    return 'edge_weight' in str(signature)\n\n\ndef supports_edge_features(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    signature = inspect.signature(cls.forward)\n    return 'edge_attr' in str(signature)\n\n\ndef supports_bipartite_graphs(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    signature = inspect.signature(cls.forward)\n    return 'Union[torch.Tensor, Tuple[torch.Tensor' in str(signature)\n\n\ndef supports_static_graphs(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    return 'node_dim=' not in inspect.getsource(cls.__init__)\n\n\ndef supports_lazy_initialization(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    doc = re.sub(' +', ' ', inspect.getdoc(cls).replace('\\n', ' '))\n    match = re.search('or :obj:`-1` to derive the size 
from the first', doc)\n    return match is not None\n\n\ndef processes_heterogeneous_graphs(cls: str) -> bool:\n    if 'hetero' in cls.lower():\n        return True\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    signature = inspect.signature(cls.forward)\n    return 'edge_index_dict' in str(signature) or 'edge_type' in str(signature)\n\n\ndef processes_hypergraphs(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    signature = inspect.signature(cls.forward)\n    return 'hyperedge_index' in str(signature)\n\n\ndef processes_point_clouds(cls: str) -> bool:\n    cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]\n    signature = inspect.signature(cls.forward)\n    return (('edge_index' not in str(signature)\n             and 'csc' not in str(signature)) or 'pos' in str(signature))\n"
  },
  {
    "path": "torch_geometric/nn/conv/wl_conv.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import Adj\nfrom torch_geometric.utils import (\n    degree,\n    is_sparse,\n    scatter,\n    sort_edge_index,\n    to_edge_index,\n)\n\n\nclass WLConv(torch.nn.Module):\n    r\"\"\"The Weisfeiler Lehman (WL) operator from the `\"A Reduction of a Graph\n    to a Canonical Form and an Algebra Arising During this Reduction\"\n    <https://www.iti.zcu.cz/wl2018/pdf/wl_paper_translation.pdf>`_ paper.\n\n    :class:`WLConv` iteratively refines node colorings according to:\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\textrm{hash} \\left( \\mathbf{x}_i, \\{\n        \\mathbf{x}_j \\colon j \\in \\mathcal{N}(i) \\} \\right)\n\n    Shapes:\n        - **input:**\n          node coloring :math:`(|\\mathcal{V}|, F_{in})` *(one-hot encodings)*\n          or :math:`(|\\mathcal{V}|)` *(integer-based)*,\n          edge indices :math:`(2, |\\mathcal{E}|)`\n        - **output:** node coloring :math:`(|\\mathcal{V}|)` *(integer-based)*\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n        self.hashmap = {}\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.hashmap = {}\n\n    @torch.no_grad()\n    def forward(self, x: Tensor, edge_index: Adj) -> Tensor:\n        r\"\"\"Runs the forward pass of the module.\"\"\"\n        if x.dim() > 1:\n            assert (x.sum(dim=-1) == 1).sum() == x.size(0)\n            x = x.argmax(dim=-1)  # one-hot -> integer.\n        assert x.dtype == torch.long\n\n        if is_sparse(edge_index):\n            col_and_row, _ = to_edge_index(edge_index)\n            col = col_and_row[0]\n            row = col_and_row[1]\n        else:\n            edge_index = sort_edge_index(edge_index, num_nodes=x.size(0),\n                                         sort_by_row=False)\n            row, col = edge_index[0], edge_index[1]\n\n        # `col` is 
sorted, so we can use it to `split` neighbors to groups:\n        deg = degree(col, x.size(0), dtype=torch.long).tolist()\n\n        out = []\n        for node, neighbors in zip(x.tolist(), x[row].split(deg)):\n            idx = hash(tuple([node] + neighbors.sort()[0].tolist()))\n            if idx not in self.hashmap:\n                self.hashmap[idx] = len(self.hashmap)\n            out.append(self.hashmap[idx])\n\n        return torch.tensor(out, device=x.device)\n\n    def histogram(self, x: Tensor, batch: Optional[Tensor] = None,\n                  norm: bool = False) -> Tensor:\n        r\"\"\"Given a node coloring :obj:`x`, computes the color histograms of\n        the respective graphs (separated by :obj:`batch`).\n        \"\"\"\n        if batch is None:\n            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)\n\n        num_colors = len(self.hashmap)\n        batch_size = int(batch.max()) + 1\n\n        index = batch * num_colors + x\n        out = scatter(torch.ones_like(index), index, dim=0,\n                      dim_size=num_colors * batch_size, reduce='sum')\n        out = out.view(batch_size, num_colors)\n\n        if norm:\n            out = out.to(torch.float)\n            out /= out.norm(dim=-1, keepdim=True)\n\n        return out\n"
  },
  {
    "path": "torch_geometric/nn/conv/wl_conv_continuous.py",
    "content": "from typing import Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.typing import (\n    Adj,\n    OptPairTensor,\n    OptTensor,\n    Size,\n    SparseTensor,\n)\nfrom torch_geometric.utils import scatter, spmm\n\n\nclass WLConvContinuous(MessagePassing):\n    r\"\"\"The Weisfeiler Lehman operator from the `\"Wasserstein\n    Weisfeiler-Lehman Graph Kernels\" <https://arxiv.org/abs/1906.01277>`_\n    paper.\n\n    Refinement is done though a degree-scaled mean aggregation and works on\n    nodes with continuous attributes:\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{1}{2}\\big(\\mathbf{x}_i +\n        \\frac{1}{\\textrm{deg}(i)}\n        \\sum_{j \\in \\mathcal{N}(i)} e_{j,i} \\cdot \\mathbf{x}_j \\big)\n\n    where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to\n    target node :obj:`i` (default: :obj:`1`)\n\n    Args:\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MessagePassing`.\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F)` or\n          :math:`((|\\mathcal{V_s}|, F), (|\\mathcal{V_t}|, F))` if bipartite,\n          edge indices :math:`(2, |\\mathcal{E}|)`,\n          edge weights :math:`(|\\mathcal{E}|)` *(optional)*\n        - **output:** node features :math:`(|\\mathcal{V}|, F)` or\n          :math:`(|\\mathcal{V}_t|, F)` if bipartite\n    \"\"\"\n    def __init__(self, **kwargs):\n        super().__init__(aggr='add', **kwargs)\n\n    def forward(\n        self,\n        x: Union[Tensor, OptPairTensor],\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n        size: Size = None,\n    ) -> Tensor:\n\n        if isinstance(x, Tensor):\n            x = (x, x)\n\n        # propagate_type: (x: OptPairTensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,\n                             
size=size)\n\n        if isinstance(edge_index, SparseTensor):\n            assert edge_weight is None\n            dst_index, _, edge_weight = edge_index.coo()\n        else:\n            dst_index = edge_index[1]\n\n        if edge_weight is None:\n            edge_weight = x[0].new_ones(dst_index.numel())\n\n        deg = scatter(edge_weight, dst_index, 0, out.size(0), reduce='sum')\n        deg_inv = 1. / deg\n        deg_inv.masked_fill_(deg_inv == float('inf'), 0)\n        out = deg_inv.view(-1, 1) * out\n\n        x_dst = x[1]\n        if x_dst is not None:\n            out = 0.5 * (x_dst + out)\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor:\n        return spmm(adj_t, x[0], reduce=self.aggr)\n"
  },
  {
    "path": "torch_geometric/nn/conv/x_conv.py",
    "content": "from math import ceil\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import ELU\nfrom torch.nn import BatchNorm1d as BN\nfrom torch.nn import Conv1d\nfrom torch.nn import Linear as L\nfrom torch.nn import Sequential as S\n\nimport torch_geometric.typing\nfrom torch_geometric.nn import Reshape\nfrom torch_geometric.nn.inits import reset\n\nif torch_geometric.typing.WITH_TORCH_CLUSTER:\n    from torch_cluster import knn_graph\nelse:\n    knn_graph = None\n\n\nclass XConv(torch.nn.Module):\n    r\"\"\"The convolutional operator on :math:`\\mathcal{X}`-transformed points\n    from the `\"PointCNN: Convolution On X-Transformed Points\"\n    <https://arxiv.org/abs/1801.07791>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\mathrm{Conv}\\left(\\mathbf{K},\n        \\gamma_{\\mathbf{\\Theta}}(\\mathbf{P}_i - \\mathbf{p}_i) \\times\n        \\left( h_\\mathbf{\\Theta}(\\mathbf{P}_i - \\mathbf{p}_i) \\, \\Vert \\,\n        \\mathbf{x}_i \\right) \\right),\n\n    where :math:`\\mathbf{K}` and :math:`\\mathbf{P}_i` denote the trainable\n    filter and neighboring point positions of :math:`\\mathbf{x}_i`,\n    respectively.\n    :math:`\\gamma_{\\mathbf{\\Theta}}` and :math:`h_{\\mathbf{\\Theta}}` describe\n    neural networks, *i.e.* MLPs, where :math:`h_{\\mathbf{\\Theta}}`\n    individually lifts each point into a higher-dimensional space, and\n    :math:`\\gamma_{\\mathbf{\\Theta}}` computes the :math:`\\mathcal{X}`-\n    transformation matrix based on *all* points in a neighborhood.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        out_channels (int): Size of each output sample.\n        dim (int): Point cloud dimensionality.\n        kernel_size (int): Size of the convolving kernel, *i.e.* number of\n            neighbors including self-loops.\n        hidden_channels (int, optional): Output size of\n            :math:`h_{\\mathbf{\\Theta}}`, *i.e.* dimensionality of lifted\n  
          points. If set to :obj:`None`, will be automatically set to\n            :obj:`in_channels / 4`. (default: :obj:`None`)\n        dilation (int, optional): The factor by which the neighborhood is\n            extended, from which :obj:`kernel_size` neighbors are then\n            uniformly sampled. Can be interpreted as the dilation rate of\n            classical convolutional operators. (default: :obj:`1`)\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        num_workers (int): Number of workers to use for k-NN computation.\n            Has no effect in case :obj:`batch` is not :obj:`None`, or the input\n            lies on the GPU. (default: :obj:`1`)\n\n    Shapes:\n        - **input:**\n          node features :math:`(|\\mathcal{V}|, F_{in})`,\n          positions :math:`(|\\mathcal{V}|, D)`,\n          batch vector :math:`(|\\mathcal{V}|)` *(optional)*\n        - **output:**\n          node features :math:`(|\\mathcal{V}|, F_{out})`\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, dim: int,\n                 kernel_size: int, hidden_channels: Optional[int] = None,\n                 dilation: int = 1, bias: bool = True, num_workers: int = 1):\n        super().__init__()\n\n        if knn_graph is None:\n            raise ImportError('`XConv` requires `torch-cluster`.')\n\n        self.in_channels = in_channels\n        if hidden_channels is None:\n            hidden_channels = in_channels // 4\n        assert hidden_channels > 0\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.dim = dim\n        self.kernel_size = kernel_size\n        self.dilation = dilation\n        self.num_workers = num_workers\n\n        C_in, C_delta, C_out = in_channels, hidden_channels, out_channels\n        D, K = dim, kernel_size\n\n        self.mlp1 = S(\n            L(dim, C_delta),\n            ELU(),\n     
       BN(C_delta),\n            L(C_delta, C_delta),\n            ELU(),\n            BN(C_delta),\n            Reshape(-1, K, C_delta),\n        )\n\n        self.mlp2 = S(\n            L(D * K, K**2),\n            ELU(),\n            BN(K**2),\n            Reshape(-1, K, K),\n            Conv1d(K, K**2, K, groups=K),\n            ELU(),\n            BN(K**2),\n            Reshape(-1, K, K),\n            Conv1d(K, K**2, K, groups=K),\n            BN(K**2),\n            Reshape(-1, K, K),\n        )\n\n        C_in = C_in + C_delta\n        depth_multiplier = int(ceil(C_out / C_in))\n        self.conv = S(\n            Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in),\n            Reshape(-1, C_in * depth_multiplier),\n            L(C_in * depth_multiplier, C_out, bias=bias),\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        reset(self.mlp1)\n        reset(self.mlp2)\n        reset(self.conv)\n\n    def forward(self, x: Tensor, pos: Tensor, batch: Optional[Tensor] = None):\n        r\"\"\"Runs the forward pass of the module.\"\"\"\n        pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos\n        (N, D), K = pos.size(), self.kernel_size\n\n        edge_index = knn_graph(pos, K * self.dilation, batch, loop=True,\n                               flow='target_to_source',\n                               num_workers=self.num_workers)\n\n        if self.dilation > 1:\n            edge_index = edge_index[:, ::self.dilation]\n\n        row, col = edge_index[0], edge_index[1]\n\n        pos = pos[col] - pos[row]\n\n        x_star = self.mlp1(pos)\n        if x is not None:\n            x = x.unsqueeze(-1) if x.dim() == 1 else x\n            x = x[col].view(N, K, self.in_channels)\n            x_star = torch.cat([x_star, x], dim=-1)\n        x_star = x_star.transpose(1, 2).contiguous()\n\n        transform_matrix = self.mlp2(pos.view(N, K * D))\n\n        
x_transformed = torch.matmul(x_star, transform_matrix)\n\n        out = self.conv(x_transformed)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/data_parallel.py",
    "content": "import logging\nimport warnings\nfrom itertools import chain\n\nimport torch\n\nfrom torch_geometric.data import Batch\nfrom torch_geometric.utils import cumsum\n\n\nclass DataParallel(torch.nn.DataParallel):\n    r\"\"\"Implements data parallelism at the module level.\n\n    This container parallelizes the application of the given :attr:`module` by\n    splitting a list of :class:`torch_geometric.data.Data` objects and copying\n    them as :class:`torch_geometric.data.Batch` objects to each device.\n    In the forward pass, the module is replicated on each device, and each\n    replica handles a portion of the input.\n    During the backwards pass, gradients from each replica are summed into the\n    original module.\n\n    The batch size should be larger than the number of GPUs used.\n\n    The parallelized :attr:`module` must have its parameters and buffers on\n    :obj:`device_ids[0]`.\n\n    .. note::\n\n        You need to use the :class:`torch_geometric.loader.DataListLoader` for\n        this module.\n\n    .. 
warning::\n\n        It is recommended to use\n        :class:`torch.nn.parallel.DistributedDataParallel` instead of\n        :class:`DataParallel` for multi-GPU training.\n        :class:`DataParallel` is usually much slower than\n        :class:`~torch.nn.parallel.DistributedDataParallel` even on a single\n        machine.\n        Take a look `here <https://github.com/pyg-team/pytorch_geometric/blob/\n        master/examples/multi_gpu/distributed_batching.py>`_ for an example on\n        how to use :pyg:`PyG` in combination with\n        :class:`~torch.nn.parallel.DistributedDataParallel`.\n\n    Args:\n        module (Module): Module to be parallelized.\n        device_ids (list of int or torch.device): CUDA devices.\n            (default: all devices)\n        output_device (int or torch.device): Device location of output.\n            (default: :obj:`device_ids[0]`)\n        follow_batch (list or tuple, optional): Creates assignment batch\n            vectors for each key in the list. (default: :obj:`None`)\n        exclude_keys (list or tuple, optional): Will exclude each key in the\n            list. (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, module, device_ids=None, output_device=None,\n                 follow_batch=None, exclude_keys=None):\n        super().__init__(module, device_ids, output_device)\n\n        warnings.warn(\n            \"'DataParallel' is usually much slower than \"\n            \"'DistributedDataParallel' even on a single machine. 
\"\n            \"Please consider switching to 'DistributedDataParallel' \"\n            \"for multi-GPU training.\", stacklevel=2)\n\n        self.src_device = torch.device(f'cuda:{self.device_ids[0]}')\n        self.follow_batch = follow_batch or []\n        self.exclude_keys = exclude_keys or []\n\n    def forward(self, data_list):\n        \"\"\"\"\"\"  # noqa: D419\n        if len(data_list) == 0:\n            logging.warning('DataParallel received an empty data list, which '\n                            'may result in unexpected behavior.')\n            return None\n\n        if not self.device_ids or len(self.device_ids) == 1:  # Fallback\n            data = Batch.from_data_list(\n                data_list, follow_batch=self.follow_batch,\n                exclude_keys=self.exclude_keys).to(self.src_device)\n            return self.module(data)\n\n        for t in chain(self.module.parameters(), self.module.buffers()):\n            if t.device != self.src_device:\n                raise RuntimeError(\n                    f\"Module must have its parameters and buffers on device \"\n                    f\"'{self.src_device}' but found one of them on device \"\n                    f\"'{t.device}'\")\n\n        inputs = self.scatter(data_list, self.device_ids)\n        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])\n        outputs = self.parallel_apply(replicas, inputs, None)\n        return self.gather(outputs, self.output_device)\n\n    def scatter(self, data_list, device_ids):\n        num_devices = min(len(device_ids), len(data_list))\n\n        count = torch.tensor([data.num_nodes for data in data_list])\n        ptr = cumsum(count)\n        device_id = num_devices * ptr.to(torch.float) / ptr[-1].item()\n        device_id = (device_id[:-1] + device_id[1:]) / 2.0\n        device_id = device_id.to(torch.long)  # round.\n        split = cumsum(device_id.bincount())\n        split = torch.unique(split, sorted=True)\n        split = 
split.tolist()\n\n        return [\n            Batch.from_data_list(data_list[split[i]:split[i + 1]],\n                                 follow_batch=self.follow_batch,\n                                 exclude_keys=self.exclude_keys).to(\n                                     torch.device(f'cuda:{device_ids[i]}'))\n            for i in range(len(split) - 1)\n        ]\n"
  },
  {
    "path": "torch_geometric/nn/dense/__init__.py",
    "content": "r\"\"\"Dense neural network module package.\n\nThis package provides modules applicable for operating on dense tensor\nrepresentations.\n\"\"\"\n\nfrom .linear import Linear, HeteroLinear, HeteroDictLinear\nfrom .dense_gat_conv import DenseGATConv\nfrom .dense_sage_conv import DenseSAGEConv\nfrom .dense_gcn_conv import DenseGCNConv\nfrom .dense_graph_conv import DenseGraphConv\nfrom .dense_gin_conv import DenseGINConv\nfrom .diff_pool import dense_diff_pool\nfrom .mincut_pool import dense_mincut_pool\nfrom .dmon_pool import DMoNPooling\n\n__all__ = [\n    'Linear',\n    'HeteroLinear',\n    'HeteroDictLinear',\n    'DenseGCNConv',\n    'DenseGINConv',\n    'DenseGraphConv',\n    'DenseSAGEConv',\n    'DenseGATConv',\n    'dense_diff_pool',\n    'dense_mincut_pool',\n    'DMoNPooling',\n]\n\nlin_classes = __all__[:3]\nconv_classes = __all__[3:8]\npool_classes = __all__[8:]\n"
  },
  {
    "path": "torch_geometric/nn/dense/dense_gat_conv.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import glorot, zeros\n\n\nclass DenseGATConv(torch.nn.Module):\n    r\"\"\"See :class:`torch_geometric.nn.conv.GATConv`.\"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        heads: int = 1,\n        concat: bool = True,\n        negative_slope: float = 0.2,\n        dropout: float = 0.0,\n        bias: bool = True,\n    ):\n        # TODO Add support for edge features.\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.concat = concat\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n\n        self.lin = Linear(in_channels, heads * out_channels, bias=False,\n                          weight_initializer='glorot')\n\n        # The learnable parameters to compute attention coefficients:\n        self.att_src = Parameter(torch.empty(1, 1, heads, out_channels))\n        self.att_dst = Parameter(torch.empty(1, 1, heads, out_channels))\n\n        if bias and concat:\n            self.bias = Parameter(torch.empty(heads * out_channels))\n        elif bias and not concat:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.lin.reset_parameters()\n        glorot(self.att_src)\n        glorot(self.att_dst)\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None,\n                add_loop: bool = True):\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times 
N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            adj (torch.Tensor): Adjacency tensor\n                :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n                The adjacency tensor is broadcastable in the batch dimension,\n                resulting in a shared adjacency matrix for the complete batch.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes for each graph. (default: :obj:`None`)\n            add_loop (bool, optional): If set to :obj:`False`, the layer will\n                not automatically add self-loops to the adjacency matrices.\n                (default: :obj:`True`)\n        \"\"\"\n        x = x.unsqueeze(0) if x.dim() == 2 else x  # [B, N, F]\n        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj  # [B, N, N]\n\n        H, C = self.heads, self.out_channels\n        B, N, _ = x.size()\n\n        if add_loop:\n            adj = adj.clone()\n            idx = torch.arange(N, dtype=torch.long, device=adj.device)\n            adj[:, idx, idx] = 1.0\n\n        x = self.lin(x).view(B, N, H, C)  # [B, N, H, C]\n\n        alpha_src = torch.sum(x * self.att_src, dim=-1)  # [B, N, H]\n        alpha_dst = torch.sum(x * self.att_dst, dim=-1)  # [B, N, H]\n\n        alpha = alpha_src.unsqueeze(1) + alpha_dst.unsqueeze(2)  # [B, N, N, H]\n\n        # Weighted and masked softmax:\n        alpha = F.leaky_relu(alpha, self.negative_slope)\n        alpha = alpha.masked_fill(adj.unsqueeze(-1) == 0, float('-inf'))\n        alpha = alpha.softmax(dim=2)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n\n        out = torch.matmul(alpha.movedim(3, 1), x.movedim(2, 1))\n        out = out.movedim(1, 2)  # [B,N,H,C]\n\n        if self.concat:\n            out = out.reshape(B, N, H * 
C)\n        else:\n            out = out.mean(dim=2)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        if mask is not None:\n            out = out * mask.view(-1, N, 1).to(x.dtype)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n"
  },
  {
    "path": "torch_geometric/nn/dense/dense_gcn_conv.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.typing import OptTensor\n\n\nclass DenseGCNConv(torch.nn.Module):\n    r\"\"\"See :class:`torch_geometric.nn.conv.GCNConv`.\"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        improved: bool = False,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.improved = improved\n\n        self.lin = Linear(in_channels, out_channels, bias=False,\n                          weight_initializer='glorot')\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin.reset_parameters()\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, adj: Tensor, mask: OptTensor = None,\n                add_loop: bool = True) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            adj (torch.Tensor): Adjacency tensor\n                :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n                The adjacency tensor is broadcastable in the batch dimension,\n                resulting in a shared adjacency matrix for the complete batch.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n           
     the valid nodes for each graph. (default: :obj:`None`)\n            add_loop (bool, optional): If set to :obj:`False`, the layer will\n                not automatically add self-loops to the adjacency matrices.\n                (default: :obj:`True`)\n        \"\"\"\n        x = x.unsqueeze(0) if x.dim() == 2 else x\n        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj\n        _, N, _ = adj.size()\n\n        if add_loop:\n            adj = adj.clone()\n            idx = torch.arange(N, dtype=torch.long, device=adj.device)\n            adj[:, idx, idx] = 1 if not self.improved else 2\n\n        out = self.lin(x)\n        deg_inv_sqrt = adj.sum(dim=-1).clamp(min=1).pow(-0.5)\n\n        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)\n        out = torch.matmul(adj, out)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        if mask is not None:\n            # Infer the batch size so a batch-shared `adj` works with a mask.\n            out = out * mask.view(-1, N, 1).to(x.dtype)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/dense/dense_gin_conv.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module\n\nfrom torch_geometric.nn.inits import reset\n\n\nclass DenseGINConv(torch.nn.Module):\n    r\"\"\"See :class:`torch_geometric.nn.conv.GINConv`.\"\"\"\n    def __init__(\n        self,\n        nn: Module,\n        eps: float = 0.0,\n        train_eps: bool = False,\n    ):\n        super().__init__()\n\n        self.nn = nn\n        self.initial_eps = eps\n        if train_eps:\n            self.eps = torch.nn.Parameter(torch.empty(1))\n        else:\n            self.register_buffer('eps', torch.empty(1))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        reset(self.nn)\n        self.eps.data.fill_(self.initial_eps)\n\n    def forward(self, x: Tensor, adj: Tensor, mask: Optional[Tensor] = None,\n                add_loop: bool = True) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            adj (torch.Tensor): Adjacency tensor\n                :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n                The adjacency tensor is broadcastable in the batch dimension,\n                resulting in a shared adjacency matrix for the complete batch.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes for each graph. 
(default: :obj:`None`)\n            add_loop (bool, optional): If set to :obj:`False`, the layer will\n                not automatically add self-loops to the adjacency matrices.\n                (default: :obj:`True`)\n        \"\"\"\n        x = x.unsqueeze(0) if x.dim() == 2 else x\n        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj\n        _, N, _ = adj.size()\n\n        out = torch.matmul(adj, x)\n        if add_loop:\n            out = (1 + self.eps) * x + out\n\n        out = self.nn(out)\n\n        if mask is not None:\n            # Infer the batch size so a batch-shared `adj` works with a mask.\n            out = out * mask.view(-1, N, 1).to(x.dtype)\n\n        return out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(nn={self.nn})'\n"
  },
  {
    "path": "torch_geometric/nn/dense/dense_graph_conv.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear\n\n\nclass DenseGraphConv(torch.nn.Module):\n    r\"\"\"See :class:`torch_geometric.nn.conv.GraphConv`.\"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        aggr: str = 'add',\n        bias: bool = True,\n    ):\n        assert aggr in ['add', 'mean', 'max']\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.aggr = aggr\n\n        self.lin_rel = Linear(in_channels, out_channels, bias=bias)\n        self.lin_root = Linear(in_channels, out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin_rel.reset_parameters()\n        self.lin_root.reset_parameters()\n\n    def forward(self, x: Tensor, adj: Tensor,\n                mask: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            adj (torch.Tensor): Adjacency tensor\n                :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n                The adjacency tensor is broadcastable in the batch dimension,\n                resulting in a shared adjacency matrix for the complete batch.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes for each graph. 
(default: :obj:`None`)\n        \"\"\"\n        x = x.unsqueeze(0) if x.dim() == 2 else x\n        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj\n        B, N, C = x.size()\n\n        if self.aggr == 'add':\n            out = torch.matmul(adj, x)\n        elif self.aggr == 'mean':\n            out = torch.matmul(adj, x)\n            out = out / adj.sum(dim=-1, keepdim=True).clamp_(min=1)\n        elif self.aggr == 'max':\n            out = x.unsqueeze(-2).repeat(1, 1, N, 1)\n            adj = adj.unsqueeze(-1).expand(B, N, N, C)\n            out[adj == 0] = float('-inf')\n            out = out.max(dim=-3)[0]\n            out[out == float('-inf')] = 0.\n        else:\n            raise NotImplementedError\n\n        out = self.lin_rel(out)\n        out = out + self.lin_root(x)\n\n        if mask is not None:\n            out = out * mask.view(-1, N, 1).to(x.dtype)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/dense/dense_sage_conv.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Linear\n\nfrom torch_geometric.typing import OptTensor\n\n\nclass DenseSAGEConv(torch.nn.Module):\n    r\"\"\"See :class:`torch_geometric.nn.conv.SAGEConv`.\n\n    .. note::\n\n        :class:`~torch_geometric.nn.dense.DenseSAGEConv` expects to work on\n        binary adjacency matrices.\n        If you want to make use of weighted dense adjacency matrices, please\n        use :class:`torch_geometric.nn.dense.DenseGraphConv` instead.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        normalize: bool = False,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.normalize = normalize\n\n        self.lin_rel = Linear(in_channels, out_channels, bias=False)\n        self.lin_root = Linear(in_channels, out_channels, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin_rel.reset_parameters()\n        self.lin_root.reset_parameters()\n\n    def forward(self, x: Tensor, adj: Tensor,\n                mask: OptTensor = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n            adj (torch.Tensor): Adjacency tensor\n                :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n                The adjacency tensor is broadcastable in the batch dimension,\n                resulting in a shared adjacency matrix for the complete batch.\n            mask (torch.Tensor, optional): Mask 
matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes for each graph. (default: :obj:`None`)\n        \"\"\"\n        x = x.unsqueeze(0) if x.dim() == 2 else x\n        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj\n        _, N, _ = adj.size()\n\n        out = torch.matmul(adj, x)\n        out = out / adj.sum(dim=-1, keepdim=True).clamp(min=1)\n        out = self.lin_rel(out) + self.lin_root(x)\n\n        if self.normalize:\n            out = F.normalize(out, p=2.0, dim=-1)\n\n        if mask is not None:\n            # Infer the batch size so a batch-shared `adj` works with a mask.\n            out = out * mask.view(-1, N, 1).to(x.dtype)\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/dense/diff_pool.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\n\ndef dense_diff_pool(\n    x: Tensor,\n    adj: Tensor,\n    s: Tensor,\n    mask: Optional[Tensor] = None,\n    normalize: bool = True,\n) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n    r\"\"\"The differentiable pooling operator from the `\"Hierarchical Graph\n    Representation Learning with Differentiable Pooling\"\n    <https://arxiv.org/abs/1806.08804>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} &= {\\mathrm{softmax}(\\mathbf{S})}^{\\top} \\cdot\n        \\mathbf{X}\n\n        \\mathbf{A}^{\\prime} &= {\\mathrm{softmax}(\\mathbf{S})}^{\\top} \\cdot\n        \\mathbf{A} \\cdot \\mathrm{softmax}(\\mathbf{S})\n\n    based on dense learned assignments :math:`\\mathbf{S} \\in \\mathbb{R}^{B\n    \\times N \\times C}`.\n    Returns the pooled node feature matrix, the coarsened adjacency matrix and\n    two auxiliary objectives: (1) The link prediction loss\n\n    .. math::\n        \\mathcal{L}_{LP} = {\\| \\mathbf{A} -\n        \\mathrm{softmax}(\\mathbf{S}) {\\mathrm{softmax}(\\mathbf{S})}^{\\top}\n        \\|}_F,\n\n    and (2) the entropy regularization\n\n    .. 
math::\n        \\mathcal{L}_E = \\frac{1}{N} \\sum_{n=1}^N H(\\mathbf{S}_n).\n\n    Args:\n        x (torch.Tensor): Node feature tensor\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n            batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n            each graph, and feature dimension :math:`F`.\n        adj (torch.Tensor): Adjacency tensor\n            :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n        s (torch.Tensor): Assignment tensor\n            :math:`\\mathbf{S} \\in \\mathbb{R}^{B \\times N \\times C}`\n            with number of clusters :math:`C`.\n            The softmax does not have to be applied before-hand, since it is\n            executed within this method.\n        mask (torch.Tensor, optional): Mask matrix\n            :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n            the valid nodes for each graph. (default: :obj:`None`)\n        normalize (bool, optional): If set to :obj:`False`, the link\n            prediction loss is not divided by :obj:`adj.numel()`.\n            (default: :obj:`True`)\n\n    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,\n        :class:`torch.Tensor`, :class:`torch.Tensor`)\n    \"\"\"\n    x = x.unsqueeze(0) if x.dim() == 2 else x\n    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj\n    s = s.unsqueeze(0) if s.dim() == 2 else s\n\n    batch_size, num_nodes, _ = x.size()\n\n    s = torch.softmax(s, dim=-1)\n\n    if mask is not None:\n        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)\n        x, s = x * mask, s * mask\n\n    out = torch.matmul(s.transpose(1, 2), x)\n    out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)\n\n    link_loss = adj - torch.matmul(s, s.transpose(1, 2))\n    link_loss = torch.norm(link_loss, p=2)\n    if normalize is True:\n        link_loss = link_loss / adj.numel()\n\n    ent_loss = (-s * torch.log(s + 1e-15)).sum(dim=-1).mean()\n\n    return out, out_adj, 
link_loss, ent_loss\n"
  },
  {
    "path": "torch_geometric/nn/dense/dmon_pool.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.dense.mincut_pool import _rank3_trace\n\nEPS = 1e-15\n\n\nclass DMoNPooling(torch.nn.Module):\n    r\"\"\"The spectral modularity pooling operator from the `\"Graph Clustering\n    with Graph Neural Networks\" <https://arxiv.org/abs/2006.16904>`_ paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} &= {\\mathrm{softmax}(\\mathbf{S})}^{\\top} \\cdot\n        \\mathbf{X}\n\n        \\mathbf{A}^{\\prime} &= {\\mathrm{softmax}(\\mathbf{S})}^{\\top} \\cdot\n        \\mathbf{A} \\cdot \\mathrm{softmax}(\\mathbf{S})\n\n    based on dense learned assignments :math:`\\mathbf{S} \\in \\mathbb{R}^{B\n    \\times N \\times C}`.\n    Returns the learned cluster assignment matrix, the pooled node feature\n    matrix, the coarsened symmetrically normalized adjacency matrix, and three\n    auxiliary objectives: (1) The spectral loss\n\n    .. math::\n        \\mathcal{L}_s = - \\frac{1}{2m}\n        \\cdot{\\mathrm{Tr}(\\mathbf{S}^{\\top} \\mathbf{B} \\mathbf{S})}\n\n    where :math:`\\mathbf{B}` is the modularity matrix, (2) the orthogonality\n    loss\n\n    .. math::\n        \\mathcal{L}_o = {\\left\\| \\frac{\\mathbf{S}^{\\top} \\mathbf{S}}\n        {{\\|\\mathbf{S}^{\\top} \\mathbf{S}\\|}_F} -\\frac{\\mathbf{I}_C}{\\sqrt{C}}\n        \\right\\|}_F\n\n    where :math:`C` is the number of clusters, and (3) the cluster loss\n\n    .. math::\n        \\mathcal{L}_c = \\frac{\\sqrt{C}}{n}\n        {\\left\\|\\sum_i\\mathbf{C_i}^{\\top}\\right\\|}_F - 1.\n\n    .. note::\n\n        For an example of using :class:`DMoNPooling`, see\n        `examples/proteins_dmon_pool.py\n        <https://github.com/pyg-team/pytorch_geometric/blob\n        /master/examples/proteins_dmon_pool.py>`_.\n\n    Args:\n        channels (int or List[int]): Size of each input sample. 
If given as a\n            list, will construct an MLP based on the given feature sizes.\n        k (int): The number of clusters.\n        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)\n    \"\"\"\n    def __init__(self, channels: Union[int, List[int]], k: int,\n                 dropout: float = 0.0):\n        super().__init__()\n\n        if isinstance(channels, int):\n            channels = [channels]\n\n        from torch_geometric.nn.models.mlp import MLP\n        self.mlp = MLP(channels + [k], act=None, norm=None)\n\n        self.dropout = dropout\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.mlp.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        adj: Tensor,\n        mask: Optional[Tensor] = None,\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Node feature tensor\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n                batch-size :math:`B`, (maximum) number of nodes :math:`N` for\n                each graph, and feature dimension :math:`F`.\n                Note that the cluster assignment matrix\n                :math:`\\mathbf{S} \\in \\mathbb{R}^{B \\times N \\times C}` is\n                being created within this method.\n            adj (torch.Tensor): Adjacency tensor\n                :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n            mask (torch.Tensor, optional): Mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n                the valid nodes for each graph. 
(default: :obj:`None`)\n\n        :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,\n            :class:`torch.Tensor`, :class:`torch.Tensor`,\n            :class:`torch.Tensor`, :class:`torch.Tensor`)\n        \"\"\"\n        x = x.unsqueeze(0) if x.dim() == 2 else x\n        adj = adj.unsqueeze(0) if adj.dim() == 2 else adj\n\n        s = self.mlp(x)\n        s = F.dropout(s, self.dropout, training=self.training)\n        s = torch.softmax(s, dim=-1)\n\n        (batch_size, num_nodes, _), C = x.size(), s.size(-1)\n\n        if mask is None:\n            mask = torch.ones(batch_size, num_nodes, dtype=torch.bool,\n                              device=x.device)\n\n        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)\n        x, s = x * mask, s * mask\n\n        out = F.selu(torch.matmul(s.transpose(1, 2), x))\n        out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)\n\n        # Spectral loss:\n        degrees = torch.einsum('ijk->ij', adj)  # B X N\n        degrees = degrees.unsqueeze(-1) * mask  # B x N x 1\n        degrees_t = degrees.transpose(1, 2)  # B x 1 x N\n\n        m = torch.einsum('ijk->i', degrees) / 2  # B\n        m_expand = m.view(-1, 1, 1).expand(-1, C, C)  # B x C x C\n\n        ca = torch.matmul(s.transpose(1, 2), degrees)  # B x C x 1\n        cb = torch.matmul(degrees_t, s)  # B x 1 x C\n\n        normalizer = torch.matmul(ca, cb) / 2 / m_expand\n        decompose = out_adj - normalizer\n        spectral_loss = -_rank3_trace(decompose) / 2 / m\n        spectral_loss = spectral_loss.mean()\n\n        # Orthogonality regularization:\n        ss = torch.matmul(s.transpose(1, 2), s)\n        i_s = torch.eye(C).type_as(ss)\n        ortho_loss = torch.norm(\n            ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -\n            i_s / torch.norm(i_s), dim=(-1, -2))\n        ortho_loss = ortho_loss.mean()\n\n        # Cluster loss:\n        cluster_size = torch.einsum('ijk->ik', s)  # B x C\n        cluster_loss = 
torch.norm(input=cluster_size, dim=1)\n        cluster_loss = cluster_loss / mask.sum(dim=1) * torch.norm(i_s) - 1\n        cluster_loss = cluster_loss.mean()\n\n        # Fix and normalize coarsened adjacency matrix:\n        ind = torch.arange(C, device=out_adj.device)\n        out_adj[:, ind, ind] = 0\n        d = torch.einsum('ijk->ij', out_adj)\n        d = torch.sqrt(d)[:, None] + EPS\n        out_adj = (out_adj / d) / d.transpose(1, 2)\n\n        return s, out, out_adj, spectral_loss, ortho_loss, cluster_loss\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.mlp.in_channels}, '\n                f'num_clusters={self.mlp.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/dense/linear.py",
    "content": "import math\nimport os\nimport time\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn.parameter import Parameter\n\nimport torch_geometric.backend\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.index import index2ptr\nfrom torch_geometric.nn import inits\nfrom torch_geometric.typing import pyg_lib\nfrom torch_geometric.utils import index_sort\n\n\ndef is_uninitialized_parameter(x: Any) -> bool:\n    if not hasattr(torch.nn.parameter, 'UninitializedParameter'):\n        return False\n    return isinstance(x, torch.nn.parameter.UninitializedParameter)\n\n\ndef reset_weight_(weight: Tensor, in_channels: int,\n                  initializer: Optional[str] = None) -> Tensor:\n    if in_channels <= 0:\n        pass\n    elif initializer == 'glorot':\n        inits.glorot(weight)\n    elif initializer == 'uniform':\n        bound = 1.0 / math.sqrt(in_channels)\n        torch.nn.init.uniform_(weight.data, -bound, bound)\n    elif initializer == 'kaiming_uniform':\n        inits.kaiming_uniform(weight, fan=in_channels, a=math.sqrt(5))\n    elif initializer is None:\n        inits.kaiming_uniform(weight, fan=in_channels, a=math.sqrt(5))\n    else:\n        raise RuntimeError(f\"Weight initializer '{initializer}' not supported\")\n\n    return weight\n\n\ndef reset_bias_(bias: Optional[Tensor], in_channels: int,\n                initializer: Optional[str] = None) -> Optional[Tensor]:\n    if bias is None or in_channels <= 0:\n        pass\n    elif initializer == 'zeros':\n        inits.zeros(bias)\n    elif initializer is None:\n        inits.uniform(in_channels, bias)\n    else:\n        raise RuntimeError(f\"Bias initializer '{initializer}' not supported\")\n\n    return bias\n\n\nclass Linear(torch.nn.Module):\n    r\"\"\"Applies a linear transformation to the incoming data.\n\n    .. 
math::\n        \\mathbf{x}^{\\prime} = \\mathbf{x} \\mathbf{W}^{\\top} + \\mathbf{b}\n\n    In contrast to :class:`torch.nn.Linear`, it supports lazy initialization\n    and customizable weight and bias initialization.\n\n    Args:\n        in_channels (int): Size of each input sample. Will be initialized\n            lazily in case it is given as :obj:`-1`.\n        out_channels (int): Size of each output sample.\n        bias (bool, optional): If set to :obj:`False`, the layer will not learn\n            an additive bias. (default: :obj:`True`)\n        weight_initializer (str, optional): The initializer for the weight\n            matrix (:obj:`\"glorot\"`, :obj:`\"uniform\"`, :obj:`\"kaiming_uniform\"`\n            or :obj:`None`).\n            If set to :obj:`None`, will match default weight initialization of\n            :class:`torch.nn.Linear`. (default: :obj:`None`)\n        bias_initializer (str, optional): The initializer for the bias vector\n            (:obj:`\"zeros\"` or :obj:`None`).\n            If set to :obj:`None`, will match default bias initialization of\n            :class:`torch.nn.Linear`. 
(default: :obj:`None`)\n\n    Shapes:\n        - **input:** features :math:`(*, F_{in})`\n        - **output:** features :math:`(*, F_{out})`\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        bias: bool = True,\n        weight_initializer: Optional[str] = None,\n        bias_initializer: Optional[str] = None,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.weight_initializer = weight_initializer\n        self.bias_initializer = bias_initializer\n\n        if in_channels > 0:\n            self.weight = Parameter(torch.empty(out_channels, in_channels))\n        else:\n            self.weight = torch.nn.parameter.UninitializedParameter()\n            self._hook = self.register_forward_pre_hook(\n                self.initialize_parameters)\n\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        reset_weight_(self.weight, self.in_channels, self.weight_initializer)\n        reset_bias_(self.bias, self.in_channels, self.bias_initializer)\n\n    def forward(self, x: Tensor) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The input features.\n        \"\"\"\n        return F.linear(x, self.weight, self.bias)\n\n    @torch.no_grad()\n    def initialize_parameters(self, module, input):\n        if is_uninitialized_parameter(self.weight):\n            self.in_channels = input[0].size(-1)\n            self.weight.materialize((self.out_channels, self.in_channels))\n            self.reset_parameters()\n        self._hook.remove()\n        delattr(self, '_hook')\n\n    def _save_to_state_dict(self, destination, prefix, keep_vars):\n        if 
(is_uninitialized_parameter(self.weight)\n                or torch.onnx.is_in_onnx_export() or keep_vars):\n            destination[prefix + 'weight'] = self.weight\n        else:\n            destination[prefix + 'weight'] = self.weight.detach()\n        if self.bias is not None:\n            if torch.onnx.is_in_onnx_export() or keep_vars:\n                destination[prefix + 'bias'] = self.bias\n            else:\n                destination[prefix + 'bias'] = self.bias.detach()\n\n    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):\n        weight = state_dict.get(prefix + 'weight', None)\n\n        if weight is not None and is_uninitialized_parameter(weight):\n            self.in_channels = -1\n            self.weight = torch.nn.parameter.UninitializedParameter()\n            if not hasattr(self, '_hook'):\n                self._hook = self.register_forward_pre_hook(\n                    self.initialize_parameters)\n\n        elif weight is not None and is_uninitialized_parameter(self.weight):\n            self.in_channels = weight.size(-1)\n            self.weight.materialize((self.out_channels, self.in_channels))\n            if hasattr(self, '_hook'):\n                self._hook.remove()\n                delattr(self, '_hook')\n\n        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, bias={self.bias is not None})')\n\n\nclass HeteroLinear(torch.nn.Module):\n    r\"\"\"Applies separate linear transformations to the incoming data according\n    to types.\n\n    For type :math:`\\kappa`, it computes\n\n    .. 
math::\n        \\mathbf{x}^{\\prime}_{\\kappa} = \\mathbf{x}_{\\kappa}\n        \\mathbf{W}^{\\top}_{\\kappa} + \\mathbf{b}_{\\kappa}.\n\n    It supports lazy initialization and customizable weight and bias\n    initialization.\n\n    Args:\n        in_channels (int): Size of each input sample. Will be initialized\n            lazily in case it is given as :obj:`-1`.\n        out_channels (int): Size of each output sample.\n        num_types (int): The number of types.\n        is_sorted (bool, optional): If set to :obj:`True`, assumes that\n            :obj:`type_vec` is sorted. This avoids internal re-sorting of the\n            data and can improve runtime and memory efficiency.\n            (default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.Linear`.\n\n    Shapes:\n        - **input:**\n          features :math:`(*, F_{in})`,\n          type vector :math:`(*)`\n        - **output:** features :math:`(*, F_{out})`\n    \"\"\"\n    _timing_cache: Dict[int, Tuple[float, float]]\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        num_types: int,\n        is_sorted: bool = False,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_types = num_types\n        self.is_sorted = is_sorted\n        self.kwargs = kwargs\n\n        if self.in_channels == -1:\n            self.weight = torch.nn.parameter.UninitializedParameter()\n            self._hook = self.register_forward_pre_hook(\n                self.initialize_parameters)\n        else:\n            self.weight = torch.nn.Parameter(\n                torch.empty(num_types, in_channels, out_channels))\n\n        if kwargs.get('bias', True):\n            self.bias = Parameter(torch.empty(num_types, out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        # Timing cache 
for benchmarking naive vs. segment matmul usage:\n        self._timing_cache: Dict[int, Tuple[float, float]] = {}\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        reset_weight_(self.weight, self.in_channels,\n                      self.kwargs.get('weight_initializer', None))\n        reset_bias_(self.bias, self.in_channels,\n                    self.kwargs.get('bias_initializer', None))\n\n    def forward_naive(self, x: Tensor, type_ptr: Tensor) -> Tensor:\n        out = x.new_empty(x.size(0), self.out_channels)\n        for i, (start, end) in enumerate(zip(type_ptr[:-1], type_ptr[1:])):\n            out[start:end] = x[start:end] @ self.weight[i]\n        return out\n\n    def forward_segmm(self, x: Tensor, type_ptr: Tensor) -> Tensor:\n        return pyg_lib.ops.segment_matmul(x, type_ptr, self.weight)\n\n    @torch.no_grad()\n    def _update_timing_cache(\n        self,\n        x: Tensor,\n        type_ptr: Tensor,\n        key: int,\n    ) -> None:\n\n        MEASURE_ITER = 1 if 'PYTEST_CURRENT_TEST' in os.environ else 3\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        t = time.perf_counter()\n        for _ in range(MEASURE_ITER):\n            _ = self.forward_segmm(x, type_ptr)\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        time_segmm = time.perf_counter() - t\n\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        t = time.perf_counter()\n        for _ in range(MEASURE_ITER):\n            _ = self.forward_naive(x, type_ptr)\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        time_naive = time.perf_counter() - t\n\n        self._timing_cache[key] = (time_segmm, time_naive)\n\n    def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:\n        r\"\"\"The forward pass.\n\n        Args:\n            x (torch.Tensor): The 
input features.\n            type_vec (torch.Tensor): A vector that maps each entry to a type.\n        \"\"\"\n        perm: Optional[Tensor] = None\n        if not self.is_sorted and (type_vec[1:] < type_vec[:-1]).any():\n            type_vec, perm = index_sort(type_vec, self.num_types)\n            x = x[perm]\n\n        type_ptr = index2ptr(type_vec, self.num_types)\n\n        if torch_geometric.backend.use_segment_matmul is None:\n            use_segment_matmul = False\n            if (torch_geometric.typing.WITH_SEGMM and not is_compiling()\n                    and not torch.jit.is_scripting()):\n\n                # Use \"magnitude\" of number of rows as timing key:\n                key = math.floor(math.log10(x.size(0)))\n                if key not in self._timing_cache:\n                    self._update_timing_cache(x, type_ptr, key)\n                time_segmm, time_naive = self._timing_cache[key]\n                use_segment_matmul = time_segmm < time_naive\n        else:\n            use_segment_matmul = torch_geometric.backend.use_segment_matmul\n\n        if (torch_geometric.typing.WITH_SEGMM and not is_compiling()\n                and use_segment_matmul):\n            out = self.forward_segmm(x, type_ptr)\n        else:\n            out = self.forward_naive(x, type_ptr)\n\n        if self.bias is not None:\n            out += self.bias[type_vec]\n\n        if perm is not None:  # Restore original order (if necessary).\n            out_unsorted = torch.empty_like(out)\n            out_unsorted[perm] = out\n            out = out_unsorted\n\n        return out\n\n    @torch.no_grad()\n    def initialize_parameters(self, module, input):\n        if is_uninitialized_parameter(self.weight):\n            self.in_channels = input[0].size(-1)\n            self.weight.materialize(\n                (self.num_types, self.in_channels, self.out_channels))\n            self.reset_parameters()\n        self._hook.remove()\n        delattr(self, '_hook')\n\n    def 
__repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_types={self.num_types}, '\n                f'bias={self.kwargs.get(\"bias\", True)})')\n\n\nclass HeteroDictLinear(torch.nn.Module):\n    r\"\"\"Applies separate linear transformations to the incoming data\n    dictionary.\n\n    For key :math:`\\kappa`, it computes\n\n    .. math::\n        \\mathbf{x}^{\\prime}_{\\kappa} = \\mathbf{x}_{\\kappa}\n        \\mathbf{W}^{\\top}_{\\kappa} + \\mathbf{b}_{\\kappa}.\n\n    It supports lazy initialization and customizable weight and bias\n    initialization.\n\n    Args:\n        in_channels (int or Dict[Any, int]): Size of each input sample. If\n            passed an integer, :obj:`types` is a mandatory argument.\n            Will be initialized lazily in case it is given as :obj:`-1`.\n        out_channels (int): Size of each output sample.\n        types (List[Any], optional): The keys of the input dictionary.\n            (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.Linear`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: Union[int, Dict[Any, int]],\n        out_channels: int,\n        types: Optional[Any] = None,\n        **kwargs,\n    ):\n        super().__init__()\n\n        if isinstance(in_channels, dict):\n            self.types = list(in_channels.keys())\n\n            if any([i == -1 for i in in_channels.values()]):\n                self._hook = self.register_forward_pre_hook(\n                    self.initialize_parameters)\n\n            if types is not None and set(self.types) != set(types):\n                raise ValueError(\"The provided 'types' do not match with the \"\n                                 \"keys in the 'in_channels' dictionary\")\n\n        else:\n            if types is None:\n                raise ValueError(\"Please provide a list of 'types' if passing \"\n             
                    \"'in_channels' as an integer\")\n\n            if in_channels == -1:\n                self._hook = self.register_forward_pre_hook(\n                    self.initialize_parameters)\n\n            self.types = types\n            in_channels = {node_type: in_channels for node_type in types}\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kwargs = kwargs\n\n        self.lins = torch.nn.ModuleDict({\n            key:\n            Linear(channels, self.out_channels, **kwargs)\n            for key, channels in self.in_channels.items()\n        })\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for lin in self.lins.values():\n            lin.reset_parameters()\n\n    def forward(\n        self,\n        x_dict: Dict[str, Tensor],\n    ) -> Dict[str, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x_dict (Dict[Any, torch.Tensor]): A dictionary holding input\n                features for each individual type.\n        \"\"\"\n        out_dict = {}\n\n        # Only apply fused kernel for more than 10 types, otherwise use\n        # sequential computation (which is generally faster for these cases).\n        use_segment_matmul = torch_geometric.backend.use_segment_matmul\n        if use_segment_matmul is None:\n            use_segment_matmul = len(x_dict) >= 10\n\n        if (use_segment_matmul and torch_geometric.typing.WITH_GMM\n                and not is_compiling() and not torch.jit.is_scripting()):\n            xs, weights, biases = [], [], []\n            for key, lin in self.lins.items():\n                if key in x_dict:\n                    xs.append(x_dict[key])\n                    weights.append(lin.weight.t())\n                    biases.append(lin.bias)\n            biases = None if biases[0] is None else biases\n            outs = pyg_lib.ops.grouped_matmul(xs, weights, 
biases)\n            # `outs` follows the order of `self.lins`, so pair it with the\n            # same filtered key order rather than with `x_dict.keys()`:\n            keys = [key for key in self.lins.keys() if key in x_dict]\n            for key, out in zip(keys, outs):\n                out_dict[key] = out\n        else:\n            for key, lin in self.lins.items():\n                if key in x_dict:\n                    out_dict[key] = lin(x_dict[key])\n\n        return out_dict\n\n    @torch.no_grad()\n    def initialize_parameters(self, module, input):\n        for key, x in input[0].items():\n            lin = self.lins[key]\n            if is_uninitialized_parameter(lin.weight):\n                self.lins[key].initialize_parameters(None, x)\n                self.lins[key].reset_parameters()\n        self._hook.remove()\n        self.in_channels = {key: x.size(-1) for key, x in input[0].items()}\n        delattr(self, '_hook')\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, bias={self.kwargs.get(\"bias\", True)})')\n"
  },
  {
    "path": "torch_geometric/nn/dense/mincut_pool.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\n\ndef dense_mincut_pool(\n    x: Tensor,\n    adj: Tensor,\n    s: Tensor,\n    mask: Optional[Tensor] = None,\n    temp: float = 1.0,\n) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n    r\"\"\"The MinCut pooling operator from the `\"Spectral Clustering in Graph\n    Neural Networks for Graph Pooling\" <https://arxiv.org/abs/1907.00481>`_\n    paper.\n\n    .. math::\n        \\mathbf{X}^{\\prime} &= {\\mathrm{softmax}(\\mathbf{S})}^{\\top} \\cdot\n        \\mathbf{X}\n\n        \\mathbf{A}^{\\prime} &= {\\mathrm{softmax}(\\mathbf{S})}^{\\top} \\cdot\n        \\mathbf{A} \\cdot \\mathrm{softmax}(\\mathbf{S})\n\n    based on dense learned assignments :math:`\\mathbf{S} \\in \\mathbb{R}^{B\n    \\times N \\times C}`.\n    Returns the pooled node feature matrix, the coarsened and symmetrically\n    normalized adjacency matrix and two auxiliary objectives: (1) The MinCut\n    loss\n\n    .. math::\n        \\mathcal{L}_c = - \\frac{\\mathrm{Tr}(\\mathbf{S}^{\\top} \\mathbf{A}\n        \\mathbf{S})} {\\mathrm{Tr}(\\mathbf{S}^{\\top} \\mathbf{D}\n        \\mathbf{S})}\n\n    where :math:`\\mathbf{D}` is the degree matrix, and (2) the orthogonality\n    loss\n\n    .. 
math::\n        \\mathcal{L}_o = {\\left\\| \\frac{\\mathbf{S}^{\\top} \\mathbf{S}}\n        {{\\|\\mathbf{S}^{\\top} \\mathbf{S}\\|}_F} -\\frac{\\mathbf{I}_C}{\\sqrt{C}}\n        \\right\\|}_F.\n\n    Args:\n        x (torch.Tensor): Node feature tensor\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`, with\n            batch size :math:`B`, (maximum) number of nodes :math:`N` for\n            each graph, and feature dimension :math:`F`.\n        adj (torch.Tensor): Adjacency tensor\n            :math:`\\mathbf{A} \\in \\mathbb{R}^{B \\times N \\times N}`.\n        s (torch.Tensor): Assignment tensor\n            :math:`\\mathbf{S} \\in \\mathbb{R}^{B \\times N \\times C}`\n            with number of clusters :math:`C`.\n            The softmax does not have to be applied beforehand, since it is\n            applied within this method.\n        mask (torch.Tensor, optional): Mask matrix\n            :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}` indicating\n            the valid nodes for each graph. 
(default: :obj:`None`)\n        temp (float, optional): Temperature parameter for softmax function.\n            (default: :obj:`1.0`)\n\n    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`,\n        :class:`torch.Tensor`, :class:`torch.Tensor`)\n    \"\"\"\n    x = x.unsqueeze(0) if x.dim() == 2 else x\n    adj = adj.unsqueeze(0) if adj.dim() == 2 else adj\n    s = s.unsqueeze(0) if s.dim() == 2 else s\n\n    (batch_size, num_nodes, _), k = x.size(), s.size(-1)\n\n    s = torch.softmax(s / temp if temp != 1.0 else s, dim=-1)\n\n    if mask is not None:\n        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)\n        x, s = x * mask, s * mask\n\n    out = torch.matmul(s.transpose(1, 2), x)\n    out_adj = torch.matmul(torch.matmul(s.transpose(1, 2), adj), s)\n\n    # MinCut regularization.\n    mincut_num = _rank3_trace(out_adj)\n    d_flat = torch.einsum('ijk->ij', adj)\n    d = _rank3_diag(d_flat)\n    mincut_den = _rank3_trace(\n        torch.matmul(torch.matmul(s.transpose(1, 2), d), s))\n    mincut_loss = -(mincut_num / mincut_den)\n    mincut_loss = torch.mean(mincut_loss)\n\n    # Orthogonality regularization.\n    ss = torch.matmul(s.transpose(1, 2), s)\n    i_s = torch.eye(k).type_as(ss)\n    ortho_loss = torch.norm(\n        ss / torch.norm(ss, dim=(-1, -2), keepdim=True) -\n        i_s / torch.norm(i_s), dim=(-1, -2))\n    ortho_loss = torch.mean(ortho_loss)\n\n    EPS = 1e-15\n\n    # Fix and normalize coarsened adjacency matrix.\n    ind = torch.arange(k, device=out_adj.device)\n    out_adj[:, ind, ind] = 0\n    d = torch.einsum('ijk->ij', out_adj)\n    d = torch.sqrt(d)[:, None] + EPS\n    out_adj = (out_adj / d) / d.transpose(1, 2)\n\n    return out, out_adj, mincut_loss, ortho_loss\n\n\ndef _rank3_trace(x: Tensor) -> Tensor:\n    return torch.einsum('ijj->i', x)\n\n\ndef _rank3_diag(x: Tensor) -> Tensor:\n    eye = torch.eye(x.size(1)).type_as(x)\n    out = eye * x.unsqueeze(2).expand(x.size(0), x.size(1), x.size(1))\n\n    return out\n"
  },
  {
    "path": "torch_geometric/nn/encoding.py",
    "content": "import math\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\n__all__ = classes = [\n    'PositionalEncoding',\n    'TemporalEncoding',\n]\n\n\nclass PositionalEncoding(torch.nn.Module):\n    r\"\"\"The positional encoding scheme from the `\"Attention Is All You Need\"\n    <https://arxiv.org/abs/1706.03762>`_ paper.\n\n    .. math::\n\n        PE(x)_{2 \\cdot i} &= \\sin(x / 10000^{2 \\cdot i / d})\n\n        PE(x)_{2 \\cdot i + 1} &= \\cos(x / 10000^{2 \\cdot i / d})\n\n    where :math:`x` is the position and :math:`i` is the dimension.\n\n    Args:\n        out_channels (int): Size :math:`d` of each output sample.\n        base_freq (float, optional): The base frequency of sinusoidal\n            functions. (default: :obj:`1e-4`)\n        granularity (float, optional): The granularity of the positions. If\n            set to a smaller value, the encoder will capture more fine-grained\n            changes in positions. (default: :obj:`1.0`)\n        device (torch.device, optional): The device of the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        out_channels: int,\n        base_freq: float = 1e-4,\n        granularity: float = 1.0,\n        device: Optional[torch.device] = None,\n    ):\n        super().__init__()\n\n        if out_channels % 2 != 0:\n            raise ValueError(f\"Cannot use sinusoidal positional encoding with \"\n                             f\"odd 'out_channels' (got {out_channels}).\")\n\n        self.out_channels = out_channels\n        self.base_freq = base_freq\n        self.granularity = granularity\n\n        frequency = torch.logspace(0, 1, out_channels // 2, base_freq,\n                                   device=device)\n        self.register_buffer('frequency', frequency)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        pass\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        x 
= x / self.granularity if self.granularity != 1.0 else x\n        out = x.view(-1, 1) * self.frequency.view(1, -1)\n        return torch.cat([torch.sin(out), torch.cos(out)], dim=-1)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.out_channels})'\n\n\nclass TemporalEncoding(torch.nn.Module):\n    r\"\"\"The time-encoding function from the `\"Do We Really Need Complicated\n    Model Architectures for Temporal Networks?\"\n    <https://openreview.net/forum?id=ayPPc0SyLv1>`_ paper.\n\n    It first maps each entry to a vector with exponentially decreasing values,\n    and then uses the cosine function to project all values to range\n    :math:`[-1, 1]`.\n\n    .. math::\n        y_{i} = \\cos \\left(x \\cdot \\sqrt{d}^{-(i - 1)/\\sqrt{d}} \\right)\n\n    where :math:`d` defines the output feature dimension, and\n    :math:`1 \\leq i \\leq d`.\n\n    Args:\n        out_channels (int): Size :math:`d` of each output sample.\n        device (torch.device, optional): The device of the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, out_channels: int,\n                 device: Optional[torch.device] = None):\n        super().__init__()\n        self.out_channels = out_channels\n\n        sqrt = math.sqrt(out_channels)\n        weight = 1.0 / sqrt**torch.linspace(0, sqrt, out_channels,\n                                            device=device).view(1, -1)\n        self.register_buffer('weight', weight)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        pass\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        return torch.cos(x.view(-1, 1) @ self.weight)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.out_channels})'\n"
  },
  {
    "path": "torch_geometric/nn/functional/__init__.py",
    "content": "r\"\"\"Functional operator package.\"\"\"\n\nfrom .bro import bro\nfrom .gini import gini\n\n__all__ = classes = [\n    'bro',\n    'gini',\n]\n"
  },
  {
    "path": "torch_geometric/nn/functional/bro.py",
    "content": "from typing import Union\n\nimport torch\n\n\ndef bro(\n    x: torch.Tensor,\n    batch: torch.Tensor,\n    p: Union[int, str] = 2,\n) -> torch.Tensor:\n    r\"\"\"The Batch Representation Orthogonality penalty from the `\"Improving\n    Molecular Graph Neural Network Explainability with Orthonormalization\n    and Induced Sparsity\" <https://arxiv.org/abs/2105.04854>`_ paper.\n\n    Computes a regularization for each graph representation in a mini-batch\n    according to\n\n    .. math::\n        \\mathcal{L}_{\\textrm{BRO}}^\\mathrm{graph} =\n          || \\mathbf{HH}^T - \\mathbf{I}||_p\n\n    and returns an average over all graphs in the batch.\n\n    Args:\n        x (torch.Tensor): The node feature matrix.\n        batch (torch.Tensor): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n            each node to a specific example.\n        p (int or str, optional): The norm order to use. (default: :obj:`2`)\n    \"\"\"\n    _, counts = torch.unique(batch, return_counts=True)\n    diags = torch.stack([\n        torch.diag(ones) for ones in torch.nn.utils.rnn.pad_sequence(\n            sequences=torch.ones_like(batch).split_with_sizes(counts.tolist()),\n            padding_value=0.,\n            batch_first=True,\n        )\n    ])\n    x = x.split_with_sizes(split_sizes=counts.tolist())\n    x = torch.nn.utils.rnn.pad_sequence(\n        sequences=x,\n        padding_value=0.,\n        batch_first=True,\n    )\n    return torch.sum(torch.norm(x @ x.transpose(1, 2) - diags, p=p,\n                                dim=(1, 2))) / counts.shape[0]\n"
  },
  {
    "path": "torch_geometric/nn/functional/gini.py",
    "content": "import torch\n\n\ndef gini(w: torch.Tensor) -> torch.Tensor:\n    r\"\"\"The Gini coefficient from the `\"Improving Molecular Graph Neural\n    Network Explainability with Orthonormalization and Induced Sparsity\"\n    <https://arxiv.org/abs/2105.04854>`_ paper.\n\n    Computes a regularization penalty :math:`\\in [0, 1]` for each row of a\n    matrix according to\n\n    .. math::\n        \\mathcal{L}_\\textrm{Gini}^i = \\sum_j^n \\sum_{j'}^n \\frac{|w_{ij}\n         - w_{ij'}|}{2 (n^2 - n)\\bar{w_i}}\n\n    and returns an average over all rows.\n\n    Args:\n        w (torch.Tensor): A two-dimensional tensor.\n    \"\"\"\n    s = 0\n    for row in w:\n        t = row.repeat(row.size(0), 1)\n        u = (t - t.T).abs().sum() / (2 * (row.size(-1)**2 - row.size(-1)) *\n                                     row.abs().mean() + torch.finfo().eps)\n        s += u\n    s /= w.shape[0]\n    return s\n"
  },
  {
    "path": "torch_geometric/nn/fx.py",
    "content": "import copy\nimport warnings\nfrom typing import Any, Callable, Dict, List, Optional, Type, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module, ModuleDict, ModuleList, Sequential\n\ntry:\n    from torch.fx import Graph, GraphModule, Node\nexcept (ImportError, ModuleNotFoundError, AttributeError):\n    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'\n\n\nclass Transformer:\n    r\"\"\"A :class:`Transformer` executes an FX graph node-by-node, applies\n    transformations to each node, and produces a new :class:`torch.nn.Module`.\n    It exposes a :func:`transform` method that returns the transformed\n    :class:`~torch.nn.Module`.\n    :class:`Transformer` works entirely symbolically.\n\n    Methods in the :class:`Transformer` class can be overridden to customize\n    the behavior of transformation.\n\n    .. code-block:: none\n\n        transform()\n            +-- Iterate over each node in the graph\n                +-- placeholder()\n                +-- get_attr()\n                +-- call_function()\n                +-- call_method()\n                +-- call_module()\n                +-- call_message_passing_module()\n                +-- call_global_pooling_module()\n                +-- output()\n            +-- Erase unused nodes in the graph\n            +-- Iterate over each child module\n                +-- init_submodule()\n\n    In contrast to the :class:`torch.fx.Transformer` class, the\n    :class:`Transformer` exposes additional functionality:\n\n    #. It subdivides :func:`call_module` into nodes that call a regular\n       :class:`torch.nn.Module` (:func:`call_module`), a\n       :class:`MessagePassing` module (:func:`call_message_passing_module`),\n       or a :class:`GlobalPooling` module (:func:`call_global_pooling_module`).\n\n    #. It allows customizing or initializing new child modules via\n       :func:`init_submodule`.\n\n    #. 
It allows inferring whether a node returns node-level or edge-level\n       information via :meth:`is_edge_level`.\n\n    Args:\n        module (torch.nn.Module): The module to be transformed.\n        input_map (Dict[str, str], optional): A dictionary holding information\n            about the type of input arguments of :obj:`module.forward`.\n            For example, in case :obj:`arg` is a node-level argument, then\n            :obj:`input_map['arg'] = 'node'`, and\n            :obj:`input_map['arg'] = 'edge'` otherwise.\n            In case :obj:`input_map` is not further specified, will try to\n            automatically determine the correct type of input arguments.\n            (default: :obj:`None`)\n        debug (bool, optional): If set to :obj:`True`, will perform\n            transformation in debug mode. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        module: Module,\n        input_map: Optional[Dict[str, str]] = None,\n        debug: bool = False,\n    ):\n        self.module = module\n        self.gm = symbolic_trace(module)\n        self.input_map = input_map\n        self.debug = debug\n\n    # Methods to override #####################################################\n\n    def placeholder(self, node: Node, target: Any, name: str):\n        pass\n\n    def get_attr(self, node: Node, target: Any, name: str):\n        pass\n\n    def call_message_passing_module(self, node: Node, target: Any, name: str):\n        pass\n\n    def call_global_pooling_module(self, node: Node, target: Any, name: str):\n        pass\n\n    def call_module(self, node: Node, target: Any, name: str):\n        pass\n\n    def call_method(self, node: Node, target: Any, name: str):\n        pass\n\n    def call_function(self, node: Node, target: Any, name: str):\n        pass\n\n    def output(self, node: Node, target: Any, name: str):\n        pass\n\n    def init_submodule(self, module: Module, target: str) -> Module:\n        return module\n\n    
# Internal functionality ##################################################\n\n    @property\n    def graph(self) -> Graph:\n        return self.gm.graph\n\n    def transform(self) -> GraphModule:\n        r\"\"\"Transforms :obj:`self.module` and returns a transformed\n        :class:`torch.fx.GraphModule`.\n        \"\"\"\n        if self.debug:\n            self.graph.print_tabular()\n            print()\n            code = self.graph.python_code('self')\n            print(code.src if hasattr(code, 'src') else code)\n\n        # We create a private dictionary `self._state` which holds information\n        # about whether a node returns node-level or edge-level information:\n        # `self._state[node.name] in { 'node', 'edge' }`\n        self._state = copy.copy(self.input_map or {})\n\n        # We iterate over each node and determine its output level\n        # (node-level, edge-level) by filling `self._state`:\n        for node in list(self.graph.nodes):\n            if node.op == 'call_function' and 'training' in node.kwargs:\n                warnings.warn(\n                    f\"Found function '{node.name}' with keyword \"\n                    f\"argument 'training'. During FX tracing, this \"\n                    f\"will likely be baked in as a constant value. 
\"\n                    f\"Consider replacing this function by a module \"\n                    f\"to properly encapsulate its training flag.\",\n                    stacklevel=2)\n\n            if node.op == 'placeholder':\n                if node.name not in self._state:\n                    if 'edge' in node.name or 'adj' in node.name:\n                        self._state[node.name] = 'edge'\n                    else:\n                        self._state[node.name] = 'node'\n            elif is_message_passing_op(self.module, node.op, node.target):\n                self._state[node.name] = 'node'\n            elif is_global_pooling_op(self.module, node.op, node.target):\n                self._state[node.name] = 'graph'\n            elif node.op in ['call_module', 'call_method', 'call_function']:\n                if self.has_edge_level_arg(node):\n                    self._state[node.name] = 'edge'\n                elif self.has_node_level_arg(node):\n                    self._state[node.name] = 'node'\n                else:\n                    self._state[node.name] = 'graph'\n\n        # We iterate over each node and may transform it:\n        for node in list(self.graph.nodes):\n            # Call the corresponding `Transformer` method for each `node.op`,\n            # e.g.: `call_module(...)`, `call_function(...)`, ...\n            op = node.op\n            if is_message_passing_op(self.module, op, node.target):\n                op = 'call_message_passing_module'\n            elif is_global_pooling_op(self.module, op, node.target):\n                op = 'call_global_pooling_module'\n            getattr(self, op)(node, node.target, node.name)\n\n        # Remove all unused nodes in the computation graph, i.e., all nodes\n        # which have been replaced by node type-wise or edge type-wise variants\n        # but which are still present in the computation graph.\n        # We do this by iterating over the computation graph in reversed order,\n        # and 
try to remove every node. This does only succeed in case there\n        # are no users of that node left in the computation graph.\n        for node in reversed(list(self.graph.nodes)):\n            try:\n                if node.op not in ['placeholder', 'output']:\n                    self.graph.erase_node(node)\n            except RuntimeError:\n                pass\n\n        for target, submodule in dict(self.module._modules).items():\n            self.gm._modules[target] = self._init_submodule(submodule, target)\n\n        del self._state\n\n        if self.debug:\n            self.gm.graph.print_tabular()\n            print()\n            code = self.graph.python_code('self')\n            print(code.src if hasattr(code, 'src') else code)\n\n        self.gm.graph.lint()\n        self.gm.recompile()\n\n        return self.gm\n\n    def _init_submodule(self, module: Module, target: str) -> Module:\n        if isinstance(module, ModuleList) or isinstance(module, Sequential):\n            return ModuleList([\n                self._init_submodule(submodule, f'{target}.{i}')\n                for i, submodule in enumerate(module)\n            ])\n        elif isinstance(module, ModuleDict):\n            return ModuleDict({\n                key:\n                self._init_submodule(submodule, f'{target}.{key}')\n                for key, submodule in module.items()\n            })\n        else:\n            return self.init_submodule(module, target)\n\n    def _is_level(self, node: Node, name: str) -> bool:\n        return self._state[node.name] == name\n\n    def _has_level_arg(self, node: Node, name: str) -> bool:\n        def _recurse(value: Any) -> bool:\n            if isinstance(value, Node):\n                return getattr(self, f'is_{name}_level')(value)\n            elif isinstance(value, dict):\n                return any([_recurse(v) for v in value.values()])\n            elif isinstance(value, (list, tuple)):\n                return any([_recurse(v) for v 
in value])\n            else:\n                return False\n\n        return (any([_recurse(value) for value in node.args])\n                or any([_recurse(value) for value in node.kwargs.values()]))\n\n    def is_node_level(self, node: Node) -> bool:\n        return self._is_level(node, name='node')\n\n    def is_edge_level(self, node: Node) -> bool:\n        return self._is_level(node, name='edge')\n\n    def is_graph_level(self, node: Node) -> bool:\n        return self._is_level(node, name='graph')\n\n    def has_node_level_arg(self, node: Node) -> bool:\n        return self._has_level_arg(node, name='node')\n\n    def has_edge_level_arg(self, node: Node) -> bool:\n        return self._has_level_arg(node, name='edge')\n\n    def has_graph_level_arg(self, node: Node) -> bool:\n        return self._has_level_arg(node, name='graph')\n\n    def find_by_name(self, name: str) -> Optional[Node]:\n        for node in self.graph.nodes:\n            if node.name == name:\n                return node\n        return None\n\n    def find_by_target(self, target: Any) -> Optional[Node]:\n        for node in self.graph.nodes:\n            if node.target == target:\n                return node\n        return None\n\n    def replace_all_uses_with(self, to_replace: Node, replace_with: Node):\n        def maybe_replace_node(n: Node) -> Node:\n            return replace_with if n == to_replace else n\n\n        node = replace_with.next\n        while node.op != 'root':\n            node.args = torch.fx.map_arg(node.args, maybe_replace_node)\n            node.kwargs = torch.fx.map_arg(node.kwargs, maybe_replace_node)\n            node = node.next\n\n\ndef symbolic_trace(\n        module: Module,\n        concrete_args: Optional[Dict[str, Any]] = None) -> GraphModule:\n\n    # This is to support compatibility with pytorch version 1.9 and lower\n    try:\n        import torch.fx._symbolic_trace as st\n    except (ImportError, ModuleNotFoundError):\n        import 
torch.fx.symbolic_trace as st\n\n    from torch_geometric.nn import Aggregation\n\n    class Tracer(torch.fx.Tracer):\n        def is_leaf_module(self, module: Module, *args, **kwargs) -> bool:\n            # TODO We currently only trace top-level modules.\n            return not isinstance(module, torch.nn.Sequential)\n\n        # Note: This is a hack around the fact that `Aggregation.__call__`\n        # is not patched by the base implementation of `trace`.\n        # see https://github.com/pyg-team/pytorch_geometric/pull/5021 for\n        # details on the rationale\n        # TODO: Revisit https://github.com/pyg-team/pytorch_geometric/pull/5021\n        @st.compatibility(is_backward_compatible=True)\n        def trace(self, root: Union[torch.nn.Module, Callable[..., Any]],\n                  concrete_args: Optional[Dict[str, Any]] = None) -> Graph:\n\n            if isinstance(root, torch.nn.Module):\n                self.root = root\n                fn = type(root).forward\n                self.submodule_paths = {\n                    mod: name\n                    for name, mod in root.named_modules()\n                }\n            else:\n                self.root = torch.nn.Module()\n                fn = root\n\n            tracer_cls: Optional[Type['Tracer']] = getattr(\n                self, '__class__', None)\n            self.graph = Graph(tracer_cls=tracer_cls)\n\n            self.tensor_attrs: Dict[Union[Tensor, st.ScriptObject], str] = {}\n\n            def collect_tensor_attrs(m: torch.nn.Module,\n                                     prefix_atoms: List[str]):\n                for k, v in m.__dict__.items():\n                    if isinstance(v, (Tensor, st.ScriptObject)):\n                        self.tensor_attrs[v] = '.'.join(prefix_atoms + [k])\n                for k, v in m.named_children():\n                    collect_tensor_attrs(v, prefix_atoms + [k])\n\n            collect_tensor_attrs(self.root, [])\n\n            assert isinstance(fn, 
st.FunctionType)\n\n            fn_globals = fn.__globals__  # run before it gets patched\n            fn, args = self.create_args_for_root(\n                fn, isinstance(root, torch.nn.Module), concrete_args)\n\n            parameter_proxy_cache: Dict[str, st.Proxy] = {\n            }  # Reduce number of get_attr calls\n\n            @st.functools.wraps(st._orig_module_getattr)\n            def module_getattr_wrapper(mod, attr):\n                attr_val = st._orig_module_getattr(mod, attr)\n                # Support for PyTorch > 1.12, see:\n                # https://github.com/pytorch/pytorch/pull/84011\n                if hasattr(self, 'getattr'):\n                    return self.getattr(attr, attr_val, parameter_proxy_cache)\n                return self._module_getattr(attr, attr_val,\n                                            parameter_proxy_cache)\n\n            @st.functools.wraps(st._orig_module_call)\n            def module_call_wrapper(mod, *args, **kwargs):\n                def forward(*args, **kwargs):\n                    return st._orig_module_call(mod, *args, **kwargs)\n\n                st._autowrap_check(\n                    patcher,\n                    getattr(getattr(mod, \"forward\", mod), \"__globals__\", {}),\n                    self._autowrap_function_ids)\n                return self.call_module(mod, forward, args, kwargs)\n\n            with st._Patcher() as patcher:\n                # allow duplicate patches to support the case of nested calls\n                patcher.patch_method(torch.nn.Module, \"__getattr__\",\n                                     module_getattr_wrapper, deduplicate=False)\n                patcher.patch_method(torch.nn.Module, \"__call__\",\n                                     module_call_wrapper, deduplicate=False)\n                patcher.patch_method(Aggregation, \"__call__\",\n                                     module_call_wrapper, deduplicate=False)\n                
st._patch_wrapped_functions(patcher)\n                st._autowrap_check(patcher, fn_globals,\n                                   self._autowrap_function_ids)\n                for module in self._autowrap_search:\n                    st._autowrap_check(patcher, module.__dict__,\n                                       self._autowrap_function_ids)\n                self.create_node(\n                    'output', 'output', (self.create_arg(fn(*args)), ), {},\n                    type_expr=fn.__annotations__.get('return', None))\n\n            self.submodule_paths = None\n\n            return self.graph\n\n    return GraphModule(module, Tracer().trace(module, concrete_args))\n\n\ndef get_submodule(module: Module, target: str) -> Module:\n    out = module\n    for attr in target.split('.'):\n        out = getattr(out, attr)\n    return out\n\n\ndef is_message_passing_op(module: Module, op: str, target: str) -> bool:\n    from torch_geometric.nn import MessagePassing\n    if op == 'call_module':\n        return isinstance(get_submodule(module, target), MessagePassing)\n    return False\n\n\ndef is_global_pooling_op(module: Module, op: str, target: str) -> bool:\n    from torch_geometric.nn import Aggregation\n    if op == 'call_module':\n        return isinstance(get_submodule(module, target), Aggregation)\n    return False\n"
  },
  {
    "path": "torch_geometric/nn/glob.py",
    "content": "from torch_geometric.deprecation import deprecated\nfrom torch_geometric.nn import (\n    global_add_pool,\n    global_max_pool,\n    global_mean_pool,\n)\nfrom torch_geometric.nn.aggr import AttentionalAggregation, SortAggregation\n\n\n@deprecated(\n    details=\"use 'nn.aggr.AttentionalAggregation' instead\",\n    func_name='nn.glob.GlobalAttention',\n)\nclass GlobalAttention(AttentionalAggregation):\n    def __call__(self, x, batch=None, size=None):\n        return super().__call__(x, batch, dim_size=size)\n\n\n@deprecated(\n    details=\"use 'nn.aggr.SortAggregation' instead\",\n    func_name='nn.glob.global_sort_pool',\n)\ndef global_sort_pool(x, index, k):\n    module = SortAggregation(k=k)\n    return module(x, index=index)\n\n\ndeprecated(\n    details=\"use 'nn.pool.global_add_pool' instead\",\n    func_name='nn.glob.global_add_pool',\n)(global_add_pool)\n\ndeprecated(\n    details=\"use 'nn.pool.global_max_pool' instead\",\n    func_name='nn.glob.global_max_pool',\n)(global_max_pool)\n\ndeprecated(\n    details=\"use 'nn.pool.global_mean_pool' instead\",\n    func_name='nn.glob.global_mean_pool',\n)(global_mean_pool)\n"
  },
  {
    "path": "torch_geometric/nn/inits.py",
    "content": "import math\nfrom typing import Any\n\nimport torch\nfrom torch import Tensor\n\n\ndef uniform(size: int, value: Any):\n    if isinstance(value, Tensor):\n        bound = 1.0 / math.sqrt(size)\n        value.data.uniform_(-bound, bound)\n    else:\n        for v in value.parameters() if hasattr(value, 'parameters') else []:\n            uniform(size, v)\n        for v in value.buffers() if hasattr(value, 'buffers') else []:\n            uniform(size, v)\n\n\ndef kaiming_uniform(value: Any, fan: int, a: float):\n    if isinstance(value, Tensor):\n        bound = math.sqrt(6 / ((1 + a**2) * fan))\n        value.data.uniform_(-bound, bound)\n    else:\n        for v in value.parameters() if hasattr(value, 'parameters') else []:\n            kaiming_uniform(v, fan, a)\n        for v in value.buffers() if hasattr(value, 'buffers') else []:\n            kaiming_uniform(v, fan, a)\n\n\ndef glorot(value: Any):\n    if isinstance(value, Tensor):\n        stdv = math.sqrt(6.0 / (value.size(-2) + value.size(-1)))\n        value.data.uniform_(-stdv, stdv)\n    else:\n        for v in value.parameters() if hasattr(value, 'parameters') else []:\n            glorot(v)\n        for v in value.buffers() if hasattr(value, 'buffers') else []:\n            glorot(v)\n\n\ndef glorot_orthogonal(tensor, scale):\n    if tensor is not None:\n        torch.nn.init.orthogonal_(tensor.data)\n        scale /= ((tensor.size(-2) + tensor.size(-1)) * tensor.var())\n        tensor.data *= scale.sqrt()\n\n\ndef constant(value: Any, fill_value: float):\n    if isinstance(value, Tensor):\n        value.data.fill_(fill_value)\n    else:\n        for v in value.parameters() if hasattr(value, 'parameters') else []:\n            constant(v, fill_value)\n        for v in value.buffers() if hasattr(value, 'buffers') else []:\n            constant(v, fill_value)\n\n\ndef zeros(value: Any):\n    constant(value, 0.)\n\n\ndef ones(tensor: Any):\n    constant(tensor, 1.)\n\n\ndef normal(value: 
Any, mean: float, std: float):\n    if isinstance(value, Tensor):\n        value.data.normal_(mean, std)\n    else:\n        for v in value.parameters() if hasattr(value, 'parameters') else []:\n            normal(v, mean, std)\n        for v in value.buffers() if hasattr(value, 'buffers') else []:\n            normal(v, mean, std)\n\n\ndef reset(value: Any):\n    if hasattr(value, 'reset_parameters'):\n        value.reset_parameters()\n    else:\n        for child in value.children() if hasattr(value, 'children') else []:\n            reset(child)\n"
  },
  {
    "path": "torch_geometric/nn/kge/__init__.py",
    "content": "r\"\"\"Knowledge Graph Embedding (KGE) package.\"\"\"\n\nfrom .base import KGEModel\nfrom .transe import TransE\nfrom .complex import ComplEx\nfrom .distmult import DistMult\nfrom .rotate import RotatE\n\n__all__ = classes = [\n    'KGEModel',\n    'TransE',\n    'ComplEx',\n    'DistMult',\n    'RotatE',\n]\n"
  },
  {
    "path": "torch_geometric/nn/kge/base.py",
    "content": "from typing import Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Embedding\nfrom tqdm import tqdm\n\nfrom torch_geometric.nn.kge.loader import KGTripletLoader\n\n\nclass KGEModel(torch.nn.Module):\n    r\"\"\"An abstract base class for implementing custom KGE models.\n\n    Args:\n        num_nodes (int): The number of nodes/entities in the graph.\n        num_relations (int): The number of relations in the graph.\n        hidden_channels (int): The hidden embedding size.\n        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. to the\n            embedding matrices will be sparse. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        num_relations: int,\n        hidden_channels: int,\n        sparse: bool = False,\n    ):\n        super().__init__()\n\n        self.num_nodes = num_nodes\n        self.num_relations = num_relations\n        self.hidden_channels = hidden_channels\n\n        self.node_emb = Embedding(num_nodes, hidden_channels, sparse=sparse)\n        self.rel_emb = Embedding(num_relations, hidden_channels, sparse=sparse)\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.node_emb.reset_parameters()\n        self.rel_emb.reset_parameters()\n\n    def forward(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n        r\"\"\"Returns the score for the given triplet.\n\n        Args:\n            head_index (torch.Tensor): The head indices.\n            rel_type (torch.Tensor): The relation type.\n            tail_index (torch.Tensor): The tail indices.\n        \"\"\"\n        raise NotImplementedError\n\n    def loss(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n        r\"\"\"Returns the loss value for the given triplet.\n\n        Args:\n   
         head_index (torch.Tensor): The head indices.\n            rel_type (torch.Tensor): The relation type.\n            tail_index (torch.Tensor): The tail indices.\n        \"\"\"\n        raise NotImplementedError\n\n    def loader(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n        **kwargs,\n    ) -> Tensor:\n        r\"\"\"Returns a mini-batch loader that samples a subset of triplets.\n\n        Args:\n            head_index (torch.Tensor): The head indices.\n            rel_type (torch.Tensor): The relation type.\n            tail_index (torch.Tensor): The tail indices.\n            **kwargs (optional): Additional arguments of\n                :class:`torch.utils.data.DataLoader`, such as\n                :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last`\n                or :obj:`num_workers`.\n        \"\"\"\n        return KGTripletLoader(head_index, rel_type, tail_index, **kwargs)\n\n    @torch.no_grad()\n    def test(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n        batch_size: int,\n        k: int = 10,\n        log: bool = True,\n    ) -> Tuple[float, float, float]:\n        r\"\"\"Evaluates the model quality by computing Mean Rank, MRR and\n        Hits@:math:`k` across all possible tail entities.\n\n        Args:\n            head_index (torch.Tensor): The head indices.\n            rel_type (torch.Tensor): The relation type.\n            tail_index (torch.Tensor): The tail indices.\n            batch_size (int): The batch size to use for evaluating.\n            k (int, optional): The :math:`k` in Hits @ :math:`k`.\n                (default: :obj:`10`)\n            log (bool, optional): If set to :obj:`False`, will not print a\n                progress bar to the console. 
(default: :obj:`True`)\n        \"\"\"\n        arange = range(head_index.numel())\n        arange = tqdm(arange) if log else arange\n\n        mean_ranks, reciprocal_ranks, hits_at_k = [], [], []\n        for i in arange:\n            h, r, t = head_index[i], rel_type[i], tail_index[i]\n\n            scores = []\n            tail_indices = torch.arange(self.num_nodes, device=t.device)\n            for ts in tail_indices.split(batch_size):\n                scores.append(self(h.expand_as(ts), r.expand_as(ts), ts))\n            rank = int((torch.cat(scores).argsort(\n                descending=True) == t).nonzero().view(-1))\n            mean_ranks.append(rank)\n            reciprocal_ranks.append(1 / (rank + 1))\n            hits_at_k.append(rank < k)\n\n        mean_rank = float(torch.tensor(mean_ranks, dtype=torch.float).mean())\n        mrr = float(torch.tensor(reciprocal_ranks, dtype=torch.float).mean())\n        hits_at_k = int(torch.tensor(hits_at_k).sum()) / len(hits_at_k)\n\n        return mean_rank, mrr, hits_at_k\n\n    @torch.no_grad()\n    def random_sample(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor]:\n        r\"\"\"Randomly samples negative triplets by either replacing the head or\n        the tail (but not both).\n\n        Args:\n            head_index (torch.Tensor): The head indices.\n            rel_type (torch.Tensor): The relation type.\n            tail_index (torch.Tensor): The tail indices.\n        \"\"\"\n        # Random sample either `head_index` or `tail_index` (but not both):\n        num_negatives = head_index.numel() // 2\n        rnd_index = torch.randint(self.num_nodes, head_index.size(),\n                                  device=head_index.device)\n\n        head_index = head_index.clone()\n        head_index[:num_negatives] = rnd_index[:num_negatives]\n        tail_index = tail_index.clone()\n        tail_index[num_negatives:] = 
rnd_index[num_negatives:]\n\n        return head_index, rel_type, tail_index\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.num_nodes}, '\n                f'num_relations={self.num_relations}, '\n                f'hidden_channels={self.hidden_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/kge/complex.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Embedding\n\nfrom torch_geometric.nn.kge import KGEModel\n\n\nclass ComplEx(KGEModel):\n    r\"\"\"The ComplEx model from the `\"Complex Embeddings for Simple Link\n    Prediction\" <https://arxiv.org/abs/1606.06357>`_ paper.\n\n    :class:`ComplEx` models relations as complex-valued bilinear mappings\n    between head and tail entities using the Hermitian dot product.\n    The entities and relations are embedded in different dimensional spaces,\n    resulting in the scoring function:\n\n    .. math::\n        d(h, r, t) = Re(< \\mathbf{e}_h,  \\mathbf{e}_r, \\mathbf{e}_t>)\n\n    .. note::\n\n        For an example of using the :class:`ComplEx` model, see\n        `examples/kge_fb15k_237.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        kge_fb15k_237.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes/entities in the graph.\n        num_relations (int): The number of relations in the graph.\n        hidden_channels (int): The hidden embedding size.\n        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t.\n            the embedding matrices will be sparse. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        num_relations: int,\n        hidden_channels: int,\n        sparse: bool = False,\n    ):\n        super().__init__(num_nodes, num_relations, hidden_channels, sparse)\n\n        self.node_emb_im = Embedding(num_nodes, hidden_channels, sparse=sparse)\n        self.rel_emb_im = Embedding(num_relations, hidden_channels,\n                                    sparse=sparse)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.node_emb.weight)\n        torch.nn.init.xavier_uniform_(self.node_emb_im.weight)\n        torch.nn.init.xavier_uniform_(self.rel_emb.weight)\n        torch.nn.init.xavier_uniform_(self.rel_emb_im.weight)\n\n    def forward(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        head_re = self.node_emb(head_index)\n        head_im = self.node_emb_im(head_index)\n        rel_re = self.rel_emb(rel_type)\n        rel_im = self.rel_emb_im(rel_type)\n        tail_re = self.node_emb(tail_index)\n        tail_im = self.node_emb_im(tail_index)\n\n        return (triple_dot(head_re, rel_re, tail_re) +\n                triple_dot(head_im, rel_re, tail_im) +\n                triple_dot(head_re, rel_im, tail_im) -\n                triple_dot(head_im, rel_im, tail_re))\n\n    def loss(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        pos_score = self(head_index, rel_type, tail_index)\n        neg_score = self(*self.random_sample(head_index, rel_type, tail_index))\n        scores = torch.cat([pos_score, neg_score], dim=0)\n\n        pos_target = torch.ones_like(pos_score)\n        neg_target = torch.zeros_like(neg_score)\n        target = torch.cat([pos_target, neg_target], dim=0)\n\n        return F.binary_cross_entropy_with_logits(scores, 
target)\n\n\ndef triple_dot(x: Tensor, y: Tensor, z: Tensor) -> Tensor:\n    return (x * y * z).sum(dim=-1)\n"
  },
  {
    "path": "torch_geometric/nn/kge/distmult.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.kge import KGEModel\n\n\nclass DistMult(KGEModel):\n    r\"\"\"The DistMult model from the `\"Embedding Entities and Relations for\n    Learning and Inference in Knowledge Bases\"\n    <https://arxiv.org/abs/1412.6575>`_ paper.\n\n    :class:`DistMult` models relations as diagonal matrices, which simplifies\n    the bi-linear interaction between the head and tail entities to the score\n    function:\n\n    .. math::\n        d(h, r, t) = < \\mathbf{e}_h,  \\mathbf{e}_r, \\mathbf{e}_t >\n\n    .. note::\n\n        For an example of using the :class:`DistMult` model, see\n        `examples/kge_fb15k_237.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        kge_fb15k_237.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes/entities in the graph.\n        num_relations (int): The number of relations in the graph.\n        hidden_channels (int): The hidden embedding size.\n        margin (float, optional): The margin of the ranking loss.\n            (default: :obj:`1.0`)\n        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t.\n            the embedding matrices will be sparse. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        num_relations: int,\n        hidden_channels: int,\n        margin: float = 1.0,\n        sparse: bool = False,\n    ):\n        super().__init__(num_nodes, num_relations, hidden_channels, sparse)\n\n        self.margin = margin\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.node_emb.weight)\n        torch.nn.init.xavier_uniform_(self.rel_emb.weight)\n\n    def forward(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        head = self.node_emb(head_index)\n        rel = self.rel_emb(rel_type)\n        tail = self.node_emb(tail_index)\n\n        return (head * rel * tail).sum(dim=-1)\n\n    def loss(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        pos_score = self(head_index, rel_type, tail_index)\n        neg_score = self(*self.random_sample(head_index, rel_type, tail_index))\n\n        return F.margin_ranking_loss(\n            pos_score,\n            neg_score,\n            target=torch.ones_like(pos_score),\n            margin=self.margin,\n        )\n"
  },
  {
    "path": "torch_geometric/nn/kge/loader.py",
    "content": "from typing import List, Tuple\n\nimport torch\nfrom torch import Tensor\n\n\nclass KGTripletLoader(torch.utils.data.DataLoader):\n    def __init__(self, head_index: Tensor, rel_type: Tensor,\n                 tail_index: Tensor, **kwargs):\n        self.head_index = head_index\n        self.rel_type = rel_type\n        self.tail_index = tail_index\n\n        super().__init__(range(head_index.numel()), collate_fn=self.sample,\n                         **kwargs)\n\n    def sample(self, index: List[int]) -> Tuple[Tensor, Tensor, Tensor]:\n        index = torch.tensor(index, device=self.head_index.device)\n\n        head_index = self.head_index[index]\n        rel_type = self.rel_type[index]\n        tail_index = self.tail_index[index]\n\n        return head_index, rel_type, tail_index\n"
  },
  {
    "path": "torch_geometric/nn/kge/rotate.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Embedding\n\nfrom torch_geometric.nn.kge import KGEModel\n\n\nclass RotatE(KGEModel):\n    r\"\"\"The RotatE model from the `\"RotatE: Knowledge Graph Embedding by\n    Relational Rotation in Complex Space\" <https://arxiv.org/abs/\n    1902.10197>`_ paper.\n\n    :class:`RotatE` models relations as a rotation in complex space\n    from head to tail such that\n\n    .. math::\n        \\mathbf{e}_t = \\mathbf{e}_h \\circ \\mathbf{e}_r,\n\n    resulting in the scoring function\n\n    .. math::\n        d(h, r, t) = - {\\| \\mathbf{e}_h \\circ \\mathbf{e}_r - \\mathbf{e}_t \\|}_p\n\n    .. note::\n\n        For an example of using the :class:`RotatE` model, see\n        `examples/kge_fb15k_237.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        kge_fb15k_237.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes/entities in the graph.\n        num_relations (int): The number of relations in the graph.\n        hidden_channels (int): The hidden embedding size.\n        margin (float, optional): The margin of the ranking loss.\n            (default: :obj:`1.0`)\n        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t.\n            the embedding matrices will be sparse. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        num_relations: int,\n        hidden_channels: int,\n        margin: float = 1.0,\n        sparse: bool = False,\n    ):\n        super().__init__(num_nodes, num_relations, hidden_channels, sparse)\n\n        self.margin = margin\n        self.node_emb_im = Embedding(num_nodes, hidden_channels, sparse=sparse)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.node_emb.weight)\n        torch.nn.init.xavier_uniform_(self.node_emb_im.weight)\n        torch.nn.init.uniform_(self.rel_emb.weight, 0, 2 * math.pi)\n\n    def forward(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        head_re = self.node_emb(head_index)\n        head_im = self.node_emb_im(head_index)\n        tail_re = self.node_emb(tail_index)\n        tail_im = self.node_emb_im(tail_index)\n\n        rel_theta = self.rel_emb(rel_type)\n        rel_re, rel_im = torch.cos(rel_theta), torch.sin(rel_theta)\n\n        re_score = (rel_re * head_re - rel_im * head_im) - tail_re\n        im_score = (rel_re * head_im + rel_im * head_re) - tail_im\n        complex_score = torch.stack([re_score, im_score], dim=2)\n        score = torch.linalg.vector_norm(complex_score, dim=(1, 2))\n\n        return self.margin - score\n\n    def loss(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        pos_score = self(head_index, rel_type, tail_index)\n        neg_score = self(*self.random_sample(head_index, rel_type, tail_index))\n        scores = torch.cat([pos_score, neg_score], dim=0)\n\n        pos_target = torch.ones_like(pos_score)\n        neg_target = torch.zeros_like(neg_score)\n        target = torch.cat([pos_target, neg_target], dim=0)\n\n        return F.binary_cross_entropy_with_logits(scores, target)\n"
  },
  {
    "path": "torch_geometric/nn/kge/transe.py",
    "content": "import math\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.kge import KGEModel\n\n\nclass TransE(KGEModel):\n    r\"\"\"The TransE model from the `\"Translating Embeddings for Modeling\n    Multi-Relational Data\" <https://proceedings.neurips.cc/paper/2013/file/\n    1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf>`_ paper.\n\n    :class:`TransE` models relations as a translation from head to tail\n    entities such that\n\n    .. math::\n        \\mathbf{e}_h + \\mathbf{e}_r \\approx \\mathbf{e}_t,\n\n    resulting in the scoring function:\n\n    .. math::\n        d(h, r, t) = - {\\| \\mathbf{e}_h + \\mathbf{e}_r - \\mathbf{e}_t \\|}_p\n\n    .. note::\n\n        For an example of using the :class:`TransE` model, see\n        `examples/kge_fb15k_237.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        kge_fb15k_237.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes/entities in the graph.\n        num_relations (int): The number of relations in the graph.\n        hidden_channels (int): The hidden embedding size.\n        margin (float, optional): The margin of the ranking loss.\n            (default: :obj:`1.0`)\n        p_norm (float, optional): The order of the embedding and distance\n            normalization. (default: :obj:`1.0`)\n        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the\n            embedding matrices will be sparse. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        num_relations: int,\n        hidden_channels: int,\n        margin: float = 1.0,\n        p_norm: float = 1.0,\n        sparse: bool = False,\n    ):\n        super().__init__(num_nodes, num_relations, hidden_channels, sparse)\n\n        self.p_norm = p_norm\n        self.margin = margin\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        bound = 6. 
/ math.sqrt(self.hidden_channels)\n        torch.nn.init.uniform_(self.node_emb.weight, -bound, bound)\n        torch.nn.init.uniform_(self.rel_emb.weight, -bound, bound)\n        F.normalize(self.rel_emb.weight.data, p=self.p_norm, dim=-1,\n                    out=self.rel_emb.weight.data)\n\n    def forward(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        head = self.node_emb(head_index)\n        rel = self.rel_emb(rel_type)\n        tail = self.node_emb(tail_index)\n\n        head = F.normalize(head, p=self.p_norm, dim=-1)\n        tail = F.normalize(tail, p=self.p_norm, dim=-1)\n\n        # Calculate *negative* TransE norm:\n        return -((head + rel) - tail).norm(p=self.p_norm, dim=-1)\n\n    def loss(\n        self,\n        head_index: Tensor,\n        rel_type: Tensor,\n        tail_index: Tensor,\n    ) -> Tensor:\n\n        pos_score = self(head_index, rel_type, tail_index)\n        neg_score = self(*self.random_sample(head_index, rel_type, tail_index))\n\n        return F.margin_ranking_loss(\n            pos_score,\n            neg_score,\n            target=torch.ones_like(pos_score),\n            margin=self.margin,\n        )\n"
  },
  {
    "path": "torch_geometric/nn/lr_scheduler.py",
    "content": "# See HuggingFace `transformers/optimization.py`.\nimport functools\nimport math\n\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import LambdaLR\n\n\nclass ConstantWithWarmupLR(LambdaLR):\n    r\"\"\"Creates a LR scheduler with a constant learning rate preceded by a\n    warmup period during which the learning rate increases linearly between\n    :obj:`0` and the initial LR set in the optimizer.\n\n    Args:\n        optimizer (Optimizer): The optimizer to be scheduled.\n        num_warmup_steps (int): The number of steps for the warmup phase.\n        last_epoch (int, optional): The index of the last epoch when resuming\n            training. (default: :obj:`-1`)\n    \"\"\"\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        num_warmup_steps: int,\n        last_epoch: int = -1,\n    ):\n        lr_lambda = functools.partial(\n            self._lr_lambda,\n            num_warmup_steps=num_warmup_steps,\n        )\n        super().__init__(optimizer, lr_lambda, last_epoch)\n\n    @staticmethod\n    def _lr_lambda(\n        current_step: int,\n        num_warmup_steps: int,\n    ) -> float:\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1.0, num_warmup_steps))\n        return 1.0\n\n\nclass LinearWithWarmupLR(LambdaLR):\n    r\"\"\"Creates a LR scheduler with a learning rate that decreases linearly\n    from the initial LR set in the optimizer to :obj:`0`, after a warmup period\n    during which it increases linearly from :obj:`0` to the initial LR set in\n    the optimizer.\n\n    Args:\n        optimizer (Optimizer): The optimizer to be scheduled.\n        num_warmup_steps (int): The number of steps for the warmup phase.\n        num_training_steps (int): The total number of training steps.\n        last_epoch (int, optional): The index of the last epoch when resuming\n            training. 
(default: :obj:`-1`)\n    \"\"\"\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        num_warmup_steps: int,\n        num_training_steps: int,\n        last_epoch: int = -1,\n    ):\n        lr_lambda = functools.partial(\n            self._lr_lambda,\n            num_warmup_steps=num_warmup_steps,\n            num_training_steps=num_training_steps,\n        )\n        super().__init__(optimizer, lr_lambda, last_epoch)\n\n    @staticmethod\n    def _lr_lambda(\n        current_step: int,\n        num_warmup_steps: int,\n        num_training_steps: int,\n    ) -> float:\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        return max(\n            0.0,\n            float(num_training_steps - current_step) /\n            float(max(1, num_training_steps - num_warmup_steps)),\n        )\n\n\nclass CosineWithWarmupLR(LambdaLR):\n    r\"\"\"Creates a LR scheduler with a learning rate that decreases following\n    the values of the cosine function between the initial LR set in the\n    optimizer to :obj:`0`, after a warmup period during which it increases\n    linearly between :obj:`0` and the initial LR set in the optimizer.\n\n    Args:\n        optimizer (Optimizer): The optimizer to be scheduled.\n        num_warmup_steps (int): The number of steps for the warmup phase.\n        num_training_steps (int): The total number of training steps.\n        num_cycles (float, optional): The number of waves in the cosine\n            schedule (the default decreases LR from the max value to :obj:`0`\n            following a half-cosine). (default: :obj:`0.5`)\n        last_epoch (int, optional): The index of the last epoch when resuming\n            training. 
(default: :obj:`-1`)\n    \"\"\"\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        num_warmup_steps: int,\n        num_training_steps: int,\n        num_cycles: float = 0.5,\n        last_epoch: int = -1,\n    ):\n        lr_lambda = functools.partial(\n            self._lr_lambda,\n            num_warmup_steps=num_warmup_steps,\n            num_training_steps=num_training_steps,\n            num_cycles=num_cycles,\n        )\n        super().__init__(optimizer, lr_lambda, last_epoch)\n\n    @staticmethod\n    def _lr_lambda(\n        current_step: int,\n        num_warmup_steps: int,\n        num_training_steps: int,\n        num_cycles: float,\n    ):\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        progress = float(current_step - num_warmup_steps) / float(\n            max(1, num_training_steps - num_warmup_steps))\n        return max(\n            0.0,\n            0.5 *\n            (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),\n        )\n\n\nclass CosineWithWarmupRestartsLR(LambdaLR):\n    r\"\"\"Creates a LR scheduler with a learning rate that decreases following\n    the values of the cosine function between the initial LR set in the\n    optimizer to :obj:`0`, with several hard restarts, after a warmup period\n    during which it increases linearly between :obj:`0` and the initial LR set\n    in the optimizer.\n\n    Args:\n        optimizer (Optimizer): The optimizer to be scheduled.\n        num_warmup_steps (int): The number of steps for the warmup phase.\n        num_training_steps (int): The total number of training steps.\n        num_cycles (int, optional): The number of hard restarts to use.\n            (default: :obj:`3`)\n        last_epoch (int, optional): The index of the last epoch when resuming\n            training. 
(default: :obj:`-1`)\n    \"\"\"\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        num_warmup_steps: int,\n        num_training_steps: int,\n        num_cycles: int = 3,\n        last_epoch: int = -1,\n    ):\n        lr_lambda = functools.partial(\n            self._lr_lambda,\n            num_warmup_steps=num_warmup_steps,\n            num_training_steps=num_training_steps,\n            num_cycles=num_cycles,\n        )\n        super().__init__(optimizer, lr_lambda, last_epoch)\n\n    @staticmethod\n    def _lr_lambda(\n        current_step: int,\n        num_warmup_steps: int,\n        num_training_steps: int,\n        num_cycles: int,\n    ) -> float:\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        progress = float(current_step - num_warmup_steps) / float(\n            max(1, num_training_steps - num_warmup_steps))\n        if progress >= 1.0:\n            return 0.0\n        return max(\n            0.0,\n            0.5 * (1.0 + math.cos(math.pi *\n                                  ((float(num_cycles) * progress) % 1.0))),\n        )\n\n\nclass PolynomialWithWarmupLR(LambdaLR):\n    r\"\"\"Creates a LR scheduler with a learning rate that decreases as a\n    polynomial decay from the initial LR set in the optimizer to end LR defined\n    by `lr_end`, after a warmup period during which it increases linearly from\n    :obj:`0` to the initial LR set in the optimizer.\n\n    Args:\n        optimizer (Optimizer): The optimizer to be scheduled.\n        num_warmup_steps (int): The number of steps for the warmup phase.\n        num_training_steps (int): The total number of training steps.\n        lr_end (float, optional): The end learning rate. 
(default: :obj:`1e-7`)\n        power (float, optional): The power factor of the polynomial decay.\n            (default: :obj:`1.0`)\n        last_epoch (int, optional): The index of the last epoch when resuming\n            training. (default: :obj:`-1`)\n    \"\"\"\n    def __init__(\n        self,\n        optimizer: Optimizer,\n        num_warmup_steps: int,\n        num_training_steps: int,\n        lr_end: float = 1e-7,\n        power: float = 1.0,\n        last_epoch: int = -1,\n    ):\n        lr_init = optimizer.defaults[\"lr\"]\n        if not (lr_init > lr_end):\n            raise ValueError(f\"`lr_end` ({lr_end}) must be smaller than the \"\n                             f\"initial lr ({lr_init})\")\n\n        lr_lambda = functools.partial(\n            self._lr_lambda,\n            num_warmup_steps=num_warmup_steps,\n            num_training_steps=num_training_steps,\n            lr_init=lr_init,\n            lr_end=lr_end,\n            power=power,\n        )\n        super().__init__(optimizer, lr_lambda, last_epoch)\n\n    @staticmethod\n    def _lr_lambda(\n        current_step: int,\n        num_warmup_steps: int,\n        num_training_steps: int,\n        lr_init: float,\n        lr_end: float,\n        power: float,\n    ) -> float:\n        if current_step < num_warmup_steps:\n            return float(current_step) / float(max(1, num_warmup_steps))\n        elif current_step > num_training_steps:\n            return lr_end / lr_init  # As `LambdaLR` multiplies by `lr_init`.\n        else:\n            lr_range = lr_init - lr_end\n            decay_steps = num_training_steps - num_warmup_steps\n            pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps\n            decay = lr_range * pct_remaining**power + lr_end\n            return decay / lr_init  # As `LambdaLR` multiplies by `lr_init`.\n"
  },
  {
    "path": "torch_geometric/nn/model_hub.py",
    "content": "import os.path as osp\nfrom pathlib import Path\nfrom typing import Any, Dict, Optional, Union\n\nimport torch\n\nfrom torch_geometric.io import fs\n\ntry:\n    from huggingface_hub import ModelHubMixin, hf_hub_download\nexcept ImportError:\n    ModelHubMixin = object\n    hf_hub_download = None\n\nCONFIG_NAME = 'config.json'\nMODEL_HUB_ORGANIZATION = \"pytorch_geometric\"\nMODEL_WEIGHTS_NAME = 'model.pth'\nTAGS = ['graph-machine-learning']\n\n\nclass PyGModelHubMixin(ModelHubMixin):\n    r\"\"\"A mixin for saving and loading models to the\n    `Huggingface Model Hub <https://huggingface.co/docs/hub/index>`_.\n\n    .. code-block:: python\n\n       from torch_geometric.datasets import Planetoid\n       from torch_geometric.nn import Node2Vec\n       from torch_geometric.nn.model_hub import PyGModelHubMixin\n\n       # Define your class with the mixin:\n       class N2V(Node2Vec, PyGModelHubMixin):\n           def __init__(self,model_name, dataset_name, model_kwargs):\n               Node2Vec.__init__(self,**model_kwargs)\n               PyGModelHubMixin.__init__(self, model_name,\n                   dataset_name, model_kwargs)\n\n       # Instantiate your model:\n       n2v = N2V(model_name='node2vec',\n           dataset_name='Cora', model_kwargs=dict(\n           edge_index=data.edge_index, embedding_dim=128,\n           walk_length=20, context_size=10, walks_per_node=10,\n           num_negative_samples=1, p=1, q=1, sparse=True))\n\n       # Train the model:\n       ...\n\n       # Push to the HuggingFace hub:\n       repo_id = ...  
# your repo id\n       n2v.save_pretrained(\n           local_file_path,\n           push_to_hub=True,\n           repo_id=repo_id,\n        )\n\n       # Load the model for inference:\n       # The required arguments are the repo id/local folder, and any model\n       # initialisation arguments that are not native python types (e.g\n       # Node2Vec requires the edge_index argument which is not stored in the\n       # model hub).\n       model = N2V.from_pretrained(\n           repo_id,\n           model_name='node2vec',\n           dataset_name='Cora',\n           edge_index=data.edge_index,\n       )\n\n    Args:\n        model_name (str): Name of the model.\n        dataset_name (str): Name of the dataset the model was trained against.\n        model_kwargs (Dict[str, Any]): The arguments to initialise the model.\n    \"\"\"\n    def __init__(self, model_name: str, dataset_name: str, model_kwargs: Dict):\n        ModelHubMixin.__init__(self)\n\n        # Huggingface Hub API only accepts saving the config as a dict.\n        # If the model is instantiated with non-native python types\n        # such as torch Tensors (node2vec being an example), we have to remove\n        # these as they are not json serialisable\n        self.model_config = {\n            k: v\n            for k, v in model_kwargs.items() if type(v) in [str, int, float]\n        }\n        self.model_name = model_name\n        self.dataset_name = dataset_name\n\n    def construct_model_card(self, model_name: str, dataset_name: str) -> Any:\n        from huggingface_hub import ModelCard, ModelCardData\n        card_data = ModelCardData(\n            language='en',\n            license='mit',\n            library_name=MODEL_HUB_ORGANIZATION,\n            tags=TAGS,\n            datasets=dataset_name,\n            model_name=model_name,\n        )\n        card = ModelCard.from_template(card_data)\n        return card\n\n    def _save_pretrained(self, save_directory: Union[Path, str]):\n        
path = osp.join(save_directory, MODEL_WEIGHTS_NAME)\n        model_to_save = self.module if hasattr(self, 'module') else self\n        torch.save(model_to_save.state_dict(), path)\n\n    def save_pretrained(self, save_directory: Union[str, Path],\n                        push_to_hub: bool = False,\n                        repo_id: Optional[str] = None, **kwargs):\n        r\"\"\"Save a trained model to a local directory or to the HuggingFace\n        model hub.\n\n        Args:\n            save_directory (str): The directory where weights are saved.\n            push_to_hub (bool, optional): If :obj:`True`, push the model to the\n                HuggingFace model hub. (default: :obj:`False`)\n            repo_id (str, optional): The repository name in the hub.\n                If not provided will default to the name of\n                :obj:`save_directory` in your namespace. (default: :obj:`None`)\n            **kwargs: Additional keyword arguments passed to\n                :meth:`huggingface_hub.ModelHubMixin.save_pretrained`.\n        \"\"\"\n        config = self.model_config\n        # due to way huggingface hub handles the loading/saving of models,\n        # the model config can end up in one of the items in the kwargs\n        # this has to be removed to prevent a duplication of arguments to\n        # ModelHubMixin.save_pretrained\n        kwargs.pop('config', None)\n\n        super().save_pretrained(\n            save_directory=save_directory,\n            config=config,\n            push_to_hub=push_to_hub,\n            repo_id=repo_id,\n            **kwargs,\n        )\n        model_card = self.construct_model_card(self.model_name,\n                                               self.dataset_name)\n        if push_to_hub:\n            model_card.push_to_hub(repo_id)\n\n    @classmethod\n    def _from_pretrained(\n        cls,\n        model_id,\n        revision,\n        cache_dir,\n        force_download,\n        local_files_only,\n        
token,\n        proxies=None,\n        resume_download=False,\n        dataset_name='',\n        model_name='',\n        map_location='cpu',\n        strict=False,\n        **model_kwargs,\n    ):\n        map_location = torch.device(map_location)\n\n        if osp.isdir(model_id):\n            model_file = osp.join(model_id, MODEL_WEIGHTS_NAME)\n        else:\n            model_file = hf_hub_download(\n                repo_id=model_id,\n                filename=MODEL_WEIGHTS_NAME,\n                revision=revision,\n                cache_dir=cache_dir,\n                force_download=force_download,\n                token=token,\n                local_files_only=local_files_only,\n            )\n\n        config = model_kwargs.pop('config', None)\n        if config is not None:\n            model_kwargs = {**model_kwargs, **config}\n\n        model = cls(model_name, dataset_name, model_kwargs)\n\n        state_dict = fs.torch_load(model_file, map_location=map_location)\n        model.load_state_dict(state_dict, strict=strict)\n        model.eval()\n\n        return model\n\n    @classmethod\n    def from_pretrained(\n        cls,\n        pretrained_model_name_or_path: str,\n        force_download: bool = False,\n        token: Optional[Union[str, bool]] = None,\n        cache_dir: Optional[str] = None,\n        local_files_only: bool = False,\n        **model_kwargs,\n    ) -> Any:\n        r\"\"\"Downloads and instantiates a model from the HuggingFace hub.\n\n        Args:\n            pretrained_model_name_or_path (str): Can be either:\n\n                - The :obj:`model_id` of a pretrained model hosted inside the\n                  HuggingFace hub.\n\n                - You can add a :obj:`revision` by appending :obj:`@` at the\n                  end of :obj:`model_id` to load a specific model version.\n\n                - A path to a directory containing the saved model weights.\n\n                - :obj:`None` if you are both providing the configuration\n   
               :obj:`config` and state dictionary :obj:`state_dict`.\n\n            force_download (bool, optional): Whether to force the\n                (re-)download of the model weights and configuration files,\n                overriding the cached versions if they exist.\n                (default: :obj:`False`)\n            token (str or bool, optional): The token to use as HTTP bearer\n                authorization for remote files. If set to :obj:`True`, will use\n                the token generated when running :obj:`transformers-cli login`\n                (stored in :obj:`~/.huggingface`). It is **required** if you\n                want to use a private model. (default: :obj:`None`)\n            cache_dir (str, optional): The path to a directory in which a\n                downloaded model configuration should be cached if the\n                standard cache should not be used. (default: :obj:`None`)\n            local_files_only (bool, optional): Whether to only look at local\n                files, *i.e.* do not try to download the model.\n                (default: :obj:`False`)\n            **model_kwargs: Additional keyword arguments passed to the\n                model during initialization.\n        \"\"\"\n        return super().from_pretrained(\n            pretrained_model_name_or_path,\n            force_download=force_download,\n            use_auth_token=token,\n            cache_dir=cache_dir,\n            local_files_only=local_files_only,\n            **model_kwargs,\n        )\n"
  },
  {
    "path": "torch_geometric/nn/models/__init__.py",
    "content": "r\"\"\"Model package.\"\"\"\n\nfrom .mlp import MLP\nfrom .basic_gnn import GCN, GraphSAGE, GIN, GAT, PNA, EdgeCNN\nfrom .jumping_knowledge import JumpingKnowledge, HeteroJumpingKnowledge\nfrom .meta import MetaLayer\nfrom .node2vec import Node2Vec\nfrom .deep_graph_infomax import DeepGraphInfomax\nfrom .autoencoder import InnerProductDecoder, GAE, VGAE, ARGA, ARGVA\nfrom .signed_gcn import SignedGCN\nfrom .re_net import RENet\nfrom .graph_unet import GraphUNet\nfrom .schnet import SchNet\nfrom .dimenet import DimeNet, DimeNetPlusPlus\nfrom .gpse import GPSE, GPSENodeEncoder\nfrom .captum import to_captum_model\nfrom .metapath2vec import MetaPath2Vec\nfrom .deepgcn import DeepGCNLayer\nfrom .tgn import TGNMemory\nfrom .label_prop import LabelPropagation\nfrom .correct_and_smooth import CorrectAndSmooth\nfrom .attentive_fp import AttentiveFP\nfrom .rect import RECT_L\nfrom .linkx import LINKX\nfrom .lightgcn import LightGCN\nfrom .mask_label import MaskLabel\nfrom .rev_gnn import GroupAddRev\nfrom .gnnff import GNNFF\nfrom .pmlp import PMLP\nfrom .neural_fingerprint import NeuralFingerprint\nfrom .visnet import ViSNet\nfrom .lpformer import LPFormer\nfrom .sgformer import SGFormer\n\nfrom .polynormer import Polynormer\n# Deprecated:\nfrom torch_geometric.explain.algorithm.captum import (to_captum_input,\n                                                      captum_output_to_dicts)\nfrom .attract_repel import ARLinkPredictor\n\n__all__ = classes = [\n    'MLP',\n    'GCN',\n    'GraphSAGE',\n    'GIN',\n    'GAT',\n    'PNA',\n    'EdgeCNN',\n    'JumpingKnowledge',\n    'HeteroJumpingKnowledge',\n    'MetaLayer',\n    'Node2Vec',\n    'DeepGraphInfomax',\n    'InnerProductDecoder',\n    'GAE',\n    'VGAE',\n    'ARGA',\n    'ARGVA',\n    'SignedGCN',\n    'RENet',\n    'GraphUNet',\n    'SchNet',\n    'DimeNet',\n    'DimeNetPlusPlus',\n    'GPSE',\n    'GPSENodeEncoder',\n    'to_captum_model',\n    'to_captum_input',\n    
'captum_output_to_dicts',\n    'MetaPath2Vec',\n    'DeepGCNLayer',\n    'TGNMemory',\n    'LabelPropagation',\n    'CorrectAndSmooth',\n    'AttentiveFP',\n    'RECT_L',\n    'LINKX',\n    'LightGCN',\n    'MaskLabel',\n    'GroupAddRev',\n    'GNNFF',\n    'PMLP',\n    'NeuralFingerprint',\n    'ViSNet',\n    'LPFormer',\n    'SGFormer',\n    'Polynormer',\n    'ARLinkPredictor',\n]\n"
  },
  {
    "path": "torch_geometric/nn/models/attentive_fp.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import GRUCell, Linear, Parameter\n\nfrom torch_geometric.nn import GATConv, MessagePassing, global_add_pool\nfrom torch_geometric.nn.inits import glorot, zeros\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import softmax\n\n\nclass GATEConv(MessagePassing):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        edge_dim: int,\n        dropout: float = 0.0,\n    ):\n        super().__init__(aggr='add', node_dim=0)\n\n        self.dropout = dropout\n\n        self.att_l = Parameter(torch.empty(1, out_channels))\n        self.att_r = Parameter(torch.empty(1, in_channels))\n\n        self.lin1 = Linear(in_channels + edge_dim, out_channels, False)\n        self.lin2 = Linear(out_channels, out_channels, False)\n\n        self.bias = Parameter(torch.empty(out_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot(self.att_l)\n        glorot(self.att_r)\n        glorot(self.lin1.weight)\n        glorot(self.lin2.weight)\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:\n        # edge_updater_type: (x: Tensor, edge_attr: Tensor)\n        alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)\n\n        # propagate_type: (x: Tensor, alpha: Tensor)\n        out = self.propagate(edge_index, x=x, alpha=alpha)\n        out = out + self.bias\n        return out\n\n    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,\n                    index: Tensor, ptr: OptTensor,\n                    size_i: Optional[int]) -> Tensor:\n        x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1)))\n        alpha_j = (x_j @ self.att_l.t()).squeeze(-1)\n        alpha_i = (x_i @ self.att_r.t()).squeeze(-1)\n        alpha = alpha_j + alpha_i\n        
alpha = F.leaky_relu_(alpha)\n        alpha = softmax(alpha, index, ptr, size_i)\n        alpha = F.dropout(alpha, p=self.dropout, training=self.training)\n        return alpha\n\n    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:\n        return self.lin2(x_j) * alpha.unsqueeze(-1)\n\n\nclass AttentiveFP(torch.nn.Module):\n    r\"\"\"The Attentive FP model for molecular representation learning from the\n    `\"Pushing the Boundaries of Molecular Representation for Drug Discovery\n    with the Graph Attention Mechanism\"\n    <https://pubs.acs.org/doi/10.1021/acs.jmedchem.9b00959>`_ paper, based on\n    graph attention mechanisms.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Hidden node feature dimensionality.\n        out_channels (int): Size of each output sample.\n        edge_dim (int): Edge feature dimensionality.\n        num_layers (int): Number of GNN layers.\n        num_timesteps (int): Number of iterative refinement steps for global\n            readout.\n        dropout (float, optional): Dropout probability. 
(default: :obj:`0.0`)\n\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        edge_dim: int,\n        num_layers: int,\n        num_timesteps: int,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.edge_dim = edge_dim\n        self.num_layers = num_layers\n        self.num_timesteps = num_timesteps\n        self.dropout = dropout\n\n        self.lin1 = Linear(in_channels, hidden_channels)\n\n        self.gate_conv = GATEConv(hidden_channels, hidden_channels, edge_dim,\n                                  dropout)\n        self.gru = GRUCell(hidden_channels, hidden_channels)\n\n        self.atom_convs = torch.nn.ModuleList()\n        self.atom_grus = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            conv = GATConv(hidden_channels, hidden_channels, dropout=dropout,\n                           add_self_loops=False, negative_slope=0.01)\n            self.atom_convs.append(conv)\n            self.atom_grus.append(GRUCell(hidden_channels, hidden_channels))\n\n        self.mol_conv = GATConv(hidden_channels, hidden_channels,\n                                dropout=dropout, add_self_loops=False,\n                                negative_slope=0.01)\n        self.mol_conv.explain = False  # Cannot explain global pooling.\n        self.mol_gru = GRUCell(hidden_channels, hidden_channels)\n\n        self.lin2 = Linear(hidden_channels, out_channels)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin1.reset_parameters()\n        self.gate_conv.reset_parameters()\n        self.gru.reset_parameters()\n        for conv, gru in zip(self.atom_convs, self.atom_grus):\n            conv.reset_parameters()\n       
     gru.reset_parameters()\n        self.mol_conv.reset_parameters()\n        self.mol_gru.reset_parameters()\n        self.lin2.reset_parameters()\n\n    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor,\n                batch: Tensor) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        # Atom Embedding:\n        x = F.leaky_relu_(self.lin1(x))\n\n        h = F.elu_(self.gate_conv(x, edge_index, edge_attr))\n        h = F.dropout(h, p=self.dropout, training=self.training)\n        x = self.gru(h, x).relu_()\n\n        for conv, gru in zip(self.atom_convs, self.atom_grus):\n            h = conv(x, edge_index)\n            h = F.elu(h)\n            h = F.dropout(h, p=self.dropout, training=self.training)\n            x = gru(h, x).relu()\n\n        # Molecule Embedding:\n        row = torch.arange(batch.size(0), device=batch.device)\n        edge_index = torch.stack([row, batch], dim=0)\n\n        out = global_add_pool(x, batch).relu_()\n        for _ in range(self.num_timesteps):\n            h = F.elu_(self.mol_conv((x, out), edge_index))\n            h = F.dropout(h, p=self.dropout, training=self.training)\n            out = self.mol_gru(h, out).relu_()\n\n        # Predictor:\n        out = F.dropout(out, p=self.dropout, training=self.training)\n        return self.lin2(out)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'in_channels={self.in_channels}, '\n                f'hidden_channels={self.hidden_channels}, '\n                f'out_channels={self.out_channels}, '\n                f'edge_dim={self.edge_dim}, '\n                f'num_layers={self.num_layers}, '\n                f'num_timesteps={self.num_timesteps}'\n                f')')\n"
  },
  {
    "path": "torch_geometric/nn/models/attract_repel.py",
    "content": "import torch\nimport torch.nn.functional as F\n\n\nclass ARLinkPredictor(torch.nn.Module):\n    r\"\"\"Link predictor using Attract-Repel embeddings from the paper\n    `\"Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs\"\n    <https://arxiv.org/abs/2106.09671>`_.\n\n    This model splits node embeddings into: attract and\n    repel.\n    The edge prediction score is computed as the dot product of attract\n    components minus the dot product of repel components.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Size of hidden embeddings.\n        out_channels (int, optional): Size of output embeddings.\n            If set to :obj:`None`, will default to :obj:`hidden_channels`.\n            (default: :obj:`None`)\n        num_layers (int): Number of message passing layers.\n            (default: :obj:`2`)\n        dropout (float): Dropout probability. (default: :obj:`0.0`)\n        attract_ratio (float): Ratio to use for attract component.\n            Must be between 0 and 1. 
(default: :obj:`0.5`)\n    \"\"\"\n    def __init__(self, in_channels, hidden_channels, out_channels=None,\n                 num_layers=2, dropout=0.0, attract_ratio=0.5):\n        super().__init__()\n\n        if out_channels is None:\n            out_channels = hidden_channels\n\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.num_layers = num_layers\n        self.dropout = dropout\n\n        if not 0 <= attract_ratio <= 1:\n            raise ValueError(\n                f\"attract_ratio must be between 0 and 1, got {attract_ratio}\")\n\n        self.attract_ratio = attract_ratio\n        self.attract_dim = int(out_channels * attract_ratio)\n        self.repel_dim = out_channels - self.attract_dim\n\n        # Create model layers\n        self.lins = torch.nn.ModuleList()\n        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))\n\n        for _ in range(num_layers - 2):\n            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))\n\n        # Final layer splits into attract and repel components\n        self.lin_attract = torch.nn.Linear(hidden_channels, self.attract_dim)\n        self.lin_repel = torch.nn.Linear(hidden_channels, self.repel_dim)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        \"\"\"Reset all learnable parameters.\"\"\"\n        for lin in self.lins:\n            lin.reset_parameters()\n        self.lin_attract.reset_parameters()\n        self.lin_repel.reset_parameters()\n\n    def encode(self, x, *args, **kwargs):\n        \"\"\"Encode node features into attract-repel embeddings.\n\n        Args:\n            x (torch.Tensor): Node feature matrix of shape\n                :obj:`[num_nodes, in_channels]`.\n            *args: Variable length argument list\n            **kwargs: Arbitrary keyword arguments\n\n        \"\"\"\n        for lin in self.lins:\n            x = lin(x)\n           
 x = F.relu(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n\n        # Split into attract and repel components\n        attract_x = self.lin_attract(x)\n        repel_x = self.lin_repel(x)\n\n        return attract_x, repel_x\n\n    def decode(self, attract_z, repel_z, edge_index):\n        \"\"\"Decode edge scores from attract-repel embeddings.\n\n        Args:\n            attract_z (torch.Tensor): Attract embeddings of shape\n                :obj:`[num_nodes, attract_dim]`.\n            repel_z (torch.Tensor): Repel embeddings of shape\n                :obj:`[num_nodes, repel_dim]`.\n            edge_index (torch.Tensor): Edge indices of shape\n                :obj:`[2, num_edges]`.\n\n        Returns:\n            torch.Tensor: Edge prediction scores.\n        \"\"\"\n        # Get node embeddings for edges\n        row, col = edge_index\n        attract_z_row = attract_z[row]\n        attract_z_col = attract_z[col]\n        repel_z_row = repel_z[row]\n        repel_z_col = repel_z[col]\n\n        # Compute attract-repel scores\n        attract_score = torch.sum(attract_z_row * attract_z_col, dim=1)\n        repel_score = torch.sum(repel_z_row * repel_z_col, dim=1)\n\n        return attract_score - repel_score\n\n    def forward(self, x, edge_index):\n        \"\"\"Forward pass for link prediction.\n\n        Args:\n            x (torch.Tensor): Node feature matrix.\n            edge_index (torch.Tensor): Edge indices to predict.\n\n        Returns:\n            torch.Tensor: Predicted edge scores.\n        \"\"\"\n        # Encode nodes into attract-repel embeddings\n        attract_z, repel_z = self.encode(x)\n\n        # Decode target edges\n        return torch.sigmoid(self.decode(attract_z, repel_z, edge_index))\n\n    def calculate_r_fraction(self, attract_z, repel_z):\n        \"\"\"Calculate the R-fraction (proportion of energy in repel space).\n\n        Args:\n            attract_z (torch.Tensor): Attract embeddings.\n       
     repel_z (torch.Tensor): Repel embeddings.\n\n        Returns:\n            float: R-fraction value.\n        \"\"\"\n        attract_norm_squared = torch.sum(attract_z**2)\n        repel_norm_squared = torch.sum(repel_z**2)\n\n        r_fraction = repel_norm_squared / (attract_norm_squared +\n                                           repel_norm_squared + 1e-10)\n\n        return r_fraction.item()\n"
  },
  {
    "path": "torch_geometric/nn/models/autoencoder.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module\n\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.utils import negative_sampling\n\nEPS = 1e-15\nMAX_LOGSTD = 10\n\n\nclass InnerProductDecoder(torch.nn.Module):\n    r\"\"\"The inner product decoder from the `\"Variational Graph Auto-Encoders\"\n    <https://arxiv.org/abs/1611.07308>`_ paper.\n\n    .. math::\n        \\sigma(\\mathbf{Z}\\mathbf{Z}^{\\top})\n\n    where :math:`\\mathbf{Z} \\in \\mathbb{R}^{N \\times d}` denotes the latent\n    space produced by the encoder.\n    \"\"\"\n    def forward(\n        self,\n        z: Tensor,\n        edge_index: Tensor,\n        sigmoid: bool = True,\n    ) -> Tensor:\n        r\"\"\"Decodes the latent variables :obj:`z` into edge probabilities for\n        the given node-pairs :obj:`edge_index`.\n\n        Args:\n            z (torch.Tensor): The latent space :math:`\\mathbf{Z}`.\n            edge_index (torch.Tensor): The edge indices.\n            sigmoid (bool, optional): If set to :obj:`False`, does not apply\n                the logistic sigmoid function to the output.\n                (default: :obj:`True`)\n        \"\"\"\n        value = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=1)\n        return torch.sigmoid(value) if sigmoid else value\n\n    def forward_all(self, z: Tensor, sigmoid: bool = True) -> Tensor:\n        r\"\"\"Decodes the latent variables :obj:`z` into a probabilistic dense\n        adjacency matrix.\n\n        Args:\n            z (torch.Tensor): The latent space :math:`\\mathbf{Z}`.\n            sigmoid (bool, optional): If set to :obj:`False`, does not apply\n                the logistic sigmoid function to the output.\n                (default: :obj:`True`)\n        \"\"\"\n        adj = torch.matmul(z, z.t())\n        return torch.sigmoid(adj) if sigmoid else adj\n\n\nclass GAE(torch.nn.Module):\n    r\"\"\"The Graph Auto-Encoder model from the\n 
   `\"Variational Graph Auto-Encoders\" <https://arxiv.org/abs/1611.07308>`_\n    paper based on user-defined encoder and decoder models.\n\n    Args:\n        encoder (torch.nn.Module): The encoder module.\n        decoder (torch.nn.Module, optional): The decoder module. If set to\n            :obj:`None`, will default to the\n            :class:`torch_geometric.nn.models.InnerProductDecoder`.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, encoder: Module, decoder: Optional[Module] = None):\n        super().__init__()\n        self.encoder = encoder\n        self.decoder = InnerProductDecoder() if decoder is None else decoder\n        GAE.reset_parameters(self)\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        reset(self.encoder)\n        reset(self.decoder)\n\n    def forward(self, *args, **kwargs) -> Tensor:  # pragma: no cover\n        r\"\"\"Alias for :meth:`encode`.\"\"\"\n        return self.encoder(*args, **kwargs)\n\n    def encode(self, *args, **kwargs) -> Tensor:\n        r\"\"\"Runs the encoder and computes node-wise latent variables.\"\"\"\n        return self.encoder(*args, **kwargs)\n\n    def decode(self, *args, **kwargs) -> Tensor:\n        r\"\"\"Runs the decoder and computes edge probabilities.\"\"\"\n        return self.decoder(*args, **kwargs)\n\n    def recon_loss(self, z: Tensor, pos_edge_index: Tensor,\n                   neg_edge_index: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Given latent variables :obj:`z`, computes the binary cross\n        entropy loss for positive edges :obj:`pos_edge_index` and negative\n        sampled edges.\n\n        Args:\n            z (torch.Tensor): The latent space :math:`\\mathbf{Z}`.\n            pos_edge_index (torch.Tensor): The positive edges to train against.\n            neg_edge_index (torch.Tensor, optional): The negative edges to\n                train against. 
If not given, uses negative sampling to\n                calculate negative edges. (default: :obj:`None`)\n        \"\"\"\n        pos_loss = -torch.log(\n            self.decoder(z, pos_edge_index, sigmoid=True) + EPS).mean()\n\n        if neg_edge_index is None:\n            neg_edge_index = negative_sampling(pos_edge_index, z.size(0))\n        neg_loss = -torch.log(1 -\n                              self.decoder(z, neg_edge_index, sigmoid=True) +\n                              EPS).mean()\n\n        return pos_loss + neg_loss\n\n    def test(self, z: Tensor, pos_edge_index: Tensor,\n             neg_edge_index: Tensor) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Given latent variables :obj:`z`, positive edges\n        :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`,\n        computes area under the ROC curve (AUC) and average precision (AP)\n        scores.\n\n        Args:\n            z (torch.Tensor): The latent space :math:`\\mathbf{Z}`.\n            pos_edge_index (torch.Tensor): The positive edges to evaluate\n                against.\n            neg_edge_index (torch.Tensor): The negative edges to evaluate\n                against.\n        \"\"\"\n        from sklearn.metrics import average_precision_score, roc_auc_score\n\n        pos_y = z.new_ones(pos_edge_index.size(1))\n        neg_y = z.new_zeros(neg_edge_index.size(1))\n        y = torch.cat([pos_y, neg_y], dim=0)\n\n        pos_pred = self.decoder(z, pos_edge_index, sigmoid=True)\n        neg_pred = self.decoder(z, neg_edge_index, sigmoid=True)\n        pred = torch.cat([pos_pred, neg_pred], dim=0)\n\n        y, pred = y.detach().cpu().numpy(), pred.detach().cpu().numpy()\n\n        return roc_auc_score(y, pred), average_precision_score(y, pred)\n\n\nclass VGAE(GAE):\n    r\"\"\"The Variational Graph Auto-Encoder model from the\n    `\"Variational Graph Auto-Encoders\" <https://arxiv.org/abs/1611.07308>`_\n    paper.\n\n    Args:\n        encoder (torch.nn.Module): The encoder 
module to compute :math:`\\mu`\n            and :math:`\\log\\sigma`.\n        decoder (torch.nn.Module, optional): The decoder module. If set to\n            :obj:`None`, will default to the\n            :class:`torch_geometric.nn.models.InnerProductDecoder`.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, encoder: Module, decoder: Optional[Module] = None):\n        super().__init__(encoder, decoder)\n\n    def reparametrize(self, mu: Tensor, logstd: Tensor) -> Tensor:\n        if self.training:\n            return mu + torch.randn_like(logstd) * torch.exp(logstd)\n        else:\n            return mu\n\n    def encode(self, *args, **kwargs) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        self.__mu__, self.__logstd__ = self.encoder(*args, **kwargs)\n        self.__logstd__ = self.__logstd__.clamp(max=MAX_LOGSTD)\n        z = self.reparametrize(self.__mu__, self.__logstd__)\n        return z\n\n    def kl_loss(self, mu: Optional[Tensor] = None,\n                logstd: Optional[Tensor] = None) -> Tensor:\n        r\"\"\"Computes the KL loss, either for the passed arguments :obj:`mu`\n        and :obj:`logstd`, or based on latent variables from last encoding.\n\n        Args:\n            mu (torch.Tensor, optional): The latent space for :math:`\\mu`. If\n                set to :obj:`None`, uses the last computation of :math:`\\mu`.\n                (default: :obj:`None`)\n            logstd (torch.Tensor, optional): The latent space for\n                :math:`\\log\\sigma`. If set to :obj:`None`, uses the last\n                computation of :math:`\\log\\sigma`. 
(default: :obj:`None`)\n        \"\"\"\n        mu = self.__mu__ if mu is None else mu\n        logstd = self.__logstd__ if logstd is None else logstd.clamp(\n            max=MAX_LOGSTD)\n        return -0.5 * torch.mean(\n            torch.sum(1 + 2 * logstd - mu**2 - logstd.exp()**2, dim=1))\n\n\nclass ARGA(GAE):\n    r\"\"\"The Adversarially Regularized Graph Auto-Encoder model from the\n    `\"Adversarially Regularized Graph Autoencoder for Graph Embedding\"\n    <https://arxiv.org/abs/1802.04407>`_ paper.\n\n    Args:\n        encoder (torch.nn.Module): The encoder module.\n        discriminator (torch.nn.Module): The discriminator module.\n        decoder (torch.nn.Module, optional): The decoder module. If set to\n            :obj:`None`, will default to the\n            :class:`torch_geometric.nn.models.InnerProductDecoder`.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        encoder: Module,\n        discriminator: Module,\n        decoder: Optional[Module] = None,\n    ):\n        super().__init__(encoder, decoder)\n        self.discriminator = discriminator\n        reset(self.discriminator)\n\n    def reset_parameters(self):\n        super().reset_parameters()\n        reset(self.discriminator)\n\n    def reg_loss(self, z: Tensor) -> Tensor:\n        r\"\"\"Computes the regularization loss of the encoder.\n\n        Args:\n            z (torch.Tensor): The latent space :math:`\\mathbf{Z}`.\n        \"\"\"\n        real = torch.sigmoid(self.discriminator(z))\n        real_loss = -torch.log(real + EPS).mean()\n        return real_loss\n\n    def discriminator_loss(self, z: Tensor) -> Tensor:\n        r\"\"\"Computes the loss of the discriminator.\n\n        Args:\n            z (torch.Tensor): The latent space :math:`\\mathbf{Z}`.\n        \"\"\"\n        real = torch.sigmoid(self.discriminator(torch.randn_like(z)))\n        fake = torch.sigmoid(self.discriminator(z.detach()))\n        real_loss = -torch.log(real + 
EPS).mean()\n        fake_loss = -torch.log(1 - fake + EPS).mean()\n        return real_loss + fake_loss\n\n\nclass ARGVA(ARGA):\n    r\"\"\"The Adversarially Regularized Variational Graph Auto-Encoder model from\n    the `\"Adversarially Regularized Graph Autoencoder for Graph Embedding\"\n    <https://arxiv.org/abs/1802.04407>`_ paper.\n\n    Args:\n        encoder (torch.nn.Module): The encoder module to compute :math:`\\mu`\n            and :math:`\\log\\sigma`.\n        discriminator (torch.nn.Module): The discriminator module.\n        decoder (torch.nn.Module, optional): The decoder module. If set to\n            :obj:`None`, will default to the\n            :class:`torch_geometric.nn.models.InnerProductDecoder`.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        encoder: Module,\n        discriminator: Module,\n        decoder: Optional[Module] = None,\n    ):\n        super().__init__(encoder, discriminator, decoder)\n        self.VGAE = VGAE(encoder, decoder)\n\n    @property\n    def __mu__(self) -> Tensor:\n        return self.VGAE.__mu__\n\n    @property\n    def __logstd__(self) -> Tensor:\n        return self.VGAE.__logstd__\n\n    def reparametrize(self, mu: Tensor, logstd: Tensor) -> Tensor:\n        return self.VGAE.reparametrize(mu, logstd)\n\n    def encode(self, *args, **kwargs) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        return self.VGAE.encode(*args, **kwargs)\n\n    def kl_loss(\n        self,\n        mu: Optional[Tensor] = None,\n        logstd: Optional[Tensor] = None,\n    ) -> Tensor:\n        return self.VGAE.kl_loss(mu, logstd)\n"
  },
  {
    "path": "torch_geometric/nn/models/basic_gnn.py",
    "content": "import copy\nimport inspect\nfrom typing import Any, Callable, Dict, Final, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Linear, ModuleList\nfrom tqdm import tqdm\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.loader import CachedLoader, NeighborLoader\nfrom torch_geometric.nn.conv import (\n    EdgeConv,\n    GATConv,\n    GATv2Conv,\n    GCNConv,\n    GINConv,\n    MessagePassing,\n    PNAConv,\n    SAGEConv,\n)\nfrom torch_geometric.nn.models import MLP\nfrom torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge\nfrom torch_geometric.nn.resolver import (\n    activation_resolver,\n    normalization_resolver,\n)\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils._trim_to_layer import TrimToLayer\n\n\nclass BasicGNN(torch.nn.Module):\n    r\"\"\"An abstract class for implementing basic GNN models.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of message passing layers.\n        out_channels (int, optional): If not set to :obj:`None`, will apply a\n            final linear transformation to convert hidden node embeddings to\n            output size :obj:`out_channels`. (default: :obj:`None`)\n        dropout (float, optional): Dropout probability. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. 
(default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`None`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        jk (str, optional): The Jumping Knowledge mode. If specified, the model\n            will additionally apply a final linear transformation to transform\n            node embeddings to the expected output feature dimensionality.\n            (:obj:`None`, :obj:`\"last\"`, :obj:`\"cat\"`, :obj:`\"max\"`,\n            :obj:`\"lstm\"`). (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of the underlying\n            :class:`torch_geometric.nn.conv.MessagePassing` layers.\n    \"\"\"\n    supports_edge_weight: Final[bool]\n    supports_edge_attr: Final[bool]\n    supports_norm_batch: Final[bool]\n\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        num_layers: int,\n        out_channels: Optional[int] = None,\n        dropout: float = 0.0,\n        act: Union[str, Callable, None] = \"relu\",\n        act_first: bool = False,\n        act_kwargs: Optional[Dict[str, Any]] = None,\n        norm: Union[str, Callable, None] = None,\n        norm_kwargs: Optional[Dict[str, Any]] = None,\n        jk: Optional[str] = None,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.num_layers = num_layers\n\n        self.dropout = torch.nn.Dropout(p=dropout)\n        self.act = activation_resolver(act, **(act_kwargs or {}))\n        self.jk_mode = jk\n        self.act_first = act_first\n        self.norm = norm if 
isinstance(norm, str) else None\n        self.norm_kwargs = norm_kwargs\n\n        if out_channels is not None:\n            self.out_channels = out_channels\n        else:\n            self.out_channels = hidden_channels\n\n        self.convs = ModuleList()\n        if num_layers > 1:\n            self.convs.append(\n                self.init_conv(in_channels, hidden_channels, **kwargs))\n            if isinstance(in_channels, (tuple, list)):\n                in_channels = (hidden_channels, hidden_channels)\n            else:\n                in_channels = hidden_channels\n        for _ in range(num_layers - 2):\n            self.convs.append(\n                self.init_conv(in_channels, hidden_channels, **kwargs))\n            if isinstance(in_channels, (tuple, list)):\n                in_channels = (hidden_channels, hidden_channels)\n            else:\n                in_channels = hidden_channels\n        if out_channels is not None and jk is None:\n            self._is_conv_to_out = True\n            self.convs.append(\n                self.init_conv(in_channels, out_channels, **kwargs))\n        else:\n            self.convs.append(\n                self.init_conv(in_channels, hidden_channels, **kwargs))\n\n        self.norms = ModuleList()\n        norm_layer = normalization_resolver(\n            norm,\n            hidden_channels,\n            **(norm_kwargs or {}),\n        )\n        if norm_layer is None:\n            norm_layer = torch.nn.Identity()\n\n        self.supports_norm_batch = False\n        if hasattr(norm_layer, 'forward'):\n            norm_params = inspect.signature(norm_layer.forward).parameters\n            self.supports_norm_batch = 'batch' in norm_params\n\n        for _ in range(num_layers - 1):\n            self.norms.append(copy.deepcopy(norm_layer))\n\n        if jk is not None:\n            self.norms.append(copy.deepcopy(norm_layer))\n        else:\n            self.norms.append(torch.nn.Identity())\n\n        if jk is not None 
and jk != 'last':\n            self.jk = JumpingKnowledge(jk, hidden_channels, num_layers)\n\n        if jk is not None:\n            if jk == 'cat':\n                in_channels = num_layers * hidden_channels\n            else:\n                in_channels = hidden_channels\n            self.lin = Linear(in_channels, self.out_channels)\n\n        # We define `trim_to_layer` functionality as a module such that we can\n        # still use `to_hetero` on-top.\n        self._trim = TrimToLayer()\n\n    def init_conv(self, in_channels: Union[int, Tuple[int, int]],\n                  out_channels: int, **kwargs) -> MessagePassing:\n        raise NotImplementedError\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for conv in self.convs:\n            conv.reset_parameters()\n        for norm in self.norms:\n            if hasattr(norm, 'reset_parameters'):\n                norm.reset_parameters()\n        if hasattr(self, 'jk'):\n            self.jk.reset_parameters()\n        if hasattr(self, 'lin'):\n            self.lin.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n        edge_attr: OptTensor = None,\n        batch: OptTensor = None,\n        batch_size: Optional[int] = None,\n        num_sampled_nodes_per_hop: Optional[List[int]] = None,\n        num_sampled_edges_per_hop: Optional[List[int]] = None,\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The input node features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            edge_weight (torch.Tensor, optional): The edge weights (if\n                supported by the underlying GNN layer). (default: :obj:`None`)\n            edge_attr (torch.Tensor, optional): The edge features (if supported\n                by the underlying GNN layer). 
(default: :obj:`None`)\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example.\n                Only needs to be passed in case the underlying normalization\n                layers require the :obj:`batch` information.\n                (default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given.\n                Only needs to be passed in case the underlying normalization\n                layers require the :obj:`batch` information.\n                (default: :obj:`None`)\n            num_sampled_nodes_per_hop (List[int], optional): The number of\n                sampled nodes per hop.\n                Useful in :class:`~torch_geometric.loader.NeighborLoader`\n                scenarios to only operate on minimal-sized representations.\n                (default: :obj:`None`)\n            num_sampled_edges_per_hop (List[int], optional): The number of\n                sampled edges per hop.\n                Useful in :class:`~torch_geometric.loader.NeighborLoader`\n                scenarios to only operate on minimal-sized representations.\n                (default: :obj:`None`)\n        \"\"\"\n        if (num_sampled_nodes_per_hop is not None\n                and isinstance(edge_weight, Tensor)\n                and isinstance(edge_attr, Tensor)):\n            raise NotImplementedError(\"'trim_to_layer' functionality does not \"\n                                      \"yet support trimming of both \"\n                                      \"'edge_weight' and 'edge_attr'\")\n\n        xs: List[Tensor] = []\n        assert len(self.convs) == len(self.norms)\n        for i, (conv, norm) in enumerate(zip(self.convs, self.norms)):\n            if (not torch.jit.is_scripting()\n                    and num_sampled_nodes_per_hop is not 
None):\n                x, edge_index, value = self._trim(\n                    i,\n                    num_sampled_nodes_per_hop,\n                    num_sampled_edges_per_hop,\n                    x,\n                    edge_index,\n                    edge_weight if edge_weight is not None else edge_attr,\n                )\n                if edge_weight is not None:\n                    edge_weight = value\n                else:\n                    edge_attr = value\n\n            # Tracing the module is not allowed with *args and **kwargs :(\n            # As such, we rely on a static solution to pass optional edge\n            # weights and edge attributes to the module.\n            if self.supports_edge_weight and self.supports_edge_attr:\n                x = conv(x, edge_index, edge_weight=edge_weight,\n                         edge_attr=edge_attr)\n            elif self.supports_edge_weight:\n                x = conv(x, edge_index, edge_weight=edge_weight)\n            elif self.supports_edge_attr:\n                x = conv(x, edge_index, edge_attr=edge_attr)\n            else:\n                x = conv(x, edge_index)\n\n            if i < self.num_layers - 1 or self.jk_mode is not None:\n                if self.act is not None and self.act_first:\n                    x = self.act(x)\n                if self.supports_norm_batch:\n                    x = norm(x, batch, batch_size)\n                else:\n                    x = norm(x)\n                if self.act is not None and not self.act_first:\n                    x = self.act(x)\n                x = self.dropout(x)\n                if hasattr(self, 'jk'):\n                    xs.append(x)\n\n        x = self.jk(xs) if hasattr(self, 'jk') else x\n        x = self.lin(x) if hasattr(self, 'lin') else x\n\n        return x\n\n    @torch.no_grad()\n    def inference_per_layer(\n        self,\n        layer: int,\n        x: Tensor,\n        edge_index: Adj,\n        batch_size: int,\n    ) -> 
Tensor:\n\n        x = self.convs[layer](x, edge_index)[:batch_size]\n\n        if layer == self.num_layers - 1 and self.jk_mode is None:\n            return x\n\n        if self.act is not None and self.act_first:\n            x = self.act(x)\n        if self.norms is not None:\n            x = self.norms[layer](x)\n        if self.act is not None and not self.act_first:\n            x = self.act(x)\n        if layer == self.num_layers - 1 and hasattr(self, 'lin'):\n            x = self.lin(x)\n\n        return x\n\n    @torch.no_grad()\n    def inference(\n        self,\n        loader: NeighborLoader,\n        device: Optional[Union[str, torch.device]] = None,\n        embedding_device: Union[str, torch.device] = 'cpu',\n        progress_bar: bool = False,\n        cache: bool = False,\n    ) -> Tensor:\n        r\"\"\"Performs layer-wise inference on large graphs using a\n        :class:`~torch_geometric.loader.NeighborLoader`, where\n        :class:`~torch_geometric.loader.NeighborLoader` should sample the\n        full neighborhood for only one layer.\n        This is an efficient way to compute the output embeddings for all\n        nodes in the graph.\n        Only applicable in case :obj:`jk=None` or :obj:`jk='last'`.\n\n        Args:\n            loader (torch_geometric.loader.NeighborLoader): A neighbor loader\n                object that generates full 1-hop subgraphs, *i.e.*,\n                :obj:`loader.num_neighbors = [-1]`.\n            device (torch.device, optional): The device to run the GNN on.\n                (default: :obj:`None`)\n            embedding_device (torch.device, optional): The device to store\n                intermediate embeddings on. If intermediate embeddings fit on\n                GPU, this option helps to avoid unnecessary device transfers.\n                (default: :obj:`\"cpu\"`)\n            progress_bar (bool, optional): If set to :obj:`True`, will print a\n                progress bar during computation. 
(default: :obj:`False`)\n            cache (bool, optional): If set to :obj:`True`, caches intermediate\n                sampler outputs for usage in later epochs.\n                This will avoid repeated sampling to accelerate inference.\n                (default: :obj:`False`)\n        \"\"\"\n        assert self.jk_mode is None or self.jk_mode == 'last'\n        assert isinstance(loader, NeighborLoader)\n        assert len(loader.dataset) == loader.data.num_nodes\n        assert len(loader.node_sampler.num_neighbors) == 1\n        assert not self.training\n        # assert not loader.shuffle  # TODO (matthias) does not work :(\n        if progress_bar:\n            pbar = tqdm(total=len(self.convs) * len(loader))\n            pbar.set_description('Inference')\n\n        x_all = loader.data.x.to(embedding_device)\n\n        if cache:\n\n            # Only cache necessary attributes:\n            def transform(data: Data) -> Data:\n                kwargs = dict(n_id=data.n_id, batch_size=data.batch_size)\n                if hasattr(data, 'adj_t'):\n                    kwargs['adj_t'] = data.adj_t\n                else:\n                    kwargs['edge_index'] = data.edge_index\n\n                return Data.from_dict(kwargs)\n\n            loader = CachedLoader(loader, device=device, transform=transform)\n\n        for i in range(self.num_layers):\n            xs: List[Tensor] = []\n            for batch in loader:\n                x = x_all[batch.n_id].to(device)\n                batch_size = batch.batch_size\n                if hasattr(batch, 'adj_t'):\n                    edge_index = batch.adj_t.to(device)\n                else:\n                    edge_index = batch.edge_index.to(device)\n\n                x = self.inference_per_layer(i, x, edge_index, batch_size)\n                xs.append(x.to(embedding_device))\n\n                if progress_bar:\n                    pbar.update(1)\n\n            x_all = torch.cat(xs, dim=0)\n\n        if 
progress_bar:\n            pbar.close()\n\n        return x_all\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_layers={self.num_layers})')\n\n\nclass GCN(BasicGNN):\n    r\"\"\"The Graph Neural Network from the `\"Semi-supervised\n    Classification with Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1609.02907>`_ paper, using the\n    :class:`~torch_geometric.nn.conv.GCNConv` operator for message passing.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of message passing layers.\n        out_channels (int, optional): If not set to :obj:`None`, will apply a\n            final linear transformation to convert hidden node embeddings to\n            output size :obj:`out_channels`. (default: :obj:`None`)\n        dropout (float, optional): Dropout probability. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. (default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`None`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        jk (str, optional): The Jumping Knowledge mode. 
If specified, the model\n            will additionally apply a final linear transformation to transform\n            node embeddings to the expected output feature dimensionality,\n            while default will not.\n            (:obj:`None`, :obj:`\"last\"`, :obj:`\"cat\"`, :obj:`\"max\"`,\n            :obj:`\"lstm\"`). (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.GCNConv`.\n    \"\"\"\n    supports_edge_weight: Final[bool] = True\n    supports_edge_attr: Final[bool] = False\n    supports_norm_batch: Final[bool]\n\n    def init_conv(self, in_channels: int, out_channels: int,\n                  **kwargs) -> MessagePassing:\n        return GCNConv(in_channels, out_channels, **kwargs)\n\n\nclass GraphSAGE(BasicGNN):\n    r\"\"\"The Graph Neural Network from the `\"Inductive Representation Learning\n    on Large Graphs\" <https://arxiv.org/abs/1706.02216>`_ paper, using the\n    :class:`~torch_geometric.nn.SAGEConv` operator for message passing.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of message passing layers.\n        out_channels (int, optional): If not set to :obj:`None`, will apply a\n            final linear transformation to convert hidden node embeddings to\n            output size :obj:`out_channels`. (default: :obj:`None`)\n        dropout (float, optional): Dropout probability. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. 
(default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`None`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        jk (str, optional): The Jumping Knowledge mode. If specified, the model\n            will additionally apply a final linear transformation to transform\n            node embeddings to the expected output feature dimensionality.\n            (:obj:`None`, :obj:`\"last\"`, :obj:`\"cat\"`, :obj:`\"max\"`,\n            :obj:`\"lstm\"`). (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.SAGEConv`.\n    \"\"\"\n    supports_edge_weight: Final[bool] = False\n    supports_edge_attr: Final[bool] = False\n    supports_norm_batch: Final[bool]\n\n    def init_conv(self, in_channels: Union[int, Tuple[int, int]],\n                  out_channels: int, **kwargs) -> MessagePassing:\n        return SAGEConv(in_channels, out_channels, **kwargs)\n\n\nclass GIN(BasicGNN):\n    r\"\"\"The Graph Neural Network from the `\"How Powerful are Graph Neural\n    Networks?\" <https://arxiv.org/abs/1810.00826>`_ paper, using the\n    :class:`~torch_geometric.nn.GINConv` operator for message passing.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of message passing layers.\n        out_channels (int, optional): If not set to :obj:`None`, will apply a\n            final linear transformation to convert hidden node embeddings to\n            output size :obj:`out_channels`. 
(default: :obj:`None`)\n        dropout (float, optional): Dropout probability. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. (default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`None`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        jk (str, optional): The Jumping Knowledge mode. If specified, the model\n            will additionally apply a final linear transformation to transform\n            node embeddings to the expected output feature dimensionality.\n            (:obj:`None`, :obj:`\"last\"`, :obj:`\"cat\"`, :obj:`\"max\"`,\n            :obj:`\"lstm\"`). 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.GINConv`.\n    \"\"\"\n    supports_edge_weight: Final[bool] = False\n    supports_edge_attr: Final[bool] = False\n    supports_norm_batch: Final[bool]\n\n    def init_conv(self, in_channels: int, out_channels: int,\n                  **kwargs) -> MessagePassing:\n        mlp = MLP(\n            [in_channels, out_channels, out_channels],\n            act=self.act,\n            act_first=self.act_first,\n            norm=self.norm,\n            norm_kwargs=self.norm_kwargs,\n        )\n        return GINConv(mlp, **kwargs)\n\n\nclass GAT(BasicGNN):\n    r\"\"\"The Graph Neural Network from `\"Graph Attention Networks\"\n    <https://arxiv.org/abs/1710.10903>`_ or `\"How Attentive are Graph Attention\n    Networks?\" <https://arxiv.org/abs/2105.14491>`_ papers, using the\n    :class:`~torch_geometric.nn.GATConv` or\n    :class:`~torch_geometric.nn.GATv2Conv` operator for message passing,\n    respectively.\n\n    Args:\n        in_channels (int or tuple): Size of each input sample, or :obj:`-1` to\n            derive the size from the first input(s) to the forward method.\n            A tuple corresponds to the sizes of source and target\n            dimensionalities.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of message passing layers.\n        out_channels (int, optional): If not set to :obj:`None`, will apply a\n            final linear transformation to convert hidden node embeddings to\n            output size :obj:`out_channels`. (default: :obj:`None`)\n        v2 (bool, optional): If set to :obj:`True`, will make use of\n            :class:`~torch_geometric.nn.conv.GATv2Conv` rather than\n            :class:`~torch_geometric.nn.conv.GATConv`. (default: :obj:`False`)\n        dropout (float, optional): Dropout probability. 
(default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. (default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`None`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        jk (str, optional): The Jumping Knowledge mode. If specified, the model\n            will additionally apply a final linear transformation to transform\n            node embeddings to the expected output feature dimensionality.\n            (:obj:`None`, :obj:`\"last\"`, :obj:`\"cat\"`, :obj:`\"max\"`,\n            :obj:`\"lstm\"`). 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.GATConv` or\n            :class:`torch_geometric.nn.conv.GATv2Conv`.\n    \"\"\"\n    supports_edge_weight: Final[bool] = False\n    supports_edge_attr: Final[bool] = True\n    supports_norm_batch: Final[bool]\n\n    def init_conv(self, in_channels: Union[int, Tuple[int, int]],\n                  out_channels: int, **kwargs) -> MessagePassing:\n\n        v2 = kwargs.pop('v2', False)\n        heads = kwargs.pop('heads', 1)\n        concat = kwargs.pop('concat', True)\n\n        # Do not use concatenation in case the `GATConv` layer maps to\n        # the desired output channels (out_channels != None and jk == None):\n        if getattr(self, '_is_conv_to_out', False):\n            concat = False\n\n        if concat and out_channels % heads != 0:\n            raise ValueError(f\"Ensure that the number of output channels of \"\n                             f\"'GATConv' (got '{out_channels}') is divisible \"\n                             f\"by the number of heads (got '{heads}')\")\n\n        if concat:\n            out_channels = out_channels // heads\n\n        Conv = GATConv if not v2 else GATv2Conv\n        return Conv(in_channels, out_channels, heads=heads, concat=concat,\n                    dropout=self.dropout.p, **kwargs)\n\n\nclass PNA(BasicGNN):\n    r\"\"\"The Graph Neural Network from the `\"Principal Neighbourhood Aggregation\n    for Graph Nets\" <https://arxiv.org/abs/2004.05718>`_ paper, using the\n    :class:`~torch_geometric.nn.conv.PNAConv` operator for message passing.\n\n    Args:\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of message passing layers.\n        out_channels (int, optional): If not set to :obj:`None`, will apply 
a\n            final linear transformation to convert hidden node embeddings to\n            output size :obj:`out_channels`. (default: :obj:`None`)\n        dropout (float, optional): Dropout probability. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. (default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`None`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        jk (str, optional): The Jumping Knowledge mode. If specified, the model\n            will additionally apply a final linear transformation to transform\n            node embeddings to the expected output feature dimensionality.\n            (:obj:`None`, :obj:`\"last\"`, :obj:`\"cat\"`, :obj:`\"max\"`,\n            :obj:`\"lstm\"`). 
(default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.PNAConv`.\n    \"\"\"\n    supports_edge_weight: Final[bool] = False\n    supports_edge_attr: Final[bool] = True\n    supports_norm_batch: Final[bool]\n\n    def init_conv(self, in_channels: int, out_channels: int,\n                  **kwargs) -> MessagePassing:\n        return PNAConv(in_channels, out_channels, **kwargs)\n\n\nclass EdgeCNN(BasicGNN):\n    r\"\"\"The Graph Neural Network from the `\"Dynamic Graph CNN for Learning on\n    Point Clouds\" <https://arxiv.org/abs/1801.07829>`_ paper, using the\n    :class:`~torch_geometric.nn.conv.EdgeConv` operator for message passing.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of message passing layers.\n        out_channels (int, optional): If not set to :obj:`None`, will apply a\n            final linear transformation to convert hidden node embeddings to\n            output size :obj:`out_channels`. (default: :obj:`None`)\n        dropout (float, optional): Dropout probability. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. (default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`None`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        jk (str, optional): The Jumping Knowledge mode. 
If specified, the model\n            will additionally apply a final linear transformation to transform\n            node embeddings to the expected output feature dimensionality.\n            (:obj:`None`, :obj:`\"last\"`, :obj:`\"cat\"`, :obj:`\"max\"`,\n            :obj:`\"lstm\"`). (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.EdgeConv`.\n    \"\"\"\n    supports_edge_weight: Final[bool] = False\n    supports_edge_attr: Final[bool] = False\n    supports_norm_batch: Final[bool]\n\n    def init_conv(self, in_channels: int, out_channels: int,\n                  **kwargs) -> MessagePassing:\n        mlp = MLP(\n            [2 * in_channels, out_channels, out_channels],\n            act=self.act,\n            act_first=self.act_first,\n            norm=self.norm,\n            norm_kwargs=self.norm_kwargs,\n        )\n        return EdgeConv(mlp, **kwargs)\n\n\n__all__ = [\n    'GCN',\n    'GraphSAGE',\n    'GIN',\n    'GAT',\n    'PNA',\n    'EdgeCNN',\n]\n"
  },
  {
    "path": "torch_geometric/nn/models/captum.py",
    "content": "from typing import Optional, Union\n\nimport torch\n\nfrom torch_geometric.explain.algorithm.captum import (\n    CaptumHeteroModel,\n    CaptumModel,\n    MaskLevelType,\n)\nfrom torch_geometric.typing import Metadata\n\n\ndef to_captum_model(\n    model: torch.nn.Module,\n    mask_type: Union[str, MaskLevelType] = MaskLevelType.edge,\n    output_idx: Optional[int] = None,\n    metadata: Optional[Metadata] = None,\n) -> Union[CaptumModel, CaptumHeteroModel]:\n    r\"\"\"Converts a model to a model that can be used for\n    `Captum <https://captum.ai/>`_ attribution methods.\n\n    Sample code for homogeneous graphs:\n\n    .. code-block:: python\n\n        from captum.attr import IntegratedGradients\n\n        from torch_geometric.data import Data\n        from torch_geometric.nn import GCN\n        from torch_geometric.nn import to_captum_model, to_captum_input\n\n        data = Data(x=(...), edge_index=(...))\n        model = GCN(...)\n        ...  # Train the model.\n\n        # Explain predictions for node `10`:\n        mask_type = \"edge\"\n        output_idx = 10\n        captum_model = to_captum_model(model, mask_type, output_idx)\n        inputs, additional_forward_args = to_captum_input(data.x,\n                                            data.edge_index, mask_type)\n\n        ig = IntegratedGradients(captum_model)\n        ig_attr = ig.attribute(inputs=inputs,\n                               target=int(y[output_idx]),\n                               additional_forward_args=additional_forward_args,\n                               internal_batch_size=1)\n\n\n    Sample code for heterogeneous graphs:\n\n    .. 
code-block:: python\n\n        from captum.attr import IntegratedGradients\n\n        from torch_geometric.data import HeteroData\n        from torch_geometric.nn import HeteroConv\n        from torch_geometric.nn import (captum_output_to_dicts,\n                                        to_captum_model, to_captum_input)\n\n        data = HeteroData(...)\n        model = HeteroConv(...)\n        ...  # Train the model.\n\n        # Explain predictions for node `10`:\n        mask_type = \"edge\"\n        metadata = data.metadata()\n        output_idx = 10\n        captum_model = to_captum_model(model, mask_type, output_idx, metadata)\n        inputs, additional_forward_args = to_captum_input(data.x_dict,\n                                            data.edge_index_dict, mask_type)\n\n        ig = IntegratedGradients(captum_model)\n        ig_attr = ig.attribute(inputs=inputs,\n                               target=int(y[output_idx]),\n                               additional_forward_args=additional_forward_args,\n                               internal_batch_size=1)\n        edge_attr_dict = captum_output_to_dicts(ig_attr, mask_type, metadata)\n\n\n    .. note::\n        For an example of using a :captum:`Captum` attribution method within\n        :pyg:`PyG`, see `examples/explain/captum_explainer.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        explain/captum_explainer.py>`_.\n\n    Args:\n        model (torch.nn.Module): The model to be explained.\n        mask_type (str, optional): Denotes the type of mask to be created with\n            a :captum:`Captum` explainer. Valid inputs are :obj:`\"edge\"`,\n            :obj:`\"node\"`, and :obj:`\"node_and_edge\"`. (default: :obj:`\"edge\"`)\n        output_idx (int, optional): Index of the output element (node or link\n            index) to be explained. 
With :obj:`output_idx` set, the forward\n            function will return the output of the model for the element at\n            the index specified. (default: :obj:`None`)\n        metadata (Metadata, optional): The metadata of the heterogeneous graph.\n            Only required if explaining a\n            :class:`~torch_geometric.data.HeteroData` object.\n            (default: :obj:`None`)\n    \"\"\"\n    if metadata is None:\n        return CaptumModel(model, mask_type, output_idx)\n    else:\n        return CaptumHeteroModel(model, mask_type, output_idx, metadata)\n"
  },
  {
    "path": "torch_geometric/nn/models/correct_and_smooth.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.models import LabelPropagation\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import one_hot\n\n\nclass CorrectAndSmooth(torch.nn.Module):\n    r\"\"\"The correct and smooth (C&S) post-processing model from the\n    `\"Combining Label Propagation And Simple Models Out-performs Graph Neural\n    Networks\"\n    <https://arxiv.org/abs/2010.13993>`_ paper, where soft predictions\n    :math:`\\mathbf{Z}` (obtained from a simple base predictor) are\n    first corrected based on ground-truth training\n    label information :math:`\\mathbf{Y}` and residual propagation.\n\n    .. math::\n        \\mathbf{e}^{(0)}_i &= \\begin{cases}\n            \\mathbf{y}_i - \\mathbf{z}_i, & \\text{if }i\n            \\text{ is training node,}\\\\\n            \\mathbf{0}, & \\text{else}\n        \\end{cases}\n\n    .. math::\n        \\mathbf{E}^{(\\ell)} &= \\alpha_1 \\mathbf{D}^{-1/2}\\mathbf{A}\n        \\mathbf{D}^{-1/2} \\mathbf{E}^{(\\ell - 1)} +\n        (1 - \\alpha_1) \\mathbf{E}^{(\\ell - 1)}\n\n        \\mathbf{\\hat{Z}} &= \\mathbf{Z} + \\gamma \\cdot \\mathbf{E}^{(L_1)},\n\n    where :math:`\\gamma` denotes the scaling factor (either fixed or\n    automatically determined), and then smoothed over the graph via label\n    propagation\n\n    .. math::\n        \\mathbf{\\hat{z}}^{(0)}_i &= \\begin{cases}\n            \\mathbf{y}_i, & \\text{if }i\\text{ is training node,}\\\\\n            \\mathbf{\\hat{z}}_i, & \\text{else}\n        \\end{cases}\n\n    .. math::\n        \\mathbf{\\hat{Z}}^{(\\ell)} = \\alpha_2 \\mathbf{D}^{-1/2}\\mathbf{A}\n        \\mathbf{D}^{-1/2} \\mathbf{\\hat{Z}}^{(\\ell - 1)} +\n        (1 - \\alpha_2) \\mathbf{\\hat{Z}}^{(\\ell - 1)}\n\n    to obtain the final prediction :math:`\\mathbf{\\hat{Z}}^{(L_2)}`.\n\n    .. 
note::\n\n        For an example of using the C&S model, see\n        `examples/correct_and_smooth.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        correct_and_smooth.py>`_.\n\n    Args:\n        num_correction_layers (int): The number of propagations :math:`L_1`.\n        correction_alpha (float): The :math:`\\alpha_1` coefficient.\n        num_smoothing_layers (int): The number of propagations :math:`L_2`.\n        smoothing_alpha (float): The :math:`\\alpha_2` coefficient.\n        autoscale (bool, optional): If set to :obj:`True`, will automatically\n            determine the scaling factor :math:`\\gamma`. (default: :obj:`True`)\n        scale (float, optional): The scaling factor :math:`\\gamma`, in case\n            :obj:`autoscale = False`. (default: :obj:`1.0`)\n    \"\"\"\n    def __init__(self, num_correction_layers: int, correction_alpha: float,\n                 num_smoothing_layers: int, smoothing_alpha: float,\n                 autoscale: bool = True, scale: float = 1.0):\n        super().__init__()\n        self.autoscale = autoscale\n        self.scale = scale\n\n        self.prop1 = LabelPropagation(num_correction_layers, correction_alpha)\n        self.prop2 = LabelPropagation(num_smoothing_layers, smoothing_alpha)\n\n    def forward(self, y_soft: Tensor, *args) -> Tensor:  # pragma: no cover\n        r\"\"\"Applies both :meth:`correct` and :meth:`smooth`.\"\"\"\n        y_soft = self.correct(y_soft, *args)\n        return self.smooth(y_soft, *args)\n\n    def correct(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,\n                edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            y_soft (torch.Tensor): The soft predictions :math:`\\mathbf{Z}`\n                obtained from a simple base predictor.\n            y_true (torch.Tensor): The ground-truth label information\n                :math:`\\mathbf{Y}` of training nodes.\n           
 mask (torch.Tensor): A mask or index tensor denoting which nodes\n                were used for training.\n            edge_index (torch.Tensor or SparseTensor): The edge connectivity.\n            edge_weight (torch.Tensor, optional): The edge weights.\n                (default: :obj:`None`)\n        \"\"\"\n        numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)\n        assert y_true.size(0) == numel\n\n        if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():\n            y_true = one_hot(y_true.view(-1), num_classes=y_soft.size(-1),\n                             dtype=y_soft.dtype)\n\n        error = torch.zeros_like(y_soft)\n        error[mask] = y_true - y_soft[mask]\n\n        if self.autoscale:\n            smoothed_error = self.prop1(error, edge_index,\n                                        edge_weight=edge_weight,\n                                        post_step=lambda x: x.clamp_(-1., 1.))\n\n            sigma = error[mask].abs().sum() / numel\n            scale = sigma / smoothed_error.abs().sum(dim=1, keepdim=True)\n            scale[scale.isinf() | (scale > 1000)] = 1.0\n            return y_soft + scale * smoothed_error\n        else:\n\n            def fix_input(x):\n                x[mask] = error[mask]\n                return x\n\n            smoothed_error = self.prop1(error, edge_index,\n                                        edge_weight=edge_weight,\n                                        post_step=fix_input)\n            return y_soft + self.scale * smoothed_error\n\n    def smooth(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,\n               edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            y_soft (torch.Tensor): The corrected predictions :math:`\\mathbf{Z}`\n                obtained from :meth:`correct`.\n            y_true (torch.Tensor): The ground-truth label information\n                :math:`\\mathbf{Y}` of 
training nodes.\n            mask (torch.Tensor): A mask or index tensor denoting which nodes\n                were used for training.\n            edge_index (torch.Tensor or SparseTensor): The edge connectivity.\n            edge_weight (torch.Tensor, optional): The edge weights.\n                (default: :obj:`None`)\n        \"\"\"\n        numel = int(mask.sum()) if mask.dtype == torch.bool else mask.size(0)\n        assert y_true.size(0) == numel\n\n        if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():\n            y_true = one_hot(y_true.view(-1), num_classes=y_soft.size(-1),\n                             dtype=y_soft.dtype)\n\n        y_soft = y_soft.clone()\n        y_soft[mask] = y_true\n\n        return self.prop2(y_soft, edge_index, edge_weight=edge_weight)\n\n    def __repr__(self):\n        L1, alpha1 = self.prop1.num_layers, self.prop1.alpha\n        L2, alpha2 = self.prop2.num_layers, self.prop2.alpha\n        return (f'{self.__class__.__name__}(\\n'\n                f'  correct: num_layers={L1}, alpha={alpha1}\\n'\n                f'  smooth:  num_layers={L2}, alpha={alpha2}\\n'\n                f'  autoscale={self.autoscale}, scale={self.scale}\\n'\n                ')')\n"
  },
  {
    "path": "torch_geometric/nn/models/deep_graph_infomax.py",
    "content": "import copy\nfrom typing import Callable, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module, Parameter\n\nfrom torch_geometric.nn.inits import reset, uniform\n\nEPS = 1e-15\n\n\nclass DeepGraphInfomax(torch.nn.Module):\n    r\"\"\"The Deep Graph Infomax model from the\n    `\"Deep Graph Infomax\" <https://arxiv.org/abs/1809.10341>`_\n    paper based on user-defined encoder and summary model :math:`\\mathcal{E}`\n    and :math:`\\mathcal{R}` respectively, and a corruption function\n    :math:`\\mathcal{C}`.\n\n    Args:\n        hidden_channels (int): The latent space dimensionality.\n        encoder (torch.nn.Module): The encoder module :math:`\\mathcal{E}`.\n        summary (callable): The readout function :math:`\\mathcal{R}`.\n        corruption (callable): The corruption function :math:`\\mathcal{C}`.\n    \"\"\"\n    def __init__(\n        self,\n        hidden_channels: int,\n        encoder: Module,\n        summary: Callable,\n        corruption: Callable,\n    ):\n        super().__init__()\n        self.hidden_channels = hidden_channels\n        self.encoder = encoder\n        self.summary = summary\n        self.corruption = corruption\n\n        self.weight = Parameter(torch.empty(hidden_channels, hidden_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        reset(self.encoder)\n        reset(self.summary)\n        uniform(self.hidden_channels, self.weight)\n\n    def forward(self, *args, **kwargs) -> Tuple[Tensor, Tensor, Tensor]:\n        \"\"\"Returns the latent space for the input arguments, their\n        corruptions and their summary representation.\n        \"\"\"\n        pos_z = self.encoder(*args, **kwargs)\n\n        cor = self.corruption(*args, **kwargs)\n        cor = cor if isinstance(cor, tuple) else (cor, )\n        cor_args = cor[:len(args)]\n        cor_kwargs = copy.copy(kwargs)\n        
for key, value in zip(kwargs.keys(), cor[len(args):]):\n            cor_kwargs[key] = value\n\n        neg_z = self.encoder(*cor_args, **cor_kwargs)\n\n        summary = self.summary(pos_z, *args, **kwargs)\n\n        return pos_z, neg_z, summary\n\n    def discriminate(self, z: Tensor, summary: Tensor,\n                     sigmoid: bool = True) -> Tensor:\n        r\"\"\"Given the patch-summary pair :obj:`z` and :obj:`summary`, computes\n        the probability scores assigned to this patch-summary pair.\n\n        Args:\n            z (torch.Tensor): The latent space.\n            summary (torch.Tensor): The summary vector.\n            sigmoid (bool, optional): If set to :obj:`False`, does not apply\n                the logistic sigmoid function to the output.\n                (default: :obj:`True`)\n        \"\"\"\n        summary = summary.t() if summary.dim() > 1 else summary\n        value = torch.matmul(z, torch.matmul(self.weight, summary))\n        return torch.sigmoid(value) if sigmoid else value\n\n    def loss(self, pos_z: Tensor, neg_z: Tensor, summary: Tensor) -> Tensor:\n        r\"\"\"Computes the mutual information maximization objective.\"\"\"\n        pos_loss = -torch.log(\n            self.discriminate(pos_z, summary, sigmoid=True) + EPS).mean()\n        neg_loss = -torch.log(1 -\n                              self.discriminate(neg_z, summary, sigmoid=True) +\n                              EPS).mean()\n\n        return pos_loss + neg_loss\n\n    def test(\n        self,\n        train_z: Tensor,\n        train_y: Tensor,\n        test_z: Tensor,\n        test_y: Tensor,\n        solver: str = 'lbfgs',\n        *args,\n        **kwargs,\n    ) -> float:\n        r\"\"\"Evaluates latent space quality via a logistic regression downstream\n        task.\n        \"\"\"\n        from sklearn.linear_model import LogisticRegression\n\n        clf = LogisticRegression(*args, solver=solver,\n                                 
**kwargs).fit(train_z.detach().cpu().numpy(),\n                                               train_y.detach().cpu().numpy())\n        return clf.score(test_z.detach().cpu().numpy(),\n                         test_y.detach().cpu().numpy())\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.hidden_channels})'\n"
  },
  {
    "path": "torch_geometric/nn/models/deepgcn.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Module\nfrom torch.utils.checkpoint import checkpoint\n\n\nclass DeepGCNLayer(torch.nn.Module):\n    r\"\"\"The skip connection operations from the\n    `\"DeepGCNs: Can GCNs Go as Deep as CNNs?\"\n    <https://arxiv.org/abs/1904.03751>`_ and `\"All You Need to Train Deeper\n    GCNs\" <https://arxiv.org/abs/2006.07739>`_ papers.\n    The implemented skip connections include the pre-activation residual\n    connection (:obj:`\"res+\"`), the residual connection (:obj:`\"res\"`),\n    the dense connection (:obj:`\"dense\"`) and no connections (:obj:`\"plain\"`).\n\n    * **Res+** (:obj:`\"res+\"`):\n\n    .. math::\n        \\text{Normalization}\\to\\text{Activation}\\to\\text{Dropout}\\to\n        \\text{GraphConv}\\to\\text{Res}\n\n    * **Res** (:obj:`\"res\"`) / **Dense** (:obj:`\"dense\"`) / **Plain**\n      (:obj:`\"plain\"`):\n\n    .. math::\n        \\text{GraphConv}\\to\\text{Normalization}\\to\\text{Activation}\\to\n        \\text{Res/Dense/Plain}\\to\\text{Dropout}\n\n    .. note::\n\n        For an example of using :obj:`GENConv`, see\n        `examples/ogbn_proteins_deepgcn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        ogbn_proteins_deepgcn.py>`_.\n\n    Args:\n        conv (torch.nn.Module, optional): the GCN operator.\n            (default: :obj:`None`)\n        norm (torch.nn.Module): the normalization layer. (default: :obj:`None`)\n        act (torch.nn.Module): the activation layer. 
(default: :obj:`None`)\n        block (str, optional): The skip connection operation to use\n            (:obj:`\"res+\"`, :obj:`\"res\"`, :obj:`\"dense\"` or :obj:`\"plain\"`).\n            (default: :obj:`\"res+\"`)\n        dropout (float, optional): Dropout probability.\n            (default: :obj:`0.`)\n        ckpt_grad (bool, optional): If set to :obj:`True`, will checkpoint this\n            part of the model. Checkpointing works by trading compute for\n            memory, since intermediate activations do not need to be kept in\n            memory. Set this to :obj:`True` in case you encounter out-of-memory\n            errors while going deep. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        conv: Optional[Module] = None,\n        norm: Optional[Module] = None,\n        act: Optional[Module] = None,\n        block: str = 'res+',\n        dropout: float = 0.,\n        ckpt_grad: bool = False,\n    ):\n        super().__init__()\n\n        self.conv = conv\n        self.norm = norm\n        self.act = act\n        self.block = block.lower()\n        assert self.block in ['res+', 'res', 'dense', 'plain']\n        self.dropout = dropout\n        self.ckpt_grad = ckpt_grad\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.conv.reset_parameters()\n        self.norm.reset_parameters()\n\n    def forward(self, *args, **kwargs) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        args = list(args)\n        x = args.pop(0)\n\n        if self.block == 'res+':\n            h = x\n            if self.norm is not None:\n                h = self.norm(h)\n            if self.act is not None:\n                h = self.act(h)\n            h = F.dropout(h, p=self.dropout, training=self.training)\n            if self.conv is not None and self.ckpt_grad and h.requires_grad:\n                h = checkpoint(self.conv, h, *args, use_reentrant=True,\n                       
        **kwargs)\n            else:\n                h = self.conv(h, *args, **kwargs)\n\n            return x + h\n\n        else:\n            if self.conv is not None and self.ckpt_grad and x.requires_grad:\n                h = checkpoint(self.conv, x, *args, use_reentrant=True,\n                               **kwargs)\n            else:\n                h = self.conv(x, *args, **kwargs)\n            if self.norm is not None:\n                h = self.norm(h)\n            if self.act is not None:\n                h = self.act(h)\n\n            if self.block == 'res':\n                h = x + h\n            elif self.block == 'dense':\n                h = torch.cat([x, h], dim=-1)\n            elif self.block == 'plain':\n                pass\n\n            return F.dropout(h, p=self.dropout, training=self.training)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(block={self.block})'\n"
  },
  {
    "path": "torch_geometric/nn/models/dimenet.py",
    "content": "import os\nimport os.path as osp\nfrom functools import partial\nfrom math import pi as PI\nfrom math import sqrt\nfrom typing import Callable, Dict, Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Embedding, Linear\n\nfrom torch_geometric.data import Dataset, download_url\nfrom torch_geometric.nn import radius_graph\nfrom torch_geometric.nn.inits import glorot_orthogonal\nfrom torch_geometric.nn.resolver import activation_resolver\nfrom torch_geometric.typing import OptTensor, SparseTensor\nfrom torch_geometric.utils import scatter\n\nqm9_target_dict: Dict[int, str] = {\n    0: 'mu',\n    1: 'alpha',\n    2: 'homo',\n    3: 'lumo',\n    5: 'r2',\n    6: 'zpve',\n    7: 'U0',\n    8: 'U',\n    9: 'H',\n    10: 'G',\n    11: 'Cv',\n}\n\n\nclass Envelope(torch.nn.Module):\n    def __init__(self, exponent: int):\n        super().__init__()\n        self.p = exponent + 1\n        self.a = -(self.p + 1) * (self.p + 2) / 2\n        self.b = self.p * (self.p + 2)\n        self.c = -self.p * (self.p + 1) / 2\n\n    def forward(self, x: Tensor) -> Tensor:\n        p, a, b, c = self.p, self.a, self.b, self.c\n        x_pow_p0 = x.pow(p - 1)\n        x_pow_p1 = x_pow_p0 * x\n        x_pow_p2 = x_pow_p1 * x\n        return (1.0 / x + a * x_pow_p0 + b * x_pow_p1 +\n                c * x_pow_p2) * (x < 1.0).to(x.dtype)\n\n\nclass BesselBasisLayer(torch.nn.Module):\n    def __init__(self, num_radial: int, cutoff: float = 5.0,\n                 envelope_exponent: int = 5):\n        super().__init__()\n        self.cutoff = cutoff\n        self.envelope = Envelope(envelope_exponent)\n\n        self.freq = torch.nn.Parameter(torch.empty(num_radial))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        with torch.no_grad():\n            torch.arange(1, self.freq.numel() + 1, out=self.freq).mul_(PI)\n        self.freq.requires_grad_()\n\n    def forward(self, dist: Tensor) -> Tensor:\n 
       dist = dist.unsqueeze(-1) / self.cutoff\n        return self.envelope(dist) * (self.freq * dist).sin()\n\n\nclass SphericalBasisLayer(torch.nn.Module):\n    def __init__(\n        self,\n        num_spherical: int,\n        num_radial: int,\n        cutoff: float = 5.0,\n        envelope_exponent: int = 5,\n    ):\n        super().__init__()\n        import sympy as sym\n\n        from torch_geometric.nn.models.dimenet_utils import (\n            bessel_basis,\n            real_sph_harm,\n        )\n\n        assert num_radial <= 64\n        self.num_spherical = num_spherical\n        self.num_radial = num_radial\n        self.cutoff = cutoff\n        self.envelope = Envelope(envelope_exponent)\n\n        bessel_forms = bessel_basis(num_spherical, num_radial)\n        sph_harm_forms = real_sph_harm(num_spherical)\n        self.sph_funcs = []\n        self.bessel_funcs = []\n\n        x, theta = sym.symbols('x theta')\n        modules = {'sin': torch.sin, 'cos': torch.cos}\n        for i in range(num_spherical):\n            if i == 0:\n                sph1 = sym.lambdify([theta], sph_harm_forms[i][0], modules)(0)\n                self.sph_funcs.append(partial(self._sph_to_tensor, sph1))\n            else:\n                sph = sym.lambdify([theta], sph_harm_forms[i][0], modules)\n                self.sph_funcs.append(sph)\n            for j in range(num_radial):\n                bessel = sym.lambdify([x], bessel_forms[i][j], modules)\n                self.bessel_funcs.append(bessel)\n\n    @staticmethod\n    def _sph_to_tensor(sph, x: Tensor) -> Tensor:\n        return torch.zeros_like(x) + sph\n\n    def forward(self, dist: Tensor, angle: Tensor, idx_kj: Tensor) -> Tensor:\n        dist = dist / self.cutoff\n        rbf = torch.stack([f(dist) for f in self.bessel_funcs], dim=1)\n        rbf = self.envelope(dist).unsqueeze(-1) * rbf\n\n        cbf = torch.stack([f(angle) for f in self.sph_funcs], dim=1)\n\n        n, k = self.num_spherical, 
self.num_radial\n        out = (rbf[idx_kj].view(-1, n, k) * cbf.view(-1, n, 1)).view(-1, n * k)\n        return out\n\n\nclass EmbeddingBlock(torch.nn.Module):\n    def __init__(self, num_radial: int, hidden_channels: int, act: Callable):\n        super().__init__()\n        self.act = act\n\n        self.emb = Embedding(95, hidden_channels)\n        self.lin_rbf = Linear(num_radial, hidden_channels)\n        self.lin = Linear(3 * hidden_channels, hidden_channels)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        self.emb.weight.data.uniform_(-sqrt(3), sqrt(3))\n        self.lin_rbf.reset_parameters()\n        self.lin.reset_parameters()\n\n    def forward(self, x: Tensor, rbf: Tensor, i: Tensor, j: Tensor) -> Tensor:\n        x = self.emb(x)\n        rbf = self.act(self.lin_rbf(rbf))\n        return self.act(self.lin(torch.cat([x[i], x[j], rbf], dim=-1)))\n\n\nclass ResidualLayer(torch.nn.Module):\n    def __init__(self, hidden_channels: int, act: Callable):\n        super().__init__()\n        self.act = act\n        self.lin1 = Linear(hidden_channels, hidden_channels)\n        self.lin2 = Linear(hidden_channels, hidden_channels)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot_orthogonal(self.lin1.weight, scale=2.0)\n        self.lin1.bias.data.fill_(0)\n        glorot_orthogonal(self.lin2.weight, scale=2.0)\n        self.lin2.bias.data.fill_(0)\n\n    def forward(self, x: Tensor) -> Tensor:\n        return x + self.act(self.lin2(self.act(self.lin1(x))))\n\n\nclass InteractionBlock(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_channels: int,\n        num_bilinear: int,\n        num_spherical: int,\n        num_radial: int,\n        num_before_skip: int,\n        num_after_skip: int,\n        act: Callable,\n    ):\n        super().__init__()\n        self.act = act\n\n        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)\n        self.lin_sbf = 
Linear(num_spherical * num_radial, num_bilinear,\n                              bias=False)\n\n        # Dense transformations of input messages.\n        self.lin_kj = Linear(hidden_channels, hidden_channels)\n        self.lin_ji = Linear(hidden_channels, hidden_channels)\n\n        self.W = torch.nn.Parameter(\n            torch.empty(hidden_channels, num_bilinear, hidden_channels))\n\n        self.layers_before_skip = torch.nn.ModuleList([\n            ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)\n        ])\n        self.lin = Linear(hidden_channels, hidden_channels)\n        self.layers_after_skip = torch.nn.ModuleList([\n            ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)\n        ])\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)\n        glorot_orthogonal(self.lin_sbf.weight, scale=2.0)\n        glorot_orthogonal(self.lin_kj.weight, scale=2.0)\n        self.lin_kj.bias.data.fill_(0)\n        glorot_orthogonal(self.lin_ji.weight, scale=2.0)\n        self.lin_ji.bias.data.fill_(0)\n        self.W.data.normal_(mean=0, std=2 / self.W.size(0))\n        for res_layer in self.layers_before_skip:\n            res_layer.reset_parameters()\n        glorot_orthogonal(self.lin.weight, scale=2.0)\n        self.lin.bias.data.fill_(0)\n        for res_layer in self.layers_after_skip:\n            res_layer.reset_parameters()\n\n    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,\n                idx_ji: Tensor) -> Tensor:\n        rbf = self.lin_rbf(rbf)\n        sbf = self.lin_sbf(sbf)\n\n        x_ji = self.act(self.lin_ji(x))\n        x_kj = self.act(self.lin_kj(x))\n        x_kj = x_kj * rbf\n        x_kj = torch.einsum('wj,wl,ijl->wi', sbf, x_kj[idx_kj], self.W)\n        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0), reduce='sum')\n\n        h = x_ji + x_kj\n        for layer in self.layers_before_skip:\n   
         h = layer(h)\n        h = self.act(self.lin(h)) + x\n        for layer in self.layers_after_skip:\n            h = layer(h)\n\n        return h\n\n\nclass InteractionPPBlock(torch.nn.Module):\n    def __init__(\n        self,\n        hidden_channels: int,\n        int_emb_size: int,\n        basis_emb_size: int,\n        num_spherical: int,\n        num_radial: int,\n        num_before_skip: int,\n        num_after_skip: int,\n        act: Callable,\n    ):\n        super().__init__()\n        self.act = act\n\n        # Transformation of Bessel and spherical basis representations:\n        self.lin_rbf1 = Linear(num_radial, basis_emb_size, bias=False)\n        self.lin_rbf2 = Linear(basis_emb_size, hidden_channels, bias=False)\n\n        self.lin_sbf1 = Linear(num_spherical * num_radial, basis_emb_size,\n                               bias=False)\n        self.lin_sbf2 = Linear(basis_emb_size, int_emb_size, bias=False)\n\n        # Hidden transformation of input message:\n        self.lin_kj = Linear(hidden_channels, hidden_channels)\n        self.lin_ji = Linear(hidden_channels, hidden_channels)\n\n        # Embedding projections for interaction triplets:\n        self.lin_down = Linear(hidden_channels, int_emb_size, bias=False)\n        self.lin_up = Linear(int_emb_size, hidden_channels, bias=False)\n\n        # Residual layers before and after skip connection:\n        self.layers_before_skip = torch.nn.ModuleList([\n            ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)\n        ])\n        self.lin = Linear(hidden_channels, hidden_channels)\n        self.layers_after_skip = torch.nn.ModuleList([\n            ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)\n        ])\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot_orthogonal(self.lin_rbf1.weight, scale=2.0)\n        glorot_orthogonal(self.lin_rbf2.weight, scale=2.0)\n        glorot_orthogonal(self.lin_sbf1.weight, 
scale=2.0)\n        glorot_orthogonal(self.lin_sbf2.weight, scale=2.0)\n\n        glorot_orthogonal(self.lin_kj.weight, scale=2.0)\n        self.lin_kj.bias.data.fill_(0)\n        glorot_orthogonal(self.lin_ji.weight, scale=2.0)\n        self.lin_ji.bias.data.fill_(0)\n\n        glorot_orthogonal(self.lin_down.weight, scale=2.0)\n        glorot_orthogonal(self.lin_up.weight, scale=2.0)\n\n        for res_layer in self.layers_before_skip:\n            res_layer.reset_parameters()\n        glorot_orthogonal(self.lin.weight, scale=2.0)\n        self.lin.bias.data.fill_(0)\n        for res_layer in self.layers_after_skip:\n            res_layer.reset_parameters()\n\n    def forward(self, x: Tensor, rbf: Tensor, sbf: Tensor, idx_kj: Tensor,\n                idx_ji: Tensor) -> Tensor:\n        # Initial transformation:\n        x_ji = self.act(self.lin_ji(x))\n        x_kj = self.act(self.lin_kj(x))\n\n        # Transformation via Bessel basis:\n        rbf = self.lin_rbf1(rbf)\n        rbf = self.lin_rbf2(rbf)\n        x_kj = x_kj * rbf\n\n        # Down project embedding and generating triple-interactions:\n        x_kj = self.act(self.lin_down(x_kj))\n\n        # Transform via 2D spherical basis:\n        sbf = self.lin_sbf1(sbf)\n        sbf = self.lin_sbf2(sbf)\n        x_kj = x_kj[idx_kj] * sbf\n\n        # Aggregate interactions and up-project embeddings:\n        x_kj = scatter(x_kj, idx_ji, dim=0, dim_size=x.size(0), reduce='sum')\n        x_kj = self.act(self.lin_up(x_kj))\n\n        h = x_ji + x_kj\n        for layer in self.layers_before_skip:\n            h = layer(h)\n        h = self.act(self.lin(h)) + x\n        for layer in self.layers_after_skip:\n            h = layer(h)\n\n        return h\n\n\nclass OutputBlock(torch.nn.Module):\n    def __init__(\n        self,\n        num_radial: int,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int,\n        act: Callable,\n        output_initializer: str = 'zeros',\n    ):\n    
    assert output_initializer in {'zeros', 'glorot_orthogonal'}\n\n        super().__init__()\n\n        self.act = act\n        self.output_initializer = output_initializer\n\n        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)\n        self.lins = torch.nn.ModuleList()\n        for _ in range(num_layers):\n            self.lins.append(Linear(hidden_channels, hidden_channels))\n        self.lin = Linear(hidden_channels, out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)\n        for lin in self.lins:\n            glorot_orthogonal(lin.weight, scale=2.0)\n            lin.bias.data.fill_(0)\n        if self.output_initializer == 'zeros':\n            self.lin.weight.data.fill_(0)\n        elif self.output_initializer == 'glorot_orthogonal':\n            glorot_orthogonal(self.lin.weight, scale=2.0)\n\n    def forward(self, x: Tensor, rbf: Tensor, i: Tensor,\n                num_nodes: Optional[int] = None) -> Tensor:\n        x = self.lin_rbf(rbf) * x\n        x = scatter(x, i, dim=0, dim_size=num_nodes, reduce='sum')\n        for lin in self.lins:\n            x = self.act(lin(x))\n        return self.lin(x)\n\n\nclass OutputPPBlock(torch.nn.Module):\n    def __init__(\n        self,\n        num_radial: int,\n        hidden_channels: int,\n        out_emb_channels: int,\n        out_channels: int,\n        num_layers: int,\n        act: Callable,\n        output_initializer: str = 'zeros',\n    ):\n        assert output_initializer in {'zeros', 'glorot_orthogonal'}\n\n        super().__init__()\n\n        self.act = act\n        self.output_initializer = output_initializer\n\n        self.lin_rbf = Linear(num_radial, hidden_channels, bias=False)\n\n        # The up-projection layer:\n        self.lin_up = Linear(hidden_channels, out_emb_channels, bias=False)\n        self.lins = torch.nn.ModuleList()\n        for _ in range(num_layers):\n      
      self.lins.append(Linear(out_emb_channels, out_emb_channels))\n        self.lin = Linear(out_emb_channels, out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        glorot_orthogonal(self.lin_rbf.weight, scale=2.0)\n        glorot_orthogonal(self.lin_up.weight, scale=2.0)\n        for lin in self.lins:\n            glorot_orthogonal(lin.weight, scale=2.0)\n            lin.bias.data.fill_(0)\n        if self.output_initializer == 'zeros':\n            self.lin.weight.data.fill_(0)\n        elif self.output_initializer == 'glorot_orthogonal':\n            glorot_orthogonal(self.lin.weight, scale=2.0)\n\n    def forward(self, x: Tensor, rbf: Tensor, i: Tensor,\n                num_nodes: Optional[int] = None) -> Tensor:\n        x = self.lin_rbf(rbf) * x\n        x = scatter(x, i, dim=0, dim_size=num_nodes, reduce='sum')\n        x = self.lin_up(x)\n        for lin in self.lins:\n            x = self.act(lin(x))\n        return self.lin(x)\n\n\ndef triplets(\n    edge_index: Tensor,\n    num_nodes: int,\n) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n    row, col = edge_index  # j->i\n\n    value = torch.arange(row.size(0), device=row.device)\n    adj_t = SparseTensor(row=col, col=row, value=value,\n                         sparse_sizes=(num_nodes, num_nodes))\n    adj_t_row = adj_t[row]\n    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)\n\n    # Node indices (k->j->i) for triplets.\n    idx_i = col.repeat_interleave(num_triplets)\n    idx_j = row.repeat_interleave(num_triplets)\n    idx_k = adj_t_row.storage.col()\n    mask = idx_i != idx_k  # Remove i == k triplets.\n    idx_i, idx_j, idx_k = idx_i[mask], idx_j[mask], idx_k[mask]\n\n    # Edge indices (k-j, j->i) for triplets.\n    idx_kj = adj_t_row.storage.value()[mask]\n    idx_ji = adj_t_row.storage.row()[mask]\n\n    return col, row, idx_i, idx_j, idx_k, idx_kj, idx_ji\n\n\nclass DimeNet(torch.nn.Module):\n    
r\"\"\"The directional message passing neural network (DimeNet) from the\n    `\"Directional Message Passing for Molecular Graphs\"\n    <https://arxiv.org/abs/2003.03123>`_ paper.\n    DimeNet transforms messages based on the angle between them in a\n    rotation-equivariant fashion.\n\n    .. note::\n\n        For an example of using a pretrained DimeNet variant, see\n        `examples/qm9_pretrained_dimenet.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        qm9_pretrained_dimenet.py>`_.\n\n    Args:\n        hidden_channels (int): Hidden embedding size.\n        out_channels (int): Size of each output sample.\n        num_blocks (int): Number of building blocks.\n        num_bilinear (int): Size of the bilinear layer tensor.\n        num_spherical (int): Number of spherical harmonics.\n        num_radial (int): Number of radial basis functions.\n        cutoff (float, optional): Cutoff distance for interatomic\n            interactions. (default: :obj:`5.0`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            collect for each node within the :attr:`cutoff` distance.\n            (default: :obj:`32`)\n        envelope_exponent (int, optional): Shape of the smooth cutoff.\n            (default: :obj:`5`)\n        num_before_skip (int, optional): Number of residual layers in the\n            interaction blocks before the skip connection. (default: :obj:`1`)\n        num_after_skip (int, optional): Number of residual layers in the\n            interaction blocks after the skip connection. (default: :obj:`2`)\n        num_output_layers (int, optional): Number of linear layers for the\n            output blocks. 
(default: :obj:`3`)\n        act (str or Callable, optional): The activation function.\n            (default: :obj:`\"swish\"`)\n        output_initializer (str, optional): The initialization method for the\n            output layer (:obj:`\"zeros\"`, :obj:`\"glorot_orthogonal\"`).\n            (default: :obj:`\"zeros\"`)\n    \"\"\"\n\n    url = ('https://github.com/klicperajo/dimenet/raw/master/pretrained/'\n           'dimenet')\n\n    def __init__(\n        self,\n        hidden_channels: int,\n        out_channels: int,\n        num_blocks: int,\n        num_bilinear: int,\n        num_spherical: int,\n        num_radial: int,\n        cutoff: float = 5.0,\n        max_num_neighbors: int = 32,\n        envelope_exponent: int = 5,\n        num_before_skip: int = 1,\n        num_after_skip: int = 2,\n        num_output_layers: int = 3,\n        act: Union[str, Callable] = 'swish',\n        output_initializer: str = 'zeros',\n    ):\n        super().__init__()\n\n        if num_spherical < 2:\n            raise ValueError(\"'num_spherical' should be greater than 1\")\n\n        act = activation_resolver(act)\n\n        self.cutoff = cutoff\n        self.max_num_neighbors = max_num_neighbors\n        self.num_blocks = num_blocks\n\n        self.rbf = BesselBasisLayer(num_radial, cutoff, envelope_exponent)\n        self.sbf = SphericalBasisLayer(num_spherical, num_radial, cutoff,\n                                       envelope_exponent)\n\n        self.emb = EmbeddingBlock(num_radial, hidden_channels, act)\n\n        self.output_blocks = torch.nn.ModuleList([\n            OutputBlock(\n                num_radial,\n                hidden_channels,\n                out_channels,\n                num_output_layers,\n                act,\n                output_initializer,\n            ) for _ in range(num_blocks + 1)\n        ])\n\n        self.interaction_blocks = torch.nn.ModuleList([\n            InteractionBlock(\n                hidden_channels,\n               
 num_bilinear,\n                num_spherical,\n                num_radial,\n                num_before_skip,\n                num_after_skip,\n                act,\n            ) for _ in range(num_blocks)\n        ])\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.rbf.reset_parameters()\n        self.emb.reset_parameters()\n        for out in self.output_blocks:\n            out.reset_parameters()\n        for interaction in self.interaction_blocks:\n            interaction.reset_parameters()\n\n    @classmethod\n    def from_qm9_pretrained(\n        cls,\n        root: str,\n        dataset: Dataset,\n        target: int,\n    ) -> Tuple['DimeNet', Dataset, Dataset, Dataset]:  # pragma: no cover\n        r\"\"\"Returns a pre-trained :class:`DimeNet` model on the\n        :class:`~torch_geometric.datasets.QM9` dataset, trained on the\n        specified target :obj:`target`.\n        \"\"\"\n        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n        import tensorflow as tf\n\n        assert target >= 0 and target <= 12 and not target == 4\n\n        root = osp.expanduser(osp.normpath(root))\n        path = osp.join(root, 'pretrained_dimenet', qm9_target_dict[target])\n\n        os.makedirs(path, exist_ok=True)\n        url = f'{cls.url}/{qm9_target_dict[target]}'\n\n        if not osp.exists(osp.join(path, 'checkpoint')):\n            download_url(f'{url}/checkpoint', path)\n            download_url(f'{url}/ckpt.data-00000-of-00002', path)\n            download_url(f'{url}/ckpt.data-00001-of-00002', path)\n            download_url(f'{url}/ckpt.index', path)\n\n        path = osp.join(path, 'ckpt')\n        reader = tf.train.load_checkpoint(path)\n\n        model = cls(\n            hidden_channels=128,\n            out_channels=1,\n            num_blocks=6,\n            num_bilinear=8,\n            num_spherical=7,\n            num_radial=6,\n            cutoff=5.0,\n            
envelope_exponent=5,\n            num_before_skip=1,\n            num_after_skip=2,\n            num_output_layers=3,\n        )\n\n        def copy_(src, name, transpose=False):\n            init = reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE')\n            init = torch.from_numpy(init)\n            if name[-6:] == 'kernel':\n                init = init.t()\n            src.data.copy_(init)\n\n        copy_(model.rbf.freq, 'rbf_layer/frequencies')\n        copy_(model.emb.emb.weight, 'emb_block/embeddings')\n        copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel')\n        copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias')\n        copy_(model.emb.lin.weight, 'emb_block/dense/kernel')\n        copy_(model.emb.lin.bias, 'emb_block/dense/bias')\n\n        for i, block in enumerate(model.output_blocks):\n            copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel')\n            for j, lin in enumerate(block.lins):\n                copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel')\n                copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias')\n            copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel')\n\n        for i, block in enumerate(model.interaction_blocks):\n            copy_(block.lin_rbf.weight, f'int_blocks/{i}/dense_rbf/kernel')\n            copy_(block.lin_sbf.weight, f'int_blocks/{i}/dense_sbf/kernel')\n            copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel')\n            copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias')\n            copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel')\n            copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias')\n            copy_(block.W, f'int_blocks/{i}/bilinear')\n            for j, layer in enumerate(block.layers_before_skip):\n                copy_(layer.lin1.weight,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_1/kernel')\n                
copy_(layer.lin1.bias,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_1/bias')\n                copy_(layer.lin2.weight,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_2/kernel')\n                copy_(layer.lin2.bias,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_2/bias')\n            copy_(block.lin.weight, f'int_blocks/{i}/final_before_skip/kernel')\n            copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias')\n            for j, layer in enumerate(block.layers_after_skip):\n                copy_(layer.lin1.weight,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_1/kernel')\n                copy_(layer.lin1.bias,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_1/bias')\n                copy_(layer.lin2.weight,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_2/kernel')\n                copy_(layer.lin2.bias,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_2/bias')\n\n        # Use the same random seed as the official DimeNet implementation.\n        random_state = np.random.RandomState(seed=42)\n        perm = torch.from_numpy(random_state.permutation(np.arange(130831)))\n        perm = perm.long()\n        train_idx = perm[:110000]\n        val_idx = perm[110000:120000]\n        test_idx = perm[120000:]\n\n        return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx])\n\n    def forward(\n        self,\n        z: Tensor,\n        pos: Tensor,\n        batch: OptTensor = None,\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            z (torch.Tensor): Atomic number of each atom with shape\n                :obj:`[num_atoms]`.\n            pos (torch.Tensor): Coordinates of each atom with shape\n                :obj:`[num_atoms, 3]`.\n            batch (torch.Tensor, optional): Batch indices assigning each atom\n                to a separate molecule with shape 
:obj:`[num_atoms]`.\n                (default: :obj:`None`)\n        \"\"\"\n        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,\n                                  max_num_neighbors=self.max_num_neighbors)\n\n        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(\n            edge_index, num_nodes=z.size(0))\n\n        # Calculate distances.\n        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()\n\n        # Calculate angles.\n        if isinstance(self, DimeNetPlusPlus):\n            pos_jk, pos_ij = pos[idx_j] - pos[idx_k], pos[idx_i] - pos[idx_j]\n            a = (pos_ij * pos_jk).sum(dim=-1)\n            b = torch.cross(pos_ij, pos_jk, dim=1).norm(dim=-1)\n        elif isinstance(self, DimeNet):\n            pos_ji, pos_ki = pos[idx_j] - pos[idx_i], pos[idx_k] - pos[idx_i]\n            a = (pos_ji * pos_ki).sum(dim=-1)\n            b = torch.cross(pos_ji, pos_ki, dim=1).norm(dim=-1)\n        angle = torch.atan2(b, a)\n\n        rbf = self.rbf(dist)\n        sbf = self.sbf(dist, angle, idx_kj)\n\n        # Embedding block.\n        x = self.emb(z, rbf, i, j)\n        P = self.output_blocks[0](x, rbf, i, num_nodes=pos.size(0))\n\n        # Interaction blocks.\n        for interaction_block, output_block in zip(self.interaction_blocks,\n                                                   self.output_blocks[1:]):\n            x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)\n            P = P + output_block(x, rbf, i, num_nodes=pos.size(0))\n\n        if batch is None:\n            return P.sum(dim=0)\n        else:\n            return scatter(P, batch, dim=0, reduce='sum')\n\n\nclass DimeNetPlusPlus(DimeNet):\n    r\"\"\"The DimeNet++ from the `\"Fast and Uncertainty-Aware\n    Directional Message Passing for Non-Equilibrium Molecules\"\n    <https://arxiv.org/abs/2011.14115>`_ paper.\n\n    :class:`DimeNetPlusPlus` is an upgrade to the :class:`DimeNet` model that\n    is 8x faster and 10% more accurate than the original model.\n\n    
Args:\n        hidden_channels (int): Hidden embedding size.\n        out_channels (int): Size of each output sample.\n        num_blocks (int): Number of building blocks.\n        int_emb_size (int): Size of embedding in the interaction block.\n        basis_emb_size (int): Size of basis embedding in the interaction block.\n        out_emb_channels (int): Size of embedding in the output block.\n        num_spherical (int): Number of spherical harmonics.\n        num_radial (int): Number of radial basis functions.\n        cutoff: (float, optional): Cutoff distance for interatomic\n            interactions. (default: :obj:`5.0`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            collect for each node within the :attr:`cutoff` distance.\n            (default: :obj:`32`)\n        envelope_exponent (int, optional): Shape of the smooth cutoff.\n            (default: :obj:`5`)\n        num_before_skip: (int, optional): Number of residual layers in the\n            interaction blocks before the skip connection. (default: :obj:`1`)\n        num_after_skip: (int, optional): Number of residual layers in the\n            interaction blocks after the skip connection. (default: :obj:`2`)\n        num_output_layers: (int, optional): Number of linear layers for the\n            output blocks. 
(default: :obj:`3`)\n        act: (str or Callable, optional): The activation function.\n            (default: :obj:`\"swish\"`)\n        output_initializer (str, optional): The initialization method for the\n            output layer (:obj:`\"zeros\"`, :obj:`\"glorot_orthogonal\"`).\n            (default: :obj:`\"zeros\"`)\n    \"\"\"\n\n    url = ('https://raw.githubusercontent.com/gasteigerjo/dimenet/'\n           'master/pretrained/dimenet_pp')\n\n    def __init__(\n        self,\n        hidden_channels: int,\n        out_channels: int,\n        num_blocks: int,\n        int_emb_size: int,\n        basis_emb_size: int,\n        out_emb_channels: int,\n        num_spherical: int,\n        num_radial: int,\n        cutoff: float = 5.0,\n        max_num_neighbors: int = 32,\n        envelope_exponent: int = 5,\n        num_before_skip: int = 1,\n        num_after_skip: int = 2,\n        num_output_layers: int = 3,\n        act: Union[str, Callable] = 'swish',\n        output_initializer: str = 'zeros',\n    ):\n        act = activation_resolver(act)\n\n        super().__init__(\n            hidden_channels=hidden_channels,\n            out_channels=out_channels,\n            num_blocks=num_blocks,\n            num_bilinear=1,\n            num_spherical=num_spherical,\n            num_radial=num_radial,\n            cutoff=cutoff,\n            max_num_neighbors=max_num_neighbors,\n            envelope_exponent=envelope_exponent,\n            num_before_skip=num_before_skip,\n            num_after_skip=num_after_skip,\n            num_output_layers=num_output_layers,\n            act=act,\n            output_initializer=output_initializer,\n        )\n\n        # We are re-using the RBF, SBF and embedding layers of `DimeNet` and\n        # redefine output_block and interaction_block in DimeNet++.\n        # Hence, it is to be noted that in the above initialization, the\n        # variable `num_bilinear` does not have any purpose as it is used\n        # solely in 
the `InteractionBlock` of DimeNet:\n        self.output_blocks = torch.nn.ModuleList([\n            OutputPPBlock(\n                num_radial,\n                hidden_channels,\n                out_emb_channels,\n                out_channels,\n                num_output_layers,\n                act,\n                output_initializer,\n            ) for _ in range(num_blocks + 1)\n        ])\n\n        self.interaction_blocks = torch.nn.ModuleList([\n            InteractionPPBlock(\n                hidden_channels,\n                int_emb_size,\n                basis_emb_size,\n                num_spherical,\n                num_radial,\n                num_before_skip,\n                num_after_skip,\n                act,\n            ) for _ in range(num_blocks)\n        ])\n\n        self.reset_parameters()\n\n    @classmethod\n    def from_qm9_pretrained(\n        cls,\n        root: str,\n        dataset: Dataset,\n        target: int,\n    ) -> Tuple['DimeNetPlusPlus', Dataset, Dataset,\n               Dataset]:  # pragma: no cover\n        r\"\"\"Returns a pre-trained :class:`DimeNetPlusPlus` model on the\n        :class:`~torch_geometric.datasets.QM9` dataset, trained on the\n        specified target :obj:`target`.\n        \"\"\"\n        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n        import tensorflow as tf\n\n        assert target >= 0 and target <= 12 and not target == 4\n\n        root = osp.expanduser(osp.normpath(root))\n        path = osp.join(root, 'pretrained_dimenet_pp', qm9_target_dict[target])\n\n        os.makedirs(path, exist_ok=True)\n        url = f'{cls.url}/{qm9_target_dict[target]}'\n\n        if not osp.exists(osp.join(path, 'checkpoint')):\n            download_url(f'{url}/checkpoint', path)\n            download_url(f'{url}/ckpt.data-00000-of-00002', path)\n            download_url(f'{url}/ckpt.data-00001-of-00002', path)\n            download_url(f'{url}/ckpt.index', path)\n\n        path = osp.join(path, 'ckpt')\n        reader = 
tf.train.load_checkpoint(path)\n\n        # Configuration from DimeNet++:\n        # https://github.com/gasteigerjo/dimenet/blob/master/config_pp.yaml\n        model = cls(\n            hidden_channels=128,\n            out_channels=1,\n            num_blocks=4,\n            int_emb_size=64,\n            basis_emb_size=8,\n            out_emb_channels=256,\n            num_spherical=7,\n            num_radial=6,\n            cutoff=5.0,\n            max_num_neighbors=32,\n            envelope_exponent=5,\n            num_before_skip=1,\n            num_after_skip=2,\n            num_output_layers=3,\n        )\n\n        def copy_(src, name, transpose=False):\n            init = reader.get_tensor(f'{name}/.ATTRIBUTES/VARIABLE_VALUE')\n            init = torch.from_numpy(init)\n            if name[-6:] == 'kernel':\n                init = init.t()\n            src.data.copy_(init)\n\n        copy_(model.rbf.freq, 'rbf_layer/frequencies')\n        copy_(model.emb.emb.weight, 'emb_block/embeddings')\n        copy_(model.emb.lin_rbf.weight, 'emb_block/dense_rbf/kernel')\n        copy_(model.emb.lin_rbf.bias, 'emb_block/dense_rbf/bias')\n        copy_(model.emb.lin.weight, 'emb_block/dense/kernel')\n        copy_(model.emb.lin.bias, 'emb_block/dense/bias')\n\n        for i, block in enumerate(model.output_blocks):\n            copy_(block.lin_rbf.weight, f'output_blocks/{i}/dense_rbf/kernel')\n            copy_(block.lin_up.weight,\n                  f'output_blocks/{i}/up_projection/kernel')\n            for j, lin in enumerate(block.lins):\n                copy_(lin.weight, f'output_blocks/{i}/dense_layers/{j}/kernel')\n                copy_(lin.bias, f'output_blocks/{i}/dense_layers/{j}/bias')\n            copy_(block.lin.weight, f'output_blocks/{i}/dense_final/kernel')\n\n        for i, block in enumerate(model.interaction_blocks):\n            copy_(block.lin_rbf1.weight, f'int_blocks/{i}/dense_rbf1/kernel')\n            copy_(block.lin_rbf2.weight, 
f'int_blocks/{i}/dense_rbf2/kernel')\n            copy_(block.lin_sbf1.weight, f'int_blocks/{i}/dense_sbf1/kernel')\n            copy_(block.lin_sbf2.weight, f'int_blocks/{i}/dense_sbf2/kernel')\n\n            copy_(block.lin_ji.weight, f'int_blocks/{i}/dense_ji/kernel')\n            copy_(block.lin_ji.bias, f'int_blocks/{i}/dense_ji/bias')\n            copy_(block.lin_kj.weight, f'int_blocks/{i}/dense_kj/kernel')\n            copy_(block.lin_kj.bias, f'int_blocks/{i}/dense_kj/bias')\n\n            copy_(block.lin_down.weight,\n                  f'int_blocks/{i}/down_projection/kernel')\n            copy_(block.lin_up.weight, f'int_blocks/{i}/up_projection/kernel')\n\n            for j, layer in enumerate(block.layers_before_skip):\n                copy_(layer.lin1.weight,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_1/kernel')\n                copy_(layer.lin1.bias,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_1/bias')\n                copy_(layer.lin2.weight,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_2/kernel')\n                copy_(layer.lin2.bias,\n                      f'int_blocks/{i}/layers_before_skip/{j}/dense_2/bias')\n\n            copy_(block.lin.weight, f'int_blocks/{i}/final_before_skip/kernel')\n            copy_(block.lin.bias, f'int_blocks/{i}/final_before_skip/bias')\n\n            for j, layer in enumerate(block.layers_after_skip):\n                copy_(layer.lin1.weight,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_1/kernel')\n                copy_(layer.lin1.bias,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_1/bias')\n                copy_(layer.lin2.weight,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_2/kernel')\n                copy_(layer.lin2.bias,\n                      f'int_blocks/{i}/layers_after_skip/{j}/dense_2/bias')\n\n        random_state = np.random.RandomState(seed=42)\n        perm = 
torch.from_numpy(random_state.permutation(np.arange(130831)))\n        perm = perm.long()\n        train_idx = perm[:110000]\n        val_idx = perm[110000:120000]\n        test_idx = perm[120000:]\n\n        return model, (dataset[train_idx], dataset[val_idx], dataset[test_idx])\n"
  },
  {
    "path": "torch_geometric/nn/models/dimenet_utils.py",
    "content": "# Shameless steal from: https://github.com/klicperajo/dimenet\n\nimport math\n\nimport numpy as np\nimport sympy as sym\nfrom scipy import special as sp\nfrom scipy.optimize import brentq\n\n\ndef Jn(r, n):\n    return np.sqrt(np.pi / (2 * r)) * sp.jv(n + 0.5, r)\n\n\ndef Jn_zeros(n, k):\n    zerosj = np.zeros((n, k), dtype='float32')\n    zerosj[0] = np.arange(1, k + 1) * np.pi\n    points = np.arange(1, k + n) * np.pi\n    racines = np.zeros(k + n - 1, dtype='float32')\n    for i in range(1, n):\n        for j in range(k + n - 1 - i):\n            foo = brentq(Jn, points[j], points[j + 1], (i, ))\n            racines[j] = foo\n        points = racines\n        zerosj[i][:k] = racines[:k]\n\n    return zerosj\n\n\ndef spherical_bessel_formulas(n):\n    x = sym.symbols('x')\n\n    f = [sym.sin(x) / x]\n    a = sym.sin(x) / x\n    for i in range(1, n):\n        b = sym.diff(a, x) / x\n        f += [sym.simplify(b * (-x)**i)]\n        a = sym.simplify(b)\n    return f\n\n\ndef bessel_basis(n, k):\n    zeros = Jn_zeros(n, k)\n    normalizer = []\n    for order in range(n):\n        normalizer_tmp = []\n        for i in range(k):\n            normalizer_tmp += [0.5 * Jn(zeros[order, i], order + 1)**2]\n        normalizer_tmp = 1 / np.array(normalizer_tmp)**0.5\n        normalizer += [normalizer_tmp]\n\n    f = spherical_bessel_formulas(n)\n    x = sym.symbols('x')\n    bess_basis = []\n    for order in range(n):\n        bess_basis_tmp = []\n        for i in range(k):\n            bess_basis_tmp += [\n                sym.simplify(normalizer[order][i] *\n                             f[order].subs(x, zeros[order, i] * x))\n            ]\n        bess_basis += [bess_basis_tmp]\n    return bess_basis\n\n\ndef sph_harm_prefactor(k, m):\n    return ((2 * k + 1) * math.factorial(k - abs(m)) /\n            (4 * np.pi * math.factorial(k + abs(m))))**0.5\n\n\ndef associated_legendre_polynomials(k, zero_m_only=True):\n    r\"\"\"Helper function to calculate 
Y_l^m.\"\"\"\n    z = sym.symbols('z')\n    P_l_m = [[0] * (j + 1) for j in range(k)]\n\n    P_l_m[0][0] = 1\n    if k > 0:\n        P_l_m[1][0] = z\n\n        for j in range(2, k):\n            # Use the property of Eq (7) in\n            # https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html:\n            P_l_m[j][0] = sym.simplify(((2 * j - 1) * z * P_l_m[j - 1][0] -\n                                        (j - 1) * P_l_m[j - 2][0]) / j)\n        if not zero_m_only:\n            for i in range(1, k):\n                P_l_m[i][i] = sym.simplify(\n                    (1 - 2 * i) * P_l_m[i - 1][i - 1] * (1 - z**2)**0.5)\n                if i + 1 < k:\n                    # Use the property of Eq (11) in\n                    # https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html:\n                    P_l_m[i + 1][i] = sym.simplify(\n                        (2 * i + 1) * z * P_l_m[i][i])\n                for j in range(i + 2, k):\n                    # Use the property of Eq (7) in\n                    # https://mathworld.wolfram.com/AssociatedLegendrePolynomial.html:\n                    P_l_m[j][i] = sym.simplify(\n                        ((2 * j - 1) * z * P_l_m[j - 1][i] -\n                         (i + j - 1) * P_l_m[j - 2][i]) / (j - i))\n\n    return P_l_m\n\n\ndef real_sph_harm(k, zero_m_only=True, spherical_coordinates=True):\n    if not zero_m_only:\n        S_m = [0]\n        C_m = [1]\n        for i in range(1, k):\n            x = sym.symbols('x')\n            y = sym.symbols('y')\n            S_m += [x * S_m[i - 1] + y * C_m[i - 1]]\n            C_m += [x * C_m[i - 1] - y * S_m[i - 1]]\n\n    P_l_m = associated_legendre_polynomials(k, zero_m_only)\n    if spherical_coordinates:\n        theta = sym.symbols('theta')\n        z = sym.symbols('z')\n        for i in range(len(P_l_m)):\n            for j in range(len(P_l_m[i])):\n                if not isinstance(P_l_m[i][j], int):\n                    P_l_m[i][j] = P_l_m[i][j].subs(z, 
sym.cos(theta))\n        if not zero_m_only:\n            phi = sym.symbols('phi')\n            for i in range(len(S_m)):\n                S_m[i] = S_m[i].subs(x,\n                                     sym.sin(theta) * sym.cos(phi)).subs(\n                                         y,\n                                         sym.sin(theta) * sym.sin(phi))\n            for i in range(len(C_m)):\n                C_m[i] = C_m[i].subs(x,\n                                     sym.sin(theta) * sym.cos(phi)).subs(\n                                         y,\n                                         sym.sin(theta) * sym.sin(phi))\n\n    Y_func_l_m = [['0'] * (2 * j + 1) for j in range(k)]\n    for i in range(k):\n        Y_func_l_m[i][0] = sym.simplify(sph_harm_prefactor(i, 0) * P_l_m[i][0])\n\n    if not zero_m_only:\n        for i in range(1, k):\n            for j in range(1, i + 1):\n                Y_func_l_m[i][j] = sym.simplify(\n                    2**0.5 * sph_harm_prefactor(i, j) * C_m[j] * P_l_m[i][j])\n        for i in range(1, k):\n            for j in range(1, i + 1):\n                Y_func_l_m[i][-j] = sym.simplify(\n                    2**0.5 * sph_harm_prefactor(i, -j) * S_m[j] * P_l_m[i][j])\n\n    return Y_func_l_m\n"
  },
  {
    "path": "torch_geometric/nn/models/gnnff.py",
    "content": "import torch\nfrom torch import Tensor\nfrom torch.nn import BatchNorm1d, Embedding, Linear, ModuleList, Sequential\n\nfrom torch_geometric.nn import radius_graph\nfrom torch_geometric.nn.inits import reset\nfrom torch_geometric.nn.models.dimenet import triplets\nfrom torch_geometric.nn.models.schnet import ShiftedSoftplus\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import scatter\n\n\nclass GaussianFilter(torch.nn.Module):\n    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):\n        super().__init__()\n        offset = torch.linspace(start, stop, num_gaussians)\n        self.coeff = -0.5 / (float(offset[1]) - float(offset[0]))**2\n        self.register_buffer('offset', offset)\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n\n    def forward(self, dist: Tensor) -> Tensor:\n        dist = dist.view(-1, 1) - self.offset.view(1, -1)\n        return torch.exp(self.coeff * dist.pow(2))\n\n\nclass NodeBlock(torch.nn.Module):\n    def __init__(self, hidden_node_channels: int, hidden_edge_channels: int):\n        super().__init__()\n        self.lin_c1 = Linear(hidden_node_channels + hidden_edge_channels,\n                             2 * hidden_node_channels)\n\n        # BN was added based on previous studies.\n        # ref: https://github.com/txie-93/cgcnn/blob/master/cgcnn/model.py\n        self.bn_c1 = BatchNorm1d(2 * hidden_node_channels)\n        self.bn = BatchNorm1d(hidden_node_channels)\n\n    def reset_parameters(self):\n        self.lin_c1.reset_parameters()\n        self.bn_c1.reset_parameters()\n        self.bn.reset_parameters()\n\n    def forward(self, node_emb: Tensor, edge_emb: Tensor, i: Tensor) -> Tensor:\n        c1 = torch.cat([node_emb[i], edge_emb], dim=1)\n        c1 = self.bn_c1(self.lin_c1(c1))\n        c1_filter, c1_core = c1.chunk(2, dim=1)\n        c1_filter = c1_filter.sigmoid()\n        c1_core = c1_core.tanh()\n        c1_emb 
= scatter(c1_filter * c1_core, i, dim=0,\n                         dim_size=node_emb.size(0), reduce='sum')\n        c1_emb = self.bn(c1_emb)\n\n        return (node_emb + c1_emb).tanh()\n\n\nclass EdgeBlock(torch.nn.Module):\n    def __init__(self, hidden_node_channels: int, hidden_edge_channels: int):\n        super().__init__()\n        self.lin_c2 = Linear(hidden_node_channels, 2 * hidden_edge_channels)\n        self.lin_c3 = Linear(\n            3 * hidden_node_channels + 2 * hidden_edge_channels,\n            2 * hidden_edge_channels,\n        )\n\n        # BN was added based on previous studies.\n        # ref: https://github.com/txie-93/cgcnn/blob/master/cgcnn/model.py\n        self.bn_c2 = BatchNorm1d(2 * hidden_edge_channels)\n        self.bn_c3 = BatchNorm1d(2 * hidden_edge_channels)\n        self.bn_c2_2 = BatchNorm1d(hidden_edge_channels)\n        self.bn_c3_2 = BatchNorm1d(hidden_edge_channels)\n\n    def reset_parameters(self):\n        self.lin_c2.reset_parameters()\n        self.lin_c3.reset_parameters()\n        self.bn_c2.reset_parameters()\n        self.bn_c3.reset_parameters()\n        self.bn_c2_2.reset_parameters()\n        self.bn_c3_2.reset_parameters()\n\n    def forward(\n        self,\n        node_emb: Tensor,\n        edge_emb: Tensor,\n        i: Tensor,\n        j: Tensor,\n        idx_i: Tensor,\n        idx_j: Tensor,\n        idx_k: Tensor,\n        idx_ji: Tensor,\n        idx_kj: Tensor,\n    ) -> Tensor:\n        c2 = node_emb[i] * node_emb[j]\n        c2 = self.bn_c2(self.lin_c2(c2))\n        c2_filter, c2_core = c2.chunk(2, dim=1)\n        c2_filter = c2_filter.sigmoid()\n        c2_core = c2_core.tanh()\n        c2_emb = self.bn_c2_2(c2_filter * c2_core)\n\n        c3 = torch.cat([\n            node_emb[idx_i],\n            node_emb[idx_j],\n            node_emb[idx_k],\n            edge_emb[idx_ji],\n            edge_emb[idx_kj],\n        ], dim=1)\n        c3 = self.bn_c3(self.lin_c3(c3))\n        c3_filter, c3_core = 
c3.chunk(2, dim=1)\n        c3_filter = c3_filter.sigmoid()\n        c3_core = c3_core.tanh()\n        c3_emb = scatter(c3_filter * c3_core, idx_ji, dim=0,\n                         dim_size=edge_emb.size(0), reduce='sum')\n        c3_emb = self.bn_c3_2(c3_emb)\n\n        return (edge_emb + c2_emb + c3_emb).tanh()\n\n\nclass GNNFF(torch.nn.Module):\n    r\"\"\"The Graph Neural Network Force Field (GNNFF) from the\n    `\"Accurate and scalable graph neural network force field and molecular\n    dynamics with direct force architecture\"\n    <https://www.nature.com/articles/s41524-021-00543-3>`_ paper.\n    :class:`GNNFF` directly predicts atomic forces from automatically\n    extracted features of the local atomic environment that are\n    translationally-invariant, but rotationally-covariant to the coordinate of\n    the atoms.\n\n    Args:\n        hidden_node_channels (int): Hidden node embedding size.\n        hidden_edge_channels (int): Hidden edge embedding size.\n        num_layers (int): Number of message passing blocks.\n        cutoff (float, optional): Cutoff distance for interatomic\n            interactions. 
(default: :obj:`5.0`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            collect for each node within the :attr:`cutoff` distance.\n            (default: :obj:`32`)\n    \"\"\"\n    def __init__(\n        self,\n        hidden_node_channels: int,\n        hidden_edge_channels: int,\n        num_layers: int,\n        cutoff: float = 5.0,\n        max_num_neighbors: int = 32,\n    ):\n        super().__init__()\n\n        self.cutoff = cutoff\n        self.max_num_neighbors = max_num_neighbors\n\n        self.node_emb = Sequential(\n            Embedding(95, hidden_node_channels),\n            ShiftedSoftplus(),\n            Linear(hidden_node_channels, hidden_node_channels),\n            ShiftedSoftplus(),\n            Linear(hidden_node_channels, hidden_node_channels),\n        )\n        self.edge_emb = GaussianFilter(0.0, 5.0, hidden_edge_channels)\n\n        self.node_blocks = ModuleList([\n            NodeBlock(hidden_node_channels, hidden_edge_channels)\n            for _ in range(num_layers)\n        ])\n        self.edge_blocks = ModuleList([\n            EdgeBlock(hidden_node_channels, hidden_edge_channels)\n            for _ in range(num_layers)\n        ])\n\n        self.force_predictor = Sequential(\n            Linear(hidden_edge_channels, hidden_edge_channels),\n            ShiftedSoftplus(),\n            Linear(hidden_edge_channels, hidden_edge_channels),\n            ShiftedSoftplus(),\n            Linear(hidden_edge_channels, 1),\n        )\n\n    def reset_parameters(self):\n        reset(self.node_emb)\n        self.edge_emb.reset_parameters()\n        for node_block in self.node_blocks:\n            node_block.reset_parameters()\n        for edge_block in self.edge_blocks:\n            edge_block.reset_parameters()\n        reset(self.force_predictor)\n\n    def forward(self, z: Tensor, pos: Tensor,\n                batch: OptTensor = None) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        
edge_index = radius_graph(pos, r=self.cutoff, batch=batch,\n                                  max_num_neighbors=self.max_num_neighbors)\n\n        i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = triplets(\n            edge_index, num_nodes=z.size(0))\n\n        # Calculate distances and unit vector:\n        dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()\n        unit_vec = (pos[i] - pos[j]) / dist.view(-1, 1)\n\n        # Embedding blocks:\n        node_emb = self.node_emb(z)\n        edge_emb = self.edge_emb(dist)\n\n        # Message passing blocks:\n        for node_block, edge_block in zip(self.node_blocks, self.edge_blocks):\n            node_emb = node_block(node_emb, edge_emb, i)\n            edge_emb = edge_block(node_emb, edge_emb, i, j, idx_i, idx_j,\n                                  idx_k, idx_ji, idx_kj)\n\n        # Force prediction block:\n        force = self.force_predictor(edge_emb) * unit_vec\n\n        return scatter(force, i, dim=0, reduce='sum')\n"
  },
  {
    "path": "torch_geometric/nn/models/gpse.py",
    "content": "import logging\nimport os\nimport os.path as osp\nimport time\nfrom collections import OrderedDict\nfrom typing import List, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.nn import Module\nfrom tqdm import trange\n\nimport torch_geometric.transforms as T\nfrom torch_geometric.data import Data, Dataset, download_url\nfrom torch_geometric.loader import DataLoader, NeighborLoader\nfrom torch_geometric.nn import (\n    ResGatedGraphConv,\n    global_add_pool,\n    global_max_pool,\n    global_mean_pool,\n)\nfrom torch_geometric.nn.resolver import activation_resolver\nfrom torch_geometric.utils import to_dense_batch\n\n\nclass Linear(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        bias: bool,\n    ) -> None:\n        super().__init__()\n        self.model = torch.nn.Linear(in_channels, out_channels, bias=bias)\n\n    def forward(self, batch):\n        if isinstance(batch, torch.Tensor):\n            batch = self.model(batch)\n        else:\n            batch.x = self.model(batch.x)\n        return batch\n\n\nclass ResGatedGCNConv(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        bias: bool,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n        self.model = ResGatedGraphConv(\n            in_channels,\n            out_channels,\n            bias=bias,\n            **kwargs,\n        )\n\n    def forward(self, batch):\n        batch.x = self.model(batch.x, batch.edge_index)\n        return batch\n\n\nclass GeneralLayer(torch.nn.Module):\n    def __init__(\n        self,\n        name: str,\n        in_channels: int,\n        out_channels: int,\n        has_batch_norm: bool,\n        has_l2_norm: bool,\n        dropout: float,\n        act: Optional[str],\n        **kwargs,\n    ):\n        super().__init__()\n        self.has_l2_norm 
= has_l2_norm\n\n        layer_dict = {\n            'linear': Linear,\n            'resgatedgcnconv': ResGatedGCNConv,\n        }\n        self.layer = layer_dict[name](\n            in_channels,\n            out_channels,\n            bias=not has_batch_norm,\n            **kwargs,\n        )\n        post_layers = []\n        if has_batch_norm:\n            post_layers.append(\n                torch.nn.BatchNorm1d(out_channels, eps=1e-5, momentum=0.1))\n        if dropout > 0:\n            post_layers.append(torch.nn.Dropout(p=dropout, inplace=False))\n        if act is not None:\n            post_layers.append(activation_resolver(act))\n        self.post_layer = nn.Sequential(*post_layers)\n\n    def forward(self, batch):\n        batch = self.layer(batch)\n        if isinstance(batch, torch.Tensor):\n            batch = self.post_layer(batch)\n            if self.has_l2_norm:\n                batch = F.normalize(batch, p=2, dim=1)\n        else:\n            batch.x = self.post_layer(batch.x)\n            if self.has_l2_norm:\n                batch.x = F.normalize(batch.x, p=2, dim=1)\n        return batch\n\n\nclass GeneralMultiLayer(torch.nn.Module):\n    def __init__(\n        self,\n        name: str,\n        in_channels: int,\n        out_channels: int,\n        hidden_channels: Optional[int],\n        num_layers: int,\n        has_batch_norm: bool,\n        has_l2_norm: bool,\n        dropout: float,\n        act: str,\n        final_act: bool,\n        **kwargs,\n    ) -> None:\n        super().__init__()\n        hidden_channels = hidden_channels or out_channels\n\n        for i in range(num_layers):\n            d_in = in_channels if i == 0 else hidden_channels\n            d_out = out_channels if i == num_layers - 1 else hidden_channels\n            layer = GeneralLayer(\n                name=name,\n                in_channels=d_in,\n                out_channels=d_out,\n                has_batch_norm=has_batch_norm,\n                
has_l2_norm=has_l2_norm,\n                dropout=dropout,\n                act=None if i == num_layers - 1 and not final_act else act,\n                **kwargs,\n            )\n            self.add_module(f'Layer_{i}', layer)\n\n    def forward(self, batch):\n        for layer in self.children():\n            batch = layer(batch)\n        return batch\n\n\nclass BatchNorm1dNode(torch.nn.Module):\n    def __init__(self, channels: int) -> None:\n        super().__init__()\n        self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1)\n\n    def forward(self, batch):\n        batch.x = self.bn(batch.x)\n        return batch\n\n\nclass BatchNorm1dEdge(torch.nn.Module):\n    def __init__(self, channels: int) -> None:\n        super().__init__()\n        self.bn = torch.nn.BatchNorm1d(channels, eps=1e-5, momentum=0.1)\n\n    def forward(self, batch):\n        batch.edge_attr = self.bn(batch.edge_attr)\n        return batch\n\n\nclass MLP(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        hidden_channels: Optional[int],\n        num_layers: int,\n        has_batch_norm: bool = True,\n        has_l2_norm: bool = True,\n        dropout: float = 0.2,\n        act: str = 'relu',\n        **kwargs,\n    ):\n        super().__init__()\n        hidden_channels = hidden_channels or in_channels\n\n        layers = []\n        if num_layers > 1:\n            layer = GeneralMultiLayer(\n                'linear',\n                in_channels,\n                hidden_channels,\n                hidden_channels,\n                num_layers - 1,\n                has_batch_norm,\n                has_l2_norm,\n                dropout,\n                act,\n                final_act=True,\n                **kwargs,\n            )\n            layers.append(layer)\n        layers.append(Linear(hidden_channels, out_channels, bias=True))\n        self.model = nn.Sequential(*layers)\n\n    def forward(self, batch):\n 
       if isinstance(batch, torch.Tensor):\n            batch = self.model(batch)\n        else:\n            batch.x = self.model(batch.x)\n        return batch\n\n\nclass GNNStackStage(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        num_layers: int,\n        layer_type: str,\n        stage_type: str = 'skipsum',\n        final_l2_norm: bool = True,\n        has_batch_norm: bool = True,\n        has_l2_norm: bool = True,\n        dropout: float = 0.2,\n        act: Optional[str] = 'relu',\n    ):\n        super().__init__()\n        self.num_layers = num_layers\n        self.stage_type = stage_type\n        self.final_l2_norm = final_l2_norm\n\n        for i in range(num_layers):\n            if stage_type == 'skipconcat':\n                if i == 0:\n                    d_in = in_channels\n                else:\n                    d_in = in_channels + i * out_channels\n            else:\n                d_in = in_channels if i == 0 else out_channels\n            layer = GeneralLayer(layer_type, d_in, out_channels,\n                                 has_batch_norm, has_l2_norm, dropout, act)\n            self.add_module(f'layer{i}', layer)\n\n    def forward(self, batch):\n        for i, layer in enumerate(self.children()):\n            x = batch.x\n            batch = layer(batch)\n            if self.stage_type == 'skipsum':\n                batch.x = x + batch.x\n            elif self.stage_type == 'skipconcat' and i < self.num_layers - 1:\n                batch.x = torch.cat([x, batch.x], dim=1)\n\n        if self.final_l2_norm:\n            batch.x = F.normalize(batch.x, p=2, dim=-1)\n\n        return batch\n\n\nclass GNNInductiveHybridMultiHead(torch.nn.Module):\n    r\"\"\"GNN prediction head for inductive node and graph prediction tasks using\n    individual MLP for each task.\n\n    Args:\n        dim_in (int): Input dimension.\n        dim_out (int): Output dimension. 
Not used, as the dimension is\n            determined by :obj:`num_node_targets` and :obj:`num_graph_targets`\n            instead.\n        num_node_targets (int): Number of individual PSEs used as node-level\n            targets in pretraining :class:`GPSE`.\n        num_graph_targets (int): Number of graph-level targets used in\n            pretraining :class:`GPSE`.\n        layers_post_mp (int): Number of MLP layers after GNN message-passing.\n        virtual_node (bool, optional): Whether a virtual node is added to\n            graphs in :class:`GPSE` computation. (default: :obj:`True`)\n        multi_head_dim_inner (int, optional): Width of MLPs for PSE target\n            prediction heads. (default: :obj:`32`)\n        graph_pooling (str, optional): Type of graph pooling applied before\n            post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.\n            (default: :obj:`add`)\n        has_bn (bool, optional): Whether to apply batch normalization to layer\n            outputs. (default: :obj:`True`)\n        has_l2norm (bool, optional): Whether to apply L2 normalization to the\n            layer outputs. (default: :obj:`True`)\n        dropout (float, optional): Dropout ratio at layer output.\n            (default: :obj:`0.2`)\n        act (str, optional): Activation to apply to layer\n            outputs. 
(default: :obj:`relu`)\n    \"\"\"\n    def __init__(\n        self,\n        dim_in: int,\n        dim_out: int,\n        num_node_targets: int,\n        num_graph_targets: int,\n        layers_post_mp: int,\n        virtual_node: bool = True,\n        multi_head_dim_inner: int = 32,\n        graph_pooling: str = 'add',\n        has_bn: bool = True,\n        has_l2norm: bool = True,\n        dropout: float = 0.2,\n        act: str = 'relu',\n    ):\n        super().__init__()\n        pool_dict = {\n            'add': global_add_pool,\n            'max': global_max_pool,\n            'mean': global_mean_pool\n        }\n        self.node_target_dim = num_node_targets\n        self.graph_target_dim = num_graph_targets\n        self.virtual_node = virtual_node\n        num_layers = layers_post_mp\n\n        self.node_post_mps = nn.ModuleList([\n            MLP(dim_in, 1, multi_head_dim_inner, num_layers, has_bn,\n                has_l2norm, dropout, act) for _ in range(self.node_target_dim)\n        ])\n\n        self.graph_pooling = pool_dict[graph_pooling]\n\n        self.graph_post_mp = MLP(dim_in, self.graph_target_dim, dim_in,\n                                 num_layers, has_bn, has_l2norm, dropout, act)\n\n    def _pad_and_stack(self, x1: torch.Tensor, x2: torch.Tensor, pad1: int,\n                       pad2: int):\n        padded_x1 = nn.functional.pad(x1, (0, pad2))\n        padded_x2 = nn.functional.pad(x2, (pad1, 0))\n        return torch.vstack([padded_x1, padded_x2])\n\n    def _apply_index(self, batch, virtual_node: bool, pad_node: int,\n                     pad_graph: int):\n        graph_pred, graph_true = batch.graph_feature, batch.y_graph\n        node_pred, node_true = batch.node_feature, batch.y\n        if virtual_node:\n            # Remove virtual node\n            idx = torch.concat([\n                torch.where(batch.batch == i)[0][:-1]\n                for i in range(batch.batch.max().item() + 1)\n            ])\n            node_pred, 
node_true = node_pred[idx], node_true[idx]\n\n        # Stack node predictions on top of graph predictions and pad with zeros\n        pred = self._pad_and_stack(node_pred, graph_pred, pad_node, pad_graph)\n        true = self._pad_and_stack(node_true, graph_true, pad_node, pad_graph)\n\n        return pred, true\n\n    def forward(self, batch):\n        batch.node_feature = torch.hstack(\n            [m(batch.x) for m in self.node_post_mps])\n        graph_emb = self.graph_pooling(batch.x, batch.batch)\n        batch.graph_feature = self.graph_post_mp(graph_emb)\n        return self._apply_index(batch, self.virtual_node,\n                                 self.node_target_dim, self.graph_target_dim)\n\n\nclass IdentityHead(torch.nn.Module):\n    def forward(self, batch):\n        return batch.x, batch.y\n\n\nclass GPSE(torch.nn.Module):\n    r\"\"\"The Graph Positional and Structural Encoder (GPSE) model from the\n    `\"Graph Positional and Structural Encoder\"\n    <https://arxiv.org/abs/2307.07107>`_ paper.\n\n    The GPSE model consists of a (1) deep GNN that consists of stacked\n    message passing layers, and a (2) prediction head to predict pre-computed\n    positional and structural encodings (PSE).\n    When used on downstream datasets, these prediction heads are removed and\n    the final fully-connected layer outputs are used as learned PSE embeddings.\n\n    GPSE also provides a static method :meth:`from_pretrained` to load\n    pre-trained GPSE models trained on a variety of molecular datasets.\n\n    .. 
code-block:: python\n\n        from torch_geometric.datasets import ZINC\n        from torch_geometric.nn import GPSE, GPSENodeEncoder\n        from torch_geometric.transforms import AddGPSE\n        from torch_geometric.nn.models.gpse import precompute_GPSE\n\n        gpse_model = GPSE.from_pretrained('molpcba')\n\n        # Option 1: Precompute GPSE encodings in-place for a given dataset\n        dataset = ZINC(path, subset=True, split='train')\n        precompute_GPSE(gpse_model, dataset)\n\n        # Option 2: Use the GPSE model with AddGPSE as a pre_transform to save\n        # the encodings\n        dataset = ZINC(path, subset=True, split='train',\n                       pre_transform=AddGPSE(gpse_model, vn=True,\n                                             rand_type='NormalSE'))\n\n    Both approaches append the generated encodings to the :obj:`pestat_GPSE`\n    attribute of :class:`~torch_geometric.data.Data` objects. To use the GPSE\n    encodings for a downstream task, one may need to add these encodings to the\n    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects. To\n    do so, one can use the :class:`GPSENodeEncoder` provided to map these\n    encodings to a desired dimension before appending them to :obj:`x`.\n\n    Let's say we have a graph dataset with 64 original node features, and we\n    have generated GPSE encodings of dimension 32, i.e.\n    :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an\n    inner dimension of 128. To do so, we can map the 32-dimensional GPSE\n    encodings to a higher dimension of 64, and then append them to the :obj:`x`\n    attribute of the :class:`~torch_geometric.data.Data` objects to obtain a\n    128-dimensional node feature representation.\n    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and\n    concatenation to :obj:`x`, the outputs of which can be used as input to a\n    GNN:\n\n    .. 
code-block:: python\n\n        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,\n                                  expand_x=False)\n        gnn = GNN(...)\n\n        for batch in loader:\n            x = encoder(batch.x, batch.pestat_GPSE)\n            out = gnn(x, batch.edge_index)\n\n\n    Args:\n        dim_in (int, optional): Input dimension. (default: :obj:`20`)\n        dim_out (int, optional): Output dimension. (default: :obj:`51`)\n        dim_inner (int, optional): Width of the encoder layers.\n            (default: :obj:`512`)\n        layer_type (str, optional): Type of graph convolutional layer for\n            message-passing. (default: :obj:`resgatedgcnconv`)\n        layers_pre_mp (int, optional): Number of MLP layers before\n            message-passing. (default: :obj:`1`)\n        layers_mp (int, optional): Number of layers for message-passing.\n            (default: :obj:`20`)\n        layers_post_mp (int, optional): Number of MLP layers after\n            message-passing. (default: :obj:`2`)\n        num_node_targets (int, optional): Number of individual PSEs used as\n            node-level targets in pretraining :class:`GPSE`.\n            (default: :obj:`51`)\n        num_graph_targets (int, optional): Number of graph-level targets used\n            in pretraining :class:`GPSE`. (default: :obj:`11`)\n        stage_type (str, optional): The type of staging to apply. Possible\n            values are: :obj:`skipsum`, :obj:`skipconcat`. Any other value will\n            default to no skip connections. (default: :obj:`skipsum`)\n        has_bn (bool, optional): Whether to apply batch normalization in the\n            layer. (default: :obj:`True`)\n        head_bn (bool, optional): Whether to apply batch normalization in the\n            prediction head MLPs. (default: :obj:`False`)\n        final_l2norm (bool, optional): Whether to apply L2 normalization to the\n            outputs. (default: :obj:`True`)\n        has_l2norm (bool, optional): Whether to apply L2 normalization after\n            the layer. 
(default: :obj:`True`)\n        dropout (float, optional): Dropout ratio at layer output.\n            (default: :obj:`0.2`)\n        has_act (bool, optional): Whether to apply an activation after the\n            layer. (default: :obj:`True`)\n        final_act (bool, optional): Whether to apply activation after the layer\n            stack. (default: :obj:`True`)\n        act (str, optional): Activation to apply to layer output if\n            :obj:`has_act` is :obj:`True`. (default: :obj:`relu`)\n        virtual_node (bool, optional): Whether a virtual node is added to\n            graphs in :class:`GPSE` computation. (default: :obj:`True`)\n        multi_head_dim_inner (int, optional): Width of MLPs for PSE target\n            prediction heads. (default: :obj:`32`)\n        graph_pooling (str, optional): Type of graph pooling applied before\n            post_mp. Options are :obj:`add`, :obj:`max`, :obj:`mean`.\n            (default: :obj:`add`)\n        use_repr (bool, optional): Whether to use the hidden representation of\n            the final layer as :class:`GPSE` encodings. (default: :obj:`True`)\n        repr_type (str, optional): Type of representation to use. Options are\n            :obj:`no_post_mp`, :obj:`one_layer_before`.\n            (default: :obj:`no_post_mp`)\n        bernoulli_threshold (float, optional): Threshold for Bernoulli sampling\n            of virtual nodes. 
(default: :obj:`0.5`)\n    \"\"\"\n\n    url_dict = {\n        'molpcba':\n        'https://zenodo.org/record/8145095/files/'\n        'gpse_model_molpcba_1.0.pt',\n        'zinc':\n        'https://zenodo.org/record/8145095/files/gpse_model_zinc_1.0.pt',\n        'pcqm4mv2':\n        'https://zenodo.org/record/8145095/files/'\n        'gpse_model_pcqm4mv2_1.0.pt',\n        'geom':\n        'https://zenodo.org/record/8145095/files/gpse_model_geom_1.0.pt',\n        'chembl':\n        'https://zenodo.org/record/8145095/files/gpse_model_chembl_1.0.pt'\n    }\n\n    def __init__(\n        self,\n        dim_in: int = 20,\n        dim_out: int = 51,\n        dim_inner: int = 512,\n        layer_type: str = 'resgatedgcnconv',\n        layers_pre_mp: int = 1,\n        layers_mp: int = 20,\n        layers_post_mp: int = 2,\n        num_node_targets: int = 51,\n        num_graph_targets: int = 11,\n        stage_type: str = 'skipsum',\n        has_bn: bool = True,\n        head_bn: bool = False,\n        final_l2norm: bool = True,\n        has_l2norm: bool = True,\n        dropout: float = 0.2,\n        has_act: bool = True,\n        final_act: bool = True,\n        act: str = 'relu',\n        virtual_node: bool = True,\n        multi_head_dim_inner: int = 32,\n        graph_pooling: str = 'add',\n        use_repr: bool = True,\n        repr_type: str = 'no_post_mp',\n        bernoulli_threshold: float = 0.5,\n    ):\n        super().__init__()\n\n        self.use_repr = use_repr\n        self.repr_type = repr_type\n        self.bernoulli_threshold = bernoulli_threshold\n\n        if layers_pre_mp > 0:\n            self.pre_mp = GeneralMultiLayer(\n                name='linear',\n                in_channels=dim_in,\n                out_channels=dim_inner,\n                hidden_channels=dim_inner,\n                num_layers=layers_pre_mp,\n                has_batch_norm=has_bn,\n                has_l2_norm=has_l2norm,\n                dropout=dropout,\n                
act=act,\n                final_act=final_act,\n            )\n            dim_in = dim_inner\n        if layers_mp > 0:\n            self.mp = GNNStackStage(\n                in_channels=dim_in,\n                out_channels=dim_inner,\n                num_layers=layers_mp,\n                layer_type=layer_type,\n                stage_type=stage_type,\n                final_l2_norm=final_l2norm,\n                has_batch_norm=has_bn,\n                has_l2_norm=has_l2norm,\n                dropout=dropout,\n                act=act if has_act else None,\n            )\n\n        self.post_mp = GNNInductiveHybridMultiHead(\n            dim_inner,\n            dim_out,\n            num_node_targets,\n            num_graph_targets,\n            layers_post_mp,\n            virtual_node,\n            multi_head_dim_inner,\n            graph_pooling,\n            head_bn,\n            has_l2norm,\n            dropout,\n            act,\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        pass\n\n    @classmethod\n    def from_pretrained(cls, name: str, root: str = 'GPSE_pretrained'):\n        r\"\"\"Returns a pretrained :class:`GPSE` model on a dataset.\n\n        Args:\n            name (str): The name of the dataset (:obj:`\"molpcba\"`,\n                :obj:`\"zinc\"`, :obj:`\"pcqm4mv2\"`, :obj:`\"geom\"`,\n                :obj:`\"chembl\"`).\n            root (str, optional): The root directory to save the pre-trained\n                model. 
(default: :obj:`\"GPSE_pretrained\"`)\n        \"\"\"\n        root = osp.expanduser(osp.normpath(root))\n        os.makedirs(root, exist_ok=True)\n        path = download_url(cls.url_dict[name], root)\n\n        model = GPSE()  # All pretrained models use the default arguments\n        model_state = torch.load(path, map_location='cpu')['model_state']\n        model_state_new = OrderedDict([(k.split('.', 1)[1], v)\n                                       for k, v in model_state.items()])\n        model.load_state_dict(model_state_new)\n\n        # Set the final linear layer to identity if we use hidden reprs\n        if model.use_repr:\n            if model.repr_type == 'one_layer_before':\n                model.post_mp.layer_post_mp.model[-1] = torch.nn.Identity()\n            elif model.repr_type == 'no_post_mp':\n                model.post_mp = IdentityHead()\n            else:\n                raise ValueError(f\"Unknown type '{model.repr_type}'\")\n\n        model.eval()\n        return model\n\n    def forward(self, batch):\n        batch = batch.clone()\n        for module in self.children():\n            batch = module(batch)\n        return batch\n\n\nclass GPSENodeEncoder(torch.nn.Module):\n    r\"\"\"A helper linear/MLP encoder that takes the :class:`GPSE` encodings\n    (based on the `\"Graph Positional and Structural Encoder\"\n    <https://arxiv.org/abs/2307.07107>`_ paper) precomputed as\n    :obj:`batch.pestat_GPSE` in the input graphs, maps them to a desired\n    dimension defined by :obj:`dim_pe_out` and appends them to node features.\n\n    Let's say we have a graph dataset with 64 original node features, and we\n    have generated GPSE encodings of dimension 32, i.e.\n    :obj:`data.pestat_GPSE` = 32. Additionally, we want to use a GNN with an\n    inner dimension of 128. 
To do so, we can map the 32-dimensional GPSE\n    encodings to a higher dimension of 64, and then append them to the\n    :obj:`x` attribute of the :class:`~torch_geometric.data.Data` objects to\n    obtain a 128-dimensional node feature representation.\n    :class:`~torch_geometric.nn.GPSENodeEncoder` handles both this mapping and\n    concatenation to :obj:`x`, the outputs of which can be used as input to a\n    GNN:\n\n    .. code-block:: python\n\n        encoder = GPSENodeEncoder(dim_emb=128, dim_pe_in=32, dim_pe_out=64,\n                                  expand_x=False)\n        gnn = GNN(...)\n\n        for batch in loader:\n            x = encoder(batch.x, batch.pestat_GPSE)\n            batch = gnn(x, batch.edge_index)\n\n    Args:\n        dim_emb (int): Size of final node embedding.\n        dim_pe_in (int): Original dimension of :obj:`batch.pestat_GPSE`.\n        dim_pe_out (int): Desired dimension of :class:`GPSE` after the encoder.\n        dim_in (int, optional): Original dimension of input node features,\n            required only if :obj:`expand_x` is set to :obj:`True`.\n            (default: :obj:`None`)\n        expand_x (bool, optional): Expand node features :obj:`x` from\n            :obj:`dim_in` to (:obj:`dim_emb` - :obj:`dim_pe_out`)\n        norm_type (str, optional): Type of normalization to apply.\n            (default: :obj:`batchnorm`)\n        model_type (str, optional): Type of encoder, either :obj:`mlp` or\n            :obj:`linear`. (default: :obj:`mlp`)\n        n_layers (int, optional): Number of MLP layers if :obj:`model_type` is\n            :obj:`mlp`. (default: :obj:`2`)\n        dropout_be (float, optional): Dropout ratio of inputs to encoder, i.e.\n            before encoding. (default: :obj:`0.5`)\n        dropout_ae (float, optional): Dropout ratio of outputs, i.e. after\n            encoding. 
(default: :obj:`0.2`)\n    \"\"\"\n    def __init__(self, dim_emb: int, dim_pe_in: int, dim_pe_out: int,\n                 dim_in: int = None, expand_x=False, norm_type='batchnorm',\n                 model_type='mlp', n_layers=2, dropout_be=0.5, dropout_ae=0.2):\n        super().__init__()\n\n        assert dim_emb > dim_pe_out, ('Desired GPSE dimension (dim_pe_out) '\n                                      'must be smaller than the final node '\n                                      'embedding dimension (dim_emb).')\n\n        if expand_x:\n            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe_out)\n        self.expand_x = expand_x\n\n        self.raw_norm = None\n        if norm_type == 'batchnorm':\n            self.raw_norm = nn.BatchNorm1d(dim_pe_in)\n\n        self.dropout_be = nn.Dropout(p=dropout_be)\n        self.dropout_ae = nn.Dropout(p=dropout_ae)\n\n        activation = nn.ReLU  # register.act_dict[cfg.gnn.act]\n        if model_type == 'mlp':\n            layers = []\n            if n_layers == 1:\n                layers.append(torch.nn.Linear(dim_pe_in, dim_pe_out))\n                layers.append(activation())\n            else:\n                layers.append(torch.nn.Linear(dim_pe_in, 2 * dim_pe_out))\n                layers.append(activation())\n                for _ in range(n_layers - 2):\n                    layers.append(\n                        torch.nn.Linear(2 * dim_pe_out, 2 * dim_pe_out))\n                    layers.append(activation())\n                layers.append(torch.nn.Linear(2 * dim_pe_out, dim_pe_out))\n                layers.append(activation())\n            self.pe_encoder = nn.Sequential(*layers)\n        elif model_type == 'linear':\n            self.pe_encoder = nn.Linear(dim_pe_in, dim_pe_out)\n        else:\n            raise ValueError(f\"{self.__class__.__name__}: Does not support \"\n                             f\"'{model_type}' encoder model.\")\n\n    def forward(self, x, pos_enc):\n        pos_enc = 
self.dropout_be(pos_enc)\n        pos_enc = self.raw_norm(pos_enc) if self.raw_norm else pos_enc\n        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe\n        pos_enc = self.dropout_ae(pos_enc)\n\n        # Expand node features if needed\n        h = self.linear_x(x) if self.expand_x else x\n\n        # Concatenate final PEs to input embedding\n        return torch.cat((h, pos_enc), 1)\n\n\n@torch.no_grad()\ndef gpse_process(\n    model: Module,\n    data: Data,\n    rand_type: str,\n    use_vn: bool = True,\n    bernoulli_thresh: float = 0.5,\n    neighbor_loader: bool = False,\n    num_neighbors: Optional[List[int]] = None,\n    fillval: int = 5,\n    layers_mp: int = None,\n    **kwargs,\n) -> torch.Tensor:\n    r\"\"\"Processes the data using the :class:`GPSE` model to generate and append\n    GPSE encodings. Identical to :obj:`gpse_process_batch`, but operates on a\n    single :class:`~torch_geometric.data.Data` object.\n\n    Unlike transform-based GPSE processing (i.e.\n    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument\n    does not append virtual nodes if set to :obj:`True`, and instead assumes\n    the input graphs to :obj:`gpse_process` already have virtual nodes. Under\n    normal circumstances, one does not need to call this function; running\n    :obj:`precompute_GPSE` on your whole dataset is advised instead.\n\n    Args:\n        model (Module): The :class:`GPSE` model.\n        data (torch_geometric.data.Data): A :class:`~torch_geometric.data.Data`\n            object.\n        rand_type (str): Type of random features to use. Options are\n            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.\n        use_vn (bool, optional): Whether the input graphs have virtual nodes.\n            (default: :obj:`True`)\n        bernoulli_thresh (float, optional): Threshold used to binarize\n            :obj:`BernoulliSE` features. 
(default: :obj:`0.5`)\n        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.\n            (default: :obj:`False`)\n        num_neighbors (List[int], optional): Number of neighbors to consider\n            for each message-passing layer. (default: :obj:`[30, 20, 10]`)\n        fillval (int, optional): Value to fill for missing\n            :obj:`num_neighbors`. (default: :obj:`5`)\n        layers_mp (int, optional): Number of message-passing layers.\n            (default: :obj:`None`)\n        **kwargs (optional): Additional arguments for :obj:`NeighborLoader`.\n\n    Returns:\n        torch.Tensor: A tensor corresponding to the original\n        :class:`~torch_geometric.data.Data` object, with :class:`GPSE`\n        encodings appended as :obj:`out.pestat_GPSE` attribute.\n    \"\"\"\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    # Generate random features for the encoder\n    n = data.num_nodes\n    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]\n\n    # Prepare input distributions for GPSE\n    if rand_type == 'NormalSE':\n        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))\n    elif rand_type == 'UniformSE':\n        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))\n    elif rand_type == 'BernoulliSE':\n        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))\n        rand = (rand < bernoulli_thresh)\n    else:\n        raise ValueError(f'Unknown {rand_type=!r}')\n    data.x = torch.from_numpy(rand.astype('float32'))\n\n    if use_vn:\n        data.x[-1] = 0\n\n    model, data = model.to(device), data.to(device)\n    # Generate encodings using the pretrained encoder\n    if neighbor_loader:\n        if layers_mp is None:\n            raise ValueError('Please provide the number of message-passing '\n                             'layers as \"layers_mp\".')\n\n        num_neighbors = num_neighbors or [30, 20, 10]\n        diff = layers_mp - 
len(num_neighbors)\n        if fillval > 0 and diff > 0:\n            num_neighbors += [fillval] * diff\n\n        loader = NeighborLoader(data, num_neighbors=num_neighbors,\n                                shuffle=False, pin_memory=True, **kwargs)\n        out_list = []\n        pbar = trange(data.num_nodes, position=2)\n        for batch in loader:\n            out, _ = model(batch.to(device))\n            out = out[:batch.batch_size].to(\"cpu\", non_blocking=True)\n            out_list.append(out)\n            pbar.update(batch.batch_size)\n        out = torch.vstack(out_list)\n    else:\n        out, _ = model(data)\n        out = out.to(\"cpu\")\n\n    return out\n\n\n@torch.no_grad()\ndef gpse_process_batch(\n    model: GPSE,\n    batch,\n    rand_type: str,\n    use_vn: bool = True,\n    bernoulli_thresh: float = 0.5,\n    neighbor_loader: bool = False,\n    num_neighbors: Optional[List[int]] = None,\n    fillval: int = 5,\n    layers_mp: int = None,\n    **kwargs,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n    r\"\"\"Process a batch of data using the :class:`GPSE` model to generate and\n    append :class:`GPSE` encodings. Identical to `gpse_process`, but operates\n    on a batch of :class:`~torch_geometric.data.Data` objects.\n\n    Unlike transform-based GPSE processing (i.e.\n    :class:`~torch_geometric.transforms.AddGPSE`), the :obj:`use_vn` argument\n    does not append virtual nodes if set to :obj:`True`, and instead assumes\n    the input graphs to :obj:`gpse_process` already have virtual nodes. 
This is\n    because the virtual nodes are already added to graphs before the call to\n    :obj:`gpse_process_batch` in :obj:`precompute_GPSE` for better efficiency.\n    Under normal circumstances, one does not need to call this function;\n    running :obj:`precompute_GPSE` on your whole dataset is advised instead.\n\n    Args:\n        model (GPSE): The :class:`GPSE` model.\n        batch: A batch of PyG Data objects.\n        rand_type (str): Type of random features to use. Options are\n            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.\n        use_vn (bool, optional): Whether the input graphs have virtual nodes.\n            (default: :obj:`True`)\n        bernoulli_thresh (float, optional): Threshold used to binarize\n            :obj:`BernoulliSE` features. (default: :obj:`0.5`)\n        neighbor_loader (bool, optional): Whether to use :obj:`NeighborLoader`.\n            (default: :obj:`False`)\n        num_neighbors (List[int], optional): Number of neighbors to consider\n            for each message-passing layer. (default: :obj:`[30, 20, 10]`)\n        fillval (int, optional): Value to fill for missing\n            :obj:`num_neighbors`. 
(default: :obj:`5`)\n        layers_mp (int, optional): Number of message-passing layers.\n            (default: :obj:`None`)\n        **kwargs: Additional keyword arguments for :obj:`NeighborLoader`.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors corresponding\n            to the stacked :class:`GPSE` encodings and the pointers indicating\n            individual graphs.\n    \"\"\"\n    n = batch.num_nodes\n    dim_in = model.state_dict()[list(model.state_dict())[0]].shape[1]\n\n    # Prepare input distributions for GPSE\n    if rand_type == 'NormalSE':\n        rand = np.random.normal(loc=0, scale=1.0, size=(n, dim_in))\n    elif rand_type == 'UniformSE':\n        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))\n    elif rand_type == 'BernoulliSE':\n        rand = np.random.uniform(low=0.0, high=1.0, size=(n, dim_in))\n        rand = (rand < bernoulli_thresh)\n    else:\n        raise ValueError(f'Unknown {rand_type=!r}')\n    batch.x = torch.from_numpy(rand.astype('float32'))\n\n    if use_vn:\n        # HACK: We need to reset virtual node features to zeros to match the\n        # pretraining setting (virtual node applied after random node features\n        # are set, and the default node features for the virtual node are all\n        # zeros). 
Can potentially test if initializing virtual node features to\n        # random features is better than setting them to zeros.\n        for i in batch.ptr[1:]:\n            batch.x[i - 1] = 0\n\n    # Generate encodings using the pretrained encoder\n    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n    model = model.to(device)\n    if neighbor_loader:\n        if layers_mp is None:\n            raise ValueError('Please provide the number of message-passing '\n                             'layers as \"layers_mp\".')\n\n        num_neighbors = num_neighbors or [30, 20, 10]\n        diff = layers_mp - len(num_neighbors)\n        if fillval > 0 and diff > 0:\n            num_neighbors += [fillval] * diff\n\n        loader = NeighborLoader(batch, num_neighbors=num_neighbors,\n                                shuffle=False, pin_memory=True, **kwargs)\n        out_list = []\n        pbar = trange(batch.num_nodes, position=2)\n        for batch in loader:\n            out, _ = model(batch.to(device))\n            out = out[:batch.batch_size].to('cpu', non_blocking=True)\n            out_list.append(out)\n            pbar.update(batch.batch_size)\n        out = torch.vstack(out_list)\n    else:\n        out, _ = model(batch.to(device))\n        out = out.to('cpu')\n\n    return out, batch.ptr\n\n\n@torch.no_grad()\ndef precompute_GPSE(model: GPSE, dataset: Dataset, use_vn: bool = True,\n                    rand_type: str = 'NormalSE', **kwargs):\n    r\"\"\"Precomputes :class:`GPSE` encodings in-place for a given dataset using\n    a :class:`GPSE` model.\n\n    Args:\n        model (GPSE): The :class:`GPSE` model.\n        dataset (Dataset): A PyG Dataset.\n        use_vn (bool, optional): Whether to append virtual nodes to graphs in\n            :class:`GPSE` computation. Should match the setting used when\n            pre-training the :class:`GPSE` model. 
(default: :obj:`True`)\n        rand_type (str, optional): The type of randomization to use.\n            (default: :obj:`NormalSE`)\n        **kwargs (optional): Additional arguments for\n            :class:`~torch_geometric.loader.DataLoader`.\n    \"\"\"\n    # Temporarily replace the transformation\n    orig_dataset_transform = dataset.transform\n    dataset.transform = None\n    if use_vn:\n        dataset.transform = T.VirtualNode()\n\n    # Remove split indices, to be recovered at the end of the precomputation\n    tmp_store = {}\n    for name in [\n            'train_mask', 'val_mask', 'test_mask', 'train_graph_index',\n            'val_graph_index', 'test_graph_index', 'train_edge_index',\n            'val_edge_index', 'test_edge_index'\n    ]:\n        if (name in dataset.data) and (dataset.slices is None\n                                       or name in dataset.slices):\n            tmp_store_data = dataset.data.pop(name)\n            tmp_store_slices = dataset.slices.pop(name) \\\n                if dataset.slices else None\n            tmp_store[name] = (tmp_store_data, tmp_store_slices)\n\n    loader = DataLoader(dataset, shuffle=False, pin_memory=True, **kwargs)\n\n    # Batched GPSE precomputation loop\n    data_list = []\n    curr_idx = 0\n    pbar = trange(len(dataset), desc='Pre-computing GPSE')\n    tic = time.perf_counter()\n    for batch in loader:\n        batch_out, batch_ptr = gpse_process_batch(model, batch, rand_type,\n                                                  use_vn=use_vn, **kwargs)\n\n        batch_out = batch_out.to('cpu', non_blocking=True)\n        # Need to wait for batch_ptr to finish transferring so that start and\n        # end indices are ready to use\n        batch_ptr = batch_ptr.to('cpu', non_blocking=False)\n\n        for start, end in zip(batch_ptr[:-1], batch_ptr[1:]):\n            data = dataset.get(curr_idx)\n            if use_vn:\n                end = end - 1\n            data.pestat_GPSE = batch_out[start:end]\n            
data_list.append(data)\n            curr_idx += 1\n\n        pbar.update(len(batch_ptr) - 1)\n    pbar.close()\n\n    # Collate dataset and reset indices and data list\n    dataset.transform = orig_dataset_transform\n    dataset._indices = None\n    dataset._data_list = data_list\n    dataset.data, dataset.slices = dataset.collate(data_list)\n\n    # Recover split indices\n    for name, (tmp_store_data, tmp_store_slices) in tmp_store.items():\n        dataset.data[name] = tmp_store_data\n        if tmp_store_slices is not None:\n            dataset.slices[name] = tmp_store_slices\n    dataset._data_list = None\n\n    timestr = time.strftime('%H:%M:%S', time.gmtime(time.perf_counter() - tic))\n    logging.info(f'Finished GPSE pre-computation, took {timestr}')\n\n    # Release resource and recover original configs\n    del model\n    torch.cuda.empty_cache()\n\n\ndef cosim_col_sep(pred: torch.Tensor, true: torch.Tensor,\n                  batch_idx: torch.Tensor) -> torch.Tensor:\n    r\"\"\"Calculates the average cosine similarity between predicted and true\n    features on a batch of graphs.\n\n    Args:\n        pred (torch.Tensor): Predicted outputs.\n        true (torch.Tensor): Value of ground truths.\n        batch_idx (torch.Tensor): Batch indices to separate the graphs.\n\n    Returns:\n        torch.Tensor: Average cosine similarity per graph in batch.\n\n    Raises:\n        ValueError: If batch_index is not specified.\n    \"\"\"\n    if batch_idx is None:\n        raise ValueError(\"mae_cosim_col_sep requires batch index as \"\n                         \"input to distinguish different graphs.\")\n    batch_idx = batch_idx + 1 if batch_idx.min() == -1 else batch_idx\n    pred_dense = to_dense_batch(pred, batch_idx)[0]\n    true_dense = to_dense_batch(true, batch_idx)[0]\n    mask = (true_dense == 0).all(1)  # exclude trivial features from loss\n    loss = 1 - F.cosine_similarity(pred_dense, true_dense, dim=1)[~mask].mean()\n    return loss\n\n\ndef 
gpse_loss(pred: torch.Tensor, true: torch.Tensor,\n              batch_idx: torch.Tensor = None) \\\n        -> Tuple[torch.Tensor, torch.Tensor]:\n    r\"\"\"Calculates :class:`GPSE` loss as the sum of MAE loss and cosine\n    similarity loss over a batch of graphs.\n\n    Args:\n        pred (torch.Tensor): Predicted outputs.\n        true (torch.Tensor): Value of ground truths.\n        batch_idx (torch.Tensor): Batch indices to separate the graphs.\n\n    Returns:\n        Tuple[torch.Tensor, torch.Tensor]: A two-tuple of tensors corresponding\n        to the :class:`GPSE` loss and the predicted node-and-graph level\n        outputs.\n\n    \"\"\"\n    if batch_idx is None:\n        raise ValueError(\"mae_cosim_col_sep requires batch index as \"\n                         \"input to distinguish different graphs.\")\n    mae_loss = F.l1_loss(pred, true)\n    cosim_loss = cosim_col_sep(pred, true, batch_idx)\n    loss = mae_loss + cosim_loss\n    return loss, pred\n\n\ndef process_batch_idx(batch_idx, true, use_vn=True):\n    r\"\"\"Processes batch indices to adjust for the removal of virtual nodes, and\n    pads batch index for hybrid tasks.\n\n    Args:\n        batch_idx: Batch indices to separate the graphs.\n        true: Value of ground truths.\n        use_vn: If input graphs have virtual nodes that need to be removed.\n\n    Returns:\n        torch.Tensor: Batch indices that separate the graphs.\n    \"\"\"\n    if batch_idx is None:\n        return\n    if use_vn:  # remove virtual node\n        batch_idx = torch.concat([\n            batch_idx[batch_idx == i][:-1]\n            for i in range(batch_idx.max().item() + 1)\n        ])\n    # Pad batch index for hybrid tasks (set batch index for graph heads to -1)\n    if (pad := true.shape[0] - batch_idx.shape[0]) > 0:\n        pad_idx = -torch.ones(pad, dtype=torch.long, device=batch_idx.device)\n        batch_idx = torch.hstack([batch_idx, pad_idx])\n    return batch_idx\n"
  },
  {
    "path": "torch_geometric/nn/models/graph_mixer.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import LayerNorm, Linear\n\nfrom torch_geometric.nn import TemporalEncoding\nfrom torch_geometric.utils import scatter, to_dense_batch\n\n\nclass NodeEncoder(torch.nn.Module):\n    r\"\"\"The node encoder module from the `\"Do We Really Need Complicated\n    Model Architectures for Temporal Networks?\"\n    <https://openreview.net/forum?id=ayPPc0SyLv1>`_ paper.\n    :class:`NodeEncoder` captures the 1-hop temporal neighborhood information\n    via mean pooling.\n\n    .. math::\n        \\mathbf{x}_v^{\\prime}(t_0) = \\mathbf{x}_v + \\textrm{mean} \\left\\{\n        \\mathbf{x}_w : w \\in \\mathcal{N}(v, t_0 - T, t_0) \\right\\}\n\n    Args:\n        time_window (int): The temporal window size :math:`T` to define the\n            1-hop temporal neighborhood.\n    \"\"\"\n    def __init__(self, time_window: int):\n        super().__init__()\n        self.time_window = time_window\n\n    def reset_parameters(self):\n        pass\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_time: Tensor,\n        seed_time: Tensor,\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The input node features.\n            edge_index (torch.Tensor): The edge indices.\n            edge_time (torch.Tensor): The timestamp attached to every edge.\n            seed_time (torch.Tensor): The seed time :math:`t_0` for every\n                destination node.\n        \"\"\"\n        mask = ((edge_time <= seed_time[edge_index[1]]) &\n                (edge_time > seed_time[edge_index[1]] - self.time_window))\n\n        src, dst = edge_index[:, mask]\n        mean = scatter(x[src], dst, dim=0, dim_size=x.size(0), reduce='mean')\n        return x + mean\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(time_window={self.time_window})'\n\n\nclass 
_MLPMixer(torch.nn.Module):\n    r\"\"\"The MLP-Mixer module.\n\n    Args:\n        num_tokens (int): Number of tokens/patches in each sample.\n        in_channels (int): Input channels.\n        out_channels (int): Output channels.\n        dropout (float, optional): Dropout probability. (default: :obj:`0.0`)\n    \"\"\"\n    def __init__(\n        self,\n        num_tokens: int,\n        in_channels: int,\n        out_channels: int,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n\n        self.dropout = dropout\n\n        self.token_norm = LayerNorm(in_channels)\n        self.token_lin1 = Linear(num_tokens, num_tokens // 2)\n        self.token_lin2 = Linear(num_tokens // 2, num_tokens)\n\n        self.channel_norm = LayerNorm(in_channels)\n        self.channel_lin1 = Linear(in_channels, 4 * in_channels)\n        self.channel_lin2 = Linear(4 * in_channels, in_channels)\n\n        self.head_norm = LayerNorm(in_channels)\n        self.head_lin = Linear(in_channels, out_channels)\n\n    def reset_parameters(self):\n        self.token_norm.reset_parameters()\n        self.token_lin1.reset_parameters()\n        self.token_lin2.reset_parameters()\n        self.channel_norm.reset_parameters()\n        self.channel_lin1.reset_parameters()\n        self.channel_lin2.reset_parameters()\n        self.head_norm.reset_parameters()\n        self.head_lin.reset_parameters()\n\n    def forward(self, x: Tensor) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): Tensor of size\n                :obj:`[*, num_tokens, in_channels]`.\n\n        Returns:\n            Tensor of size :obj:`[*, out_channels]`.\n        \"\"\"\n        # Token mixing:\n        h = self.token_norm(x).mT\n        h = self.token_lin1(h)\n        h = F.gelu(h)\n        h = F.dropout(h, p=self.dropout, training=self.training)\n        h = self.token_lin2(h)\n        h = F.dropout(h, p=self.dropout, training=self.training)\n        h_token = h.mT + x\n\n  
      # Channel mixing:\n        h = self.channel_norm(h_token)\n        h = self.channel_lin1(h)\n        h = F.gelu(h)\n        h = F.dropout(h, p=self.dropout, training=self.training)\n        h = self.channel_lin2(h)\n        h = F.dropout(h, p=self.dropout, training=self.training)\n        h_channel = h + h_token\n\n        # Head:\n        out = self.head_norm(h_channel)\n        out = out.mean(dim=1)\n        out = self.head_lin(out)\n        return out\n\n\ndef get_latest_k_edge_attr(\n    k: int,\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    edge_time: Tensor,\n    num_nodes: int,\n    is_sorted: bool = False,\n) -> Tensor:\n    r\"\"\"Returns the latest :obj:`k` incoming edge attributes by\n    :obj:`edge_time` for each node.\n    The shape of the output tensor is :obj:`[num_nodes, k, edge_attr_dim]`.\n    Nodes with fewer than :obj:`k` incoming edges are zero-padded.\n    \"\"\"\n    _, col = edge_index\n\n    if not is_sorted:\n        perm = np.lexsort([\n            -edge_time.detach().cpu().numpy(),\n            col.detach().cpu().numpy(),\n        ])\n        perm = torch.from_numpy(perm).to(edge_index.device)\n        col = col[perm]\n        edge_attr = edge_attr[perm]\n\n    return to_dense_batch(\n        edge_attr,\n        col,\n        max_num_nodes=k,\n        batch_size=num_nodes,\n    )[0]\n\n\nclass LinkEncoder(torch.nn.Module):\n    r\"\"\"The link encoder module from the `\"Do We Really Need Complicated\n    Model Architectures for Temporal Networks?\"\n    <https://openreview.net/forum?id=ayPPc0SyLv1>`_ paper.\n    It is composed of two components: (1) :class:`TemporalEncoding` maps each\n    edge timestamp to a :obj:`time_channels`-dimensional vector; (2) an MLP\n    that groups and maps the :math:`k`-latest encoded timestamps and edge\n    features to a :obj:`out_channels`-dimensional representation.\n\n    Args:\n        k (int): The number of most recent temporal links to use.\n        in_channels (int): The edge feature 
dimensionality.\n        hidden_channels (int): Size of each hidden sample.\n        time_channels (int): Size of encoded timestamp.\n        out_channels (int): Size of each output sample.\n        is_sorted (bool, optional): If set to :obj:`True`, assumes that\n            :obj:`edge_index` is sorted by column and the\n            rows are sorted according to :obj:`edge_time`\n            within individual neighborhoods. This avoids internal\n            re-sorting of the data and can improve runtime and memory\n            efficiency. (default: :obj:`False`)\n        dropout (float, optional): Dropout probability of the MLP layer.\n            (default: :obj:`0.0`)\n    \"\"\"\n    def __init__(\n        self,\n        k: int,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        time_channels: int,\n        is_sorted: bool = False,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n\n        self.k = k\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.time_channels = time_channels\n        self.is_sorted = is_sorted\n        self.dropout = dropout\n\n        self.temporal_encoder = TemporalEncoding(time_channels)\n        self.temporal_head = Linear(time_channels + in_channels,\n                                    hidden_channels)\n\n        self.mlp_mixer = _MLPMixer(  # MLP that summarizes temporal embeddings:\n            num_tokens=k,\n            in_channels=hidden_channels,\n            out_channels=out_channels,\n            dropout=dropout,\n        )\n\n    def reset_parameters(self):\n        self.temporal_encoder.reset_parameters()\n        self.temporal_head.reset_parameters()\n        self.mlp_mixer.reset_parameters()\n\n    def forward(\n        self,\n        edge_index: Tensor,\n        edge_attr: Tensor,\n        edge_time: Tensor,\n        seed_time: Tensor,\n    ) -> Tensor:\n        
r\"\"\"Forward pass.\n\n        Args:\n            edge_index (torch.Tensor): The edge indices.\n            edge_attr (torch.Tensor): The edge features of shape\n                :obj:`[num_edges, in_channels]`.\n            edge_time (torch.Tensor): The time tensor of shape\n                :obj:`[num_edges]`. This can be in the order of millions.\n            seed_time (torch.Tensor): The seed time :math:`t_0` for every\n                destination node.\n\n        Returns:\n            A node embedding tensor of shape :obj:`[num_nodes, out_channels]`.\n        \"\"\"\n        mask = edge_time <= seed_time[edge_index[1]]\n\n        edge_index = edge_index[:, mask]\n        edge_attr = edge_attr[mask]\n        edge_time = edge_time[mask]\n\n        time_enc = self.temporal_encoder(seed_time[edge_index[1]] - edge_time)\n        edge_attr = torch.cat([time_enc, edge_attr], dim=-1)\n        edge_attr = self.temporal_head(edge_attr)\n\n        edge_attr = get_latest_k_edge_attr(\n            k=self.k,\n            edge_index=edge_index,\n            edge_attr=edge_attr,\n            edge_time=edge_time,\n            num_nodes=seed_time.size(0),\n            is_sorted=self.is_sorted,\n        )\n\n        return self.mlp_mixer(edge_attr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(k={self.k}, '\n                f'in_channels={self.in_channels}, '\n                f'hidden_channels={self.hidden_channels}, '\n                f'out_channels={self.out_channels}, '\n                f'time_channels={self.time_channels}, '\n                f'dropout={self.dropout})')\n"
  },
  {
    "path": "torch_geometric/nn/models/graph_unet.py",
    "content": "from typing import Callable, List, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn import GCNConv, TopKPooling\nfrom torch_geometric.nn.resolver import activation_resolver\nfrom torch_geometric.typing import OptTensor, PairTensor\nfrom torch_geometric.utils import (\n    add_self_loops,\n    remove_self_loops,\n    to_torch_csr_tensor,\n)\nfrom torch_geometric.utils.repeat import repeat\n\n\nclass GraphUNet(torch.nn.Module):\n    r\"\"\"The Graph U-Net model from the `\"Graph U-Nets\"\n    <https://arxiv.org/abs/1905.05178>`_ paper which implements a U-Net like\n    architecture with graph pooling and unpooling operations.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Size of each hidden sample.\n        out_channels (int): Size of each output sample.\n        depth (int): The depth of the U-Net architecture.\n        pool_ratios (float or [float], optional): Graph pooling ratio for each\n            depth. (default: :obj:`0.5`)\n        sum_res (bool, optional): If set to :obj:`False`, will use\n            concatenation for integration of skip connections instead\n            summation. 
(default: :obj:`True`)\n        act (str or Callable, optional): The nonlinearity to use.\n            (default: :obj:`\"relu\"`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        depth: int,\n        pool_ratios: Union[float, List[float]] = 0.5,\n        sum_res: bool = True,\n        act: Union[str, Callable] = 'relu',\n    ):\n        super().__init__()\n        assert depth >= 1\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.depth = depth\n        self.pool_ratios = repeat(pool_ratios, depth)\n        self.act = activation_resolver(act)\n        self.sum_res = sum_res\n\n        channels = hidden_channels\n\n        self.down_convs = torch.nn.ModuleList()\n        self.pools = torch.nn.ModuleList()\n        self.down_convs.append(GCNConv(in_channels, channels, improved=True))\n        for i in range(depth):\n            self.pools.append(TopKPooling(channels, self.pool_ratios[i]))\n            self.down_convs.append(GCNConv(channels, channels, improved=True))\n\n        in_channels = channels if sum_res else 2 * channels\n\n        self.up_convs = torch.nn.ModuleList()\n        for _ in range(depth - 1):\n            self.up_convs.append(GCNConv(in_channels, channels, improved=True))\n        self.up_convs.append(GCNConv(in_channels, out_channels, improved=True))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for conv in self.down_convs:\n            conv.reset_parameters()\n        for pool in self.pools:\n            pool.reset_parameters()\n        for conv in self.up_convs:\n            conv.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: OptTensor = None,\n        edge_weight: OptTensor 
= None,\n    ) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        if batch is None:\n            batch = edge_index.new_zeros(x.size(0))\n\n        if edge_weight is None:\n            edge_weight = x.new_ones(edge_index.size(1))\n        assert edge_weight.dim() == 1\n        assert edge_weight.size(0) == edge_index.size(1)\n\n        x = self.down_convs[0](x, edge_index, edge_weight)\n        x = self.act(x)\n\n        xs = [x]\n        edge_indices = [edge_index]\n        edge_weights = [edge_weight]\n        perms = []\n\n        for i in range(1, self.depth + 1):\n            edge_index, edge_weight = self.augment_adj(edge_index, edge_weight,\n                                                       x.size(0))\n            x, edge_index, edge_weight, batch, perm, _ = self.pools[i - 1](\n                x, edge_index, edge_weight, batch)\n\n            x = self.down_convs[i](x, edge_index, edge_weight)\n            x = self.act(x)\n\n            if i < self.depth:\n                xs += [x]\n                edge_indices += [edge_index]\n                edge_weights += [edge_weight]\n            perms += [perm]\n\n        for i in range(self.depth):\n            j = self.depth - 1 - i\n\n            res = xs[j]\n            edge_index = edge_indices[j]\n            edge_weight = edge_weights[j]\n            perm = perms[j]\n\n            up = torch.zeros_like(res)\n            up[perm] = x\n            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)\n\n            x = self.up_convs[i](x, edge_index, edge_weight)\n            x = self.act(x) if i < self.depth - 1 else x\n\n        return x\n\n    def augment_adj(self, edge_index: Tensor, edge_weight: Tensor,\n                    num_nodes: int) -> PairTensor:\n        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)\n        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,\n                                                 num_nodes=num_nodes)\n        adj = 
to_torch_csr_tensor(edge_index, edge_weight,\n                                  size=(num_nodes, num_nodes))\n        adj = (adj @ adj).to_sparse_coo()\n        edge_index, edge_weight = adj.indices(), adj.values()\n        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)\n        return edge_index, edge_weight\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.hidden_channels}, {self.out_channels}, '\n                f'depth={self.depth}, pool_ratios={self.pool_ratios})')\n"
  },
  {
    "path": "torch_geometric/nn/models/jumping_knowledge.py",
    "content": "from typing import Dict, List, Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import LSTM, Linear\n\n\nclass JumpingKnowledge(torch.nn.Module):\n    r\"\"\"The Jumping Knowledge layer aggregation module from the\n    `\"Representation Learning on Graphs with Jumping Knowledge Networks\"\n    <https://arxiv.org/abs/1806.03536>`_ paper.\n\n    Jumping knowledge is performed based on either **concatenation**\n    (:obj:`\"cat\"`)\n\n    .. math::\n\n        \\mathbf{x}_v^{(1)} \\, \\Vert \\, \\ldots \\, \\Vert \\, \\mathbf{x}_v^{(T)},\n\n    **max pooling** (:obj:`\"max\"`)\n\n    .. math::\n\n        \\max \\left( \\mathbf{x}_v^{(1)}, \\ldots, \\mathbf{x}_v^{(T)} \\right),\n\n    or **weighted summation**\n\n    .. math::\n\n        \\sum_{t=1}^T \\alpha_v^{(t)} \\mathbf{x}_v^{(t)}\n\n    with attention scores :math:`\\alpha_v^{(t)}` obtained from a bi-directional\n    LSTM (:obj:`\"lstm\"`).\n\n    Args:\n        mode (str): The aggregation scheme to use\n            (:obj:`\"cat\"`, :obj:`\"max\"` or :obj:`\"lstm\"`).\n        channels (int, optional): The number of channels per representation.\n            Needs to be only set for LSTM-style aggregation.\n            (default: :obj:`None`)\n        num_layers (int, optional): The number of layers to aggregate. Needs to\n            be only set for LSTM-style aggregation. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        mode: str,\n        channels: Optional[int] = None,\n        num_layers: Optional[int] = None,\n    ) -> None:\n        super().__init__()\n        self.mode = mode.lower()\n        assert self.mode in ['cat', 'max', 'lstm']\n\n        if self.mode == 'lstm':\n            assert channels is not None, 'channels cannot be None for lstm'\n            assert num_layers is not None, 'num_layers cannot be None for lstm'\n            self.lstm = LSTM(channels, (num_layers * channels) // 2,\n                             bidirectional=True, batch_first=True)\n            self.att = Linear(2 * ((num_layers * channels) // 2), 1)\n            self.channels = channels\n            self.num_layers = num_layers\n        else:\n            self.lstm = None\n            self.att = None\n            self.channels = None\n            self.num_layers = None\n\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        if self.lstm is not None:\n            self.lstm.reset_parameters()\n        if self.att is not None:\n            self.att.reset_parameters()\n\n    def forward(self, xs: List[Tensor]) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            xs (List[torch.Tensor]): List containing the layer-wise\n                representations.\n        \"\"\"\n        if self.mode == 'cat':\n            return torch.cat(xs, dim=-1)\n        elif self.mode == 'max':\n            return torch.stack(xs, dim=-1).max(dim=-1)[0]\n        else:  # self.mode == 'lstm'\n            assert self.lstm is not None and self.att is not None\n            x = torch.stack(xs, dim=1)  # [num_nodes, num_layers, num_channels]\n            alpha, _ = self.lstm(x)\n            alpha = self.att(alpha).squeeze(-1)  # [num_nodes, num_layers]\n            alpha = torch.softmax(alpha, dim=-1)\n            return (x * 
alpha.unsqueeze(-1)).sum(dim=1)\n\n    def __repr__(self) -> str:\n        if self.mode == 'lstm':\n            return (f'{self.__class__.__name__}({self.mode}, '\n                    f'channels={self.channels}, layers={self.num_layers})')\n        return f'{self.__class__.__name__}({self.mode})'\n\n\nclass HeteroJumpingKnowledge(torch.nn.Module):\n    r\"\"\"A heterogeneous version of the :class:`JumpingKnowledge` module.\n\n    Args:\n        types (List[str]): The keys of the input dictionary.\n        mode (str): The aggregation scheme to use\n            (:obj:`\"cat\"`, :obj:`\"max\"` or :obj:`\"lstm\"`).\n        channels (int, optional): The number of channels per representation.\n            Needs to be only set for LSTM-style aggregation.\n            (default: :obj:`None`)\n        num_layers (int, optional): The number of layers to aggregate. Needs to\n            be only set for LSTM-style aggregation. (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        types: List[str],\n        mode: str,\n        channels: Optional[int] = None,\n        num_layers: Optional[int] = None,\n    ) -> None:\n        super().__init__()\n\n        self.mode = mode.lower()\n\n        self.jk_dict = torch.nn.ModuleDict({\n            key:\n            JumpingKnowledge(mode, channels, num_layers)\n            for key in types\n        })\n\n    def reset_parameters(self) -> None:\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for jk in self.jk_dict.values():\n            jk.reset_parameters()\n\n    def forward(self, xs_dict: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            xs_dict (Dict[str, List[torch.Tensor]]): A dictionary holding a\n                list of layer-wise representation for each type.\n        \"\"\"\n        return {key: jk(xs_dict[key]) for key, jk in self.jk_dict.items()}\n\n    def __repr__(self):\n        if self.mode == 'lstm':\n            
jk = next(iter(self.jk_dict.values()))\n            return (f'{self.__class__.__name__}('\n                    f'num_types={len(self.jk_dict)}, '\n                    f'mode={self.mode}, channels={jk.channels}, '\n                    f'layers={jk.num_layers})')\n        return (f'{self.__class__.__name__}(num_types={len(self.jk_dict)}, '\n                f'mode={self.mode})')\n"
  },
  {
    "path": "torch_geometric/nn/models/label_prop.py",
    "content": "from typing import Callable, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.conv.gcn_conv import gcn_norm\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import one_hot, spmm\n\n\nclass LabelPropagation(MessagePassing):\n    r\"\"\"The label propagation operator, firstly introduced in the\n    `\"Learning from Labeled and Unlabeled Data with Label Propagation\"\n    <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`_ paper.\n\n    .. math::\n        \\mathbf{Y}^{\\prime} = \\alpha \\cdot \\mathbf{D}^{-1/2} \\mathbf{A}\n        \\mathbf{D}^{-1/2} \\mathbf{Y} + (1 - \\alpha) \\mathbf{Y},\n\n    where unlabeled data is inferred by labeled data via propagation.\n    This concrete implementation here is derived from the `\"Combining Label\n    Propagation And Simple Models Out-performs Graph Neural Networks\"\n    <https://arxiv.org/abs/2010.13993>`_ paper.\n\n    .. 
note::\n\n        For an example of using the :class:`LabelPropagation`, see\n        `examples/label_prop.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        label_prop.py>`_.\n\n    Args:\n        num_layers (int): The number of propagations.\n        alpha (float): The :math:`\\alpha` coefficient.\n    \"\"\"\n    def __init__(self, num_layers: int, alpha: float):\n        super().__init__(aggr='add')\n        self.num_layers = num_layers\n        self.alpha = alpha\n\n    @torch.no_grad()\n    def forward(\n        self,\n        y: Tensor,\n        edge_index: Adj,\n        mask: OptTensor = None,\n        edge_weight: OptTensor = None,\n        post_step: Optional[Callable[[Tensor], Tensor]] = None,\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            y (torch.Tensor): The ground-truth label information\n                :math:`\\mathbf{Y}`.\n            edge_index (torch.Tensor or SparseTensor): The edge connectivity.\n            mask (torch.Tensor, optional): A mask or index tensor denoting\n                which nodes are used for label propagation.\n                (default: :obj:`None`)\n            edge_weight (torch.Tensor, optional): The edge weights.\n                (default: :obj:`None`)\n            post_step (callable, optional): A post step function specified\n                to apply after label propagation. 
If no post step function\n                is specified, the output will be clamped between 0 and 1.\n                (default: :obj:`None`)\n        \"\"\"\n        if y.dtype == torch.long and y.size(0) == y.numel():\n            y = one_hot(y.view(-1))\n\n        out = y\n        if mask is not None:\n            out = torch.zeros_like(y)\n            out[mask] = y[mask]\n\n        if isinstance(edge_index, SparseTensor) and not edge_index.has_value():\n            edge_index = gcn_norm(edge_index, add_self_loops=False)\n        elif isinstance(edge_index, Tensor) and edge_weight is None:\n            edge_index, edge_weight = gcn_norm(edge_index, num_nodes=y.size(0),\n                                               add_self_loops=False)\n\n        res = (1 - self.alpha) * out\n        for _ in range(self.num_layers):\n            # propagate_type: (x: Tensor, edge_weight: OptTensor)\n            out = self.propagate(edge_index, x=out, edge_weight=edge_weight)\n            out.mul_(self.alpha).add_(res)\n            if post_step is not None:\n                out = post_step(out)\n            else:\n                out.clamp_(0., 1.)\n\n        return out\n\n    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j\n\n    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:\n        return spmm(adj_t, x, reduce=self.aggr)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(num_layers={self.num_layers}, '\n                f'alpha={self.alpha})')\n"
  },
  {
    "path": "torch_geometric/nn/models/lightgcn.py",
    "content": "from typing import Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Embedding, ModuleList\nfrom torch.nn.modules.loss import _Loss\n\nfrom torch_geometric.nn.conv import LGConv\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import is_sparse, to_edge_index\n\n\nclass LightGCN(torch.nn.Module):\n    r\"\"\"The LightGCN model from the `\"LightGCN: Simplifying and Powering\n    Graph Convolution Network for Recommendation\"\n    <https://arxiv.org/abs/2002.02126>`_ paper.\n\n    :class:`~torch_geometric.nn.models.LightGCN` learns embeddings by linearly\n    propagating them on the underlying graph, and uses the weighted sum of the\n    embeddings learned at all layers as the final embedding\n\n    .. math::\n        \\textbf{x}_i = \\sum_{l=0}^{L} \\alpha_l \\textbf{x}^{(l)}_i,\n\n    where each layer's embedding is computed as\n\n    .. math::\n        \\mathbf{x}^{(l+1)}_i = \\sum_{j \\in \\mathcal{N}(i)}\n        \\frac{1}{\\sqrt{\\deg(i)\\deg(j)}}\\mathbf{x}^{(l)}_j.\n\n    Two prediction heads and training objectives are provided:\n    **link prediction** (via\n    :meth:`~torch_geometric.nn.models.LightGCN.link_pred_loss` and\n    :meth:`~torch_geometric.nn.models.LightGCN.predict_link`) and\n    **recommendation** (via\n    :meth:`~torch_geometric.nn.models.LightGCN.recommendation_loss` and\n    :meth:`~torch_geometric.nn.models.LightGCN.recommend`).\n\n    .. note::\n\n        Embeddings are propagated according to the graph connectivity specified\n        by :obj:`edge_index` while rankings or link probabilities are computed\n        according to the edges specified by :obj:`edge_label_index`.\n\n    .. 
note::\n\n        For an example of using :class:`LightGCN`, see `examples/lightgcn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        lightgcn.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes in the graph.\n        embedding_dim (int): The dimensionality of node embeddings.\n        num_layers (int): The number of\n            :class:`~torch_geometric.nn.conv.LGConv` layers.\n        alpha (float or torch.Tensor, optional): The scalar or vector\n            specifying the re-weighting coefficients for aggregating the final\n            embedding. If set to :obj:`None`, the uniform initialization of\n            :obj:`1 / (num_layers + 1)` is used. (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of the underlying\n            :class:`~torch_geometric.nn.conv.LGConv` layers.\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        embedding_dim: int,\n        num_layers: int,\n        alpha: Optional[Union[float, Tensor]] = None,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.num_nodes = num_nodes\n        self.embedding_dim = embedding_dim\n        self.num_layers = num_layers\n\n        if alpha is None:\n            alpha = 1. 
/ (num_layers + 1)\n\n        if isinstance(alpha, Tensor):\n            assert alpha.size(0) == num_layers + 1\n        else:\n            alpha = torch.tensor([alpha] * (num_layers + 1))\n        self.register_buffer('alpha', alpha)\n\n        self.embedding = Embedding(num_nodes, embedding_dim)\n        self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)])\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        torch.nn.init.xavier_uniform_(self.embedding.weight)\n        for conv in self.convs:\n            conv.reset_parameters()\n\n    def get_embedding(\n        self,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n        r\"\"\"Returns the embedding of nodes in the graph.\"\"\"\n        x = self.embedding.weight\n        out = x * self.alpha[0]\n\n        for i in range(self.num_layers):\n            x = self.convs[i](x, edge_index, edge_weight)\n            out = out + x * self.alpha[i + 1]\n\n        return out\n\n    def forward(\n        self,\n        edge_index: Adj,\n        edge_label_index: OptTensor = None,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n        r\"\"\"Computes rankings for pairs of nodes.\n\n        Args:\n            edge_index (torch.Tensor or SparseTensor): Edge tensor specifying\n                the connectivity of the graph.\n            edge_label_index (torch.Tensor, optional): Edge tensor specifying\n                the node pairs for which to compute rankings or probabilities.\n                If :obj:`edge_label_index` is set to :obj:`None`, all edges in\n                :obj:`edge_index` will be used instead. (default: :obj:`None`)\n            edge_weight (torch.Tensor, optional): The weight of each edge in\n                :obj:`edge_index`. 
(default: :obj:`None`)\n        \"\"\"\n        if edge_label_index is None:\n            if is_sparse(edge_index):\n                edge_label_index, _ = to_edge_index(edge_index)\n            else:\n                edge_label_index = edge_index\n\n        out = self.get_embedding(edge_index, edge_weight)\n\n        out_src = out[edge_label_index[0]]\n        out_dst = out[edge_label_index[1]]\n\n        return (out_src * out_dst).sum(dim=-1)\n\n    def predict_link(\n        self,\n        edge_index: Adj,\n        edge_label_index: OptTensor = None,\n        edge_weight: OptTensor = None,\n        prob: bool = False,\n    ) -> Tensor:\n        r\"\"\"Predict links between nodes specified in :obj:`edge_label_index`.\n\n        Args:\n            edge_index (torch.Tensor or SparseTensor): Edge tensor specifying\n                the connectivity of the graph.\n            edge_label_index (torch.Tensor, optional): Edge tensor specifying\n                the node pairs for which to compute probabilities.\n                If :obj:`edge_label_index` is set to :obj:`None`, all edges in\n                :obj:`edge_index` will be used instead. (default: :obj:`None`)\n            edge_weight (torch.Tensor, optional): The weight of each edge in\n                :obj:`edge_index`. 
(default: :obj:`None`)\n            prob (bool, optional): Whether probabilities should be returned.\n                (default: :obj:`False`)\n        \"\"\"\n        pred = self(edge_index, edge_label_index, edge_weight).sigmoid()\n        return pred if prob else pred.round()\n\n    def recommend(\n        self,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n        src_index: OptTensor = None,\n        dst_index: OptTensor = None,\n        k: int = 1,\n        sorted: bool = True,\n    ) -> Tensor:\n        r\"\"\"Get top-:math:`k` recommendations for nodes in :obj:`src_index`.\n\n        Args:\n            edge_index (torch.Tensor or SparseTensor): Edge tensor specifying\n                the connectivity of the graph.\n            edge_weight (torch.Tensor, optional): The weight of each edge in\n                :obj:`edge_index`. (default: :obj:`None`)\n            src_index (torch.Tensor, optional): Node indices for which\n                recommendations should be generated.\n                If set to :obj:`None`, all nodes will be used.\n                (default: :obj:`None`)\n            dst_index (torch.Tensor, optional): Node indices which represent\n                the possible recommendation choices.\n                If set to :obj:`None`, all nodes will be used.\n                (default: :obj:`None`)\n            k (int, optional): Number of recommendations. (default: :obj:`1`)\n            sorted (bool, optional): Whether to sort the recommendations\n                by score. 
(default: :obj:`True`)\n        \"\"\"\n        out_src = out_dst = self.get_embedding(edge_index, edge_weight)\n\n        if src_index is not None:\n            out_src = out_src[src_index]\n\n        if dst_index is not None:\n            out_dst = out_dst[dst_index]\n\n        pred = out_src @ out_dst.t()\n        top_index = pred.topk(k, dim=-1, sorted=sorted).indices\n\n        if dst_index is not None:  # Map local top-indices to original indices.\n            top_index = dst_index[top_index.view(-1)].view(*top_index.size())\n\n        return top_index\n\n    def link_pred_loss(self, pred: Tensor, edge_label: Tensor,\n                       **kwargs) -> Tensor:\n        r\"\"\"Computes the model loss for a link prediction objective via the\n        :class:`torch.nn.BCEWithLogitsLoss`.\n\n        Args:\n            pred (torch.Tensor): The predictions.\n            edge_label (torch.Tensor): The ground-truth edge labels.\n            **kwargs (optional): Additional arguments of the underlying\n                :class:`torch.nn.BCEWithLogitsLoss` loss function.\n        \"\"\"\n        loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)\n        return loss_fn(pred, edge_label.to(pred.dtype))\n\n    def recommendation_loss(\n        self,\n        pos_edge_rank: Tensor,\n        neg_edge_rank: Tensor,\n        node_id: Optional[Tensor] = None,\n        lambda_reg: float = 1e-4,\n        **kwargs,\n    ) -> Tensor:\n        r\"\"\"Computes the model loss for a ranking objective via the Bayesian\n        Personalized Ranking (BPR) loss.\n\n        .. 
note::\n\n            The i-th entry in the :obj:`pos_edge_rank` vector and i-th entry\n            in the :obj:`neg_edge_rank` vector must correspond to ranks of\n            positive and negative edges of the same entity (*e.g.*, user).\n\n        Args:\n            pos_edge_rank (torch.Tensor): Positive edge rankings.\n            neg_edge_rank (torch.Tensor): Negative edge rankings.\n            node_id (torch.Tensor, optional): The indices of the nodes involved\n                for deriving a prediction for both positive and negative edges.\n                If set to :obj:`None`, all nodes will be used.\n                (default: :obj:`None`)\n            lambda_reg (float, optional): The :math:`L_2` regularization\n                strength of the Bayesian Personalized Ranking (BPR) loss.\n                (default: :obj:`1e-4`)\n            **kwargs (optional): Additional arguments of the underlying\n                :class:`torch_geometric.nn.models.lightgcn.BPRLoss` loss\n                function.\n        \"\"\"\n        loss_fn = BPRLoss(lambda_reg, **kwargs)\n        emb = self.embedding.weight\n        emb = emb if node_id is None else emb[node_id]\n        return loss_fn(pos_edge_rank, neg_edge_rank, emb)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.num_nodes}, '\n                f'{self.embedding_dim}, num_layers={self.num_layers})')\n\n\nclass BPRLoss(_Loss):\n    r\"\"\"The Bayesian Personalized Ranking (BPR) loss.\n\n    The BPR loss is a pairwise loss that encourages the prediction of an\n    observed entry to be higher than its unobserved counterparts\n    (see `here <https://arxiv.org/abs/2002.02126>`__).\n\n    .. 
math::\n        L_{\\text{BPR}} = - \\sum_{u=1}^{M} \\sum_{i \\in \\mathcal{N}_u}\n        \\sum_{j \\not\\in \\mathcal{N}_u} \\ln \\sigma(\\hat{y}_{ui} - \\hat{y}_{uj})\n        + \\lambda \\vert\\vert \\textbf{x}^{(0)} \\vert\\vert^2\n\n    where :math:`\\lambda` controls the :math:`L_2` regularization strength.\n    We compute the mean BPR loss for simplicity.\n\n    Args:\n        lambda_reg (float, optional): The :math:`L_2` regularization strength\n            (default: 0).\n        **kwargs (optional): Additional arguments of the underlying\n            :class:`torch.nn.modules.loss._Loss` class.\n    \"\"\"\n    __constants__ = ['lambda_reg']\n    lambda_reg: float\n\n    def __init__(self, lambda_reg: float = 0, **kwargs):\n        super().__init__(None, None, \"sum\", **kwargs)\n        self.lambda_reg = lambda_reg\n\n    def forward(self, positives: Tensor, negatives: Tensor,\n                parameters: OptTensor = None) -> Tensor:\n        r\"\"\"Compute the mean Bayesian Personalized Ranking (BPR) loss.\n\n        .. note::\n\n            The i-th entry in the :obj:`positives` vector and i-th entry\n            in the :obj:`negatives` vector should correspond to the same\n            entity (*e.g.*, user), as the BPR is a personalized ranking loss.\n\n        Args:\n            positives (Tensor): The vector of positive-pair rankings.\n            negatives (Tensor): The vector of negative-pair rankings.\n            parameters (Tensor, optional): The tensor of parameters which\n                should be used for :math:`L_2` regularization\n                (default: :obj:`None`).\n        \"\"\"\n        log_prob = F.logsigmoid(positives - negatives).mean()\n\n        regularization = 0\n        if self.lambda_reg != 0:\n            regularization = self.lambda_reg * parameters.norm(p=2).pow(2)\n            regularization = regularization / positives.size(0)\n\n        return -log_prob + regularization\n"
  },
  {
    "path": "torch_geometric/nn/models/linkx.py",
    "content": "import math\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import BatchNorm1d, Parameter\n\nfrom torch_geometric.nn import inits\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.models import MLP\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import spmm\n\n\nclass SparseLinear(MessagePassing):\n    def __init__(self, in_channels: int, out_channels: int, bias: bool = True):\n        super().__init__(aggr='add')\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n\n        self.weight = Parameter(torch.empty(in_channels, out_channels))\n        if bias:\n            self.bias = Parameter(torch.empty(out_channels))\n        else:\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        inits.kaiming_uniform(self.weight, fan=self.in_channels,\n                              a=math.sqrt(5))\n        inits.uniform(self.in_channels, self.bias)\n\n    def forward(\n        self,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n        # propagate_type: (weight: Tensor, edge_weight: OptTensor)\n        out = self.propagate(edge_index, weight=self.weight,\n                             edge_weight=edge_weight)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        return out\n\n    def message(self, weight_j: Tensor, edge_weight: OptTensor) -> Tensor:\n        if edge_weight is None:\n            return weight_j\n        else:\n            return edge_weight.view(-1, 1) * weight_j\n\n    def message_and_aggregate(self, adj_t: Adj, weight: Tensor) -> Tensor:\n        return spmm(adj_t, weight, reduce=self.aggr)\n\n\nclass LINKX(torch.nn.Module):\n    r\"\"\"The LINKX model from the `\"Large Scale Learning on Non-Homophilous\n    Graphs: New Benchmarks and Strong Simple Methods\"\n    <https://arxiv.org/abs/2110.14446>`_ 
paper.\n\n    .. math::\n        \\mathbf{H}_{\\mathbf{A}} &= \\textrm{MLP}_{\\mathbf{A}}(\\mathbf{A})\n\n        \\mathbf{H}_{\\mathbf{X}} &= \\textrm{MLP}_{\\mathbf{X}}(\\mathbf{X})\n\n        \\mathbf{Y} &= \\textrm{MLP}_{f} \\left( \\sigma \\left( \\mathbf{W}\n        [\\mathbf{H}_{\\mathbf{A}}, \\mathbf{H}_{\\mathbf{X}}] +\n        \\mathbf{H}_{\\mathbf{A}} + \\mathbf{H}_{\\mathbf{X}} \\right) \\right)\n\n    .. note::\n\n        For an example of using LINKX, see `examples/linkx.py <https://\n        github.com/pyg-team/pytorch_geometric/blob/master/examples/linkx.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes in the graph.\n        in_channels (int): Size of each input sample, or :obj:`-1` to derive\n            the size from the first input(s) to the forward method.\n        hidden_channels (int): Size of each hidden sample.\n        out_channels (int): Size of each output sample.\n        num_layers (int): Number of layers of :math:`\\textrm{MLP}_{f}`.\n        num_edge_layers (int, optional): Number of layers of\n            :math:`\\textrm{MLP}_{\\mathbf{A}}`. (default: :obj:`1`)\n        num_node_layers (int, optional): Number of layers of\n            :math:`\\textrm{MLP}_{\\mathbf{X}}`. (default: :obj:`1`)\n        dropout (float, optional): Dropout probability of each hidden\n            embedding. 
(default: :obj:`0.0`)\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int,\n        num_edge_layers: int = 1,\n        num_node_layers: int = 1,\n        dropout: float = 0.0,\n    ):\n        super().__init__()\n\n        self.num_nodes = num_nodes\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.num_edge_layers = num_edge_layers\n\n        self.edge_lin = SparseLinear(num_nodes, hidden_channels)\n\n        if self.num_edge_layers > 1:\n            self.edge_norm = BatchNorm1d(hidden_channels)\n            channels = [hidden_channels] * num_edge_layers\n            self.edge_mlp = MLP(channels, dropout=0., act_first=True)\n        else:\n            self.edge_norm = None\n            self.edge_mlp = None\n\n        channels = [in_channels] + [hidden_channels] * num_node_layers\n        self.node_mlp = MLP(channels, dropout=0., act_first=True)\n\n        self.cat_lin1 = torch.nn.Linear(hidden_channels, hidden_channels)\n        self.cat_lin2 = torch.nn.Linear(hidden_channels, hidden_channels)\n\n        channels = [hidden_channels] * num_layers + [out_channels]\n        self.final_mlp = MLP(channels, dropout=dropout, act_first=True)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.edge_lin.reset_parameters()\n        if self.edge_norm is not None:\n            self.edge_norm.reset_parameters()\n        if self.edge_mlp is not None:\n            self.edge_mlp.reset_parameters()\n        self.node_mlp.reset_parameters()\n        self.cat_lin1.reset_parameters()\n        self.cat_lin2.reset_parameters()\n        self.final_mlp.reset_parameters()\n\n    def forward(\n        self,\n        x: OptTensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n       
 \"\"\"\"\"\"  # noqa: D419\n        out = self.edge_lin(edge_index, edge_weight)\n\n        if self.edge_norm is not None and self.edge_mlp is not None:\n            out = out.relu_()\n            out = self.edge_norm(out)\n            out = self.edge_mlp(out)\n\n        out = out + self.cat_lin1(out)\n\n        if x is not None:\n            x = self.node_mlp(x)\n            out = out + x\n            out = out + self.cat_lin2(x)\n\n        return self.final_mlp(out.relu_())\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(num_nodes={self.num_nodes}, '\n                f'in_channels={self.in_channels}, '\n                f'out_channels={self.out_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/models/lpformer.py",
    "content": "import math\nfrom typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom ...nn.conv import MessagePassing\nfrom ...nn.dense.linear import Linear\nfrom ...nn.inits import glorot, zeros\nfrom ...typing import Adj, OptTensor, Tuple\nfrom ...utils import get_ppr, is_sparse, scatter, softmax\nfrom .basic_gnn import GCN\n\n\nclass LPFormer(nn.Module):\n    r\"\"\"The LPFormer model from the\n    `\"LPFormer: An Adaptive Graph Transformer for Link Prediction\"\n    <https://arxiv.org/abs/2310.11009>`_ paper.\n\n    .. note::\n\n        For an example of using LPFormer, see\n        `examples/lpformer.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        lpformer.py>`_.\n\n    Args:\n        in_channels (int): Size of input dimension\n        hidden_channels (int): Size of hidden dimension\n        num_gnn_layers (int, optional): Number of GNN layers\n            (default: :obj:`2`)\n        gnn_dropout(float, optional): Dropout used for GNN\n            (default: :obj:`0.1`)\n        num_transformer_layers (int, optional): Number of Transformer layers\n            (default: :obj:`1`)\n        num_heads (int, optional): Number of heads to use in MHA\n            (default: :obj:`1`)\n        transformer_dropout (float, optional): Dropout used for Transformer\n            (default: :obj:`0.1`)\n        ppr_thresholds (list): PPR thresholds for different types of nodes.\n            Types include (in order) common neighbors, 1-Hop nodes\n            (that aren't CNs), and all other nodes.\n            (default: :obj:`[0, 1e-4, 1e-2]`)\n        gcn_cache (bool, optional): Whether to cache edge indices\n            during message passing. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        num_gnn_layers: int = 2,\n        gnn_dropout: float = 0.1,\n        num_transformer_layers: int = 1,\n        num_heads: int = 1,\n        transformer_dropout: float = 0.1,\n        ppr_thresholds: Optional[list] = None,\n        gcn_cache: bool = False,\n    ):\n        super().__init__()\n\n        # Default thresholds\n        if ppr_thresholds is None:\n            ppr_thresholds = [0, 1e-4, 1e-2]\n\n        if len(ppr_thresholds) == 3:\n            self.thresh_cn = ppr_thresholds[0]\n            self.thresh_1hop = ppr_thresholds[1]\n            self.thresh_non1hop = ppr_thresholds[2]\n        else:\n            raise ValueError(\n                \"Argument 'ppr_thresholds' must only be length 3!\")\n\n        self.in_dim = in_channels\n        self.hid_dim = hidden_channels\n        self.gnn_drop = gnn_dropout\n        self.trans_drop = transformer_dropout\n\n        self.gnn = GCN(in_channels, hidden_channels, num_gnn_layers,\n                       dropout=gnn_dropout, norm=\"layer_norm\",\n                       cached=gcn_cache)\n        self.gnn_norm = nn.LayerNorm(hidden_channels)\n\n        # Create Transformer Layers\n        self.att_layers = nn.ModuleList()\n        for il in range(num_transformer_layers):\n            if il == 0:\n                node_dim = None\n                self.out_dim = self.hid_dim * 2 if num_transformer_layers > 1 \\\n                    else self.hid_dim\n            elif il == num_transformer_layers - 1:\n                node_dim = self.hid_dim\n            else:\n                self.out_dim = node_dim = self.hid_dim\n\n            self.att_layers.append(\n                LPAttLayer(self.hid_dim, self.out_dim, node_dim, num_heads,\n                           self.trans_drop))\n\n        self.elementwise_lin = MLP(self.hid_dim, self.hid_dim, self.hid_dim)\n\n        # Relative Positional Encodings\n        
self.ppr_encoder_cn = MLP(2, self.hid_dim, self.hid_dim)\n        self.ppr_encoder_onehop = MLP(2, self.hid_dim, self.hid_dim)\n        self.ppr_encoder_non1hop = MLP(2, self.hid_dim, self.hid_dim)\n\n        # thresh=1 implies ignoring some set of nodes\n        # Also allows us to be more efficient later\n        if self.thresh_non1hop == 1 and self.thresh_1hop == 1:\n            self.mask = \"cn\"\n        elif self.thresh_non1hop == 1 and self.thresh_1hop < 1:\n            self.mask = \"1-hop\"\n        else:\n            self.mask = \"all\"\n\n        # 4 is for counts of diff nodes\n        pairwise_dim = self.hid_dim * num_heads + 4\n        self.pairwise_lin = MLP(pairwise_dim, pairwise_dim, self.hid_dim)\n\n        self.score_func = MLP(self.hid_dim * 2, self.hid_dim * 2, 1, norm=None)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_dim}, '\n                f'{self.hid_dim}, num_gnn_layers={self.gnn.num_layers}, '\n                f'num_transformer_layers={len(self.att_layers)})')\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.gnn.reset_parameters()\n        self.gnn_norm.reset_parameters()\n        self.elementwise_lin.reset_parameters()\n        self.pairwise_lin.reset_parameters()\n        self.ppr_encoder_cn.reset_parameters()\n        self.ppr_encoder_onehop.reset_parameters()\n        self.ppr_encoder_non1hop.reset_parameters()\n        self.score_func.reset_parameters()\n        for i in range(len(self.att_layers)):\n            self.att_layers[i].reset_parameters()\n\n    def forward(\n        self,\n        batch: Tensor,\n        x: Tensor,\n        edge_index: Adj,\n        ppr_matrix: Tensor,\n    ) -> Tensor:\n        r\"\"\"Forward Pass of LPFormer.\n\n        Returns raw logits for each link\n\n        Args:\n            batch (Tensor): The batch vector.\n                Denotes which node pairs to predict.\n            x (Tensor): 
Input node features\n            edge_index (torch.Tensor, SparseTensor): The edge indices.\n                Either in COO or SparseTensor format\n            ppr_matrix (Tensor): PPR matrix\n        \"\"\"\n        batch = batch.to(x.device)\n\n        X_node = self.propagate(x, edge_index)\n        x_i, x_j = X_node[batch[0]], X_node[batch[1]]\n        elementwise_edge_feats = self.elementwise_lin(x_i * x_j)\n\n        # Ensure in sparse format\n        # Need as native torch.sparse for later computations\n        # (necessary operations are not supported by PyG SparseTensor)\n        if not edge_index.is_sparse:\n            num_nodes = ppr_matrix.size(1)\n            vals = torch.ones(len(edge_index[0]), device=edge_index.device)\n            edge_index = torch.sparse_coo_tensor(edge_index, vals,\n                                                 [num_nodes, num_nodes])\n        # Check whether it is a SparseTensor; if so, convert it\n        if is_sparse(edge_index) and not edge_index.is_sparse:\n            edge_index = edge_index.to_torch_sparse_coo_tensor()\n\n        # Ensure {0, 1}\n        edge_index = edge_index.coalesce().bool().int()\n\n        pairwise_feats = self.calc_pairwise(batch, X_node, edge_index,\n                                            ppr_matrix)\n        combined_feats = torch.cat((elementwise_edge_feats, pairwise_feats),\n                                   dim=-1)\n\n        logits = self.score_func(combined_feats)\n        return logits\n\n    def propagate(self, x: Tensor, adj: Adj) -> Tensor:\n        \"\"\"Propagate via GNN.\n\n        Args:\n            x (Tensor): Node features\n            adj (torch.Tensor, SparseTensor): Adjacency matrix\n        \"\"\"\n        x = F.dropout(x, p=self.gnn_drop, training=self.training)\n        X_node = self.gnn(x, adj)\n        X_node = self.gnn_norm(X_node)\n\n        return X_node\n\n    def calc_pairwise(self, batch: Tensor, X_node: Tensor, adj_mask: Tensor,\n                      ppr_matrix: Tensor) 
-> Tensor:\n        r\"\"\"Calculate the pairwise features for the node pairs.\n\n        Args:\n            batch (Tensor): The batch vector.\n                Denotes which node pairs to predict.\n            X_node (Tensor): Node representations\n            adj_mask (Tensor): Mask of adjacency matrix used for computing the\n                different node types.\n            ppr_matrix (Tensor): PPR matrix\n        \"\"\"\n        k_i, k_j = X_node[batch[0]], X_node[batch[1]]\n        pairwise_feats = torch.cat((k_i, k_j), dim=-1)\n\n        cn_info, onehop_info, non1hop_info = self.compute_node_mask(\n            batch, adj_mask, ppr_matrix)\n\n        all_mask = cn_info[0]\n        if onehop_info is not None:\n            all_mask = torch.cat((all_mask, onehop_info[0]), dim=-1)\n        if non1hop_info is not None:\n            all_mask = torch.cat((all_mask, non1hop_info[0]), dim=-1)\n\n        pes = self.get_pos_encodings(cn_info[1:], onehop_info[1:],\n                                     non1hop_info[1:])\n\n        for lay in range(len(self.att_layers)):\n            pairwise_feats = self.att_layers[lay](all_mask, pairwise_feats,\n                                                  X_node, pes)\n\n        num_cns, num_1hop, num_non1hop, num_neigh = self.get_structure_cnts(\n            batch, cn_info, onehop_info, non1hop_info)\n\n        pairwise_feats = torch.cat(\n            (pairwise_feats, num_cns, num_1hop, num_non1hop, num_neigh),\n            dim=-1)\n\n        pairwise_feats = self.pairwise_lin(pairwise_feats)\n        return pairwise_feats\n\n    def get_pos_encodings(\n            self, cn_ppr: Tuple[Tensor, Tensor],\n            onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None,\n            non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:\n        r\"\"\"Calculate the PPR-based relative positional encodings.\n\n        Due to thresholds, sometimes we don't have 1-hop or >1-hop nodes.\n        In those cases, the value of 
onehop_ppr and/or non1hop_ppr should\n        be `None`.\n\n        Args:\n            cn_ppr (tuple, optional): PPR scores of CNs.\n            onehop_ppr (tuple, optional): PPR scores of 1-Hop.\n                (default: :obj:`None`)\n            non1hop_ppr (tuple, optional): PPR scores of >1-Hop.\n                (default: :obj:`None`)\n        \"\"\"\n        cn_a = self.ppr_encoder_cn(torch.stack((cn_ppr[0], cn_ppr[1])).t())\n        cn_b = self.ppr_encoder_cn(torch.stack((cn_ppr[1], cn_ppr[0])).t())\n        cn_pe = cn_a + cn_b\n\n        if onehop_ppr is None:\n            return cn_pe\n\n        onehop_a = self.ppr_encoder_onehop(\n            torch.stack((onehop_ppr[0], onehop_ppr[1])).t())\n        onehop_b = self.ppr_encoder_onehop(\n            torch.stack((onehop_ppr[1], onehop_ppr[0])).t())\n        onehop_pe = onehop_a + onehop_b\n\n        if non1hop_ppr is None:\n            return torch.cat((cn_pe, onehop_pe), dim=0)\n\n        non1hop_a = self.ppr_encoder_non1hop(\n            torch.stack((non1hop_ppr[0], non1hop_ppr[1])).t())\n        non1hop_b = self.ppr_encoder_non1hop(\n            torch.stack((non1hop_ppr[1], non1hop_ppr[0])).t())\n        non1hop_pe = non1hop_a + non1hop_b\n\n        return torch.cat((cn_pe, onehop_pe, non1hop_pe), dim=0)\n\n    def compute_node_mask(\n            self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor\n    ) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]:\n        r\"\"\"Get mask based on type of node.\n\n        When mask_type is not \"cn\", also return the ppr vals for both\n        the source and target.\n\n        Args:\n            batch (Tensor): The batch vector.\n                Denotes which node pairs to predict.\n            adj (SparseTensor): Adjacency matrix\n            ppr_matrix (Tensor): PPR matrix\n        \"\"\"\n        src_adj = torch.index_select(adj, 0, batch[0])\n        tgt_adj = torch.index_select(adj, 0, batch[1])\n\n        if self.mask == \"cn\":\n            # 1 when CN, 0 
otherwise\n            pair_adj = src_adj * tgt_adj\n        else:\n            # Equals: {0: \">1-Hop\", 1: \"1-Hop (Non-CN)\", 2: \"CN\"}\n            pair_adj = src_adj + tgt_adj\n\n        pair_ix, node_type, src_ppr, tgt_ppr = self.get_ppr_vals(\n            batch, pair_adj, ppr_matrix)\n\n        cn_filt_cond = (src_ppr >= self.thresh_cn) & (tgt_ppr\n                                                      >= self.thresh_cn)\n        onehop_filt_cond = (src_ppr >= self.thresh_1hop) & (\n            tgt_ppr >= self.thresh_1hop)\n\n        if self.mask != \"cn\":\n            filt_cond = torch.where(node_type == 1, onehop_filt_cond,\n                                    cn_filt_cond)\n        else:\n            filt_cond = torch.where(node_type == 0, onehop_filt_cond,\n                                    cn_filt_cond)\n\n        pair_ix, node_type = pair_ix[:, filt_cond], node_type[filt_cond]\n        src_ppr, tgt_ppr = src_ppr[filt_cond], tgt_ppr[filt_cond]\n\n        # >1-Hop mask is gotten separately\n        if self.mask == \"all\":\n            non1hop_ix, non1hop_sppr, non1hop_tppr = self.get_non_1hop_ppr(\n                batch, adj, ppr_matrix)\n\n        # Dropout\n        if self.training and self.trans_drop > 0:\n            pair_ix, src_ppr, tgt_ppr, node_type = self.drop_pairwise(\n                pair_ix, src_ppr, tgt_ppr, node_type)\n            if self.mask == \"all\":\n                non1hop_ix, non1hop_sppr, non1hop_tppr, _ = self.drop_pairwise(\n                    non1hop_ix, non1hop_sppr, non1hop_tppr)\n\n        # Separate out CN and 1-Hop\n        if self.mask != \"cn\":\n            cn_ind = node_type == 2\n            cn_ix = pair_ix[:, cn_ind]\n            cn_src_ppr = src_ppr[cn_ind]\n            cn_tgt_ppr = tgt_ppr[cn_ind]\n\n            one_hop_ind = node_type == 1\n            onehop_ix = pair_ix[:, one_hop_ind]\n            onehop_src_ppr = src_ppr[one_hop_ind]\n            onehop_tgt_ppr = tgt_ppr[one_hop_ind]\n\n        if 
self.mask == \"cn\":\n            return (pair_ix, src_ppr, tgt_ppr), None, None\n        elif self.mask == \"1-hop\":\n            return (cn_ix, cn_src_ppr, cn_tgt_ppr), (onehop_ix, onehop_src_ppr,\n                                                     onehop_tgt_ppr), None\n        else:\n            return (cn_ix, cn_src_ppr,\n                    cn_tgt_ppr), (onehop_ix, onehop_src_ppr,\n                                  onehop_tgt_ppr), (non1hop_ix, non1hop_sppr,\n                                                    non1hop_tppr)\n\n    def get_ppr_vals(\n            self, batch: Tensor, pair_diff_adj: Tensor,\n            ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n        r\"\"\"Get the src and tgt ppr vals.\n\n        Returns the: link the node belongs to, type of node\n        (e.g., CN), PPR relative to src, PPR relative to tgt.\n\n        Args:\n            batch (Tensor): The batch vector.\n                Denotes which node pairs to predict.\n            pair_diff_adj (SparseTensor): Combination of rows in\n                adjacency for src and tgt nodes (e.g., X1 + X2)\n            ppr_matrix (Tensor): PPR matrix\n        \"\"\"\n        # Additional terms for also choosing scores when ppr=0\n        # Multiplication removes any values for nodes not in batch\n        # Addition then adds offset to ensure we select when ppr=0\n        # All selected scores are +1 higher than their true val\n        src_ppr_adj = torch.index_select(\n            ppr_matrix, 0, batch[0]) * pair_diff_adj + pair_diff_adj\n        tgt_ppr_adj = torch.index_select(\n            ppr_matrix, 0, batch[1]) * pair_diff_adj + pair_diff_adj\n\n        # Can now convert ppr scores to dense\n        ppr_ix = src_ppr_adj.coalesce().indices()\n        src_ppr = src_ppr_adj.coalesce().values()\n        tgt_ppr = tgt_ppr_adj.coalesce().values()\n\n        # TODO: Needed due to a bug in recent torch versions\n        # see here for more - 
https://github.com/pytorch/pytorch/issues/114529\n        # note that if one is 0 so is the other\n        zero_vals = (src_ppr != 0)\n        src_ppr = src_ppr[zero_vals]\n        tgt_ppr = tgt_ppr[tgt_ppr != 0]\n        ppr_ix = ppr_ix[:, zero_vals]\n\n        pair_diff_adj = pair_diff_adj.coalesce().values()\n        node_type = pair_diff_adj[src_ppr != 0]\n\n        # Remove additional +1 from each ppr val\n        src_ppr = (src_ppr - node_type) / node_type\n        tgt_ppr = (tgt_ppr - node_type) / node_type\n\n        return ppr_ix, node_type, src_ppr, tgt_ppr\n\n    def drop_pairwise(\n        self,\n        pair_ix: Tensor,\n        src_ppr: Optional[Tensor] = None,\n        tgt_ppr: Optional[Tensor] = None,\n        node_indicator: Optional[Tensor] = None,\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n        r\"\"\"Perform dropout on pairwise information\n        by randomly dropping a percentage of nodes.\n\n        Done before performing attention for efficiency\n\n        Args:\n            pair_ix (Tensor): Link node belongs to\n            src_ppr (Tensor, optional): PPR relative to src\n                (default: :obj:`None`)\n            tgt_ppr (Tensor, optional): PPR relative to tgt\n                (default: :obj:`None`)\n            node_indicator (Tensor, optional): Type of node (e.g., CN)\n                (default: :obj:`None`)\n        \"\"\"\n        num_indices = math.ceil(pair_ix.size(1) * (1 - self.trans_drop))\n        indices = torch.randperm(pair_ix.size(1))[:num_indices]\n        pair_ix = pair_ix[:, indices]\n\n        if src_ppr is not None:\n            src_ppr = src_ppr[indices]\n        if tgt_ppr is not None:\n            tgt_ppr = tgt_ppr[indices]\n        if node_indicator is not None:\n            node_indicator = node_indicator[indices]\n\n        return pair_ix, src_ppr, tgt_ppr, node_indicator\n\n    def get_structure_cnts(\n        self,\n        batch: Tensor,\n        cn_info: Tuple[Tensor, Tensor],\n        
onehop_info: Tuple[Tensor, Tensor],\n        non1hop_info: Optional[Tuple[Tensor, Tensor]],\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n        \"\"\"Counts for CNs, 1-Hop, and >1-Hop that satisfy PPR threshold.\n\n        Also include total # of neighbors\n\n        Args:\n            batch (Tensor): The batch vector.\n                Denotes which node pairs to predict.\n            cn_info (tuple): Information of CN nodes\n                Contains (ID of node, src ppr, tgt ppr)\n            onehop_info (tuple): Information of 1-Hop nodes.\n                Contains (ID of node, src ppr, tgt ppr)\n            non1hop_info (tuple): Information of >1-Hop nodes.\n                Contains (ID of node, src ppr, tgt ppr)\n        \"\"\"\n        num_cns = self.get_num_ppr_thresh(batch, cn_info[0], cn_info[1],\n                                          cn_info[2], self.thresh_cn)\n        num_1hop = self.get_num_ppr_thresh(batch, onehop_info[0],\n                                           onehop_info[1], onehop_info[2],\n                                           self.thresh_1hop)\n\n        # TOTAL num of 1-hop neighbors union\n        num_ppr_ones = self.get_num_ppr_thresh(batch, onehop_info[0],\n                                               onehop_info[1], onehop_info[2],\n                                               thresh=0)\n        num_neighbors = num_cns + num_ppr_ones\n\n        # Process for >1-hop is different which is why we use get_count below\n        if non1hop_info is None:\n            return num_cns, num_1hop, 0, num_neighbors\n        else:\n            num_non1hop = self.get_count(non1hop_info[0], batch)\n            return num_cns, num_1hop, num_non1hop, num_neighbors\n\n    def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor,\n                           src_ppr: Tensor, tgt_ppr: Tensor,\n                           thresh: float) -> Tensor:\n        \"\"\"Get # of nodes `v` where `ppr(a, v) >= eta` & `ppr(b, v) >= eta`.\n\n        
Args:\n            batch (Tensor): The batch vector.\n                Denotes which node pairs to predict.\n            node_mask (Tensor): IDs of nodes\n            src_ppr (Tensor): PPR relative to src node\n            tgt_ppr (Tensor): PPR relative to tgt node\n            thresh (float): PPR threshold for nodes (`eta`)\n        \"\"\"\n        weight = torch.ones(node_mask.size(1), device=node_mask.device)\n\n        ppr_above_thresh = (src_ppr >= thresh) & (tgt_ppr >= thresh)\n        num_ppr = scatter(ppr_above_thresh.float() * weight,\n                          node_mask[0].long(), dim=0, dim_size=batch.size(1),\n                          reduce=\"sum\")\n        num_ppr = num_ppr.unsqueeze(-1)\n\n        return num_ppr\n\n    def get_count(\n        self,\n        node_mask: Tensor,\n        batch: Tensor,\n    ) -> Tensor:\n        \"\"\"# of nodes for each sample in batch.\n\n        The nodes have already been filtered by PPR beforehand.\n\n        Args:\n            node_mask (Tensor): IDs of nodes\n            batch (Tensor): The batch vector.\n                Denotes which node pairs to predict.\n        \"\"\"\n        weight = torch.ones(node_mask.size(1), device=node_mask.device)\n        num_nodes = scatter(weight, node_mask[0].long(), dim=0,\n                            dim_size=batch.size(1), reduce=\"sum\")\n        num_nodes = num_nodes.unsqueeze(-1)\n\n        return num_nodes\n\n    def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor,\n                         ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor]:\n        r\"\"\"Get PPR scores for non-1hop nodes.\n\n        Args:\n            batch (Tensor): Links in batch\n            adj (Tensor): Adjacency matrix\n            ppr_matrix (Tensor): Sparse PPR matrix\n        \"\"\"\n        # NOTE: Use original adj (one pass in forward() removes links in batch)\n        # Done since removing them converts src/tgt nodes to >1-hop nodes.\n        # Therefore removing CN and 1-hop will also remove the batch links.\n\n  
      # During training we add back in the links in the batch\n        # (they were removed from adjacency before being passed to model)\n        # Done since otherwise they will be mistakenly seen as >1-Hop nodes\n        # Instead they're 1-Hop, and get ignored accordingly\n        # Ignored during eval since we know the links aren't in the adj\n        adj2 = adj\n        if self.training:\n            n = adj.size(0)\n            batch_flip = torch.cat(\n                (batch, torch.flip(batch, (0, )).to(batch.device)), dim=-1)\n            batch_ones = torch.ones_like(batch_flip[0], device=batch.device)\n            adj_edges = torch.sparse_coo_tensor(batch_flip, batch_ones, [n, n],\n                                                device=batch.device)\n            adj2 = (adj + adj_edges).coalesce().bool().int()\n\n        src_adj = torch.index_select(adj2, 0, batch[0])\n        tgt_adj = torch.index_select(adj2, 0, batch[1])\n\n        src_ppr = torch.index_select(ppr_matrix, 0, batch[0])\n        tgt_ppr = torch.index_select(ppr_matrix, 0, batch[1])\n\n        # Remove CN scores\n        src_ppr = src_ppr - src_ppr * (src_adj * tgt_adj)\n        tgt_ppr = tgt_ppr - tgt_ppr * (src_adj * tgt_adj)\n        # Also need to remove CN entries in Adj\n        # Otherwise they leak into next computation\n        src_adj = src_adj - src_adj * (src_adj * tgt_adj)\n        tgt_adj = tgt_adj - tgt_adj * (src_adj * tgt_adj)\n\n        # Remove 1-Hop scores\n        src_ppr = src_ppr - src_ppr * (src_adj + tgt_adj)\n        tgt_ppr = tgt_ppr - tgt_ppr * (src_adj + tgt_adj)\n\n        # Make sure we include both when we convert to dense so indices align\n        # Do so by adding 1 to each based on the other\n        src_ppr_add = src_ppr + torch.sign(tgt_ppr)\n        tgt_ppr_add = tgt_ppr + torch.sign(src_ppr)\n\n        src_ix = src_ppr_add.coalesce().indices()\n        src_vals = src_ppr_add.coalesce().values()\n        tgt_vals = 
tgt_ppr_add.coalesce().values()\n\n        # Now we can remove value which is just 1\n        # Technically creates -1 scores for ppr scores that were 0\n        # Doesn't matter as they'll be filtered out by condition later\n        src_vals = src_vals - 1\n        tgt_vals = tgt_vals - 1\n\n        ppr_condition = (src_vals >= self.thresh_non1hop) & (\n            tgt_vals >= self.thresh_non1hop)\n        src_ix, src_vals, tgt_vals = src_ix[:, ppr_condition], src_vals[\n            ppr_condition], tgt_vals[ppr_condition]\n\n        return src_ix, src_vals, tgt_vals\n\n    def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int,\n                        alpha: float = 0.15, eps: float = 5e-5) -> Tensor:\n        r\"\"\"Calculate the PPR of the graph in sparse format.\n\n        Args:\n            edge_index: The edge indices\n            num_nodes: Number of nodes\n            alpha (float, optional): The alpha value of the PageRank algorithm.\n                (default: :obj:`0.15`)\n            eps (float, optional): Threshold for stopping the PPR calculation\n                (default: :obj:`5e-5`)\n        \"\"\"\n        ei, ei_w = get_ppr(edge_index.cpu(), alpha=alpha, eps=eps,\n                           num_nodes=num_nodes)\n        ppr_matrix = torch.sparse_coo_tensor(ei, ei_w, [num_nodes, num_nodes])\n\n        return ppr_matrix\n\n\nclass LPAttLayer(MessagePassing):\n    r\"\"\"Attention Layer for pairwise interaction module.\n\n    Args:\n        in_channels (int): Size of input dimension\n        out_channels (int): Size of output dimension\n        node_dim (int): Dimension of nodes being aggregated\n        num_heads (int): Number of heads to use in MHA\n        dropout (float): Dropout on attention values\n        concat (bool, optional): Whether to concat attention\n            heads. 
Otherwise sum (default: :obj:`True`)\n    \"\"\"\n    _alpha: OptTensor\n\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        node_dim: int,\n        num_heads: int,\n        dropout: float,\n        concat: bool = True,\n        **kwargs,\n    ):\n        super().__init__(node_dim=0, flow=\"target_to_source\", **kwargs)\n\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = num_heads\n        self.concat = concat\n        self.dropout = dropout\n        self.negative_slope = 0.2  # LeakyRelu\n\n        out_dim = 2\n        if node_dim is None:\n            node_dim = in_channels * out_dim\n        else:\n            node_dim = node_dim * out_dim\n\n        self.lin_l = Linear(in_channels, self.heads * out_channels,\n                            weight_initializer='glorot')\n        self.lin_r = Linear(node_dim, self.heads * out_channels,\n                            weight_initializer='glorot')\n\n        att_out = out_channels\n        self.att = Parameter(Tensor(1, self.heads, att_out))\n\n        if concat:\n            self.bias = Parameter(Tensor(self.heads * out_channels))\n        else:\n            self.bias = Parameter(Tensor(out_channels))\n\n        self._alpha = None\n\n        self.dropout = dropout\n        self.post_att_norm = nn.LayerNorm(out_channels)\n\n        self.reset_parameters()\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads})')\n\n    def reset_parameters(self):\n        self.lin_l.reset_parameters()\n        self.lin_r.reset_parameters()\n        self.post_att_norm.reset_parameters()\n        glorot(self.att)\n        zeros(self.bias)\n\n    def forward(\n        self,\n        edge_index: Tensor,\n        edge_feats: Tensor,\n        node_feats: Tensor,\n        ppr_rpes: Tensor,\n    ) -> Tensor:\n        \"\"\"Runs the forward pass 
of the module.\n\n        Args:\n            edge_index (Tensor): The edge indices.\n            edge_feats (Tensor): Concatenated representations\n                of src and target nodes for each link\n            node_feats (Tensor): Representations for individual\n                nodes\n            ppr_rpes (Tensor): Relative PEs for each node\n        \"\"\"\n        out = self.propagate(edge_index, x=(edge_feats, node_feats),\n                             ppr_rpes=ppr_rpes, size=None)\n\n        alpha = self._alpha\n        assert alpha is not None\n        self._alpha = None\n\n        if self.concat:\n            out = out.view(-1, self.heads * self.out_channels)\n        else:\n            out = out.mean(dim=1)\n\n        if self.bias is not None:\n            out = out + self.bias\n\n        out = self.post_att_norm(out)\n        out = F.dropout(out, p=self.dropout, training=self.training)\n\n        return out\n\n    def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor,\n                index: Tensor, ptr: Tensor, size_i: Optional[int]) -> Tensor:\n        H, C = self.heads, self.out_channels\n\n        x_j = torch.cat((x_j, ppr_rpes), dim=-1)\n        x_j = self.lin_r(x_j).view(-1, H, C)\n\n        # e=(a, b) attending to v\n        e1, e2 = x_i.chunk(2, dim=-1)\n        e1 = self.lin_l(e1).view(-1, H, C)\n        e2 = self.lin_l(e2).view(-1, H, C)\n        x = x_j * (e1 + e2)\n\n        x = F.leaky_relu(x, self.negative_slope)\n        alpha = (x * self.att).sum(dim=-1)\n\n        alpha = softmax(alpha, index, ptr, size_i)\n        self._alpha = alpha\n\n        return x_j * alpha.unsqueeze(-1)\n\n\nclass MLP(nn.Module):\n    r\"\"\"L Layer MLP.\"\"\"\n    def __init__(self, in_channels: int, hid_channels: int, out_channels: int,\n                 num_layers: int = 2, drop: float = 0.0,\n                 norm: Optional[str] = \"layer\"):\n        super().__init__()\n        self.dropout = drop\n\n        if norm == \"batch\":\n            self.norm = 
nn.BatchNorm1d(hid_channels)\n        elif norm == \"layer\":\n            self.norm = nn.LayerNorm(hid_channels)\n        else:\n            self.norm = None\n\n        self.linears = torch.nn.ModuleList()\n\n        if num_layers == 1:\n            self.linears.append(nn.Linear(in_channels, out_channels))\n        else:\n            self.linears.append(nn.Linear(in_channels, hid_channels))\n            for _ in range(num_layers - 2):\n                self.linears.append(nn.Linear(hid_channels, hid_channels))\n            self.linears.append(nn.Linear(hid_channels, out_channels))\n\n    def reset_parameters(self):\n        for lin in self.linears:\n            lin.reset_parameters()\n        if self.norm is not None:\n            self.norm.reset_parameters()\n\n    def forward(self, x: Tensor) -> Tensor:\n        for lin in self.linears[:-1]:\n            x = lin(x)\n            x = self.norm(x) if self.norm is not None else x\n            x = F.relu(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n\n        x = self.linears[-1](x)\n\n        return x.squeeze(-1)\n"
  },
  {
    "path": "torch_geometric/nn/models/mask_label.py",
    "content": "import torch\nfrom torch import Tensor\n\n\nclass MaskLabel(torch.nn.Module):\n    r\"\"\"The label embedding and masking layer from the `\"Masked Label\n    Prediction: Unified Message Passing Model for Semi-Supervised\n    Classification\" <https://arxiv.org/abs/2009.03509>`_ paper.\n\n    Here, node labels :obj:`y` are merged to the initial node features :obj:`x`\n    for a subset of their nodes according to :obj:`mask`.\n\n    .. note::\n\n        For an example of using :class:`MaskLabel`, see\n        `examples/unimp_arxiv.py <https://github.com/pyg-team/\n        pytorch_geometric/blob/master/examples/unimp_arxiv.py>`_.\n\n\n    Args:\n        num_classes (int): The number of classes.\n        out_channels (int): Size of each output sample.\n        method (str, optional): If set to :obj:`\"add\"`, label embeddings are\n            added to the input. If set to :obj:`\"concat\"`, label embeddings are\n            concatenated. In case :obj:`method=\"add\"`, then :obj:`out_channels`\n            needs to be identical to the input dimensionality of node features.\n            (default: :obj:`\"add\"`)\n    \"\"\"\n    def __init__(self, num_classes: int, out_channels: int,\n                 method: str = \"add\"):\n        super().__init__()\n\n        self.method = method\n        if method not in [\"add\", \"concat\"]:\n            raise ValueError(\n                f\"'method' must be either 'add' or 'concat' (got '{method}')\")\n\n        self.emb = torch.nn.Embedding(num_classes, out_channels)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.emb.reset_parameters()\n\n    def forward(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        if self.method == \"concat\":\n            out = x.new_zeros(y.size(0), self.emb.weight.size(-1))\n            out[mask] = self.emb(y[mask])\n            return 
torch.cat([x, out], dim=-1)\n        else:\n            x = torch.clone(x)\n            x[mask] += self.emb(y[mask])\n            return x\n\n    @staticmethod\n    def ratio_mask(mask: Tensor, ratio: float):\n        r\"\"\"Returns a copy of :obj:`mask` in which each :obj:`True` entry is\n        kept with probability :obj:`ratio` and set to :obj:`False` otherwise.\n        Does not operate in-place.\n\n        Args:\n            mask (torch.Tensor): The mask to re-mask.\n            ratio (float): The probability of keeping each :obj:`True` entry.\n        \"\"\"\n        n = int(mask.sum())\n        out = mask.clone()\n        out[mask] = torch.rand(n, device=mask.device) < ratio\n        return out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/nn/models/meta.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\n\nclass MetaLayer(torch.nn.Module):\n    r\"\"\"A meta layer for building any kind of graph network, inspired by the\n    `\"Relational Inductive Biases, Deep Learning, and Graph Networks\"\n    <https://arxiv.org/abs/1806.01261>`_ paper.\n\n    A graph network takes a graph as input and returns an updated graph as\n    output (with same connectivity).\n    The input graph has node features :obj:`x`, edge features :obj:`edge_attr`\n    as well as graph-level features :obj:`u`.\n    The output graph has the same structure, but updated features.\n\n    Edge features, node features as well as global features are updated by\n    calling the modules :obj:`edge_model`, :obj:`node_model` and\n    :obj:`global_model`, respectively.\n\n    To allow for batch-wise graph processing, all callable functions take an\n    additional argument :obj:`batch`, which determines the assignment of\n    edges or nodes to their specific graphs.\n\n    Args:\n        edge_model (torch.nn.Module, optional): A callable which updates a\n            graph's edge features based on its source and target node features,\n            its current edge features and its global features.\n            (default: :obj:`None`)\n        node_model (torch.nn.Module, optional): A callable which updates a\n            graph's node features based on its current node features, its graph\n            connectivity, its edge features and its global features.\n            (default: :obj:`None`)\n        global_model (torch.nn.Module, optional): A callable which updates a\n            graph's global features based on its node features, its graph\n            connectivity, its edge features and its current global features.\n            (default: :obj:`None`)\n\n    .. 
code-block:: python\n\n        from torch.nn import Sequential as Seq, Linear as Lin, ReLU\n        from torch_geometric.utils import scatter\n        from torch_geometric.nn import MetaLayer\n\n        class EdgeModel(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.edge_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))\n\n            def forward(self, src, dst, edge_attr, u, batch):\n                # src, dst: [E, F_x], where E is the number of edges.\n                # edge_attr: [E, F_e]\n                # u: [B, F_u], where B is the number of graphs.\n                # batch: [E] with max entry B - 1.\n                out = torch.cat([src, dst, edge_attr, u[batch]], 1)\n                return self.edge_mlp(out)\n\n        class NodeModel(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.node_mlp_1 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))\n                self.node_mlp_2 = Seq(Lin(..., ...), ReLU(), Lin(..., ...))\n\n            def forward(self, x, edge_index, edge_attr, u, batch):\n                # x: [N, F_x], where N is the number of nodes.\n                # edge_index: [2, E] with max entry N - 1.\n                # edge_attr: [E, F_e]\n                # u: [B, F_u]\n                # batch: [N] with max entry B - 1.\n                row, col = edge_index\n                out = torch.cat([x[row], edge_attr], dim=1)\n                out = self.node_mlp_1(out)\n                out = scatter(out, col, dim=0, dim_size=x.size(0),\n                              reduce='mean')\n                out = torch.cat([x, out, u[batch]], dim=1)\n                return self.node_mlp_2(out)\n\n        class GlobalModel(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.global_mlp = Seq(Lin(..., ...), ReLU(), Lin(..., ...))\n\n            def forward(self, x, edge_index, edge_attr, u, batch):\n    
            # x: [N, F_x], where N is the number of nodes.\n                # edge_index: [2, E] with max entry N - 1.\n                # edge_attr: [E, F_e]\n                # u: [B, F_u]\n                # batch: [N] with max entry B - 1.\n                out = torch.cat([\n                    u,\n                    scatter(x, batch, dim=0, reduce='mean'),\n                ], dim=1)\n                return self.global_mlp(out)\n\n        op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel())\n        x, edge_attr, u = op(x, edge_index, edge_attr, u, batch)\n    \"\"\"\n    def __init__(\n        self,\n        edge_model: Optional[torch.nn.Module] = None,\n        node_model: Optional[torch.nn.Module] = None,\n        global_model: Optional[torch.nn.Module] = None,\n    ):\n        super().__init__()\n        self.edge_model = edge_model\n        self.node_model = node_model\n        self.global_model = global_model\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for item in [self.node_model, self.edge_model, self.global_model]:\n            if hasattr(item, 'reset_parameters'):\n                item.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor] = None,\n        u: Optional[Tensor] = None,\n        batch: Optional[Tensor] = None,\n    ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node features.\n            edge_index (torch.Tensor): The edge indices.\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            u (torch.Tensor, optional): The global graph features.\n                (default: :obj:`None`)\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, 
B-1\\}}^N`, which assigns\n                each node to a specific graph. (default: :obj:`None`)\n        \"\"\"\n        row = edge_index[0]\n        col = edge_index[1]\n\n        if self.edge_model is not None:\n            edge_attr = self.edge_model(x[row], x[col], edge_attr, u,\n                                        batch if batch is None else batch[row])\n\n        if self.node_model is not None:\n            x = self.node_model(x, edge_index, edge_attr, u, batch)\n\n        if self.global_model is not None:\n            u = self.global_model(x, edge_index, edge_attr, u, batch)\n\n        return x, edge_attr, u\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(\\n'\n                f'  edge_model={self.edge_model},\\n'\n                f'  node_model={self.node_model},\\n'\n                f'  global_model={self.global_model}\\n'\n                f')')\n"
  },
  {
    "path": "torch_geometric/nn/models/metapath2vec.py",
    "content": "from typing import Dict, List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Embedding\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.index import index2ptr\nfrom torch_geometric.typing import EdgeType, NodeType, OptTensor\nfrom torch_geometric.utils import sort_edge_index\n\nEPS = 1e-15\n\n\nclass MetaPath2Vec(torch.nn.Module):\n    r\"\"\"The MetaPath2Vec model from the `\"metapath2vec: Scalable Representation\n    Learning for Heterogeneous Networks\"\n    <https://ericdongyx.github.io/papers/\n    KDD17-dong-chawla-swami-metapath2vec.pdf>`_ paper where random walks based\n    on a given :obj:`metapath` are sampled in a heterogeneous graph, and node\n    embeddings are learned via negative sampling optimization.\n\n    .. note::\n\n        For an example of using MetaPath2Vec, see\n        `examples/hetero/metapath2vec.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        hetero/metapath2vec.py>`_.\n\n    Args:\n        edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): Dictionary\n            holding edge indices for each\n            :obj:`(src_node_type, rel_type, dst_node_type)` edge type present\n            in the heterogeneous graph.\n        embedding_dim (int): The size of each embedding vector.\n        metapath (List[Tuple[str, str, str]]): The metapath described as a list\n            of :obj:`(src_node_type, rel_type, dst_node_type)` tuples.\n        walk_length (int): The walk length.\n        context_size (int): The actual context size which is considered for\n            positive samples. This parameter increases the effective sampling\n            rate by reusing samples across different source nodes.\n        walks_per_node (int, optional): The number of walks to sample for each\n            node. 
(default: :obj:`1`)\n        num_negative_samples (int, optional): The number of negative samples to\n            use for each positive sample. (default: :obj:`1`)\n        num_nodes_dict (Dict[str, int], optional): Dictionary holding the\n            number of nodes for each node type. (default: :obj:`None`)\n        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the\n            weight matrix will be sparse. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        edge_index_dict: Dict[EdgeType, Tensor],\n        embedding_dim: int,\n        metapath: List[EdgeType],\n        walk_length: int,\n        context_size: int,\n        walks_per_node: int = 1,\n        num_negative_samples: int = 1,\n        num_nodes_dict: Optional[Dict[NodeType, int]] = None,\n        sparse: bool = False,\n    ):\n        super().__init__()\n\n        if num_nodes_dict is None:\n            num_nodes_dict = {}\n            for keys, edge_index in edge_index_dict.items():\n                key = keys[0]\n                N = int(edge_index[0].max() + 1)\n                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))\n\n                key = keys[-1]\n                N = int(edge_index[1].max() + 1)\n                num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))\n\n        self.rowptr_dict, self.col_dict, self.rowcount_dict = {}, {}, {}\n        for keys, edge_index in edge_index_dict.items():\n            sizes = (num_nodes_dict[keys[0]], num_nodes_dict[keys[-1]])\n            row, col = sort_edge_index(edge_index, num_nodes=max(sizes)).cpu()\n            rowptr = index2ptr(row, size=sizes[0])\n            self.rowptr_dict[keys] = rowptr\n            self.col_dict[keys] = col\n            self.rowcount_dict[keys] = rowptr[1:] - rowptr[:-1]\n\n        for edge_type1, edge_type2 in zip(metapath[:-1], metapath[1:]):\n            if edge_type1[-1] != edge_type2[0]:\n                raise ValueError(\n                    \"Found 
invalid metapath. Ensure that the destination node \"\n                    \"type matches with the source node type across all \"\n                    \"consecutive edge types.\")\n\n        assert walk_length + 1 >= context_size\n        if walk_length > len(metapath) and metapath[0][0] != metapath[-1][-1]:\n            raise AttributeError(\n                \"The 'walk_length' is longer than the given 'metapath', but \"\n                \"the 'metapath' does not denote a cycle\")\n\n        self.embedding_dim = embedding_dim\n        self.metapath = metapath\n        self.walk_length = walk_length\n        self.context_size = context_size\n        self.walks_per_node = walks_per_node\n        self.num_negative_samples = num_negative_samples\n        self.num_nodes_dict = num_nodes_dict\n\n        types = {x[0] for x in metapath} | {x[-1] for x in metapath}\n        types = sorted(list(types))\n\n        count = 0\n        self.start, self.end = {}, {}\n        for key in types:\n            self.start[key] = count\n            count += num_nodes_dict[key]\n            self.end[key] = count\n\n        offset = [self.start[metapath[0][0]]]\n        offset += [self.start[keys[-1]] for keys in metapath\n                   ] * int((walk_length / len(metapath)) + 1)\n        offset = offset[:walk_length + 1]\n        assert len(offset) == walk_length + 1\n        self.offset = torch.tensor(offset)\n\n        # + 1 denotes a dummy node used to link to for isolated nodes.\n        self.embedding = Embedding(count + 1, embedding_dim, sparse=sparse)\n        self.dummy_idx = count\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.embedding.reset_parameters()\n\n    def forward(self, node_type: str, batch: OptTensor = None) -> Tensor:\n        r\"\"\"Returns the embeddings for the nodes in :obj:`batch` of type\n        :obj:`node_type`.\n        \"\"\"\n        emb = 
self.embedding.weight[self.start[node_type]:self.end[node_type]]\n        return emb if batch is None else emb.index_select(0, batch)\n\n    def loader(self, **kwargs):\n        r\"\"\"Returns the data loader that creates both positive and negative\n        random walks on the heterogeneous graph.\n\n        Args:\n            **kwargs (optional): Arguments of\n                :class:`torch.utils.data.DataLoader`, such as\n                :obj:`batch_size`, :obj:`shuffle`, :obj:`drop_last` or\n                :obj:`num_workers`.\n        \"\"\"\n        return DataLoader(range(self.num_nodes_dict[self.metapath[0][0]]),\n                          collate_fn=self._sample, **kwargs)\n\n    def _pos_sample(self, batch: Tensor) -> Tensor:\n        batch = batch.repeat(self.walks_per_node)\n\n        rws = [batch]\n        for i in range(self.walk_length):\n            edge_type = self.metapath[i % len(self.metapath)]\n            batch = sample(\n                self.rowptr_dict[edge_type],\n                self.col_dict[edge_type],\n                self.rowcount_dict[edge_type],\n                batch,\n                num_neighbors=1,\n                dummy_idx=self.dummy_idx,\n            ).view(-1)\n            rws.append(batch)\n\n        rw = torch.stack(rws, dim=-1)\n        rw.add_(self.offset.view(1, -1))\n        rw[rw > self.dummy_idx] = self.dummy_idx\n\n        walks = []\n        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size\n        for j in range(num_walks_per_rw):\n            walks.append(rw[:, j:j + self.context_size])\n        return torch.cat(walks, dim=0)\n\n    def _neg_sample(self, batch: Tensor) -> Tensor:\n        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)\n\n        rws = [batch]\n        for i in range(self.walk_length):\n            keys = self.metapath[i % len(self.metapath)]\n            batch = torch.randint(0, self.num_nodes_dict[keys[-1]],\n                                  (batch.size(0), ), 
dtype=torch.long)\n            rws.append(batch)\n\n        rw = torch.stack(rws, dim=-1)\n        rw.add_(self.offset.view(1, -1))\n\n        walks = []\n        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size\n        for j in range(num_walks_per_rw):\n            walks.append(rw[:, j:j + self.context_size])\n        return torch.cat(walks, dim=0)\n\n    def _sample(self, batch: List[int]) -> Tuple[Tensor, Tensor]:\n        if not isinstance(batch, Tensor):\n            batch = torch.tensor(batch, dtype=torch.long)\n        return self._pos_sample(batch), self._neg_sample(batch)\n\n    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:\n        r\"\"\"Computes the loss given positive and negative random walks.\"\"\"\n        # Positive loss.\n        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()\n\n        h_start = self.embedding(start).view(pos_rw.size(0), 1,\n                                             self.embedding_dim)\n        h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1,\n                                                    self.embedding_dim)\n\n        out = (h_start * h_rest).sum(dim=-1).view(-1)\n        pos_loss = -torch.log(torch.sigmoid(out) + EPS).mean()\n\n        # Negative loss.\n        start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()\n\n        h_start = self.embedding(start).view(neg_rw.size(0), 1,\n                                             self.embedding_dim)\n        h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1,\n                                                    self.embedding_dim)\n\n        out = (h_start * h_rest).sum(dim=-1).view(-1)\n        neg_loss = -torch.log(1 - torch.sigmoid(out) + EPS).mean()\n\n        return pos_loss + neg_loss\n\n    def test(self, train_z: Tensor, train_y: Tensor, test_z: Tensor,\n             test_y: Tensor, solver: str = \"lbfgs\", *args, **kwargs) -> float:\n        r\"\"\"Evaluates latent space quality via a logistic 
regression downstream\n        task.\n        \"\"\"\n        from sklearn.linear_model import LogisticRegression\n\n        clf = LogisticRegression(*args, solver=solver,\n                                 **kwargs).fit(train_z.detach().cpu().numpy(),\n                                               train_y.detach().cpu().numpy())\n        return clf.score(test_z.detach().cpu().numpy(),\n                         test_y.detach().cpu().numpy())\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'{self.embedding.weight.size(0) - 1}, '\n                f'{self.embedding.weight.size(1)})')\n\n\ndef sample(rowptr: Tensor, col: Tensor, rowcount: Tensor, subset: Tensor,\n           num_neighbors: int, dummy_idx: int) -> Tensor:\n\n    mask = subset >= dummy_idx\n    subset = subset.clamp(min=0, max=rowptr.numel() - 2)\n    count = rowcount[subset]\n\n    rand = torch.rand((subset.size(0), num_neighbors), device=subset.device)\n    rand *= count.to(rand.dtype).view(-1, 1)\n    rand = rand.to(torch.long) + rowptr[subset].view(-1, 1)\n    rand = rand.clamp(max=col.numel() - 1)  # If last node is isolated.\n\n    col = col[rand] if col.numel() > 0 else rand\n    col[mask | (count == 0)] = dummy_idx\n    return col\n"
  },
  {
    "path": "torch_geometric/nn/models/mlp.py",
    "content": "import inspect\nimport warnings\nfrom typing import Any, Callable, Dict, Final, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Identity\n\nfrom torch_geometric.nn.dense.linear import Linear\nfrom torch_geometric.nn.resolver import (\n    activation_resolver,\n    normalization_resolver,\n)\nfrom torch_geometric.typing import NoneType\n\n\nclass MLP(torch.nn.Module):\n    r\"\"\"A Multi-Layer Perception (MLP) model.\n\n    There exists two ways to instantiate an :class:`MLP`:\n\n    1. By specifying explicit channel sizes, *e.g.*,\n\n       .. code-block:: python\n\n          mlp = MLP([16, 32, 64, 128])\n\n       creates a three-layer MLP with **differently** sized hidden layers.\n\n    1. By specifying fixed hidden channel sizes over a number of layers,\n       *e.g.*,\n\n       .. code-block:: python\n\n          mlp = MLP(in_channels=16, hidden_channels=32,\n                    out_channels=128, num_layers=3)\n\n       creates a three-layer MLP with **equally** sized hidden layers.\n\n    Args:\n        channel_list (List[int] or int, optional): List of input, intermediate\n            and output channels such that :obj:`len(channel_list) - 1` denotes\n            the number of layers of the MLP (default: :obj:`None`)\n        in_channels (int, optional): Size of each input sample.\n            Will override :attr:`channel_list`. (default: :obj:`None`)\n        hidden_channels (int, optional): Size of each hidden sample.\n            Will override :attr:`channel_list`. (default: :obj:`None`)\n        out_channels (int, optional): Size of each output sample.\n            Will override :attr:`channel_list`. (default: :obj:`None`)\n        num_layers (int, optional): The number of layers.\n            Will override :attr:`channel_list`. (default: :obj:`None`)\n        dropout (float or List[float], optional): Dropout probability of each\n            hidden embedding. 
If a list is provided, sets the dropout value per\n            layer. (default: :obj:`0.`)\n        act (str or Callable, optional): The non-linear activation function to\n            use. (default: :obj:`\"relu\"`)\n        act_first (bool, optional): If set to :obj:`True`, activation is\n            applied before normalization. (default: :obj:`False`)\n        act_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective activation function defined by :obj:`act`.\n            (default: :obj:`None`)\n        norm (str or Callable, optional): The normalization function to\n            use. (default: :obj:`\"batch_norm\"`)\n        norm_kwargs (Dict[str, Any], optional): Arguments passed to the\n            respective normalization function defined by :obj:`norm`.\n            (default: :obj:`None`)\n        plain_last (bool, optional): If set to :obj:`False`, will apply\n            non-linearity, batch normalization and dropout to the last layer as\n            well. (default: :obj:`True`)\n        bias (bool or List[bool], optional): If set to :obj:`False`, the module\n            will not learn additive biases. If a list is provided, sets the\n            bias per layer. 
(default: :obj:`True`)\n        **kwargs (optional): Additional deprecated arguments of the MLP layer.\n    \"\"\"\n    supports_norm_batch: Final[bool]\n\n    def __init__(\n        self,\n        channel_list: Optional[Union[List[int], int]] = None,\n        *,\n        in_channels: Optional[int] = None,\n        hidden_channels: Optional[int] = None,\n        out_channels: Optional[int] = None,\n        num_layers: Optional[int] = None,\n        dropout: Union[float, List[float]] = 0.,\n        act: Union[str, Callable, None] = \"relu\",\n        act_first: bool = False,\n        act_kwargs: Optional[Dict[str, Any]] = None,\n        norm: Union[str, Callable, None] = \"batch_norm\",\n        norm_kwargs: Optional[Dict[str, Any]] = None,\n        plain_last: bool = True,\n        bias: Union[bool, List[bool]] = True,\n        **kwargs,\n    ):\n        super().__init__()\n\n        # Backward compatibility:\n        act_first = act_first or kwargs.get(\"relu_first\", False)\n        batch_norm = kwargs.get(\"batch_norm\", None)\n        if batch_norm is not None and isinstance(batch_norm, bool):\n            warnings.warn(\n                \"Argument `batch_norm` is deprecated, \"\n                \"please use `norm` to specify normalization layer.\",\n                stacklevel=2)\n            norm = 'batch_norm' if batch_norm else None\n            batch_norm_kwargs = kwargs.get(\"batch_norm_kwargs\", None)\n            norm_kwargs = batch_norm_kwargs or {}\n\n        if isinstance(channel_list, int):\n            in_channels = channel_list\n\n        if in_channels is not None:\n            if num_layers is None:\n                raise ValueError(\"Argument `num_layers` must be given\")\n            if num_layers > 1 and hidden_channels is None:\n                raise ValueError(f\"Argument `hidden_channels` must be given \"\n                                 f\"for `num_layers={num_layers}`\")\n            if out_channels is None:\n                raise 
ValueError(\"Argument `out_channels` must be given\")\n\n            channel_list = [hidden_channels] * (num_layers - 1)\n            channel_list = [in_channels] + channel_list + [out_channels]\n\n        assert isinstance(channel_list, (tuple, list))\n        assert len(channel_list) >= 2\n        self.channel_list = channel_list\n\n        self.act = activation_resolver(act, **(act_kwargs or {}))\n        self.act_first = act_first\n        self.plain_last = plain_last\n\n        if isinstance(dropout, float):\n            dropout = [dropout] * (len(channel_list) - 1)\n            if plain_last:\n                dropout[-1] = 0.\n        if len(dropout) != len(channel_list) - 1:\n            raise ValueError(\n                f\"Number of dropout values provided ({len(dropout)}) does \"\n                f\"not match the number of layers specified \"\n                f\"({len(channel_list)-1})\")\n        self.dropout = dropout\n\n        if isinstance(bias, bool):\n            bias = [bias] * (len(channel_list) - 1)\n        if len(bias) != len(channel_list) - 1:\n            raise ValueError(\n                f\"Number of bias values provided ({len(bias)}) does not match \"\n                f\"the number of layers specified ({len(channel_list)-1})\")\n\n        self.lins = torch.nn.ModuleList()\n        iterator = zip(channel_list[:-1], channel_list[1:], bias)\n        for in_channels, out_channels, _bias in iterator:\n            self.lins.append(Linear(in_channels, out_channels, bias=_bias))\n\n        self.norms = torch.nn.ModuleList()\n        iterator = channel_list[1:-1] if plain_last else channel_list[1:]\n        for hidden_channels in iterator:\n            if norm is not None:\n                norm_layer = normalization_resolver(\n                    norm,\n                    hidden_channels,\n                    **(norm_kwargs or {}),\n                )\n            else:\n                norm_layer = Identity()\n            
self.norms.append(norm_layer)\n\n        self.supports_norm_batch = False\n        if len(self.norms) > 0 and hasattr(self.norms[0], 'forward'):\n            norm_params = inspect.signature(self.norms[0].forward).parameters\n            self.supports_norm_batch = 'batch' in norm_params\n\n        self.reset_parameters()\n\n    @property\n    def in_channels(self) -> int:\n        r\"\"\"Size of each input sample.\"\"\"\n        return self.channel_list[0]\n\n    @property\n    def out_channels(self) -> int:\n        r\"\"\"Size of each output sample.\"\"\"\n        return self.channel_list[-1]\n\n    @property\n    def num_layers(self) -> int:\n        r\"\"\"The number of layers.\"\"\"\n        return len(self.channel_list) - 1\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for lin in self.lins:\n            lin.reset_parameters()\n        for norm in self.norms:\n            if hasattr(norm, 'reset_parameters'):\n                norm.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        batch: Optional[Tensor] = None,\n        batch_size: Optional[int] = None,\n        return_emb: NoneType = None,\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example.\n                Only needs to be passed in case the underlying normalization\n                layers require the :obj:`batch` information.\n                (default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given.\n                Only needs to be passed in case the underlying normalization\n                layers require the :obj:`batch` information.\n                
(default: :obj:`None`)\n            return_emb (bool, optional): If set to :obj:`True`, will\n                additionally return the embeddings before execution of the\n                final output layer. (default: :obj:`False`)\n        \"\"\"\n        # `return_emb` is annotated here as `NoneType` to be compatible with\n        # TorchScript, which does not support different return types based on\n        # the value of an input argument.\n        emb: Optional[Tensor] = None\n\n        # If `plain_last=True`, then `len(norms) == len(lins) - 1`, thus\n        # skipping the execution of the last layer inside the for-loop.\n        for i, (lin, norm) in enumerate(zip(self.lins, self.norms)):\n            x = lin(x)\n            if self.act is not None and self.act_first:\n                x = self.act(x)\n            if self.supports_norm_batch:\n                x = norm(x, batch, batch_size)\n            else:\n                x = norm(x)\n            if self.act is not None and not self.act_first:\n                x = self.act(x)\n            x = F.dropout(x, p=self.dropout[i], training=self.training)\n            if isinstance(return_emb, bool) and return_emb is True:\n                emb = x\n\n        if self.plain_last:\n            x = self.lins[-1](x)\n            x = F.dropout(x, p=self.dropout[-1], training=self.training)\n\n        return (x, emb) if isinstance(return_emb, bool) else x\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({str(self.channel_list)[1:-1]})'\n"
  },
  {
    "path": "torch_geometric/nn/models/neural_fingerprint.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn import Linear, MFConv, global_add_pool\nfrom torch_geometric.typing import Adj\n\n\nclass NeuralFingerprint(torch.nn.Module):\n    r\"\"\"The Neural Fingerprint model from the\n    `\"Convolutional Networks on Graphs for Learning Molecular Fingerprints\"\n    <https://arxiv.org/abs/1509.09292>`__ paper to generate fingerprints\n    of molecules.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Size of each hidden sample.\n        out_channels (int): Size of each output fingerprint.\n        num_layers (int): Number of layers.\n        **kwargs (optional): Additional arguments of\n            :class:`torch_geometric.nn.conv.MFConv`.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int,\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.num_layers = num_layers\n\n        self.convs = torch.nn.ModuleList()\n        for i in range(self.num_layers):\n            in_channels = self.in_channels if i == 0 else self.hidden_channels\n            self.convs.append(MFConv(in_channels, hidden_channels, **kwargs))\n\n        self.lins = torch.nn.ModuleList()\n        for _ in range(self.num_layers):\n            self.lins.append(Linear(hidden_channels, out_channels, bias=False))\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for conv in self.convs:\n            conv.reset_parameters()\n        for lin in self.lins:\n            lin.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        batch: Optional[Tensor] = None,\n        batch_size: Optional[int] 
= None,\n    ) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        outs = []\n        for conv, lin in zip(self.convs, self.lins):\n            x = conv(x, edge_index).sigmoid()\n            y = lin(x).softmax(dim=-1)\n            outs.append(global_add_pool(y, batch, batch_size))\n        return sum(outs)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_layers={self.num_layers})')\n"
  },
  {
    "path": "torch_geometric/nn/models/node2vec.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Embedding\nfrom torch.utils.data import DataLoader\n\nfrom torch_geometric.index import index2ptr\nfrom torch_geometric.typing import WITH_PYG_LIB, WITH_TORCH_CLUSTER\nfrom torch_geometric.utils import sort_edge_index\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\nclass Node2Vec(torch.nn.Module):\n    r\"\"\"The Node2Vec model from the\n    `\"node2vec: Scalable Feature Learning for Networks\"\n    <https://arxiv.org/abs/1607.00653>`_ paper where random walks of\n    length :obj:`walk_length` are sampled in a given graph, and node embeddings\n    are learned via negative sampling optimization.\n\n    .. note::\n\n        For an example of using Node2Vec, see `examples/node2vec.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        node2vec.py>`_.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices.\n        embedding_dim (int): The size of each embedding vector.\n        walk_length (int): The walk length.\n        context_size (int): The actual context size which is considered for\n            positive samples. This parameter increases the effective sampling\n            rate by reusing samples across different source nodes.\n        walks_per_node (int, optional): The number of walks to sample for each\n            node. (default: :obj:`1`)\n        p (float, optional): Likelihood of immediately revisiting a node in the\n            walk. (default: :obj:`1`)\n        q (float, optional): Control parameter to interpolate between\n            breadth-first strategy and depth-first strategy (default: :obj:`1`)\n        num_negative_samples (int, optional): The number of negative samples to\n            use for each positive sample. (default: :obj:`1`)\n        num_nodes (int, optional): The number of nodes. 
(default: :obj:`None`)\n        sparse (bool, optional): If set to :obj:`True`, gradients w.r.t. the\n            weight matrix will be sparse. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        edge_index: Tensor,\n        embedding_dim: int,\n        walk_length: int,\n        context_size: int,\n        walks_per_node: int = 1,\n        p: float = 1.0,\n        q: float = 1.0,\n        num_negative_samples: int = 1,\n        num_nodes: Optional[int] = None,\n        sparse: bool = False,\n    ):\n        super().__init__()\n\n        if WITH_PYG_LIB and p == 1.0 and q == 1.0:\n            self.random_walk_fn = torch.ops.pyg.random_walk\n        elif WITH_TORCH_CLUSTER:\n            self.random_walk_fn = torch.ops.torch_cluster.random_walk\n        else:\n            if p == 1.0 and q == 1.0:\n                raise ImportError(f\"'{self.__class__.__name__}' \"\n                                  f\"requires either the 'pyg-lib' or \"\n                                  f\"'torch-cluster' package\")\n            else:\n                raise ImportError(f\"'{self.__class__.__name__}' \"\n                                  f\"requires the 'torch-cluster' package\")\n\n        self.num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n        row, col = sort_edge_index(edge_index, num_nodes=self.num_nodes).cpu()\n        self.rowptr, self.col = index2ptr(row, self.num_nodes), col\n\n        self.EPS = 1e-15\n        assert walk_length >= context_size\n\n        self.embedding_dim = embedding_dim\n        self.walk_length = walk_length - 1\n        self.context_size = context_size\n        self.walks_per_node = walks_per_node\n        self.p = p\n        self.q = q\n        self.num_negative_samples = num_negative_samples\n\n        self.embedding = Embedding(self.num_nodes, embedding_dim,\n                                   sparse=sparse)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all 
learnable parameters of the module.\"\"\"\n        self.embedding.reset_parameters()\n\n    def forward(self, batch: Optional[Tensor] = None) -> Tensor:\n        \"\"\"Returns the embeddings for the nodes in :obj:`batch`.\"\"\"\n        emb = self.embedding.weight\n        return emb if batch is None else emb[batch]\n\n    def loader(self, **kwargs) -> DataLoader:\n        return DataLoader(range(self.num_nodes), collate_fn=self.sample,\n                          **kwargs)\n\n    @torch.jit.export\n    def pos_sample(self, batch: Tensor) -> Tensor:\n        batch = batch.repeat(self.walks_per_node)\n        rw = self.random_walk_fn(self.rowptr, self.col, batch,\n                                 self.walk_length, self.p, self.q)\n        if not isinstance(rw, Tensor):\n            rw = rw[0]\n\n        walks = []\n        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size\n        for j in range(num_walks_per_rw):\n            walks.append(rw[:, j:j + self.context_size])\n        return torch.cat(walks, dim=0)\n\n    @torch.jit.export\n    def neg_sample(self, batch: Tensor) -> Tensor:\n        batch = batch.repeat(self.walks_per_node * self.num_negative_samples)\n\n        rw = torch.randint(self.num_nodes, (batch.size(0), self.walk_length),\n                           dtype=batch.dtype, device=batch.device)\n        rw = torch.cat([batch.view(-1, 1), rw], dim=-1)\n\n        walks = []\n        num_walks_per_rw = 1 + self.walk_length + 1 - self.context_size\n        for j in range(num_walks_per_rw):\n            walks.append(rw[:, j:j + self.context_size])\n        return torch.cat(walks, dim=0)\n\n    @torch.jit.export\n    def sample(self, batch: Union[List[int], Tensor]) -> Tuple[Tensor, Tensor]:\n        if not isinstance(batch, Tensor):\n            batch = torch.tensor(batch)\n        return self.pos_sample(batch), self.neg_sample(batch)\n\n    @torch.jit.export\n    def loss(self, pos_rw: Tensor, neg_rw: Tensor) -> Tensor:\n        
r\"\"\"Computes the loss given positive and negative random walks.\"\"\"\n        # Positive loss.\n        start, rest = pos_rw[:, 0], pos_rw[:, 1:].contiguous()\n\n        h_start = self.embedding(start).view(pos_rw.size(0), 1,\n                                             self.embedding_dim)\n        h_rest = self.embedding(rest.view(-1)).view(pos_rw.size(0), -1,\n                                                    self.embedding_dim)\n\n        out = (h_start * h_rest).sum(dim=-1).view(-1)\n        pos_loss = -torch.log(torch.sigmoid(out) + self.EPS).mean()\n\n        # Negative loss.\n        start, rest = neg_rw[:, 0], neg_rw[:, 1:].contiguous()\n\n        h_start = self.embedding(start).view(neg_rw.size(0), 1,\n                                             self.embedding_dim)\n        h_rest = self.embedding(rest.view(-1)).view(neg_rw.size(0), -1,\n                                                    self.embedding_dim)\n\n        out = (h_start * h_rest).sum(dim=-1).view(-1)\n        neg_loss = -torch.log(1 - torch.sigmoid(out) + self.EPS).mean()\n\n        return pos_loss + neg_loss\n\n    def test(\n        self,\n        train_z: Tensor,\n        train_y: Tensor,\n        test_z: Tensor,\n        test_y: Tensor,\n        solver: str = 'lbfgs',\n        *args,\n        **kwargs,\n    ) -> float:\n        r\"\"\"Evaluates latent space quality via a logistic regression downstream\n        task.\n        \"\"\"\n        from sklearn.linear_model import LogisticRegression\n\n        clf = LogisticRegression(*args, solver=solver,\n                                 **kwargs).fit(train_z.detach().cpu().numpy(),\n                                               train_y.detach().cpu().numpy())\n        return clf.score(test_z.detach().cpu().numpy(),\n                         test_y.detach().cpu().numpy())\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.embedding.weight.size(0)}, '\n                
f'{self.embedding.weight.size(1)})')\n"
  },
  {
    "path": "torch_geometric/nn/models/pmlp.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn import SimpleConv\nfrom torch_geometric.nn.dense.linear import Linear\n\n\nclass PMLP(torch.nn.Module):\n    r\"\"\"The P(ropagational)MLP model from the `\"Graph Neural Networks are\n    Inherently Good Generalizers: Insights by Bridging GNNs and MLPs\"\n    <https://arxiv.org/abs/2212.09034>`_ paper.\n    :class:`PMLP` is identical to a standard MLP during training, but then\n    adopts a GNN architecture during testing.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Size of each hidden sample.\n        out_channels (int): Size of each output sample.\n        num_layers (int): The number of layers.\n        dropout (float, optional): Dropout probability of each hidden\n            embedding. (default: :obj:`0.`)\n        norm (bool, optional): If set to :obj:`False`, will not apply batch\n            normalization. (default: :obj:`True`)\n        bias (bool, optional): If set to :obj:`False`, the module\n            will not learn additive biases. 
(default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        num_layers: int,\n        dropout: float = 0.,\n        norm: bool = True,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.out_channels = out_channels\n        self.num_layers = num_layers\n        self.dropout = dropout\n        self.bias = bias\n\n        self.lins = torch.nn.ModuleList()\n        self.lins.append(Linear(in_channels, hidden_channels, self.bias))\n        for _ in range(self.num_layers - 2):\n            lin = Linear(hidden_channels, hidden_channels, self.bias)\n            self.lins.append(lin)\n        self.lins.append(Linear(hidden_channels, out_channels, self.bias))\n\n        self.norm = None\n        if norm:\n            self.norm = torch.nn.BatchNorm1d(\n                hidden_channels,\n                affine=False,\n                track_running_stats=False,\n            )\n\n        self.conv = SimpleConv(aggr='mean', combine_root='self_loop')\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for lin in self.lins:\n            torch.nn.init.xavier_uniform_(lin.weight, gain=1.414)\n            if self.bias:\n                torch.nn.init.zeros_(lin.bias)\n\n    def forward(\n        self,\n        x: torch.Tensor,\n        edge_index: Optional[Tensor] = None,\n    ) -> torch.Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        if not self.training and edge_index is None:\n            raise ValueError(f\"'edge_index' needs to be present during \"\n                             f\"inference in '{self.__class__.__name__}'\")\n\n        for i in range(self.num_layers):\n            x = x @ self.lins[i].weight.t()\n            if not self.training:\n                x = 
self.conv(x, edge_index)\n            if self.bias:\n                x = x + self.lins[i].bias\n            if i != self.num_layers - 1:\n                if self.norm is not None:\n                    x = self.norm(x)\n                x = x.relu()\n                x = F.dropout(x, p=self.dropout, training=self.training)\n\n        return x\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, num_layers={self.num_layers})')\n"
  },
  {
    "path": "torch_geometric/nn/models/polynormer.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn import GATConv, GCNConv\nfrom torch_geometric.nn.attention import PolynormerAttention\nfrom torch_geometric.utils import to_dense_batch\n\n\nclass Polynormer(torch.nn.Module):\n    r\"\"\"The polynormer module from the\n    `\"Polynormer: polynomial-expressive graph\n    transformer in linear time\"\n    <https://arxiv.org/abs/2403.01232>`_ paper.\n\n    Args:\n        in_channels (int): Input channels.\n        hidden_channels (int): Hidden channels.\n        out_channels (int): Output channels.\n        local_layers (int): The number of local attention layers.\n            (default: :obj:`7`)\n        global_layers (int): The number of global attention layers.\n            (default: :obj:`2`)\n        in_dropout (float): Input dropout rate.\n            (default: :obj:`0.15`)\n        dropout (float): Dropout rate.\n            (default: :obj:`0.5`)\n        global_dropout (float): Global dropout rate.\n            (default: :obj:`0.5`)\n        heads (int): The number of heads.\n            (default: :obj:`1`)\n        beta (float): Aggregate type.\n            (default: :obj:`0.9`)\n        qk_shared (bool, optional): Whether the query and key weights are\n            shared. (default: :obj:`False`)\n        pre_ln (bool): Pre layer normalization.\n            (default: :obj:`False`)\n        post_bn (bool): Post batch normalization.\n            (default: :obj:`True`)\n        local_attn (bool): Whether to use local attention.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        local_layers: int = 7,\n        global_layers: int = 2,\n        in_dropout: float = 0.15,\n        dropout: float = 0.5,\n        global_dropout: float = 0.5,\n        heads: int = 1,\n        beta: float = 0.9,\n        
qk_shared: bool = False,\n        pre_ln: bool = False,\n        post_bn: bool = True,\n        local_attn: bool = False,\n    ) -> None:\n        super().__init__()\n        self._global = False\n        self.in_drop = in_dropout\n        self.dropout = dropout\n        self.pre_ln = pre_ln\n        self.post_bn = post_bn\n\n        self.beta = beta\n\n        self.h_lins = torch.nn.ModuleList()\n        self.local_convs = torch.nn.ModuleList()\n        self.lins = torch.nn.ModuleList()\n        self.lns = torch.nn.ModuleList()\n        if self.pre_ln:\n            self.pre_lns = torch.nn.ModuleList()\n        if self.post_bn:\n            self.post_bns = torch.nn.ModuleList()\n\n        # first layer\n        inner_channels = heads * hidden_channels\n        self.h_lins.append(torch.nn.Linear(in_channels, inner_channels))\n        if local_attn:\n            self.local_convs.append(\n                GATConv(in_channels, hidden_channels, heads=heads, concat=True,\n                        add_self_loops=False, bias=False))\n        else:\n            self.local_convs.append(\n                GCNConv(in_channels, inner_channels, cached=False,\n                        normalize=True))\n\n        self.lins.append(torch.nn.Linear(in_channels, inner_channels))\n        self.lns.append(torch.nn.LayerNorm(inner_channels))\n        if self.pre_ln:\n            self.pre_lns.append(torch.nn.LayerNorm(in_channels))\n        if self.post_bn:\n            self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))\n\n        # following layers\n        for _ in range(local_layers - 1):\n            self.h_lins.append(torch.nn.Linear(inner_channels, inner_channels))\n            if local_attn:\n                self.local_convs.append(\n                    GATConv(inner_channels, hidden_channels, heads=heads,\n                            concat=True, add_self_loops=False, bias=False))\n            else:\n                self.local_convs.append(\n                    
GCNConv(inner_channels, inner_channels, cached=False,\n                            normalize=True))\n\n            self.lins.append(torch.nn.Linear(inner_channels, inner_channels))\n            self.lns.append(torch.nn.LayerNorm(inner_channels))\n            if self.pre_ln:\n                self.pre_lns.append(torch.nn.LayerNorm(heads *\n                                                       hidden_channels))\n            if self.post_bn:\n                self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))\n\n        self.lin_in = torch.nn.Linear(in_channels, inner_channels)\n        self.ln = torch.nn.LayerNorm(inner_channels)\n\n        self.global_attn = torch.nn.ModuleList()\n        for _ in range(global_layers):\n            self.global_attn.append(\n                PolynormerAttention(\n                    channels=hidden_channels,\n                    heads=heads,\n                    head_channels=hidden_channels,\n                    beta=beta,\n                    dropout=global_dropout,\n                    qk_shared=qk_shared,\n                ))\n        self.pred_local = torch.nn.Linear(inner_channels, out_channels)\n        self.pred_global = torch.nn.Linear(inner_channels, out_channels)\n        self.reset_parameters()\n\n    def reset_parameters(self) -> None:\n        for local_conv in self.local_convs:\n            local_conv.reset_parameters()\n        for attn in self.global_attn:\n            attn.reset_parameters()\n        for lin in self.lins:\n            lin.reset_parameters()\n        for h_lin in self.h_lins:\n            h_lin.reset_parameters()\n        for ln in self.lns:\n            ln.reset_parameters()\n        if self.pre_ln:\n            for p_ln in self.pre_lns:\n                p_ln.reset_parameters()\n        if self.post_bn:\n            for p_bn in self.post_bns:\n                p_bn.reset_parameters()\n        self.lin_in.reset_parameters()\n        self.ln.reset_parameters()\n        
self.pred_local.reset_parameters()\n        self.pred_global.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Optional[Tensor],\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The input node features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example.\n        \"\"\"\n        x = F.dropout(x, p=self.in_drop, training=self.training)\n\n        # equivariant local attention\n        x_local = 0\n        for i, local_conv in enumerate(self.local_convs):\n            if self.pre_ln:\n                x = self.pre_lns[i](x)\n            h = self.h_lins[i](x)\n            h = F.relu(h)\n            x = local_conv(x, edge_index) + self.lins[i](x)\n            if self.post_bn:\n                x = self.post_bns[i](x)\n            x = F.relu(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n            x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x\n            x_local = x_local + x\n\n        # equivariant global attention\n        if self._global:\n            batch, indices = batch.sort()\n            rev_perm = torch.empty_like(indices)\n            rev_perm[indices] = torch.arange(len(indices),\n                                             device=indices.device)\n            x_local = self.ln(x_local[indices])\n            x_global, mask = to_dense_batch(x_local, batch)\n            for attn in self.global_attn:\n                x_global = attn(x_global, mask)\n            x = x_global[mask][rev_perm]\n            x = self.pred_global(x)\n        else:\n            x = self.pred_local(x_local)\n\n        return F.log_softmax(x, dim=-1)\n"
  },
  {
    "path": "torch_geometric/nn/models/re_net.py",
    "content": "import math\nfrom typing import Callable, List, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import GRU, Linear, Parameter\n\nfrom torch_geometric.data.data import Data\nfrom torch_geometric.utils import scatter\n\n\nclass RENet(torch.nn.Module):\n    r\"\"\"The Recurrent Event Network model from the `\"Recurrent Event Network\n    for Reasoning over Temporal Knowledge Graphs\"\n    <https://arxiv.org/abs/1904.05530>`_ paper.\n\n    .. math::\n        f_{\\mathbf{\\Theta}}(\\mathbf{e}_s, \\mathbf{e}_r,\n        \\mathbf{h}^{(t-1)}(s, r))\n\n    based on a RNN encoder\n\n    .. math::\n        \\mathbf{h}^{(t)}(s, r) = \\textrm{RNN}(\\mathbf{e}_s, \\mathbf{e}_r,\n        g(\\mathcal{O}^{(t)}_r(s)), \\mathbf{h}^{(t-1)}(s, r))\n\n    where :math:`\\mathbf{e}_s` and :math:`\\mathbf{e}_r` denote entity and\n    relation embeddings, and :math:`\\mathcal{O}^{(t)}_r(s)` represents the set\n    of objects interacted with subject :math:`s` under relation :math:`r` at\n    timestamp :math:`t`.\n    This model implements :math:`g` as the **Mean Aggregator** and\n    :math:`f_{\\mathbf{\\Theta}}` as a linear projection.\n\n    Args:\n        num_nodes (int): The number of nodes in the knowledge graph.\n        num_rels (int): The number of relations in the knowledge graph.\n        hidden_channels (int): Hidden size of node and relation embeddings.\n        seq_len (int): The sequence length of past events.\n        num_layers (int, optional): The number of recurrent layers.\n            (default: :obj:`1`)\n        dropout (float): If non-zero, introduces a dropout layer before the\n            final prediction. (default: :obj:`0.`)\n        bias (bool, optional): If set to :obj:`False`, all layers will not\n            learn an additive bias. 
(default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        num_nodes: int,\n        num_rels: int,\n        hidden_channels: int,\n        seq_len: int,\n        num_layers: int = 1,\n        dropout: float = 0.,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.num_nodes = num_nodes\n        self.hidden_channels = hidden_channels\n        self.num_rels = num_rels\n        self.seq_len = seq_len\n        self.dropout = dropout\n\n        self.ent = Parameter(torch.empty(num_nodes, hidden_channels))\n        self.rel = Parameter(torch.empty(num_rels, hidden_channels))\n\n        self.sub_gru = GRU(3 * hidden_channels, hidden_channels, num_layers,\n                           batch_first=True, bias=bias)\n        self.obj_gru = GRU(3 * hidden_channels, hidden_channels, num_layers,\n                           batch_first=True, bias=bias)\n\n        self.sub_lin = Linear(3 * hidden_channels, num_nodes, bias=bias)\n        self.obj_lin = Linear(3 * hidden_channels, num_nodes, bias=bias)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.ent, gain=math.sqrt(2.0))\n        torch.nn.init.xavier_uniform_(self.rel, gain=math.sqrt(2.0))\n\n        self.sub_gru.reset_parameters()\n        self.obj_gru.reset_parameters()\n        self.sub_lin.reset_parameters()\n        self.obj_lin.reset_parameters()\n\n    @staticmethod\n    def pre_transform(seq_len: int) -> Callable:\n        r\"\"\"Precomputes history objects.\n\n        .. 
math::\n            \\{ \\mathcal{O}^{(t-k-1)}_r(s), \\ldots, \\mathcal{O}^{(t-1)}_r(s) \\}\n\n        of a :class:`torch_geometric.datasets.icews.EventDataset` with\n        :math:`k` denoting the sequence length :obj:`seq_len`.\n        \"\"\"\n        class PreTransform:\n            def __init__(self, seq_len: int):\n                self.seq_len = seq_len\n                self.inc = 5000\n                self.t_last = 0\n                self.sub_hist = self.increase_hist_node_size([])\n                self.obj_hist = self.increase_hist_node_size([])\n\n            def increase_hist_node_size(self, hist: List[int]) -> List[int]:\n                hist_inc = torch.zeros((self.inc, self.seq_len + 1, 0))\n                return hist + hist_inc.tolist()\n\n            def get_history(\n                self,\n                hist: List[int],\n                node: int,\n                rel: int,\n            ) -> Tuple[Tensor, Tensor]:\n                hists, ts = [], []\n                for s in range(seq_len):\n                    h = hist[node][s]\n                    hists += h\n                    ts.append(torch.full((len(h), ), s, dtype=torch.long))\n                node, r = torch.tensor(hists, dtype=torch.long).view(\n                    -1, 2).t().contiguous()\n                node = node[r == rel]\n                t = torch.cat(ts, dim=0)[r == rel]\n                return node, t\n\n            def step(self, hist: List[int]) -> List[int]:\n                for i in range(len(hist)):\n                    hist[i] = hist[i][1:]\n                    hist[i].append([])\n                return hist\n\n            def __call__(self, data: Data) -> Data:\n                sub, rel, obj, t = data.sub, data.rel, data.obj, data.t\n\n                if max(sub, obj) + 1 > len(self.sub_hist):  # pragma: no cover\n                    self.sub_hist = self.increase_hist_node_size(self.sub_hist)\n                    self.obj_hist = 
self.increase_hist_node_size(self.obj_hist)\n\n                # Delete last timestamp in history.\n                if t > self.t_last:\n                    self.sub_hist = self.step(self.sub_hist)\n                    self.obj_hist = self.step(self.obj_hist)\n                    self.t_last = t\n\n                # Save history in data object.\n                data.h_sub, data.h_sub_t = self.get_history(\n                    self.sub_hist, sub, rel)\n                data.h_obj, data.h_obj_t = self.get_history(\n                    self.obj_hist, obj, rel)\n\n                # Add new event to history.\n                self.sub_hist[sub][-1].append([obj, rel])\n                self.obj_hist[obj][-1].append([sub, rel])\n\n                return data\n\n            def __repr__(self) -> str:  # pragma: no cover\n                return f'{self.__class__.__name__}(seq_len={self.seq_len})'\n\n        return PreTransform(seq_len)\n\n    def forward(self, data: Data) -> Tuple[Tensor, Tensor]:\n        \"\"\"Given a :obj:`data` batch, computes the forward pass.\n\n        Args:\n            data (torch_geometric.data.Data): The input data, holding subject\n                :obj:`sub`, relation :obj:`rel` and object :obj:`obj`\n                information with shape :obj:`[batch_size]`.\n                In addition, :obj:`data` needs to hold history information for\n                subjects, given by a vector of node indices :obj:`h_sub` and\n                their relative timestamps :obj:`h_sub_t` and batch assignments\n                :obj:`h_sub_batch`.\n                The same information must be given for objects (:obj:`h_obj`,\n                :obj:`h_obj_t`, :obj:`h_obj_batch`).\n        \"\"\"\n        assert 'h_sub_batch' in data and 'h_obj_batch' in data\n        batch_size, seq_len = data.sub.size(0), self.seq_len\n\n        h_sub_t = data.h_sub_t + data.h_sub_batch * seq_len\n        h_obj_t = data.h_obj_t + data.h_obj_batch * seq_len\n\n        h_sub = 
scatter(self.ent[data.h_sub], h_sub_t, dim=0,\n                        dim_size=batch_size * seq_len,\n                        reduce='mean').view(batch_size, seq_len, -1)\n        h_obj = scatter(self.ent[data.h_obj], h_obj_t, dim=0,\n                        dim_size=batch_size * seq_len,\n                        reduce='mean').view(batch_size, seq_len, -1)\n\n        sub = self.ent[data.sub].unsqueeze(1).repeat(1, seq_len, 1)\n        rel = self.rel[data.rel].unsqueeze(1).repeat(1, seq_len, 1)\n        obj = self.ent[data.obj].unsqueeze(1).repeat(1, seq_len, 1)\n\n        _, h_sub = self.sub_gru(torch.cat([sub, h_sub, rel], dim=-1))\n        _, h_obj = self.obj_gru(torch.cat([obj, h_obj, rel], dim=-1))\n        h_sub, h_obj = h_sub.squeeze(0), h_obj.squeeze(0)\n\n        h_sub = torch.cat([self.ent[data.sub], h_sub, self.rel[data.rel]],\n                          dim=-1)\n        h_obj = torch.cat([self.ent[data.obj], h_obj, self.rel[data.rel]],\n                          dim=-1)\n\n        h_sub = F.dropout(h_sub, p=self.dropout, training=self.training)\n        h_obj = F.dropout(h_obj, p=self.dropout, training=self.training)\n\n        log_prob_obj = F.log_softmax(self.sub_lin(h_sub), dim=1)\n        log_prob_sub = F.log_softmax(self.obj_lin(h_obj), dim=1)\n\n        return log_prob_obj, log_prob_sub\n\n    def test(self, logits: Tensor, y: Tensor) -> Tensor:\n        \"\"\"Given ground-truth :obj:`y`, computes Mean Reciprocal Rank (MRR)\n        and Hits at 1/3/10.\n        \"\"\"\n        _, perm = logits.sort(dim=1, descending=True)\n        mask = (y.view(-1, 1) == perm)\n\n        nnz = mask.nonzero(as_tuple=False)\n        mrr = (1 / (nnz[:, -1] + 1).to(torch.float)).mean().item()\n        hits1 = mask[:, :1].sum().item() / y.size(0)\n        hits3 = mask[:, :3].sum().item() / y.size(0)\n        hits10 = mask[:, :10].sum().item() / y.size(0)\n\n        return torch.tensor([mrr, hits1, hits3, hits10])\n"
  },
  {
    "path": "torch_geometric/nn/models/rect.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import GCNConv\nfrom torch_geometric.typing import Adj, OptTensor\nfrom torch_geometric.utils import scatter\n\n\nclass RECT_L(torch.nn.Module):\n    r\"\"\"The RECT model, *i.e.* its supervised RECT-L part, from the\n    `\"Network Embedding with Completely-imbalanced Labels\"\n    <https://arxiv.org/abs/2007.03545>`_ paper.\n    In particular, a GCN model is trained that reconstructs semantic class\n    knowledge.\n\n    .. note::\n\n        For an example of using RECT, see `examples/rect.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        rect.py>`_.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Intermediate size of each sample.\n        normalize (bool, optional): Whether to add self-loops and compute\n            symmetric normalization coefficients on-the-fly.\n            (default: :obj:`True`)\n        dropout (float, optional): The dropout probability.\n            (default: :obj:`0.0`)\n    \"\"\"\n    def __init__(self, in_channels: int, hidden_channels: int,\n                 normalize: bool = True, dropout: float = 0.0):\n        super().__init__()\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.dropout = dropout\n\n        self.conv = GCNConv(in_channels, hidden_channels, normalize=normalize)\n        self.lin = Linear(hidden_channels, in_channels)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.conv.reset_parameters()\n        self.lin.reset_parameters()\n        torch.nn.init.xavier_uniform_(self.lin.weight.data)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n        
\"\"\"\"\"\"  # noqa: D419\n        x = self.conv(x, edge_index, edge_weight)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        return self.lin(x)\n\n    @torch.jit.export\n    def embed(\n        self,\n        x: Tensor,\n        edge_index: Adj,\n        edge_weight: OptTensor = None,\n    ) -> Tensor:\n        with torch.no_grad():\n            return self.conv(x, edge_index, edge_weight)\n\n    @torch.jit.export\n    def get_semantic_labels(\n        self,\n        x: Tensor,\n        y: Tensor,\n        mask: Tensor,\n    ) -> Tensor:\n        r\"\"\"Replaces the original labels by their class-centers.\"\"\"\n        with torch.no_grad():\n            y = y[mask]\n            mean = scatter(x[mask], y, dim=0, reduce='mean')\n            return mean[y]\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.hidden_channels})')\n"
  },
  {
    "path": "torch_geometric/nn/models/rev_gnn.py",
    "content": "import copy\nfrom typing import Any, List, Optional, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.typing import Adj\n\n\nclass InvertibleFunction(torch.autograd.Function):\n    r\"\"\"An invertible autograd function. This allows for automatic\n    backpropagation in a reversible fashion so that the memory of intermediate\n    results can be freed during the forward pass and be constructed on-the-fly\n    during the backward pass.\n\n    Args:\n        ctx (torch.autograd.function.InvertibleFunctionBackward):\n            A context object that can be used to stash information for backward\n            computation.\n        fn (torch.nn.Module): The forward function.\n        fn_inverse (torch.nn.Module): The inverse function to recompute the\n            freed input.\n        num_bwd_passes (int): Number of backward passes to retain a link\n            with the output. After the last backward pass the output is\n            discarded and memory is freed.\n        num_inputs (int): The number of inputs to the forward function.\n        *args (tuple): Inputs and weights.\n    \"\"\"\n    @staticmethod\n    def forward(ctx, fn: torch.nn.Module, fn_inverse: torch.nn.Module,\n                num_bwd_passes: int, num_inputs: int, *args):\n        ctx.fn = fn\n        ctx.fn_inverse = fn_inverse\n        ctx.weights = args[num_inputs:]\n        ctx.num_bwd_passes = num_bwd_passes\n        ctx.num_inputs = num_inputs\n        inputs = args[:num_inputs]\n        ctx.input_requires_grad = []\n\n        with torch.no_grad():  # Make a detached copy which shares the storage:\n            x = []\n            for element in inputs:\n                if isinstance(element, torch.Tensor):\n                    x.append(element.detach())\n                    ctx.input_requires_grad.append(element.requires_grad)\n                else:\n                    x.append(element)\n                    
ctx.input_requires_grad.append(None)\n            outputs = ctx.fn(*x)\n\n        if not isinstance(outputs, tuple):\n            outputs = (outputs, )\n\n        # Detaches outputs in-place, allows discarding the intermediate result:\n        detached_outputs = tuple(element.detach_() for element in outputs)\n\n        # Clear memory of node features:\n        if torch_geometric.typing.WITH_PT20:\n            inputs[0].untyped_storage().resize_(0)\n        else:  # pragma: no cover\n            inputs[0].storage().resize_(0)\n\n        # Store these tensor nodes for backward passes:\n        ctx.inputs = [inputs] * num_bwd_passes\n        ctx.outputs = [detached_outputs] * num_bwd_passes\n\n        return detached_outputs\n\n    @staticmethod\n    def backward(ctx, *grad_outputs):\n        if len(ctx.outputs) == 0:\n            raise RuntimeError(\n                f\"Trying to perform a backward pass on the \"\n                f\"'InvertibleFunction' for more than '{ctx.num_bwd_passes}' \"\n                f\"times. 
Try raising 'num_bwd_passes'.\")\n\n        inputs = ctx.inputs.pop()\n        outputs = ctx.outputs.pop()\n\n        # Recompute input by swapping out the first argument:\n        with torch.no_grad():\n            inputs_inverted = ctx.fn_inverse(*(outputs + inputs[1:]))\n            if len(ctx.outputs) == 0:  # Clear memory from outputs:\n                for element in outputs:\n                    if torch_geometric.typing.WITH_PT20:\n                        element.untyped_storage().resize_(0)\n                    else:  # pragma: no cover\n                        element.storage().resize_(0)\n\n            if not isinstance(inputs_inverted, tuple):\n                inputs_inverted = (inputs_inverted, )\n\n            for elem_orig, elem_inv in zip(inputs, inputs_inverted):\n                if torch_geometric.typing.WITH_PT20:\n                    elem_orig.untyped_storage().resize_(\n                        int(np.prod(elem_orig.size())) *\n                        elem_orig.element_size())\n                else:  # pragma: no cover\n                    elem_orig.storage().resize_(int(np.prod(elem_orig.size())))\n                elem_orig.set_(elem_inv)\n\n        # Compute gradients with grad enabled:\n        with torch.set_grad_enabled(True):\n            detached_inputs = []\n            for element in inputs:\n                if isinstance(element, torch.Tensor):\n                    detached_inputs.append(element.detach())\n                else:\n                    detached_inputs.append(element)\n            detached_inputs = tuple(detached_inputs)\n            for x, req_grad in zip(detached_inputs, ctx.input_requires_grad):\n                if isinstance(x, torch.Tensor):\n                    x.requires_grad = req_grad\n            tmp_output = ctx.fn(*detached_inputs)\n\n        if not isinstance(tmp_output, tuple):\n            tmp_output = (tmp_output, )\n\n        filtered_detached_inputs = tuple(\n            filter(\n                lambda x: 
x.requires_grad\n                if isinstance(x, torch.Tensor) else False,\n                detached_inputs,\n            ))\n        gradients = torch.autograd.grad(\n            outputs=tmp_output,\n            inputs=filtered_detached_inputs + ctx.weights,\n            grad_outputs=grad_outputs,\n        )\n\n        input_gradients = []\n        i = 0\n        for rg in ctx.input_requires_grad:\n            if rg:\n                input_gradients.append(gradients[i])\n                i += 1\n            else:\n                input_gradients.append(None)\n\n        gradients = tuple(input_gradients) + gradients[-len(ctx.weights):]\n\n        return (None, None, None, None) + gradients\n\n\nclass InvertibleModule(torch.nn.Module):\n    r\"\"\"An abstract class for implementing invertible modules.\n\n    Args:\n        disable (bool, optional): If set to :obj:`True`, will disable the usage\n            of :class:`InvertibleFunction` and will execute the module without\n            memory savings. (default: :obj:`False`)\n        num_bwd_passes (int, optional): Number of backward passes to retain a\n            link with the output. After the last backward pass the output is\n            discarded and memory is freed. 
(default: :obj:`1`)\n    \"\"\"\n    def __init__(self, disable: bool = False, num_bwd_passes: int = 1):\n        super().__init__()\n        self.disable = disable\n        self.num_bwd_passes = num_bwd_passes\n\n    def forward(self, *args):\n        \"\"\"\"\"\"  # noqa: D419\n        return self._fn_apply(args, self._forward, self._inverse)\n\n    def inverse(self, *args):\n        return self._fn_apply(args, self._inverse, self._forward)\n\n    def _forward(self):\n        raise NotImplementedError\n\n    def _inverse(self):\n        raise NotImplementedError\n\n    def _fn_apply(self, args, fn, fn_inverse):\n        if not self.disable:\n            out = InvertibleFunction.apply(\n                fn,\n                fn_inverse,\n                self.num_bwd_passes,\n                len(args),\n                *args,\n                *tuple(p for p in self.parameters() if p.requires_grad),\n            )\n        else:\n            out = fn(*args)\n\n        # If the layer only has one input, we unpack the tuple:\n        if isinstance(out, tuple) and len(out) == 1:\n            return out[0]\n\n        return out\n\n\nclass GroupAddRev(InvertibleModule):\n    r\"\"\"The Grouped Reversible GNN module from the `\"Graph Neural Networks with\n    1000 Layers\" <https://arxiv.org/abs/2106.07476>`_ paper.\n    This module enables training of arbitrarily deep GNNs with a memory\n    complexity independent of the number of layers.\n\n    It does so by partitioning input node features :math:`\\mathbf{X}` into\n    :math:`C` groups across the feature dimension. Then, a grouped reversible\n    GNN block :math:`f_{\\theta(i)}` operates on a group of inputs and produces\n    a group of outputs:\n\n    .. math::\n\n        \\mathbf{X}^{\\prime}_0 &= \\sum_{i=2}^C \\mathbf{X}_i\n\n        \\mathbf{X}^{\\prime}_i &= f_{\\theta(i)} ( \\mathbf{X}^{\\prime}_{i - 1},\n        \\mathbf{A}) + \\mathbf{X}_i\n\n    for all :math:`i \\in \\{ 1, \\ldots, C \\}`.\n\n    .. 
note::\n\n        For an example of using :class:`GroupAddRev`, see `examples/rev_gnn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        rev_gnn.py>`_.\n\n    Args:\n        conv (torch.nn.Module or torch.nn.ModuleList): A seed GNN. The input\n            and output feature dimensions need to match.\n        split_dim (int, optional): The dimension across which to split groups.\n            (default: :obj:`-1`)\n        num_groups (int, optional): The number of groups :math:`C`.\n            (default: :obj:`None`)\n        disable (bool, optional): If set to :obj:`True`, will disable the usage\n            of :class:`InvertibleFunction` and will execute the module without\n            memory savings. (default: :obj:`False`)\n        num_bwd_passes (int, optional): Number of backward passes to retain a\n            link with the output. After the last backward pass the output is\n            discarded and memory is freed. (default: :obj:`1`)\n    \"\"\"\n    def __init__(\n        self,\n        conv: Union[torch.nn.Module, torch.nn.ModuleList],\n        split_dim: int = -1,\n        num_groups: Optional[int] = None,\n        disable: bool = False,\n        num_bwd_passes: int = 1,\n    ):\n        super().__init__(disable, num_bwd_passes)\n        self.split_dim = split_dim\n\n        if isinstance(conv, torch.nn.ModuleList):\n            self.convs = conv\n        else:\n            assert num_groups is not None, \"Please specify 'num_groups'\"\n            self.convs = torch.nn.ModuleList([conv])\n            for _ in range(num_groups - 1):\n                conv = copy.deepcopy(self.convs[0])\n                if hasattr(conv, 'reset_parameters'):\n                    conv.reset_parameters()\n                self.convs.append(conv)\n\n        if len(self.convs) < 2:\n            raise ValueError(f\"The number of groups should not be smaller \"\n                             f\"than '2' (got '{self.num_groups}')\")\n\n    
@property\n    def num_groups(self) -> int:\n        return len(self.convs)\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for conv in self.convs:\n            conv.reset_parameters()\n\n    def _forward(self, x: Tensor, edge_index: Adj, *args):\n        channels = x.size(self.split_dim)\n        xs = self._chunk(x, channels)\n        args = list(zip(*[self._chunk(arg, channels) for arg in args]))\n        args = [[]] * self.num_groups if len(args) == 0 else args\n\n        ys = []\n        y_in = sum(xs[1:])\n        for i in range(self.num_groups):\n            y_in = xs[i] + self.convs[i](y_in, edge_index, *args[i])\n            ys.append(y_in)\n        return torch.cat(ys, dim=self.split_dim)\n\n    def _inverse(self, y: Tensor, edge_index: Adj, *args):\n        channels = y.size(self.split_dim)\n        ys = self._chunk(y, channels)\n        args = list(zip(*[self._chunk(arg, channels) for arg in args]))\n        args = [[]] * self.num_groups if len(args) == 0 else args\n\n        xs = []\n        for i in range(self.num_groups - 1, -1, -1):\n            if i != 0:\n                y_in = ys[i - 1]\n            else:\n                y_in = sum(xs)\n            x = ys[i] - self.convs[i](y_in, edge_index, *args[i])\n            xs.append(x)\n\n        return torch.cat(xs[::-1], dim=self.split_dim)\n\n    def _chunk(self, x: Any, channels: int) -> List[Any]:\n        if not isinstance(x, Tensor):\n            return [x] * self.num_groups\n\n        try:\n            if x.size(self.split_dim) != channels:\n                return [x] * self.num_groups\n        except IndexError:\n            return [x] * self.num_groups\n\n        return torch.chunk(x, self.num_groups, dim=self.split_dim)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.convs[0]}, '\n                f'num_groups={self.num_groups})')\n"
  },
  {
    "path": "torch_geometric/nn/models/schnet.py",
    "content": "import os\nimport os.path as osp\nimport warnings\nfrom math import pi as PI\nfrom typing import Callable, Dict, Optional, Tuple\n\nimport numpy as np\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Embedding, Linear, ModuleList, Sequential\n\nfrom torch_geometric.data import Dataset, download_url, extract_zip\nfrom torch_geometric.io import fs\nfrom torch_geometric.nn import MessagePassing, SumAggregation, radius_graph\nfrom torch_geometric.nn.resolver import aggregation_resolver as aggr_resolver\nfrom torch_geometric.typing import OptTensor\n\nqm9_target_dict: Dict[int, str] = {\n    0: 'dipole_moment',\n    1: 'isotropic_polarizability',\n    2: 'homo',\n    3: 'lumo',\n    4: 'gap',\n    5: 'electronic_spatial_extent',\n    6: 'zpve',\n    7: 'energy_U0',\n    8: 'energy_U',\n    9: 'enthalpy_H',\n    10: 'free_energy',\n    11: 'heat_capacity',\n}\n\n\nclass SchNet(torch.nn.Module):\n    r\"\"\"The continuous-filter convolutional neural network SchNet from the\n    `\"SchNet: A Continuous-filter Convolutional Neural Network for Modeling\n    Quantum Interactions\" <https://arxiv.org/abs/1706.08566>`_ paper that uses\n    interaction blocks of the form:\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\sum_{j \\in \\mathcal{N}(i)} \\mathbf{x}_j \\odot\n        h_{\\mathbf{\\Theta}} ( \\exp(-\\gamma(\\mathbf{e}_{j,i} - \\mathbf{\\mu}))),\n\n    where :math:`h_{\\mathbf{\\Theta}}` denotes an MLP and\n    :math:`\\mathbf{e}_{j,i}` denotes the interatomic distances between atoms.\n\n    .. 
note::\n\n        For an example of using a pretrained SchNet variant, see\n        `examples/qm9_pretrained_schnet.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        qm9_pretrained_schnet.py>`_.\n\n    Args:\n        hidden_channels (int, optional): Hidden embedding size.\n            (default: :obj:`128`)\n        num_filters (int, optional): The number of filters to use.\n            (default: :obj:`128`)\n        num_interactions (int, optional): The number of interaction blocks.\n            (default: :obj:`6`)\n        num_gaussians (int, optional): The number of gaussians :math:`\\mu`.\n            (default: :obj:`50`)\n        interaction_graph (callable, optional): The function used to compute\n            the pairwise interaction graph and interatomic distances. If set to\n            :obj:`None`, will construct a graph based on :obj:`cutoff` and\n            :obj:`max_num_neighbors` properties.\n            If provided, this method takes in :obj:`pos` and :obj:`batch`\n            tensors and should return :obj:`(edge_index, edge_weight)` tensors.\n            (default: :obj:`None`)\n        cutoff (float, optional): Cutoff distance for interatomic interactions.\n            (default: :obj:`10.0`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            collect for each node within the :attr:`cutoff` distance.\n            (default: :obj:`32`)\n        readout (str, optional): Whether to apply :obj:`\"add\"` or :obj:`\"mean\"`\n            global aggregation. 
(default: :obj:`\"add\"`)\n        dipole (bool, optional): If set to :obj:`True`, will use the magnitude\n            of the dipole moment to make the final prediction, *e.g.*, for\n            target 0 of :class:`torch_geometric.datasets.QM9`.\n            (default: :obj:`False`)\n        mean (float, optional): The mean of the property to predict.\n            (default: :obj:`None`)\n        std (float, optional): The standard deviation of the property to\n            predict. (default: :obj:`None`)\n        atomref (torch.Tensor, optional): The reference of single-atom\n            properties.\n            Expects a vector of shape :obj:`(max_atomic_number, )`.\n    \"\"\"\n\n    url = 'http://www.quantum-machine.org/datasets/trained_schnet_models.zip'\n\n    def __init__(\n        self,\n        hidden_channels: int = 128,\n        num_filters: int = 128,\n        num_interactions: int = 6,\n        num_gaussians: int = 50,\n        cutoff: float = 10.0,\n        interaction_graph: Optional[Callable] = None,\n        max_num_neighbors: int = 32,\n        readout: str = 'add',\n        dipole: bool = False,\n        mean: Optional[float] = None,\n        std: Optional[float] = None,\n        atomref: OptTensor = None,\n    ):\n        super().__init__()\n\n        self.hidden_channels = hidden_channels\n        self.num_filters = num_filters\n        self.num_interactions = num_interactions\n        self.num_gaussians = num_gaussians\n        self.cutoff = cutoff\n        self.dipole = dipole\n        self.sum_aggr = SumAggregation()\n        self.readout = aggr_resolver('sum' if self.dipole else readout)\n        self.mean = mean\n        self.std = std\n        self.scale = None\n\n        if self.dipole:\n            import ase\n\n            atomic_mass = torch.from_numpy(ase.data.atomic_masses)\n            self.register_buffer('atomic_mass', atomic_mass)\n\n        # Support z == 0 for padding atoms so that their embedding vectors\n        # are zeroed 
and do not receive any gradients.\n        self.embedding = Embedding(100, hidden_channels, padding_idx=0)\n\n        if interaction_graph is not None:\n            self.interaction_graph = interaction_graph\n        else:\n            self.interaction_graph = RadiusInteractionGraph(\n                cutoff, max_num_neighbors)\n\n        self.distance_expansion = GaussianSmearing(0.0, cutoff, num_gaussians)\n\n        self.interactions = ModuleList()\n        for _ in range(num_interactions):\n            block = InteractionBlock(hidden_channels, num_gaussians,\n                                     num_filters, cutoff)\n            self.interactions.append(block)\n\n        self.lin1 = Linear(hidden_channels, hidden_channels // 2)\n        self.act = ShiftedSoftplus()\n        self.lin2 = Linear(hidden_channels // 2, 1)\n\n        self.register_buffer('initial_atomref', atomref)\n        self.atomref = None\n        if atomref is not None:\n            self.atomref = Embedding(100, 1)\n            self.atomref.weight.data.copy_(atomref)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.embedding.reset_parameters()\n        for interaction in self.interactions:\n            interaction.reset_parameters()\n        torch.nn.init.xavier_uniform_(self.lin1.weight)\n        self.lin1.bias.data.fill_(0)\n        torch.nn.init.xavier_uniform_(self.lin2.weight)\n        self.lin2.bias.data.fill_(0)\n        if self.atomref is not None:\n            self.atomref.weight.data.copy_(self.initial_atomref)\n\n    @staticmethod\n    def from_qm9_pretrained(\n        root: str,\n        dataset: Dataset,\n        target: int,\n    ) -> Tuple['SchNet', Tuple[Dataset, Dataset, Dataset]]:  # pragma: no cover\n        r\"\"\"Returns a pre-trained :class:`SchNet` model on the\n        :class:`~torch_geometric.datasets.QM9` dataset, trained on the\n        specified target :obj:`target`.\n     
   \"\"\"\n        import ase\n        import schnetpack as spk  # noqa\n\n        assert target >= 0 and target < 12\n        is_dipole = target == 0\n\n        units = [1] * 12\n        units[0] = ase.units.Debye\n        units[1] = ase.units.Bohr**3\n        units[5] = ase.units.Bohr**2\n\n        root = osp.expanduser(osp.normpath(root))\n        os.makedirs(root, exist_ok=True)\n        folder = 'trained_schnet_models'\n        if not osp.exists(osp.join(root, folder)):\n            path = download_url(SchNet.url, root)\n            extract_zip(path, root)\n            os.unlink(path)\n\n        name = f'qm9_{qm9_target_dict[target]}'\n        path = osp.join(root, 'trained_schnet_models', name, 'split.npz')\n\n        split = np.load(path)\n        train_idx = split['train_idx']\n        val_idx = split['val_idx']\n        test_idx = split['test_idx']\n\n        # Filter the splits to only contain characterized molecules.\n        idx = dataset.data.idx\n        assoc = idx.new_empty(idx.max().item() + 1)\n        assoc[idx] = torch.arange(idx.size(0))\n\n        train_idx = assoc[train_idx[np.isin(train_idx, idx)]]\n        val_idx = assoc[val_idx[np.isin(val_idx, idx)]]\n        test_idx = assoc[test_idx[np.isin(test_idx, idx)]]\n\n        path = osp.join(root, 'trained_schnet_models', name, 'best_model')\n\n        with warnings.catch_warnings():\n            warnings.simplefilter('ignore')\n            state = fs.torch_load(path, map_location='cpu')\n\n        net = SchNet(\n            hidden_channels=128,\n            num_filters=128,\n            num_interactions=6,\n            num_gaussians=50,\n            cutoff=10.0,\n            dipole=is_dipole,\n            atomref=dataset.atomref(target),\n        )\n\n        net.embedding.weight = state.representation.embedding.weight\n\n        for int1, int2 in zip(state.representation.interactions,\n                              net.interactions):\n            int2.mlp[0].weight = 
int1.filter_network[0].weight\n            int2.mlp[0].bias = int1.filter_network[0].bias\n            int2.mlp[2].weight = int1.filter_network[1].weight\n            int2.mlp[2].bias = int1.filter_network[1].bias\n            int2.lin.weight = int1.dense.weight\n            int2.lin.bias = int1.dense.bias\n\n            int2.conv.lin1.weight = int1.cfconv.in2f.weight\n            int2.conv.lin2.weight = int1.cfconv.f2out.weight\n            int2.conv.lin2.bias = int1.cfconv.f2out.bias\n\n        net.lin1.weight = state.output_modules[0].out_net[1].out_net[0].weight\n        net.lin1.bias = state.output_modules[0].out_net[1].out_net[0].bias\n        net.lin2.weight = state.output_modules[0].out_net[1].out_net[1].weight\n        net.lin2.bias = state.output_modules[0].out_net[1].out_net[1].bias\n\n        mean = state.output_modules[0].atom_pool.average\n        net.readout = aggr_resolver('mean' if mean is True else 'add')\n\n        dipole = state.output_modules[0].__class__.__name__ == 'DipoleMoment'\n        net.dipole = dipole\n\n        net.mean = state.output_modules[0].standardize.mean.item()\n        net.std = state.output_modules[0].standardize.stddev.item()\n\n        if state.output_modules[0].atomref is not None:\n            net.atomref.weight = state.output_modules[0].atomref.weight\n        else:\n            net.atomref = None\n\n        net.scale = 1.0 / units[target]\n\n        return net, (dataset[train_idx], dataset[val_idx], dataset[test_idx])\n\n    def forward(self, z: Tensor, pos: Tensor,\n                batch: OptTensor = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            z (torch.Tensor): Atomic number of each atom with shape\n                :obj:`[num_atoms]`.\n            pos (torch.Tensor): Coordinates of each atom with shape\n                :obj:`[num_atoms, 3]`.\n            batch (torch.Tensor, optional): Batch indices assigning each atom\n                to a separate molecule with shape 
:obj:`[num_atoms]`.\n                (default: :obj:`None`)\n        \"\"\"\n        batch = torch.zeros_like(z) if batch is None else batch\n\n        h = self.embedding(z)\n        edge_index, edge_weight = self.interaction_graph(pos, batch)\n        edge_attr = self.distance_expansion(edge_weight)\n\n        for interaction in self.interactions:\n            h = h + interaction(h, edge_index, edge_weight, edge_attr)\n\n        h = self.lin1(h)\n        h = self.act(h)\n        h = self.lin2(h)\n\n        if self.dipole:\n            # Get center of mass.\n            mass = self.atomic_mass[z].view(-1, 1)\n            M = self.sum_aggr(mass, batch, dim=0)\n            c = self.sum_aggr(mass * pos, batch, dim=0) / M\n            h = h * (pos - c.index_select(0, batch))\n\n        if not self.dipole and self.mean is not None and self.std is not None:\n            h = h * self.std + self.mean\n\n        if not self.dipole and self.atomref is not None:\n            h = h + self.atomref(z)\n\n        out = self.readout(h, batch, dim=0)\n\n        if self.dipole:\n            out = torch.norm(out, dim=-1, keepdim=True)\n\n        if self.scale is not None:\n            out = self.scale * out\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'hidden_channels={self.hidden_channels}, '\n                f'num_filters={self.num_filters}, '\n                f'num_interactions={self.num_interactions}, '\n                f'num_gaussians={self.num_gaussians}, '\n                f'cutoff={self.cutoff})')\n\n\nclass RadiusInteractionGraph(torch.nn.Module):\n    r\"\"\"Creates edges based on atom positions :obj:`pos` to all points within\n    the cutoff distance.\n\n    Args:\n        cutoff (float, optional): Cutoff distance for interatomic interactions.\n            (default: :obj:`10.0`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            collect for each node within 
the :attr:`cutoff` distance with the\n            default interaction graph method.\n            (default: :obj:`32`)\n    \"\"\"\n    def __init__(self, cutoff: float = 10.0, max_num_neighbors: int = 32):\n        super().__init__()\n        self.cutoff = cutoff\n        self.max_num_neighbors = max_num_neighbors\n\n    def forward(self, pos: Tensor, batch: Tensor) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            pos (Tensor): Coordinates of each atom.\n            batch (LongTensor, optional): Batch indices assigning each atom to\n                a separate molecule.\n\n        :rtype: (:class:`LongTensor`, :class:`Tensor`)\n        \"\"\"\n        edge_index = radius_graph(pos, r=self.cutoff, batch=batch,\n                                  max_num_neighbors=self.max_num_neighbors)\n        row, col = edge_index\n        edge_weight = (pos[row] - pos[col]).norm(dim=-1)\n        return edge_index, edge_weight\n\n\nclass InteractionBlock(torch.nn.Module):\n    def __init__(self, hidden_channels: int, num_gaussians: int,\n                 num_filters: int, cutoff: float):\n        super().__init__()\n        self.mlp = Sequential(\n            Linear(num_gaussians, num_filters),\n            ShiftedSoftplus(),\n            Linear(num_filters, num_filters),\n        )\n        self.conv = CFConv(hidden_channels, hidden_channels, num_filters,\n                           self.mlp, cutoff)\n        self.act = ShiftedSoftplus()\n        self.lin = Linear(hidden_channels, hidden_channels)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.mlp[0].weight)\n        self.mlp[0].bias.data.fill_(0)\n        torch.nn.init.xavier_uniform_(self.mlp[2].weight)\n        self.mlp[2].bias.data.fill_(0)\n        self.conv.reset_parameters()\n        torch.nn.init.xavier_uniform_(self.lin.weight)\n        self.lin.bias.data.fill_(0)\n\n    def forward(self, x: Tensor, edge_index: Tensor, 
edge_weight: Tensor,\n                edge_attr: Tensor) -> Tensor:\n        x = self.conv(x, edge_index, edge_weight, edge_attr)\n        x = self.act(x)\n        x = self.lin(x)\n        return x\n\n\nclass CFConv(MessagePassing):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels: int,\n        num_filters: int,\n        nn: Sequential,\n        cutoff: float,\n    ):\n        super().__init__(aggr='add')\n        self.lin1 = Linear(in_channels, num_filters, bias=False)\n        self.lin2 = Linear(num_filters, out_channels)\n        self.nn = nn\n        self.cutoff = cutoff\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        torch.nn.init.xavier_uniform_(self.lin1.weight)\n        torch.nn.init.xavier_uniform_(self.lin2.weight)\n        self.lin2.bias.data.fill_(0)\n\n    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Tensor,\n                edge_attr: Tensor) -> Tensor:\n        C = 0.5 * (torch.cos(edge_weight * PI / self.cutoff) + 1.0)\n        W = self.nn(edge_attr) * C.view(-1, 1)\n\n        x = self.lin1(x)\n        x = self.propagate(edge_index, x=x, W=W)\n        x = self.lin2(x)\n        return x\n\n    def message(self, x_j: Tensor, W: Tensor) -> Tensor:\n        return x_j * W\n\n\nclass GaussianSmearing(torch.nn.Module):\n    def __init__(\n        self,\n        start: float = 0.0,\n        stop: float = 5.0,\n        num_gaussians: int = 50,\n    ):\n        super().__init__()\n        offset = torch.linspace(start, stop, num_gaussians)\n        self.coeff = -0.5 / (offset[1] - offset[0]).item()**2\n        self.register_buffer('offset', offset)\n\n    def forward(self, dist: Tensor) -> Tensor:\n        dist = dist.view(-1, 1) - self.offset.view(1, -1)\n        return torch.exp(self.coeff * torch.pow(dist, 2))\n\n\nclass ShiftedSoftplus(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.shift = torch.log(torch.tensor(2.0)).item()\n\n    def 
forward(self, x: Tensor) -> Tensor:\n        return F.softplus(x) - self.shift\n"
  },
  {
    "path": "torch_geometric/nn/models/sgformer.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn.attention import SGFormerAttention\nfrom torch_geometric.nn.conv import GCNConv\nfrom torch_geometric.utils import to_dense_batch\n\n\nclass GraphModule(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        hidden_channels,\n        num_layers=2,\n        dropout=0.5,\n    ):\n        super().__init__()\n\n        self.convs = torch.nn.ModuleList()\n        self.fcs = torch.nn.ModuleList()\n        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))\n\n        self.bns = torch.nn.ModuleList()\n        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))\n        for _ in range(num_layers):\n            self.convs.append(GCNConv(hidden_channels, hidden_channels))\n            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))\n\n        self.dropout = dropout\n        self.activation = F.relu\n\n    def reset_parameters(self):\n        for conv in self.convs:\n            conv.reset_parameters()\n        for bn in self.bns:\n            bn.reset_parameters()\n        for fc in self.fcs:\n            fc.reset_parameters()\n\n    def forward(self, x, edge_index):\n        x = self.fcs[0](x)\n        x = self.bns[0](x)\n        x = self.activation(x)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        last_x = x\n\n        for i, conv in enumerate(self.convs):\n            x = conv(x, edge_index)\n            x = self.bns[i + 1](x)\n            x = self.activation(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n            x = x + last_x\n        return x\n\n\nclass SGModule(torch.nn.Module):\n    def __init__(\n        self,\n        in_channels,\n        hidden_channels,\n        num_layers=2,\n        num_heads=1,\n        dropout=0.5,\n    ):\n        super().__init__()\n\n        self.attns = torch.nn.ModuleList()\n        
self.fcs = torch.nn.ModuleList()\n        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))\n        self.bns = torch.nn.ModuleList()\n        self.bns.append(torch.nn.LayerNorm(hidden_channels))\n        for _ in range(num_layers):\n            self.attns.append(\n                SGFormerAttention(hidden_channels, num_heads, hidden_channels))\n            self.bns.append(torch.nn.LayerNorm(hidden_channels))\n\n        self.dropout = dropout\n        self.activation = F.relu\n\n    def reset_parameters(self):\n        for attn in self.attns:\n            attn.reset_parameters()\n        for bn in self.bns:\n            bn.reset_parameters()\n        for fc in self.fcs:\n            fc.reset_parameters()\n\n    def forward(self, x: Tensor, batch: Tensor):\n        # `to_dense_batch` expects a sorted batch vector\n        batch, indices = batch.sort(stable=True)\n        rev_perm = torch.empty_like(indices)\n        rev_perm[indices] = torch.arange(len(indices), device=indices.device)\n        x = x[indices]\n        x, mask = to_dense_batch(x, batch)\n        layer_ = []\n\n        # input MLP layer\n        x = self.fcs[0](x)\n        x = self.bns[0](x)\n        x = self.activation(x)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n\n        # store as residual link\n        layer_.append(x)\n\n        for i, attn in enumerate(self.attns):\n            x = attn(x, mask)\n            x = (x + layer_[i]) / 2.\n            x = self.bns[i + 1](x)\n            x = self.activation(x)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n            layer_.append(x)\n\n        x_mask = x[mask]\n        # reverse the sorting\n        unsorted_x_mask = x_mask[rev_perm]\n        return unsorted_x_mask\n\n\nclass SGFormer(torch.nn.Module):\n    r\"\"\"The SGFormer model from the\n    `\"SGFormer: Simplifying and Empowering Transformers for\n    Large-Graph Representations\"\n    <https://arxiv.org/abs/2306.10759>`_ paper.\n\n    Args:\n  
      in_channels (int): Input channels.\n        hidden_channels (int): Hidden channels.\n        out_channels (int): Output channels.\n        trans_num_layers (int): The number of layers for all-pair attention.\n            (default: :obj:`2`)\n        trans_num_heads (int): The number of heads for attention.\n            (default: :obj:`1`)\n        trans_dropout (float): Global dropout rate.\n            (default: :obj:`0.5`)\n        gnn_num_layers (int): The number of layers for GNN.\n            (default: :obj:`3`)\n        gnn_dropout (float): GNN dropout rate.\n            (default: :obj:`0.5`)\n        graph_weight (float): The weight of the GNN output when combined with\n            the global attention output. (default: :obj:`0.5`)\n        aggregate (str): The aggregation type, either :obj:`\"add\"` or\n            :obj:`\"cat\"`. (default: :obj:`\"add\"`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        out_channels: int,\n        trans_num_layers: int = 2,\n        trans_num_heads: int = 1,\n        trans_dropout: float = 0.5,\n        gnn_num_layers: int = 3,\n        gnn_dropout: float = 0.5,\n        graph_weight: float = 0.5,\n        aggregate: str = 'add',\n    ):\n        super().__init__()\n        self.trans_conv = SGModule(\n            in_channels,\n            hidden_channels,\n            trans_num_layers,\n            trans_num_heads,\n            trans_dropout,\n        )\n        self.graph_conv = GraphModule(\n            in_channels,\n            hidden_channels,\n            gnn_num_layers,\n            gnn_dropout,\n        )\n        self.graph_weight = graph_weight\n\n        self.aggregate = aggregate\n\n        if aggregate == 'add':\n            self.fc = torch.nn.Linear(hidden_channels, out_channels)\n        elif aggregate == 'cat':\n            self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)\n        else:\n            raise ValueError(f'Invalid aggregate type: {aggregate}')\n\n        self.params1 = 
list(self.trans_conv.parameters())\n        self.params2 = list(self.graph_conv.parameters())\n        self.params2.extend(list(self.fc.parameters()))\n\n        self.out_channels = out_channels\n\n    def reset_parameters(self) -> None:\n        self.trans_conv.reset_parameters()\n        self.graph_conv.reset_parameters()\n        self.fc.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Optional[Tensor],\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The input node features.\n            edge_index (torch.Tensor or SparseTensor): The edge indices.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example.\n        \"\"\"\n        x1 = self.trans_conv(x, batch)\n        x2 = self.graph_conv(x, edge_index)\n        if self.aggregate == 'add':\n            x = self.graph_weight * x2 + (1 - self.graph_weight) * x1\n        else:\n            x = torch.cat((x1, x2), dim=1)\n        x = self.fc(x)\n        return F.log_softmax(x, dim=-1)\n"
  },
  {
    "path": "torch_geometric/nn/models/signed_gcn.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.nn import SignedConv\nfrom torch_geometric.utils import (\n    coalesce,\n    negative_sampling,\n    structured_negative_sampling,\n)\n\n\nclass SignedGCN(torch.nn.Module):\n    r\"\"\"The signed graph convolutional network model from the `\"Signed Graph\n    Convolutional Network\" <https://arxiv.org/abs/1808.06354>`_ paper.\n    Internally, this module uses the\n    :class:`torch_geometric.nn.conv.SignedConv` operator.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        hidden_channels (int): Size of each hidden sample.\n        num_layers (int): Number of layers.\n        lamb (float, optional): Balances the contributions of the overall\n            objective. (default: :obj:`5`)\n        bias (bool, optional): If set to :obj:`False`, all layers will not\n            learn an additive bias. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        hidden_channels: int,\n        num_layers: int,\n        lamb: float = 5,\n        bias: bool = True,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.hidden_channels = hidden_channels\n        self.num_layers = num_layers\n        self.lamb = lamb\n\n        self.conv1 = SignedConv(in_channels, hidden_channels // 2,\n                                first_aggr=True)\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_layers - 1):\n            self.convs.append(\n                SignedConv(hidden_channels // 2, hidden_channels // 2,\n                           first_aggr=False))\n\n        self.lin = torch.nn.Linear(2 * hidden_channels, 3)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.conv1.reset_parameters()\n        for conv in 
self.convs:\n            conv.reset_parameters()\n        self.lin.reset_parameters()\n\n    def split_edges(\n        self,\n        edge_index: Tensor,\n        test_ratio: float = 0.2,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Splits the edges :obj:`edge_index` into train and test edges.\n\n        Args:\n            edge_index (LongTensor): The edge indices.\n            test_ratio (float, optional): The ratio of test edges.\n                (default: :obj:`0.2`)\n        \"\"\"\n        mask = torch.ones(edge_index.size(1), dtype=torch.bool)\n        mask[torch.randperm(mask.size(0))[:int(test_ratio * mask.size(0))]] = 0\n\n        train_edge_index = edge_index[:, mask]\n        test_edge_index = edge_index[:, ~mask]\n\n        return train_edge_index, test_edge_index\n\n    def create_spectral_features(\n        self,\n        pos_edge_index: Tensor,\n        neg_edge_index: Tensor,\n        num_nodes: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Creates :obj:`in_channels` spectral node features based on\n        positive and negative edges.\n\n        Args:\n            pos_edge_index (LongTensor): The positive edge indices.\n            neg_edge_index (LongTensor): The negative edge indices.\n            num_nodes (int, optional): The number of nodes, *i.e.*\n                :obj:`max_val + 1` of :attr:`pos_edge_index` and\n                :attr:`neg_edge_index`. 
(default: :obj:`None`)\n        \"\"\"\n        import scipy.sparse as sp\n        from sklearn.decomposition import TruncatedSVD\n\n        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)\n        N = edge_index.max().item() + 1 if num_nodes is None else num_nodes\n        edge_index = edge_index.to(torch.device('cpu'))\n\n        pos_val = torch.full((pos_edge_index.size(1), ), 2, dtype=torch.float)\n        neg_val = torch.full((neg_edge_index.size(1), ), 0, dtype=torch.float)\n        val = torch.cat([pos_val, neg_val], dim=0)\n\n        row, col = edge_index\n        edge_index = torch.cat([edge_index, torch.stack([col, row])], dim=1)\n        val = torch.cat([val, val], dim=0)\n\n        edge_index, val = coalesce(edge_index, val, num_nodes=N)\n        val = val - 1\n\n        # Borrowed from:\n        # https://github.com/benedekrozemberczki/SGCN/blob/master/src/utils.py\n        edge_index = edge_index.detach().numpy()\n        val = val.detach().numpy()\n        A = sp.coo_matrix((val, edge_index), shape=(N, N))\n        svd = TruncatedSVD(n_components=self.in_channels, n_iter=128)\n        svd.fit(A)\n        x = svd.components_.T\n        return torch.from_numpy(x).to(torch.float).to(pos_edge_index.device)\n\n    def forward(\n        self,\n        x: Tensor,\n        pos_edge_index: Tensor,\n        neg_edge_index: Tensor,\n    ) -> Tensor:\n        \"\"\"Computes node embeddings :obj:`z` based on positive edges\n        :obj:`pos_edge_index` and negative edges :obj:`neg_edge_index`.\n\n        Args:\n            x (torch.Tensor): The input node features.\n            pos_edge_index (torch.Tensor): The positive edge indices.\n            neg_edge_index (torch.Tensor): The negative edge indices.\n        \"\"\"\n        z = F.relu(self.conv1(x, pos_edge_index, neg_edge_index))\n        for conv in self.convs:\n            z = F.relu(conv(z, pos_edge_index, neg_edge_index))\n        return z\n\n    def discriminate(self, z: Tensor, 
edge_index: Tensor) -> Tensor:\n        \"\"\"Given node embeddings :obj:`z`, classifies the link relation\n        between node pairs :obj:`edge_index` to be either positive,\n        negative or non-existent.\n\n        Args:\n            z (torch.Tensor): The node embeddings.\n            edge_index (torch.Tensor): The edge indices.\n        \"\"\"\n        value = torch.cat([z[edge_index[0]], z[edge_index[1]]], dim=1)\n        value = self.lin(value)\n        return torch.log_softmax(value, dim=1)\n\n    def nll_loss(\n        self,\n        z: Tensor,\n        pos_edge_index: Tensor,\n        neg_edge_index: Tensor,\n    ) -> Tensor:\n        \"\"\"Computes the discriminator loss based on node embeddings :obj:`z`,\n        positive edges :obj:`pos_edge_index` and negative edges\n        :obj:`neg_edge_index`.\n\n        Args:\n            z (torch.Tensor): The node embeddings.\n            pos_edge_index (torch.Tensor): The positive edge indices.\n            neg_edge_index (torch.Tensor): The negative edge indices.\n        \"\"\"\n        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=1)\n        none_edge_index = negative_sampling(edge_index, z.size(0))\n\n        nll_loss = 0\n        nll_loss += F.nll_loss(\n            self.discriminate(z, pos_edge_index),\n            pos_edge_index.new_full((pos_edge_index.size(1), ), 0))\n        nll_loss += F.nll_loss(\n            self.discriminate(z, neg_edge_index),\n            neg_edge_index.new_full((neg_edge_index.size(1), ), 1))\n        nll_loss += F.nll_loss(\n            self.discriminate(z, none_edge_index),\n            none_edge_index.new_full((none_edge_index.size(1), ), 2))\n        return nll_loss / 3.0\n\n    def pos_embedding_loss(\n        self,\n        z: Tensor,\n        pos_edge_index: Tensor,\n    ) -> Tensor:\n        \"\"\"Computes the triplet loss between positive node pairs and sampled\n        non-node pairs.\n\n        Args:\n            z (torch.Tensor): The node 
embeddings.\n            pos_edge_index (torch.Tensor): The positive edge indices.\n        \"\"\"\n        i, j, k = structured_negative_sampling(pos_edge_index, z.size(0))\n\n        out = (z[i] - z[j]).pow(2).sum(dim=1) - (z[i] - z[k]).pow(2).sum(dim=1)\n        return torch.clamp(out, min=0).mean()\n\n    def neg_embedding_loss(self, z: Tensor, neg_edge_index: Tensor) -> Tensor:\n        \"\"\"Computes the triplet loss between negative node pairs and sampled\n        non-node pairs.\n\n        Args:\n            z (torch.Tensor): The node embeddings.\n            neg_edge_index (torch.Tensor): The negative edge indices.\n        \"\"\"\n        i, j, k = structured_negative_sampling(neg_edge_index, z.size(0))\n\n        out = (z[i] - z[k]).pow(2).sum(dim=1) - (z[i] - z[j]).pow(2).sum(dim=1)\n        return torch.clamp(out, min=0).mean()\n\n    def loss(\n        self,\n        z: Tensor,\n        pos_edge_index: Tensor,\n        neg_edge_index: Tensor,\n    ) -> Tensor:\n        \"\"\"Computes the overall objective.\n\n        Args:\n            z (torch.Tensor): The node embeddings.\n            pos_edge_index (torch.Tensor): The positive edge indices.\n            neg_edge_index (torch.Tensor): The negative edge indices.\n        \"\"\"\n        nll_loss = self.nll_loss(z, pos_edge_index, neg_edge_index)\n        loss_1 = self.pos_embedding_loss(z, pos_edge_index)\n        loss_2 = self.neg_embedding_loss(z, neg_edge_index)\n        return nll_loss + self.lamb * (loss_1 + loss_2)\n\n    def test(\n        self,\n        z: Tensor,\n        pos_edge_index: Tensor,\n        neg_edge_index: Tensor,\n    ) -> Tuple[float, float]:\n        \"\"\"Evaluates node embeddings :obj:`z` on positive and negative test\n        edges by computing AUC and F1 scores.\n\n        Args:\n            z (torch.Tensor): The node embeddings.\n            pos_edge_index (torch.Tensor): The positive edge indices.\n            neg_edge_index (torch.Tensor): The negative edge indices.\n 
       \"\"\"\n        from sklearn.metrics import f1_score, roc_auc_score\n\n        with torch.no_grad():\n            pos_p = self.discriminate(z, pos_edge_index)[:, :2].max(dim=1)[1]\n            neg_p = self.discriminate(z, neg_edge_index)[:, :2].max(dim=1)[1]\n        pred = (1 - torch.cat([pos_p, neg_p])).cpu()\n        y = torch.cat(\n            [pred.new_ones(pos_p.size(0)),\n             pred.new_zeros(neg_p.size(0))])\n        pred, y = pred.numpy(), y.numpy()\n\n        auc = roc_auc_score(y, pred)\n        f1 = f1_score(y, pred, average='binary') if pred.sum() > 0 else 0\n\n        return auc, f1\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.hidden_channels}, num_layers={self.num_layers})')\n"
  },
  {
    "path": "torch_geometric/nn/models/tgn.py",
    "content": "import copy\nfrom typing import Callable, Dict, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import GRUCell, Linear\n\nfrom torch_geometric.nn.inits import zeros\nfrom torch_geometric.utils import scatter\nfrom torch_geometric.utils._scatter import scatter_argmax\n\nTGNMessageStoreType = Dict[int, Tuple[Tensor, Tensor, Tensor, Tensor]]\n\n\nclass TGNMemory(torch.nn.Module):\n    r\"\"\"The Temporal Graph Network (TGN) memory model from the\n    `\"Temporal Graph Networks for Deep Learning on Dynamic Graphs\"\n    <https://arxiv.org/abs/2006.10637>`_ paper.\n\n    .. note::\n\n        For an example of using TGN, see `examples/tgn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        tgn.py>`_.\n\n    Args:\n        num_nodes (int): The number of nodes to save memories for.\n        raw_msg_dim (int): The raw message dimensionality.\n        memory_dim (int): The hidden memory dimensionality.\n        time_dim (int): The time encoding dimensionality.\n        message_module (torch.nn.Module): The message function which\n            combines source and destination node memory embeddings, the raw\n            message and the time encoding.\n        aggregator_module (torch.nn.Module): The message aggregator function\n            which aggregates messages to the same destination into a single\n            representation.\n    \"\"\"\n    def __init__(self, num_nodes: int, raw_msg_dim: int, memory_dim: int,\n                 time_dim: int, message_module: Callable,\n                 aggregator_module: Callable):\n        super().__init__()\n\n        self.num_nodes = num_nodes\n        self.raw_msg_dim = raw_msg_dim\n        self.memory_dim = memory_dim\n        self.time_dim = time_dim\n\n        self.msg_s_module = message_module\n        self.msg_d_module = copy.deepcopy(message_module)\n        self.aggr_module = aggregator_module\n        self.time_enc = TimeEncoder(time_dim)\n        self.gru = 
GRUCell(message_module.out_channels, memory_dim)\n\n        self.register_buffer('memory', torch.empty(num_nodes, memory_dim))\n        last_update = torch.empty(self.num_nodes, dtype=torch.long)\n        self.register_buffer('last_update', last_update)\n        self.register_buffer('_assoc', torch.empty(num_nodes,\n                                                   dtype=torch.long))\n\n        self.msg_s_store = {}\n        self.msg_d_store = {}\n\n        self.reset_parameters()\n\n    @property\n    def device(self) -> torch.device:\n        return self.time_enc.lin.weight.device\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        if hasattr(self.msg_s_module, 'reset_parameters'):\n            self.msg_s_module.reset_parameters()\n        if hasattr(self.msg_d_module, 'reset_parameters'):\n            self.msg_d_module.reset_parameters()\n        if hasattr(self.aggr_module, 'reset_parameters'):\n            self.aggr_module.reset_parameters()\n        self.time_enc.reset_parameters()\n        self.gru.reset_parameters()\n        self.reset_state()\n\n    def reset_state(self):\n        \"\"\"Resets the memory to its initial state.\"\"\"\n        zeros(self.memory)\n        zeros(self.last_update)\n        self._reset_message_store()\n\n    def detach(self):\n        \"\"\"Detaches the memory from gradient computation.\"\"\"\n        self.memory.detach_()\n\n    def forward(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:\n        \"\"\"Returns, for all nodes :obj:`n_id`, their current memory and their\n        last updated timestamp.\n        \"\"\"\n        if self.training:\n            memory, last_update = self._get_updated_memory(n_id)\n        else:\n            memory, last_update = self.memory[n_id], self.last_update[n_id]\n\n        return memory, last_update\n\n    def update_state(self, src: Tensor, dst: Tensor, t: Tensor,\n                     raw_msg: Tensor):\n        \"\"\"Updates the 
memory with newly encountered interactions\n        :obj:`(src, dst, t, raw_msg)`.\n        \"\"\"\n        n_id = torch.cat([src, dst]).unique()\n\n        if self.training:\n            self._update_memory(n_id)\n            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)\n            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)\n        else:\n            self._update_msg_store(src, dst, t, raw_msg, self.msg_s_store)\n            self._update_msg_store(dst, src, t, raw_msg, self.msg_d_store)\n            self._update_memory(n_id)\n\n    def _reset_message_store(self):\n        i = self.memory.new_empty((0, ), device=self.device, dtype=torch.long)\n        msg = self.memory.new_empty((0, self.raw_msg_dim), device=self.device)\n        # Message store format: (src, dst, t, msg)\n        self.msg_s_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}\n        self.msg_d_store = {j: (i, i, i, msg) for j in range(self.num_nodes)}\n\n    def _update_memory(self, n_id: Tensor):\n        memory, last_update = self._get_updated_memory(n_id)\n        self.memory[n_id] = memory\n        self.last_update[n_id] = last_update\n\n    def _get_updated_memory(self, n_id: Tensor) -> Tuple[Tensor, Tensor]:\n        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)\n\n        # Compute messages (src -> dst).\n        msg_s, t_s, src_s, dst_s = self._compute_msg(n_id, self.msg_s_store,\n                                                     self.msg_s_module)\n\n        # Compute messages (dst -> src).\n        msg_d, t_d, src_d, dst_d = self._compute_msg(n_id, self.msg_d_store,\n                                                     self.msg_d_module)\n\n        # Aggregate messages.\n        idx = torch.cat([src_s, src_d], dim=0)\n        msg = torch.cat([msg_s, msg_d], dim=0)\n        t = torch.cat([t_s, t_d], dim=0)\n        aggr = self.aggr_module(msg, self._assoc[idx], t, n_id.size(0))\n\n        # Get local copy of updated 
memory.\n        memory = self.gru(aggr, self.memory[n_id])\n\n        # Get local copy of updated `last_update`.\n        dim_size = self.last_update.size(0)\n        last_update = scatter(t, idx, 0, dim_size, reduce='max')[n_id]\n\n        return memory, last_update\n\n    def _update_msg_store(self, src: Tensor, dst: Tensor, t: Tensor,\n                          raw_msg: Tensor, msg_store: TGNMessageStoreType):\n        n_id, perm = src.sort()\n        n_id, count = n_id.unique_consecutive(return_counts=True)\n        for i, idx in zip(n_id.tolist(), perm.split(count.tolist())):\n            msg_store[i] = (src[idx], dst[idx], t[idx], raw_msg[idx])\n\n    def _compute_msg(self, n_id: Tensor, msg_store: TGNMessageStoreType,\n                     msg_module: Callable):\n        data = [msg_store[i] for i in n_id.tolist()]\n        src, dst, t, raw_msg = list(zip(*data))\n        src = torch.cat(src, dim=0).to(self.device)\n        dst = torch.cat(dst, dim=0).to(self.device)\n        t = torch.cat(t, dim=0).to(self.device)\n        # Filter out empty tensors to avoid `invalid configuration argument`.\n        # TODO Investigate why this is needed.\n        raw_msg = [m for i, m in enumerate(raw_msg) if m.numel() > 0 or i == 0]\n        raw_msg = torch.cat(raw_msg, dim=0).to(self.device)\n        t_rel = t - self.last_update[src]\n        t_enc = self.time_enc(t_rel.to(raw_msg.dtype))\n\n        msg = msg_module(self.memory[src], self.memory[dst], raw_msg, t_enc)\n\n        return msg, t, src, dst\n\n    def train(self, mode: bool = True):\n        \"\"\"Sets the module in training mode.\"\"\"\n        if self.training and not mode:\n            # Flush message store to memory in case we just entered eval mode.\n            self._update_memory(\n                torch.arange(self.num_nodes, device=self.memory.device))\n            self._reset_message_store()\n        super().train(mode)\n\n\nclass IdentityMessage(torch.nn.Module):\n    def __init__(self, raw_msg_dim: 
int, memory_dim: int, time_dim: int):\n        super().__init__()\n        self.out_channels = raw_msg_dim + 2 * memory_dim + time_dim\n\n    def forward(self, z_src: Tensor, z_dst: Tensor, raw_msg: Tensor,\n                t_enc: Tensor):\n        return torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1)\n\n\nclass LastAggregator(torch.nn.Module):\n    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):\n        argmax = scatter_argmax(t, index, dim=0, dim_size=dim_size)\n        out = msg.new_zeros((dim_size, msg.size(-1)))\n        mask = argmax < msg.size(0)  # Filter items with at least one entry.\n        out[mask] = msg[argmax[mask]]\n        return out\n\n\nclass MeanAggregator(torch.nn.Module):\n    def forward(self, msg: Tensor, index: Tensor, t: Tensor, dim_size: int):\n        return scatter(msg, index, dim=0, dim_size=dim_size, reduce='mean')\n\n\nclass TimeEncoder(torch.nn.Module):\n    def __init__(self, out_channels: int):\n        super().__init__()\n        self.out_channels = out_channels\n        self.lin = Linear(1, out_channels)\n\n    def reset_parameters(self):\n        self.lin.reset_parameters()\n\n    def forward(self, t: Tensor) -> Tensor:\n        return self.lin(t.view(-1, 1)).cos()\n\n\nclass LastNeighborLoader:\n    def __init__(self, num_nodes: int, size: int, device=None):\n        self.size = size\n\n        self.neighbors = torch.empty((num_nodes, size), dtype=torch.long,\n                                     device=device)\n        self.e_id = torch.empty((num_nodes, size), dtype=torch.long,\n                                device=device)\n        self._assoc = torch.empty(num_nodes, dtype=torch.long, device=device)\n\n        self.reset_state()\n\n    def __call__(self, n_id: Tensor) -> Tuple[Tensor, Tensor, Tensor]:\n        neighbors = self.neighbors[n_id]\n        nodes = n_id.view(-1, 1).repeat(1, self.size)\n        e_id = self.e_id[n_id]\n\n        # Filter invalid neighbors (identified by `e_id < 
0`).\n        mask = e_id >= 0\n        neighbors, nodes, e_id = neighbors[mask], nodes[mask], e_id[mask]\n\n        # Relabel node indices.\n        n_id = torch.cat([n_id, neighbors]).unique()\n        self._assoc[n_id] = torch.arange(n_id.size(0), device=n_id.device)\n        neighbors, nodes = self._assoc[neighbors], self._assoc[nodes]\n\n        return n_id, torch.stack([neighbors, nodes]), e_id\n\n    def insert(self, src: Tensor, dst: Tensor):\n        # Inserts newly encountered interactions into an ever-growing\n        # (undirected) temporal graph.\n\n        # Collect central nodes, their neighbors and the current event ids.\n        neighbors = torch.cat([src, dst], dim=0)\n        nodes = torch.cat([dst, src], dim=0)\n        e_id = torch.arange(self.cur_e_id, self.cur_e_id + src.size(0),\n                            device=src.device).repeat(2)\n        self.cur_e_id += src.numel()\n\n        # Convert newly encountered interaction ids so that they point to\n        # locations of a \"dense\" format of shape [num_nodes, size].\n        nodes, perm = nodes.sort()\n        neighbors, e_id = neighbors[perm], e_id[perm]\n\n        n_id = nodes.unique()\n        self._assoc[n_id] = torch.arange(n_id.numel(), device=n_id.device)\n\n        dense_id = torch.arange(nodes.size(0), device=nodes.device) % self.size\n        dense_id += self._assoc[nodes].mul_(self.size)\n\n        dense_e_id = e_id.new_full((n_id.numel() * self.size, ), -1)\n        dense_e_id[dense_id] = e_id\n        dense_e_id = dense_e_id.view(-1, self.size)\n\n        dense_neighbors = e_id.new_empty(n_id.numel() * self.size)\n        dense_neighbors[dense_id] = neighbors\n        dense_neighbors = dense_neighbors.view(-1, self.size)\n\n        # Collect new and old interactions...\n        e_id = torch.cat([self.e_id[n_id, :self.size], dense_e_id], dim=-1)\n        neighbors = torch.cat(\n            [self.neighbors[n_id, :self.size], dense_neighbors], dim=-1)\n\n        # And sort them 
based on `e_id`.\n        e_id, perm = e_id.topk(self.size, dim=-1)\n        self.e_id[n_id] = e_id\n        self.neighbors[n_id] = torch.gather(neighbors, 1, perm)\n\n    def reset_state(self):\n        self.cur_e_id = 0\n        self.e_id.fill_(-1)\n"
  },
  {
    "path": "torch_geometric/nn/models/visnet.py",
    "content": "import math\nfrom typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import grad\nfrom torch.nn import Embedding, LayerNorm, Linear, Parameter\n\nfrom torch_geometric.nn import MessagePassing, radius_graph\nfrom torch_geometric.utils import scatter\n\n\nclass CosineCutoff(torch.nn.Module):\n    r\"\"\"Applies a cosine cutoff to the input distances.\n\n    .. math::\n        \\text{cutoffs} =\n        \\begin{cases}\n        0.5 * (\\cos(\\frac{\\text{distances} * \\pi}{\\text{cutoff}}) + 1.0),\n        & \\text{if } \\text{distances} < \\text{cutoff} \\\\\n        0, & \\text{otherwise}\n        \\end{cases}\n\n    Args:\n        cutoff (float): A scalar that determines the point at which the cutoff\n            is applied.\n    \"\"\"\n    def __init__(self, cutoff: float) -> None:\n        super().__init__()\n        self.cutoff = cutoff\n\n    def forward(self, distances: Tensor) -> Tensor:\n        r\"\"\"Applies a cosine cutoff to the input distances.\n\n        Args:\n            distances (torch.Tensor): A tensor of distances.\n\n        Returns:\n            cutoffs (torch.Tensor): A tensor where the cosine function\n                has been applied to the distances,\n                but any values that exceed the cutoff are set to 0.\n        \"\"\"\n        cutoffs = 0.5 * ((distances * math.pi / self.cutoff).cos() + 1.0)\n        cutoffs = cutoffs * (distances < self.cutoff).float()\n        return cutoffs\n\n\nclass ExpNormalSmearing(torch.nn.Module):\n    r\"\"\"Applies exponential normal smearing to the input distances.\n\n    .. math::\n        \\text{smeared\\_dist} = \\text{CosineCutoff}(\\text{dist})\n        * e^{-\\beta * (e^{\\alpha * (-\\text{dist})} - \\text{means})^2}\n\n    Args:\n        cutoff (float, optional): A scalar that determines the point at which\n            the cutoff is applied. 
(default: :obj:`5.0`)\n        num_rbf (int, optional): The number of radial basis functions.\n            (default: :obj:`128`)\n        trainable (bool, optional): If set to :obj:`False`, the means and betas\n            of the RBFs will not be trained. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        cutoff: float = 5.0,\n        num_rbf: int = 128,\n        trainable: bool = True,\n    ) -> None:\n        super().__init__()\n        self.cutoff = cutoff\n        self.num_rbf = num_rbf\n        self.trainable = trainable\n\n        self.cutoff_fn = CosineCutoff(cutoff)\n        self.alpha = 5.0 / cutoff\n\n        means, betas = self._initial_params()\n        if trainable:\n            self.register_parameter('means', Parameter(means))\n            self.register_parameter('betas', Parameter(betas))\n        else:\n            self.register_buffer('means', means)\n            self.register_buffer('betas', betas)\n\n    def _initial_params(self) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Initializes the means and betas for the radial basis functions.\"\"\"\n        start_value = torch.exp(torch.tensor(-self.cutoff))\n        means = torch.linspace(start_value, 1, self.num_rbf)\n        betas = torch.tensor([(2 / self.num_rbf * (1 - start_value))**-2] *\n                             self.num_rbf)\n        return means, betas\n\n    def reset_parameters(self):\n        r\"\"\"Resets the means and betas to their initial values.\"\"\"\n        means, betas = self._initial_params()\n        self.means.data.copy_(means)\n        self.betas.data.copy_(betas)\n\n    def forward(self, dist: Tensor) -> Tensor:\n        r\"\"\"Applies the exponential normal smearing to the input distance.\n\n        Args:\n            dist (torch.Tensor): A tensor of distances.\n        \"\"\"\n        dist = dist.unsqueeze(-1)\n        smeared_dist = self.cutoff_fn(dist) * (-self.betas * (\n            (self.alpha * (-dist)).exp() - self.means)**2).exp()\n       
 return smeared_dist\n\n\nclass Sphere(torch.nn.Module):\n    r\"\"\"Computes spherical harmonics of the input data.\n\n    This module computes the spherical harmonics up to a given degree\n    :obj:`lmax` for the input tensor of 3D vectors.\n    The vectors are assumed to be given in Cartesian coordinates.\n    See `here <https://en.wikipedia.org/wiki/Table_of_spherical_harmonics>`_\n    for mathematical details.\n\n    Args:\n        lmax (int, optional): The maximum degree of the spherical harmonics.\n            (default: :obj:`2`)\n    \"\"\"\n    def __init__(self, lmax: int = 2) -> None:\n        super().__init__()\n        self.lmax = lmax\n\n    def forward(self, edge_vec: Tensor) -> Tensor:\n        r\"\"\"Computes the spherical harmonics of the input tensor.\n\n        Args:\n            edge_vec (torch.Tensor): A tensor of 3D vectors.\n        \"\"\"\n        return self._spherical_harmonics(\n            self.lmax,\n            edge_vec[..., 0],\n            edge_vec[..., 1],\n            edge_vec[..., 2],\n        )\n\n    @staticmethod\n    def _spherical_harmonics(\n        lmax: int,\n        x: Tensor,\n        y: Tensor,\n        z: Tensor,\n    ) -> Tensor:\n        r\"\"\"Computes the spherical harmonics up to degree :obj:`lmax` of the\n        input vectors.\n\n        Args:\n            lmax (int): The maximum degree of the spherical harmonics.\n            x (torch.Tensor): The x coordinates of the vectors.\n            y (torch.Tensor): The y coordinates of the vectors.\n            z (torch.Tensor): The z coordinates of the vectors.\n        \"\"\"\n        sh_1_0, sh_1_1, sh_1_2 = x, y, z\n\n        if lmax == 1:\n            return torch.stack([sh_1_0, sh_1_1, sh_1_2], dim=-1)\n\n        sh_2_0 = math.sqrt(3.0) * x * z\n        sh_2_1 = math.sqrt(3.0) * x * y\n        y2 = y.pow(2)\n        x2z2 = x.pow(2) + z.pow(2)\n        sh_2_2 = y2 - 0.5 * x2z2\n        sh_2_3 = math.sqrt(3.0) * y * z\n        sh_2_4 = math.sqrt(3.0) / 2.0 * 
(z.pow(2) - x.pow(2))\n\n        if lmax == 2:\n            return torch.stack([\n                sh_1_0,\n                sh_1_1,\n                sh_1_2,\n                sh_2_0,\n                sh_2_1,\n                sh_2_2,\n                sh_2_3,\n                sh_2_4,\n            ], dim=-1)\n\n        raise ValueError(f\"'lmax' needs to be 1 or 2 (got {lmax})\")\n\n\nclass VecLayerNorm(torch.nn.Module):\n    r\"\"\"Applies layer normalization to the input data.\n\n    This module applies a custom layer normalization to a tensor of vectors.\n    The normalization can either be :obj:`\"max_min\"` normalization, or no\n    normalization.\n\n    Args:\n        hidden_channels (int): The number of hidden channels in the input.\n        trainable (bool): If set to :obj:`True`, the normalization weights are\n            trainable parameters.\n        norm_type (str, optional): The type of normalization to apply, one of\n            :obj:`\"max_min\"` or :obj:`None`. (default: :obj:`\"max_min\"`)\n    \"\"\"\n    def __init__(\n        self,\n        hidden_channels: int,\n        trainable: bool,\n        norm_type: Optional[str] = 'max_min',\n    ) -> None:\n        super().__init__()\n\n        self.hidden_channels = hidden_channels\n        self.norm_type = norm_type\n        self.eps = 1e-12\n\n        weight = torch.ones(self.hidden_channels)\n        if trainable:\n            self.register_parameter('weight', Parameter(weight))\n        else:\n            self.register_buffer('weight', weight)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the normalization weights to their initial values.\"\"\"\n        torch.nn.init.ones_(self.weight)\n\n    def max_min_norm(self, vec: Tensor) -> Tensor:\n        r\"\"\"Applies max-min normalization to the input tensor.\n\n        .. 
math::\n            \\text{dist} = ||\\text{vec}||_2\n            \\text{direct} = \\frac{\\text{vec}}{\\text{dist}}\n            \\text{max\\_val} = \\max(\\text{dist})\n            \\text{min\\_val} = \\min(\\text{dist})\n            \\text{delta} = \\text{max\\_val} - \\text{min\\_val}\n            \\text{dist} = \\frac{\\text{dist} - \\text{min\\_val}}{\\text{delta}}\n            \\text{normed\\_vec} = \\max(0, \\text{dist}) \\cdot \\text{direct}\n\n        Args:\n            vec (torch.Tensor): The input tensor.\n        \"\"\"\n        dist = torch.norm(vec, dim=1, keepdim=True)\n\n        if (dist == 0).all():\n            return torch.zeros_like(vec)\n\n        dist = dist.clamp(min=self.eps)\n        direct = vec / dist\n\n        max_val, _ = dist.max(dim=-1)\n        min_val, _ = dist.min(dim=-1)\n        delta = (max_val - min_val).view(-1)\n        delta = torch.where(delta == 0, torch.ones_like(delta), delta)\n        dist = (dist - min_val.view(-1, 1, 1)) / delta.view(-1, 1, 1)\n\n        return dist.relu() * direct\n\n    def forward(self, vec: Tensor) -> Tensor:\n        r\"\"\"Applies the layer normalization to the input tensor.\n\n        Args:\n            vec (torch.Tensor): The input tensor.\n        \"\"\"\n        if vec.size(1) == 3:\n            if self.norm_type == 'max_min':\n                vec = self.max_min_norm(vec)\n            return vec * self.weight.unsqueeze(0).unsqueeze(0)\n        elif vec.size(1) == 8:\n            vec1, vec2 = torch.split(vec, [3, 5], dim=1)\n            if self.norm_type == 'max_min':\n                vec1 = self.max_min_norm(vec1)\n                vec2 = self.max_min_norm(vec2)\n            vec = torch.cat([vec1, vec2], dim=1)\n            return vec * self.weight.unsqueeze(0).unsqueeze(0)\n\n        raise ValueError(f\"'{self.__class__.__name__}' only supports 3 or 8 \"\n                         f\"channels (got {vec.size(1)})\")\n\n\nclass Distance(torch.nn.Module):\n    r\"\"\"Computes the pairwise 
distances between atoms in a molecule.\n\n    This module computes the pairwise distances between atoms in a molecule,\n    represented by their positions :obj:`pos`.\n    The distances are computed only between points that are within a certain\n    cutoff radius.\n\n    Args:\n        cutoff (float): The cutoff radius beyond\n            which distances are not computed.\n        max_num_neighbors (int, optional): The maximum number of neighbors\n            considered for each point. (default: :obj:`32`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not\n            include self-loops. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        cutoff: float,\n        max_num_neighbors: int = 32,\n        add_self_loops: bool = True,\n    ) -> None:\n        super().__init__()\n        self.cutoff = cutoff\n        self.max_num_neighbors = max_num_neighbors\n        self.add_self_loops = add_self_loops\n\n    def forward(\n        self,\n        pos: Tensor,\n        batch: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor]:\n        r\"\"\"Computes the pairwise distances between atoms in the molecule.\n\n        Args:\n            pos (torch.Tensor): The positions of the atoms in the molecule.\n            batch (torch.Tensor): A batch vector, which assigns each node to a\n                specific example.\n\n        Returns:\n            edge_index (torch.Tensor): The indices of the edges in the graph.\n            edge_weight (torch.Tensor): The distances between connected nodes.\n            edge_vec (torch.Tensor): The vector differences between connected\n                nodes.\n        \"\"\"\n        edge_index = radius_graph(\n            pos,\n            r=self.cutoff,\n            batch=batch,\n            loop=self.add_self_loops,\n            max_num_neighbors=self.max_num_neighbors,\n        )\n        edge_vec = pos[edge_index[0]] - pos[edge_index[1]]\n\n        if self.add_self_loops:\n            mask = 
edge_index[0] != edge_index[1]\n            edge_weight = torch.zeros(edge_vec.size(0), device=edge_vec.device)\n            edge_weight[mask] = torch.norm(edge_vec[mask], dim=-1)\n        else:\n            edge_weight = torch.norm(edge_vec, dim=-1)\n\n        return edge_index, edge_weight, edge_vec\n\n\nclass NeighborEmbedding(MessagePassing):\n    r\"\"\"The :class:`NeighborEmbedding` module from the `\"Enhancing Geometric\n    Representations for Molecules with Equivariant Vector-Scalar Interactive\n    Message Passing\" <https://arxiv.org/abs/2210.16518>`_ paper.\n\n    Args:\n        hidden_channels (int): The number of hidden channels in the node\n            embeddings.\n        num_rbf (int): The number of radial basis functions.\n        cutoff (float): The cutoff distance.\n        max_z (int, optional): The maximum atomic numbers.\n            (default: :obj:`100`)\n    \"\"\"\n    def __init__(\n        self,\n        hidden_channels: int,\n        num_rbf: int,\n        cutoff: float,\n        max_z: int = 100,\n    ) -> None:\n        super().__init__(aggr='add')\n        self.embedding = Embedding(max_z, hidden_channels)\n        self.distance_proj = Linear(num_rbf, hidden_channels)\n        self.combine = Linear(hidden_channels * 2, hidden_channels)\n        self.cutoff = CosineCutoff(cutoff)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        self.embedding.reset_parameters()\n        torch.nn.init.xavier_uniform_(self.distance_proj.weight)\n        torch.nn.init.xavier_uniform_(self.combine.weight)\n        self.distance_proj.bias.data.zero_()\n        self.combine.bias.data.zero_()\n\n    def forward(\n        self,\n        z: Tensor,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_weight: Tensor,\n        edge_attr: Tensor,\n    ) -> Tensor:\n        r\"\"\"Computes the neighborhood embedding of the nodes in the graph.\n\n        Args:\n       
     z (torch.Tensor): The atomic numbers.\n            x (torch.Tensor): The node features.\n            edge_index (torch.Tensor): The indices of the edges.\n            edge_weight (torch.Tensor): The weights of the edges.\n            edge_attr (torch.Tensor): The edge features.\n\n        Returns:\n            x_neighbors (torch.Tensor): The neighborhood embeddings of the\n                nodes.\n        \"\"\"\n        mask = edge_index[0] != edge_index[1]\n        if not mask.all():\n            edge_index = edge_index[:, mask]\n            edge_weight = edge_weight[mask]\n            edge_attr = edge_attr[mask]\n\n        C = self.cutoff(edge_weight)\n        W = self.distance_proj(edge_attr) * C.view(-1, 1)\n\n        x_neighbors = self.embedding(z)\n        x_neighbors = self.propagate(edge_index, x=x_neighbors, W=W)\n        x_neighbors = self.combine(torch.cat([x, x_neighbors], dim=1))\n        return x_neighbors\n\n    def message(self, x_j: Tensor, W: Tensor) -> Tensor:\n        return x_j * W\n\n\nclass EdgeEmbedding(torch.nn.Module):\n    r\"\"\"The :class:`EdgeEmbedding` module from the `\"Enhancing Geometric\n    Representations for Molecules with Equivariant Vector-Scalar Interactive\n    Message Passing\" <https://arxiv.org/abs/2210.16518>`_ paper.\n\n    Args:\n        num_rbf (int): The number of radial basis functions.\n        hidden_channels (int): The number of hidden channels in the node\n            embeddings.\n    \"\"\"\n    def __init__(self, num_rbf: int, hidden_channels: int) -> None:\n        super().__init__()\n        self.edge_proj = Linear(num_rbf, hidden_channels)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        torch.nn.init.xavier_uniform_(self.edge_proj.weight)\n        self.edge_proj.bias.data.zero_()\n\n    def forward(\n        self,\n        edge_index: Tensor,\n        edge_attr: Tensor,\n        x: Tensor,\n    ) -> Tensor:\n       
 r\"\"\"Computes the edge embeddings of the graph.\n\n        Args:\n            edge_index (torch.Tensor): The indices of the edges.\n            edge_attr (torch.Tensor): The edge features.\n            x (torch.Tensor): The node features.\n\n        Returns:\n            out_edge_attr (torch.Tensor): The edge embeddings.\n        \"\"\"\n        x_j = x[edge_index[0]]\n        x_i = x[edge_index[1]]\n        return (x_i + x_j) * self.edge_proj(edge_attr)\n\n\nclass ViS_MP(MessagePassing):\n    r\"\"\"The message passing module without vertex geometric features of the\n    equivariant vector-scalar interactive graph neural network (ViSNet)\n    from the `\"Enhancing Geometric Representations for Molecules with\n    Equivariant Vector-Scalar Interactive Message Passing\"\n    <https://arxiv.org/abs/2210.16518>`_ paper.\n\n    Args:\n        num_heads (int): The number of attention heads.\n        hidden_channels (int): The number of hidden channels in the node\n            embeddings.\n        cutoff (float): The cutoff distance.\n        vecnorm_type (str, optional): The type of normalization to apply to the\n            vectors.\n        trainable_vecnorm (bool): Whether the normalization weights are\n            trainable.\n        last_layer (bool, optional): Whether this is the last layer in the\n            model. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num_heads: int,\n        hidden_channels: int,\n        cutoff: float,\n        vecnorm_type: Optional[str],\n        trainable_vecnorm: bool,\n        last_layer: bool = False,\n    ) -> None:\n        super().__init__(aggr='add', node_dim=0)\n\n        if hidden_channels % num_heads != 0:\n            raise ValueError(\n                f\"The number of hidden channels (got {hidden_channels}) must \"\n                f\"be evenly divisible by the number of attention heads \"\n                f\"(got {num_heads})\")\n\n        self.num_heads = num_heads\n        self.hidden_channels = hidden_channels\n        self.head_dim = hidden_channels // num_heads\n        self.last_layer = last_layer\n\n        self.layernorm = LayerNorm(hidden_channels)\n        self.vec_layernorm = VecLayerNorm(\n            hidden_channels,\n            trainable=trainable_vecnorm,\n            norm_type=vecnorm_type,\n        )\n\n        self.act = torch.nn.SiLU()\n        self.attn_activation = torch.nn.SiLU()\n\n        self.cutoff = CosineCutoff(cutoff)\n\n        self.vec_proj = Linear(hidden_channels, hidden_channels * 3, False)\n\n        self.q_proj = Linear(hidden_channels, hidden_channels)\n        self.k_proj = Linear(hidden_channels, hidden_channels)\n        self.v_proj = Linear(hidden_channels, hidden_channels)\n        self.dk_proj = Linear(hidden_channels, hidden_channels)\n        self.dv_proj = Linear(hidden_channels, hidden_channels)\n\n        self.s_proj = Linear(hidden_channels, hidden_channels * 2)\n        if not self.last_layer:\n            self.f_proj = Linear(hidden_channels, hidden_channels)\n            self.w_src_proj = Linear(hidden_channels, hidden_channels, False)\n            self.w_trg_proj = Linear(hidden_channels, hidden_channels, False)\n\n        self.o_proj = Linear(hidden_channels, hidden_channels * 3)\n\n        self.reset_parameters()\n\n    @staticmethod\n    def 
vector_rejection(vec: Tensor, d_ij: Tensor) -> Tensor:\n        r\"\"\"Computes the component of :obj:`vec` orthogonal to :obj:`d_ij`.\n\n        Args:\n            vec (torch.Tensor): The input vector.\n            d_ij (torch.Tensor): The reference vector.\n        \"\"\"\n        vec_proj = (vec * d_ij.unsqueeze(2)).sum(dim=1, keepdim=True)\n        return vec - vec_proj * d_ij.unsqueeze(2)\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        self.layernorm.reset_parameters()\n        self.vec_layernorm.reset_parameters()\n        torch.nn.init.xavier_uniform_(self.q_proj.weight)\n        self.q_proj.bias.data.zero_()\n        torch.nn.init.xavier_uniform_(self.k_proj.weight)\n        self.k_proj.bias.data.zero_()\n        torch.nn.init.xavier_uniform_(self.v_proj.weight)\n        self.v_proj.bias.data.zero_()\n        torch.nn.init.xavier_uniform_(self.o_proj.weight)\n        self.o_proj.bias.data.zero_()\n        torch.nn.init.xavier_uniform_(self.s_proj.weight)\n        self.s_proj.bias.data.zero_()\n\n        if not self.last_layer:\n            torch.nn.init.xavier_uniform_(self.f_proj.weight)\n            self.f_proj.bias.data.zero_()\n            torch.nn.init.xavier_uniform_(self.w_src_proj.weight)\n            torch.nn.init.xavier_uniform_(self.w_trg_proj.weight)\n\n        torch.nn.init.xavier_uniform_(self.vec_proj.weight)\n        torch.nn.init.xavier_uniform_(self.dk_proj.weight)\n        self.dk_proj.bias.data.zero_()\n        torch.nn.init.xavier_uniform_(self.dv_proj.weight)\n        self.dv_proj.bias.data.zero_()\n\n    def forward(\n        self,\n        x: Tensor,\n        vec: Tensor,\n        edge_index: Tensor,\n        r_ij: Tensor,\n        f_ij: Tensor,\n        d_ij: Tensor,\n    ) -> Tuple[Tensor, Tensor, Optional[Tensor]]:\n        r\"\"\"Computes the residual scalar and vector features of the nodes and\n        scalar features of the edges.\n\n        Args:\n            x 
(torch.Tensor): The scalar features of the nodes.\n            vec (torch.Tensor): The vector features of the nodes.\n            edge_index (torch.Tensor): The indices of the edges.\n            r_ij (torch.Tensor): The distances between connected nodes.\n            f_ij (torch.Tensor): The scalar features of the edges.\n            d_ij (torch.Tensor): The unit vectors of the edges.\n\n        Returns:\n            dx (torch.Tensor): The residual scalar features of the nodes.\n            dvec (torch.Tensor): The residual vector features of the nodes.\n            df_ij (torch.Tensor, optional): The residual scalar features of the\n                edges, or :obj:`None` if this is the last layer.\n        \"\"\"\n        x = self.layernorm(x)\n        vec = self.vec_layernorm(vec)\n\n        q = self.q_proj(x).reshape(-1, self.num_heads, self.head_dim)\n        k = self.k_proj(x).reshape(-1, self.num_heads, self.head_dim)\n        v = self.v_proj(x).reshape(-1, self.num_heads, self.head_dim)\n        dk = self.act(self.dk_proj(f_ij))\n        dk = dk.reshape(-1, self.num_heads, self.head_dim)\n        dv = self.act(self.dv_proj(f_ij))\n        dv = dv.reshape(-1, self.num_heads, self.head_dim)\n\n        vec1, vec2, vec3 = torch.split(self.vec_proj(vec),\n                                       self.hidden_channels, dim=-1)\n        vec_dot = (vec1 * vec2).sum(dim=1)\n\n        x, vec_out = self.propagate(edge_index, q=q, k=k, v=v, dk=dk, dv=dv,\n                                    vec=vec, r_ij=r_ij, d_ij=d_ij)\n\n        o1, o2, o3 = torch.split(self.o_proj(x), self.hidden_channels, dim=1)\n        dx = vec_dot * o2 + o3\n        dvec = vec3 * o1.unsqueeze(1) + vec_out\n        if not self.last_layer:\n            df_ij = self.edge_updater(edge_index, vec=vec, d_ij=d_ij,\n                                      f_ij=f_ij)\n            return dx, dvec, df_ij\n        else:\n            return dx, dvec, None\n\n    def message(self, q_i: Tensor, k_j: Tensor, v_j: 
Tensor, vec_j: Tensor,\n                dk: Tensor, dv: Tensor, r_ij: Tensor,\n                d_ij: Tensor) -> Tuple[Tensor, Tensor]:\n\n        attn = (q_i * k_j * dk).sum(dim=-1)\n        attn = self.attn_activation(attn) * self.cutoff(r_ij).unsqueeze(1)\n\n        v_j = v_j * dv\n        v_j = (v_j * attn.unsqueeze(2)).view(-1, self.hidden_channels)\n\n        s1, s2 = torch.split(self.act(self.s_proj(v_j)), self.hidden_channels,\n                             dim=1)\n        vec_j = vec_j * s1.unsqueeze(1) + s2.unsqueeze(1) * d_ij.unsqueeze(2)\n\n        return v_j, vec_j\n\n    def edge_update(self, vec_i: Tensor, vec_j: Tensor, d_ij: Tensor,\n                    f_ij: Tensor) -> Tensor:\n\n        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)\n        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)\n        w_dot = (w1 * w2).sum(dim=1)\n        df_ij = self.act(self.f_proj(f_ij)) * w_dot\n        return df_ij\n\n    def aggregate(\n        self,\n        features: Tuple[Tensor, Tensor],\n        index: Tensor,\n        ptr: Optional[torch.Tensor],\n        dim_size: Optional[int],\n    ) -> Tuple[Tensor, Tensor]:\n        x, vec = features\n        x = scatter(x, index, dim=self.node_dim, dim_size=dim_size)\n        vec = scatter(vec, index, dim=self.node_dim, dim_size=dim_size)\n        return x, vec\n\n\nclass ViS_MP_Vertex(ViS_MP):\n    r\"\"\"The message passing module with vertex geometric features of the\n    equivariant vector-scalar interactive graph neural network (ViSNet)\n    from the `\"Enhancing Geometric Representations for Molecules with\n    Equivariant Vector-Scalar Interactive Message Passing\"\n    <https://arxiv.org/abs/2210.16518>`_ paper.\n\n    Args:\n        num_heads (int): The number of attention heads.\n        hidden_channels (int): The number of hidden channels in the node\n            embeddings.\n        cutoff (float): The cutoff distance.\n        vecnorm_type (str, optional): The type of normalization to 
apply to the\n            vectors.\n        trainable_vecnorm (bool): Whether the normalization weights are\n            trainable.\n        last_layer (bool, optional): Whether this is the last layer in the\n            model. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num_heads: int,\n        hidden_channels: int,\n        cutoff: float,\n        vecnorm_type: Optional[str],\n        trainable_vecnorm: bool,\n        last_layer: bool = False,\n    ) -> None:\n        super().__init__(num_heads, hidden_channels, cutoff, vecnorm_type,\n                         trainable_vecnorm, last_layer)\n\n        if not self.last_layer:\n            self.f_proj = Linear(hidden_channels, hidden_channels * 2)\n            self.t_src_proj = Linear(hidden_channels, hidden_channels, False)\n            self.t_trg_proj = Linear(hidden_channels, hidden_channels, False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        super().reset_parameters()\n\n        if not self.last_layer:\n            if hasattr(self, 't_src_proj'):\n                torch.nn.init.xavier_uniform_(self.t_src_proj.weight)\n            if hasattr(self, 't_trg_proj'):\n                torch.nn.init.xavier_uniform_(self.t_trg_proj.weight)\n\n    def edge_update(self, vec_i: Tensor, vec_j: Tensor, d_ij: Tensor,\n                    f_ij: Tensor) -> Tensor:\n\n        w1 = self.vector_rejection(self.w_trg_proj(vec_i), d_ij)\n        w2 = self.vector_rejection(self.w_src_proj(vec_j), -d_ij)\n        w_dot = (w1 * w2).sum(dim=1)\n\n        t1 = self.vector_rejection(self.t_trg_proj(vec_i), d_ij)\n        t2 = self.vector_rejection(self.t_src_proj(vec_i), -d_ij)\n        t_dot = (t1 * t2).sum(dim=1)\n\n        f1, f2 = torch.split(self.act(self.f_proj(f_ij)), self.hidden_channels,\n                             dim=-1)\n\n        return f1 * w_dot + f2 * t_dot\n\n\nclass 
ViSNetBlock(torch.nn.Module):\n    r\"\"\"The representation module of the equivariant vector-scalar\n    interactive graph neural network (ViSNet) from the `\"Enhancing Geometric\n    Representations for Molecules with Equivariant Vector-Scalar Interactive\n    Message Passing\" <https://arxiv.org/abs/2210.16518>`_ paper.\n\n    Args:\n        lmax (int, optional): The maximum degree of the spherical harmonics.\n            (default: :obj:`1`)\n        vecnorm_type (str, optional): The type of normalization to apply to the\n            vectors. (default: :obj:`None`)\n        trainable_vecnorm (bool, optional):  Whether the normalization weights\n            are trainable. (default: :obj:`False`)\n        num_heads (int, optional): The number of attention heads.\n            (default: :obj:`8`)\n        num_layers (int, optional): The number of layers in the network.\n            (default: :obj:`6`)\n        hidden_channels (int, optional): The number of hidden channels in the\n            node embeddings. (default: :obj:`128`)\n        num_rbf (int, optional): The number of radial basis functions.\n            (default: :obj:`32`)\n        trainable_rbf (bool, optional): Whether the radial basis function\n            parameters are trainable. (default: :obj:`False`)\n        max_z (int, optional): The maximum atomic numbers.\n            (default: :obj:`100`)\n        cutoff (float, optional): The cutoff distance. (default: :obj:`5.0`)\n        max_num_neighbors (int, optional): The maximum number of neighbors\n            considered for each atom. 
(default: :obj:`32`)\n        vertex (bool, optional): Whether to use vertex geometric features.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        lmax: int = 1,\n        vecnorm_type: Optional[str] = None,\n        trainable_vecnorm: bool = False,\n        num_heads: int = 8,\n        num_layers: int = 6,\n        hidden_channels: int = 128,\n        num_rbf: int = 32,\n        trainable_rbf: bool = False,\n        max_z: int = 100,\n        cutoff: float = 5.0,\n        max_num_neighbors: int = 32,\n        vertex: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.lmax = lmax\n        self.vecnorm_type = vecnorm_type\n        self.trainable_vecnorm = trainable_vecnorm\n        self.num_heads = num_heads\n        self.num_layers = num_layers\n        self.hidden_channels = hidden_channels\n        self.num_rbf = num_rbf\n        self.trainable_rbf = trainable_rbf\n        self.max_z = max_z\n        self.cutoff = cutoff\n        self.max_num_neighbors = max_num_neighbors\n\n        self.embedding = Embedding(max_z, hidden_channels)\n        self.distance = Distance(cutoff, max_num_neighbors=max_num_neighbors)\n        self.sphere = Sphere(lmax=lmax)\n        self.distance_expansion = ExpNormalSmearing(cutoff, num_rbf,\n                                                    trainable_rbf)\n        self.neighbor_embedding = NeighborEmbedding(hidden_channels, num_rbf,\n                                                    cutoff, max_z)\n        self.edge_embedding = EdgeEmbedding(num_rbf, hidden_channels)\n\n        self.vis_mp_layers = torch.nn.ModuleList()\n        vis_mp_kwargs = dict(\n            num_heads=num_heads,\n            hidden_channels=hidden_channels,\n            cutoff=cutoff,\n            vecnorm_type=vecnorm_type,\n            trainable_vecnorm=trainable_vecnorm,\n        )\n        vis_mp_class = ViS_MP if not vertex else ViS_MP_Vertex\n        for _ in range(num_layers - 1):\n         
   layer = vis_mp_class(last_layer=False, **vis_mp_kwargs)\n            self.vis_mp_layers.append(layer)\n        self.vis_mp_layers.append(\n            vis_mp_class(last_layer=True, **vis_mp_kwargs))\n\n        self.out_norm = LayerNorm(hidden_channels)\n        self.vec_out_norm = VecLayerNorm(\n            hidden_channels,\n            trainable=trainable_vecnorm,\n            norm_type=vecnorm_type,\n        )\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        self.embedding.reset_parameters()\n        self.distance_expansion.reset_parameters()\n        self.neighbor_embedding.reset_parameters()\n        self.edge_embedding.reset_parameters()\n        for layer in self.vis_mp_layers:\n            layer.reset_parameters()\n        self.out_norm.reset_parameters()\n        self.vec_out_norm.reset_parameters()\n\n    def forward(\n        self,\n        z: Tensor,\n        pos: Tensor,\n        batch: Tensor,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Computes the scalar and vector features of the nodes.\n\n        Args:\n            z (torch.Tensor): The atomic numbers.\n            pos (torch.Tensor): The coordinates of the atoms.\n            batch (torch.Tensor): A batch vector, which assigns each node to a\n                specific example.\n\n        Returns:\n            x (torch.Tensor): The scalar features of the nodes.\n            vec (torch.Tensor): The vector features of the nodes.\n        \"\"\"\n        x = self.embedding(z)\n        edge_index, edge_weight, edge_vec = self.distance(pos, batch)\n        edge_attr = self.distance_expansion(edge_weight)\n        mask = edge_index[0] != edge_index[1]\n        edge_vec[mask] = edge_vec[mask] / torch.norm(edge_vec[mask],\n                                                     dim=1).unsqueeze(1)\n        edge_vec = self.sphere(edge_vec)\n        x = self.neighbor_embedding(z, x, edge_index, edge_weight, 
edge_attr)\n        vec = torch.zeros(x.size(0), ((self.lmax + 1)**2) - 1, x.size(1),\n                          dtype=x.dtype, device=x.device)\n        edge_attr = self.edge_embedding(edge_index, edge_attr, x)\n\n        for attn in self.vis_mp_layers[:-1]:\n            dx, dvec, dedge_attr = attn(x, vec, edge_index, edge_weight,\n                                        edge_attr, edge_vec)\n            x = x + dx\n            vec = vec + dvec\n            edge_attr = edge_attr + dedge_attr\n\n        dx, dvec, _ = self.vis_mp_layers[-1](x, vec, edge_index, edge_weight,\n                                             edge_attr, edge_vec)\n        x = x + dx\n        vec = vec + dvec\n\n        x = self.out_norm(x)\n        vec = self.vec_out_norm(vec)\n\n        return x, vec\n\n\nclass GatedEquivariantBlock(torch.nn.Module):\n    r\"\"\"Applies a gated equivariant operation to scalar features and vector\n    features from the `\"Enhancing Geometric Representations for Molecules with\n    Equivariant Vector-Scalar Interactive Message Passing\"\n    <https://arxiv.org/abs/2210.16518>`_ paper.\n\n    Args:\n        hidden_channels (int): The number of hidden channels in the node\n            embeddings.\n        out_channels (int): The number of output channels.\n        intermediate_channels (int, optional): The number of channels in the\n            intermediate layer, or :obj:`None` to use the same number as\n            :obj:`hidden_channels`. 
(default: :obj:`None`)\n        scalar_activation (bool, optional): Whether to apply a scalar\n            activation function to the output node features.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        hidden_channels: int,\n        out_channels: int,\n        intermediate_channels: Optional[int] = None,\n        scalar_activation: bool = False,\n    ) -> None:\n        super().__init__()\n        self.out_channels = out_channels\n\n        if intermediate_channels is None:\n            intermediate_channels = hidden_channels\n\n        self.vec1_proj = Linear(hidden_channels, hidden_channels, bias=False)\n        self.vec2_proj = Linear(hidden_channels, out_channels, bias=False)\n\n        self.update_net = torch.nn.Sequential(\n            Linear(hidden_channels * 2, intermediate_channels),\n            torch.nn.SiLU(),\n            Linear(intermediate_channels, out_channels * 2),\n        )\n\n        self.act = torch.nn.SiLU() if scalar_activation else None\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        torch.nn.init.xavier_uniform_(self.vec1_proj.weight)\n        torch.nn.init.xavier_uniform_(self.vec2_proj.weight)\n        torch.nn.init.xavier_uniform_(self.update_net[0].weight)\n        self.update_net[0].bias.data.zero_()\n        torch.nn.init.xavier_uniform_(self.update_net[2].weight)\n        self.update_net[2].bias.data.zero_()\n\n    def forward(self, x: Tensor, v: Tensor) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Applies a gated equivariant operation to node features and vector\n        features.\n\n        Args:\n            x (torch.Tensor): The scalar features of the nodes.\n            v (torch.Tensor): The vector features of the nodes.\n        \"\"\"\n        vec1 = torch.norm(self.vec1_proj(v), dim=-2)\n        vec2 = self.vec2_proj(v)\n\n        x = torch.cat([x, vec1], dim=-1)\n        x, v = 
torch.split(self.update_net(x), self.out_channels, dim=-1)\n        v = v.unsqueeze(1) * vec2\n\n        if self.act is not None:\n            x = self.act(x)\n\n        return x, v\n\n\nclass EquivariantScalar(torch.nn.Module):\n    r\"\"\"Computes final scalar outputs based on node features and vector\n    features.\n\n    Args:\n        hidden_channels (int): The number of hidden channels in the node\n            embeddings.\n    \"\"\"\n    def __init__(self, hidden_channels: int) -> None:\n        super().__init__()\n\n        self.output_network = torch.nn.ModuleList([\n            GatedEquivariantBlock(\n                hidden_channels,\n                hidden_channels // 2,\n                scalar_activation=True,\n            ),\n            GatedEquivariantBlock(\n                hidden_channels // 2,\n                1,\n                scalar_activation=False,\n            ),\n        ])\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        for layer in self.output_network:\n            layer.reset_parameters()\n\n    def pre_reduce(self, x: Tensor, v: Tensor) -> Tensor:\n        r\"\"\"Computes the final scalar outputs.\n\n        Args:\n            x (torch.Tensor): The scalar features of the nodes.\n            v (torch.Tensor): The vector features of the nodes.\n\n        Returns:\n            out (torch.Tensor): The final scalar outputs of the nodes.\n        \"\"\"\n        for layer in self.output_network:\n            x, v = layer(x, v)\n\n        return x + v.sum() * 0\n\n\nclass Atomref(torch.nn.Module):\n    r\"\"\"Adds atom reference values to atomic energies.\n\n    Args:\n        atomref (torch.Tensor, optional):  A tensor of atom reference values,\n            or :obj:`None` if not provided. 
(default: :obj:`None`)\n        max_z (int, optional): The maximum atomic number.\n            (default: :obj:`100`)\n    \"\"\"\n    def __init__(\n        self,\n        atomref: Optional[Tensor] = None,\n        max_z: int = 100,\n    ) -> None:\n        super().__init__()\n\n        if atomref is None:\n            atomref = torch.zeros(max_z, 1)\n        else:\n            atomref = torch.as_tensor(atomref)\n\n        if atomref.ndim == 1:\n            atomref = atomref.view(-1, 1)\n\n        self.register_buffer('initial_atomref', atomref)\n        self.atomref = Embedding(len(atomref), 1)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        self.atomref.weight.data.copy_(self.initial_atomref)\n\n    def forward(self, x: Tensor, z: Tensor) -> Tensor:\n        r\"\"\"Adds atom reference values to atomic energies.\n\n        Args:\n            x (torch.Tensor): The atomic energies.\n            z (torch.Tensor): The atomic numbers.\n        \"\"\"\n        return x + self.atomref(z)\n\n\nclass ViSNet(torch.nn.Module):\n    r\"\"\"A :pytorch:`PyTorch` module that implements the equivariant\n    vector-scalar interactive graph neural network (ViSNet) from the\n    `\"Enhancing Geometric Representations for Molecules with Equivariant\n    Vector-Scalar Interactive Message Passing\"\n    <https://arxiv.org/abs/2210.16518>`_ paper.\n\n    Args:\n        lmax (int, optional): The maximum degree of the spherical harmonics.\n            (default: :obj:`1`)\n        vecnorm_type (str, optional): The type of normalization to apply to the\n            vectors. (default: :obj:`None`)\n        trainable_vecnorm (bool, optional): Whether the normalization weights\n            are trainable. 
(default: :obj:`False`)\n        num_heads (int, optional): The number of attention heads.\n            (default: :obj:`8`)\n        num_layers (int, optional): The number of layers in the network.\n            (default: :obj:`6`)\n        hidden_channels (int, optional): The number of hidden channels in the\n            node embeddings. (default: :obj:`128`)\n        num_rbf (int, optional): The number of radial basis functions.\n            (default: :obj:`32`)\n        trainable_rbf (bool, optional): Whether the radial basis function\n            parameters are trainable. (default: :obj:`False`)\n        max_z (int, optional): The maximum atomic number.\n            (default: :obj:`100`)\n        cutoff (float, optional): The cutoff distance. (default: :obj:`5.0`)\n        max_num_neighbors (int, optional): The maximum number of neighbors\n            considered for each atom. (default: :obj:`32`)\n        vertex (bool, optional): Whether to use vertex geometric features.\n            (default: :obj:`False`)\n        atomref (torch.Tensor, optional): A tensor of atom reference values,\n            or :obj:`None` if not provided. (default: :obj:`None`)\n        reduce_op (str, optional): The type of reduction operation to apply\n            (:obj:`\"sum\"`, :obj:`\"mean\"`). (default: :obj:`\"sum\"`)\n        mean (float, optional): The mean of the output distribution.\n            (default: :obj:`0.0`)\n        std (float, optional): The standard deviation of the output\n            distribution. (default: :obj:`1.0`)\n        derivative (bool, optional): Whether to compute the derivative of the\n            output with respect to the positions. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        lmax: int = 1,\n        vecnorm_type: Optional[str] = None,\n        trainable_vecnorm: bool = False,\n        num_heads: int = 8,\n        num_layers: int = 6,\n        hidden_channels: int = 128,\n        num_rbf: int = 32,\n        trainable_rbf: bool = False,\n        max_z: int = 100,\n        cutoff: float = 5.0,\n        max_num_neighbors: int = 32,\n        vertex: bool = False,\n        atomref: Optional[Tensor] = None,\n        reduce_op: str = \"sum\",\n        mean: float = 0.0,\n        std: float = 1.0,\n        derivative: bool = False,\n    ) -> None:\n        super().__init__()\n\n        self.representation_model = ViSNetBlock(\n            lmax=lmax,\n            vecnorm_type=vecnorm_type,\n            trainable_vecnorm=trainable_vecnorm,\n            num_heads=num_heads,\n            num_layers=num_layers,\n            hidden_channels=hidden_channels,\n            num_rbf=num_rbf,\n            trainable_rbf=trainable_rbf,\n            max_z=max_z,\n            cutoff=cutoff,\n            max_num_neighbors=max_num_neighbors,\n            vertex=vertex,\n        )\n\n        self.output_model = EquivariantScalar(hidden_channels=hidden_channels)\n        self.prior_model = Atomref(atomref=atomref, max_z=max_z)\n        self.reduce_op = reduce_op\n        self.derivative = derivative\n\n        self.register_buffer('mean', torch.tensor(mean))\n        self.register_buffer('std', torch.tensor(std))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets the parameters of the module.\"\"\"\n        self.representation_model.reset_parameters()\n        self.output_model.reset_parameters()\n        if self.prior_model is not None:\n            self.prior_model.reset_parameters()\n\n    def forward(\n        self,\n        z: Tensor,\n        pos: Tensor,\n        batch: Tensor,\n    ) -> Tuple[Tensor, Optional[Tensor]]:\n        
r\"\"\"Computes the energies or properties (forces) for a batch of\n        molecules.\n\n        Args:\n            z (torch.Tensor): The atomic numbers.\n            pos (torch.Tensor): The coordinates of the atoms.\n            batch (torch.Tensor): A batch vector,\n                which assigns each node to a specific example.\n\n        Returns:\n            y (torch.Tensor): The energies or properties for each molecule.\n            dy (torch.Tensor, optional): The negative derivative of energies.\n        \"\"\"\n        if self.derivative:\n            pos.requires_grad_(True)\n\n        x, v = self.representation_model(z, pos, batch)\n        x = self.output_model.pre_reduce(x, v)\n        x = x * self.std\n\n        if self.prior_model is not None:\n            x = self.prior_model(x, z)\n\n        y = scatter(x, batch, dim=0, reduce=self.reduce_op)\n        y = y + self.mean\n\n        if self.derivative:\n            grad_outputs = [torch.ones_like(y)]\n            dy = grad(\n                [y],\n                [pos],\n                grad_outputs=grad_outputs,\n                create_graph=True,\n                retain_graph=True,\n            )[0]\n            if dy is None:\n                raise RuntimeError(\n                    \"Autograd returned None for the force prediction.\")\n            return y, -dy\n\n        return y, None\n"
  },
  {
    "path": "torch_geometric/nn/module_dict.py",
    "content": "from typing import Final, Iterable, Mapping, Optional, Tuple, Union\n\nimport torch\nfrom torch.nn import Module\n\nKey = Union[str, Tuple[str, ...]]\n\n\n# `torch.nn.ModuleDict` doesn't allow `.` to be used in key names.\n# This `ModuleDict` supports it by converting `.` to `#` in the internal\n# representation and back to `.` in the external representation.\n# It also allows passing tuples as keys.\nclass ModuleDict(torch.nn.ModuleDict):\n    CLASS_ATTRS: Final[Tuple[str, ...]] = tuple(dir(torch.nn.ModuleDict))\n\n    def __init__(\n        self,\n        modules: Optional[Mapping[Union[str, Tuple[str, ...]], Module]] = None,\n    ):\n        if modules is not None:  # Replace the keys in modules:\n            modules = {\n                self.to_internal_key(key): module\n                for key, module in modules.items()\n            }\n        super().__init__(modules)\n\n    @classmethod\n    def to_internal_key(cls, key: Key) -> str:\n        if isinstance(key, tuple):  # ModuleDict can't handle tuples as keys\n            assert len(key) > 1\n            key = f\"<{'___'.join(key)}>\"\n        assert isinstance(key, str)\n\n        # ModuleDict cannot handle keys that exist as class attributes:\n        if key in cls.CLASS_ATTRS:\n            key = f'<{key}>'\n\n        # ModuleDict cannot handle dots in keys:\n        return key.replace('.', '#')\n\n    @classmethod\n    def to_external_key(cls, key: str) -> Key:\n        key = key.replace('#', '.')\n\n        if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS:\n            key = key[1:-1]\n\n        if key[0] == '<' and key[-1] == '>' and '___' in key:\n            key = tuple(key[1:-1].split('___'))\n\n        return key\n\n    def __getitem__(self, key: Key) -> Module:\n        return super().__getitem__(self.to_internal_key(key))\n\n    def __setitem__(self, key: Key, module: Module):\n        return super().__setitem__(self.to_internal_key(key), 
module)\n\n    def __delitem__(self, key: Key):\n        return super().__delitem__(self.to_internal_key(key))\n\n    def __contains__(self, key: Key) -> bool:\n        return super().__contains__(self.to_internal_key(key))\n\n    def keys(self) -> Iterable[Key]:\n        return [self.to_external_key(key) for key in super().keys()]\n\n    def items(self) -> Iterable[Tuple[Key, Module]]:\n        return [(self.to_external_key(k), v) for k, v in super().items()]\n"
  },
  {
    "path": "torch_geometric/nn/norm/__init__.py",
    "content": "r\"\"\"Normalization package.\"\"\"\n\nfrom .batch_norm import BatchNorm, HeteroBatchNorm\nfrom .instance_norm import InstanceNorm\nfrom .layer_norm import LayerNorm, HeteroLayerNorm\nfrom .graph_norm import GraphNorm\nfrom .graph_size_norm import GraphSizeNorm\nfrom .pair_norm import PairNorm\nfrom .mean_subtraction_norm import MeanSubtractionNorm\nfrom .msg_norm import MessageNorm\nfrom .diff_group_norm import DiffGroupNorm\n\n__all__ = [\n    'BatchNorm',\n    'HeteroBatchNorm',\n    'InstanceNorm',\n    'LayerNorm',\n    'HeteroLayerNorm',\n    'GraphNorm',\n    'GraphSizeNorm',\n    'PairNorm',\n    'MeanSubtractionNorm',\n    'MessageNorm',\n    'DiffGroupNorm',\n]\n\nclasses = __all__\n"
  },
  {
    "path": "torch_geometric/nn/norm/batch_norm.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.aggr.fused import FusedAggregation\n\n\nclass BatchNorm(torch.nn.Module):\n    r\"\"\"Applies batch normalization over a batch of features as described in\n    the `\"Batch Normalization: Accelerating Deep Network Training by\n    Reducing Internal Covariate Shift\" <https://arxiv.org/abs/1502.03167>`_\n    paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{\\mathbf{x} -\n        \\textrm{E}[\\mathbf{x}]}{\\sqrt{\\textrm{Var}[\\mathbf{x}] + \\epsilon}}\n        \\odot \\gamma + \\beta\n\n    The mean and standard-deviation are calculated per-dimension over all nodes\n    inside the mini-batch.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        eps (float, optional): A value added to the denominator for numerical\n            stability. (default: :obj:`1e-5`)\n        momentum (float, optional): The value used for the running mean and\n            running variance computation. (default: :obj:`0.1`)\n        affine (bool, optional): If set to :obj:`True`, this module has\n            learnable affine parameters :math:`\\gamma` and :math:`\\beta`.\n            (default: :obj:`True`)\n        track_running_stats (bool, optional): If set to :obj:`True`, this\n            module tracks the running mean and variance, and when set to\n            :obj:`False`, this module does not track such statistics and always\n            uses batch statistics in both training and eval modes.\n            (default: :obj:`True`)\n        allow_single_element (bool, optional): If set to :obj:`True`, batches\n            with only a single element will work as during evaluation.\n            That is, the running mean and variance will be used.\n            Requires :obj:`track_running_stats=True`. 
(default: :obj:`False`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        eps: float = 1e-5,\n        momentum: Optional[float] = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n        allow_single_element: bool = False,\n        device: Optional[torch.device] = None,\n    ):\n        super().__init__()\n\n        if allow_single_element and not track_running_stats:\n            raise ValueError(\"'allow_single_element' requires \"\n                             \"'track_running_stats' to be set to `True`\")\n\n        self.module = torch.nn.BatchNorm1d(in_channels, eps, momentum, affine,\n                                           track_running_stats, device=device)\n        self.in_channels = in_channels\n        self.allow_single_element = allow_single_element\n\n    def reset_running_stats(self):\n        r\"\"\"Resets all running statistics of the module.\"\"\"\n        self.module.reset_running_stats()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.module.reset_parameters()\n\n    def forward(self, x: Tensor) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n        \"\"\"\n        if self.allow_single_element and x.size(0) <= 1:\n            return torch.nn.functional.batch_norm(\n                x,\n                self.module.running_mean,\n                self.module.running_var,\n                self.module.weight,\n                self.module.bias,\n                False,  # bn_training\n                0.0,  # momentum\n                self.module.eps,\n            )\n        return self.module(x)\n\n    def __repr__(self):\n        return f'{self.__class__.__name__}({self.module.extra_repr()})'\n\n\nclass HeteroBatchNorm(torch.nn.Module):\n    
r\"\"\"Applies batch normalization over a batch of heterogeneous features as\n    described in the `\"Batch Normalization: Accelerating Deep Network Training\n    by Reducing Internal Covariate Shift\" <https://arxiv.org/abs/1502.03167>`_\n    paper.\n    Compared to :class:`BatchNorm`, :class:`HeteroBatchNorm` applies\n    normalization individually for each node or edge type.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        num_types (int): The number of types.\n        eps (float, optional): A value added to the denominator for numerical\n            stability. (default: :obj:`1e-5`)\n        momentum (float, optional): The value used for the running mean and\n            running variance computation. (default: :obj:`0.1`)\n        affine (bool, optional): If set to :obj:`True`, this module has\n            learnable affine parameters :math:`\\gamma` and :math:`\\beta`.\n            (default: :obj:`True`)\n        track_running_stats (bool, optional): If set to :obj:`True`, this\n            module tracks the running mean and variance, and when set to\n            :obj:`False`, this module does not track such statistics and always\n            uses batch statistics in both training and eval modes.\n            (default: :obj:`True`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        num_types: int,\n        eps: float = 1e-5,\n        momentum: Optional[float] = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n        device: Optional[torch.device] = None,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.num_types = num_types\n        self.eps = eps\n        self.momentum = momentum\n        self.affine = affine\n        self.track_running_stats = track_running_stats\n\n        if self.affine:\n            self.weight = 
Parameter(\n                torch.empty(num_types, in_channels, device=device))\n            self.bias = Parameter(\n                torch.empty(num_types, in_channels, device=device))\n        else:\n            self.register_parameter('weight', None)\n            self.register_parameter('bias', None)\n\n        if self.track_running_stats:\n            self.register_buffer(\n                'running_mean',\n                torch.empty(num_types, in_channels, device=device))\n            self.register_buffer(\n                'running_var',\n                torch.empty(num_types, in_channels, device=device))\n            self.register_buffer('num_batches_tracked', torch.tensor(0))\n        else:\n            self.register_buffer('running_mean', None)\n            self.register_buffer('running_var', None)\n            self.register_buffer('num_batches_tracked', None)\n\n        self.mean_var = FusedAggregation(['mean', 'var'])\n\n        self.reset_parameters()\n\n    def reset_running_stats(self):\n        r\"\"\"Resets all running statistics of the module.\"\"\"\n        if self.track_running_stats:\n            self.running_mean.zero_()\n            self.running_var.fill_(1)\n            self.num_batches_tracked.zero_()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.reset_running_stats()\n        if self.affine:\n            torch.nn.init.ones_(self.weight)\n            torch.nn.init.zeros_(self.bias)\n\n    def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The input features.\n            type_vec (torch.Tensor): A vector that maps each entry to a type.\n        \"\"\"\n        if not self.training and self.track_running_stats:\n            mean, var = self.running_mean, self.running_var\n        else:\n            with torch.no_grad():\n                mean, var = self.mean_var(x, type_vec, 
dim_size=self.num_types)\n\n        if self.training and self.track_running_stats:\n            if self.momentum is None:\n                self.num_batches_tracked.add_(1)\n                exp_avg_factor = 1.0 / float(self.num_batches_tracked)\n            else:\n                exp_avg_factor = self.momentum\n\n            with torch.no_grad():  # Update running mean and variance:\n                type_index = torch.unique(type_vec)\n\n                self.running_mean[type_index] = (\n                    (1.0 - exp_avg_factor) * self.running_mean[type_index] +\n                    exp_avg_factor * mean[type_index])\n                self.running_var[type_index] = (\n                    (1.0 - exp_avg_factor) * self.running_var[type_index] +\n                    exp_avg_factor * var[type_index])\n\n        out = (x - mean[type_vec]) / var.clamp(self.eps).sqrt()[type_vec]\n\n        if self.affine:\n            out = out * self.weight[type_vec] + self.bias[type_vec]\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'num_types={self.num_types})')\n"
  },
  {
    "path": "torch_geometric/nn/norm/diff_group_norm.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import BatchNorm1d, Linear\n\n\nclass DiffGroupNorm(torch.nn.Module):\n    r\"\"\"The differentiable group normalization layer from the `\"Towards Deeper\n    Graph Neural Networks with Differentiable Group Normalization\"\n    <https://arxiv.org/abs/2006.06972>`_ paper, which normalizes node features\n    group-wise via a learnable soft cluster assignment.\n\n    .. math::\n\n        \\mathbf{S} = \\text{softmax} (\\mathbf{X} \\mathbf{W})\n\n    where :math:`\\mathbf{W} \\in \\mathbb{R}^{F \\times G}` denotes a trainable\n    weight matrix mapping each node into one of :math:`G` clusters.\n    Normalization is then performed group-wise via:\n\n    .. math::\n\n        \\mathbf{X}^{\\prime} = \\mathbf{X} + \\lambda \\sum_{i = 1}^G\n        \\text{BatchNorm}(\\mathbf{S}[:, i] \\odot \\mathbf{X})\n\n    Args:\n        in_channels (int): Size of each input sample :math:`F`.\n        groups (int): The number of groups :math:`G`.\n        lamda (float, optional): The balancing factor :math:`\\lambda` between\n            input embeddings and normalized embeddings. (default: :obj:`0.01`)\n        eps (float, optional): A value added to the denominator for numerical\n            stability. (default: :obj:`1e-5`)\n        momentum (float, optional): The value used for the running mean and\n            running variance computation. 
(default: :obj:`0.1`)\n        affine (bool, optional): If set to :obj:`True`, this module has\n            learnable affine parameters :math:`\\gamma` and :math:`\\beta`.\n            (default: :obj:`True`)\n        track_running_stats (bool, optional): If set to :obj:`True`, this\n            module tracks the running mean and variance, and when set to\n            :obj:`False`, this module does not track such statistics and always\n            uses batch statistics in both training and eval modes.\n            (default: :obj:`True`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        groups: int,\n        lamda: float = 0.01,\n        eps: float = 1e-5,\n        momentum: float = 0.1,\n        affine: bool = True,\n        track_running_stats: bool = True,\n        device: Optional[torch.device] = None,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.groups = groups\n        self.lamda = lamda\n\n        self.lin = Linear(in_channels, groups, bias=False, device=device)\n        self.norm = BatchNorm1d(groups * in_channels, eps, momentum, affine,\n                                track_running_stats, device=device)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin.reset_parameters()\n        self.norm.reset_parameters()\n\n    def forward(self, x: Tensor) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n        \"\"\"\n        F, G = self.in_channels, self.groups\n\n        s = self.lin(x).softmax(dim=-1)  # [N, G]\n        out = s.unsqueeze(-1) * x.unsqueeze(-2)  # [N, G, F]\n        out = self.norm(out.view(-1, G * F)).view(-1, G, F).sum(-2)  # [N, F]\n\n        return x + self.lamda * out\n\n    @staticmethod\n    
def group_distance_ratio(x: Tensor, y: Tensor, eps: float = 1e-5) -> float:\n        r\"\"\"Measures the ratio of inter-group distance over intra-group\n        distance.\n\n        .. math::\n            R_{\\text{Group}} = \\frac{\\frac{1}{(C-1)^2} \\sum_{i \\neq j}\n            \\frac{1}{|\\mathbf{X}_i||\\mathbf{X}_j|} \\sum_{\\mathbf{x}_{iv}\n            \\in \\mathbf{X}_i } \\sum_{\\mathbf{x}_{jv^{\\prime}} \\in \\mathbf{X}_j}\n            {\\| \\mathbf{x}_{iv} - \\mathbf{x}_{jv^{\\prime}} \\|}_2 }{\n            \\frac{1}{C} \\sum_{i} \\frac{1}{{|\\mathbf{X}_i|}^2}\n            \\sum_{\\mathbf{x}_{iv}, \\mathbf{x}_{iv^{\\prime}} \\in \\mathbf{X}_i }\n            {\\| \\mathbf{x}_{iv} - \\mathbf{x}_{iv^{\\prime}} \\|}_2 }\n\n        where :math:`\\mathbf{X}_i` denotes the set of all nodes that belong to\n        class :math:`i`, and :math:`C` denotes the total number of classes in\n        :obj:`y`.\n        \"\"\"\n        num_classes = int(y.max()) + 1\n\n        numerator = 0.\n        for i in range(num_classes):\n            mask = y == i\n            dist = torch.cdist(x[mask].unsqueeze(0), x[~mask].unsqueeze(0))\n            numerator += (1 / dist.numel()) * float(dist.sum())\n        numerator *= 1 / (num_classes - 1)**2\n\n        denominator = 0.\n        for i in range(num_classes):\n            mask = y == i\n            dist = torch.cdist(x[mask].unsqueeze(0), x[mask].unsqueeze(0))\n            denominator += (1 / dist.numel()) * float(dist.sum())\n        denominator *= 1 / num_classes\n\n        return numerator / (denominator + eps)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'groups={self.groups})')\n"
  },
  {
    "path": "torch_geometric/nn/norm/graph_norm.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.inits import ones, zeros\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import scatter\n\n\nclass GraphNorm(torch.nn.Module):\n    r\"\"\"Applies graph normalization over individual graphs as described in the\n    `\"GraphNorm: A Principled Approach to Accelerating Graph Neural Network\n    Training\" <https://arxiv.org/abs/2009.03294>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{\\mathbf{x} - \\alpha \\odot\n        \\textrm{E}[\\mathbf{x}]}\n        {\\sqrt{\\textrm{Var}[\\mathbf{x} - \\alpha \\odot \\textrm{E}[\\mathbf{x}]]\n        + \\epsilon}} \\odot \\gamma + \\beta\n\n    where :math:`\\alpha` denotes parameters that learn how much information\n    to keep in the mean.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        eps (float, optional): A value added to the denominator for numerical\n            stability. 
(default: :obj:`1e-5`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, in_channels: int, eps: float = 1e-5,\n                 device: Optional[torch.device] = None):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.eps = eps\n\n        self.weight = torch.nn.Parameter(\n            torch.empty(in_channels, device=device))\n        self.bias = torch.nn.Parameter(torch.empty(in_channels, device=device))\n        self.mean_scale = torch.nn.Parameter(\n            torch.empty(in_channels, device=device))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        ones(self.weight)\n        zeros(self.bias)\n        ones(self.mean_scale)\n\n    def forward(self, x: Tensor, batch: OptTensor = None,\n                batch_size: Optional[int] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example. (default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given. 
(default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            batch = x.new_zeros(x.size(0), dtype=torch.long)\n            batch_size = 1\n\n        if batch_size is None:\n            batch_size = int(batch.max()) + 1\n\n        mean = scatter(x, batch, 0, batch_size, reduce='mean')\n        out = x - mean.index_select(0, batch) * self.mean_scale\n        var = scatter(out.pow(2), batch, 0, batch_size, reduce='mean')\n        std = (var + self.eps).sqrt().index_select(0, batch)\n        return self.weight * out / std + self.bias\n\n    def __repr__(self):\n        return f'{self.__class__.__name__}({self.in_channels})'\n"
  },
  {
    "path": "torch_geometric/nn/norm/graph_size_norm.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import degree\n\n\nclass GraphSizeNorm(torch.nn.Module):\n    r\"\"\"Applies Graph Size Normalization over each individual graph in a batch\n    of node features as described in the\n    `\"Benchmarking Graph Neural Networks\" <https://arxiv.org/abs/2003.00982>`_\n    paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{\\mathbf{x}_i}{\\sqrt{|\\mathcal{V}|}}\n    \"\"\"\n    def __init__(self):\n        super().__init__()\n\n    def forward(self, x: Tensor, batch: OptTensor = None,\n                batch_size: Optional[int] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example. (default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given. (default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)\n            batch_size = 1\n\n        inv_sqrt_deg = degree(batch, batch_size, dtype=x.dtype).pow(-0.5)\n        return x * inv_sqrt_deg.index_select(0, batch).view(-1, 1)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/nn/norm/instance_norm.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn.modules.instancenorm import _InstanceNorm\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import degree, scatter\n\n\nclass InstanceNorm(_InstanceNorm):\n    r\"\"\"Applies instance normalization over each individual example in a batch\n    of node features as described in the `\"Instance Normalization: The Missing\n    Ingredient for Fast Stylization\" <https://arxiv.org/abs/1607.08022>`_\n    paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{\\mathbf{x} -\n        \\textrm{E}[\\mathbf{x}]}{\\sqrt{\\textrm{Var}[\\mathbf{x}] + \\epsilon}}\n        \\odot \\gamma + \\beta\n\n    The mean and standard-deviation are calculated per-dimension separately for\n    each object in a mini-batch.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        eps (float, optional): A value added to the denominator for numerical\n            stability. (default: :obj:`1e-5`)\n        momentum (float, optional): The value used for the running mean and\n            running variance computation. 
(default: :obj:`0.1`)\n        affine (bool, optional): If set to :obj:`True`, this module has\n            learnable affine parameters :math:`\\gamma` and :math:`\\beta`.\n            (default: :obj:`False`)\n        track_running_stats (bool, optional): If set to :obj:`True`, this\n            module tracks the running mean and variance, and when set to\n            :obj:`False`, this module does not track such statistics and always\n            uses instance statistics in both training and eval modes.\n            (default: :obj:`False`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        eps: float = 1e-5,\n        momentum: float = 0.1,\n        affine: bool = False,\n        track_running_stats: bool = False,\n        device: Optional[torch.device] = None,\n    ):\n        super().__init__(in_channels, eps, momentum, affine,\n                         track_running_stats, device=device)\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        super().reset_parameters()\n\n    def forward(self, x: Tensor, batch: OptTensor = None,\n                batch_size: Optional[int] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example. (default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given. 
(default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            out = F.instance_norm(\n                x.t().unsqueeze(0), self.running_mean, self.running_var,\n                self.weight, self.bias, self.training\n                or not self.track_running_stats, self.momentum, self.eps)\n            return out.squeeze(0).t()\n\n        if batch_size is None:\n            batch_size = int(batch.max()) + 1\n\n        mean = var = unbiased_var = x  # Dummies.\n\n        if self.training or not self.track_running_stats:\n            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)\n            norm = norm.view(-1, 1)\n            unbiased_norm = (norm - 1).clamp_(min=1)\n\n            mean = scatter(x, batch, dim=0, dim_size=batch_size,\n                           reduce='sum') / norm\n\n            x = x - mean.index_select(0, batch)\n\n            var = scatter(x * x, batch, dim=0, dim_size=batch_size,\n                          reduce='sum')\n            unbiased_var = var / unbiased_norm\n            var = var / norm\n\n            momentum = self.momentum\n            if self.running_mean is not None:\n                self.running_mean = (\n                    1 - momentum) * self.running_mean + momentum * mean.mean(0)\n            if self.running_var is not None:\n                self.running_var = (\n                    1 - momentum\n                ) * self.running_var + momentum * unbiased_var.mean(0)\n        else:\n            if self.running_mean is not None:\n                mean = self.running_mean.view(1, -1).expand(batch_size, -1)\n            if self.running_var is not None:\n                var = self.running_var.view(1, -1).expand(batch_size, -1)\n\n            x = x - mean.index_select(0, batch)\n\n        out = x / (var + self.eps).sqrt().index_select(0, batch)\n\n        if self.weight is not None and self.bias is not None:\n            out = out * self.weight.view(1, -1) + self.bias.view(1, -1)\n\n        return 
out\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.num_features})'\n"
  },
  {
    "path": "torch_geometric/nn/norm/layer_norm.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.inits import ones, zeros\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import degree, scatter\n\n\nclass LayerNorm(torch.nn.Module):\n    r\"\"\"Applies layer normalization over each individual example in a batch\n    of features as described in the `\"Layer Normalization\"\n    <https://arxiv.org/abs/1607.06450>`_ paper.\n\n    .. math::\n        \\mathbf{x}^{\\prime}_i = \\frac{\\mathbf{x} -\n        \\textrm{E}[\\mathbf{x}]}{\\sqrt{\\textrm{Var}[\\mathbf{x}] + \\epsilon}}\n        \\odot \\gamma + \\beta\n\n    The mean and standard-deviation are calculated across all nodes and all\n    node channels separately for each object in a mini-batch.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        eps (float, optional): A value added to the denominator for numerical\n            stability. (default: :obj:`1e-5`)\n        affine (bool, optional): If set to :obj:`True`, this module has\n            learnable affine parameters :math:`\\gamma` and :math:`\\beta`.\n            (default: :obj:`True`)\n        mode (str, optional): The normalization mode to use for layer\n            normalization (:obj:`\"graph\"` or :obj:`\"node\"`). If :obj:`\"graph\"`\n            is used, each graph will be considered as an element to be\n            normalized. If `\"node\"` is used, each node will be considered as\n            an element to be normalized. 
(default: :obj:`\"graph\"`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        eps: float = 1e-5,\n        affine: bool = True,\n        mode: str = 'graph',\n        device: Optional[torch.device] = None,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.eps = eps\n        self.affine = affine\n        self.mode = mode\n\n        if affine:\n            self.weight = Parameter(torch.empty(in_channels, device=device))\n            self.bias = Parameter(torch.empty(in_channels, device=device))\n        else:\n            self.register_parameter('weight', None)\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        ones(self.weight)\n        zeros(self.bias)\n\n    def forward(self, x: Tensor, batch: OptTensor = None,\n                batch_size: Optional[int] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example. (default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given. 
(default: :obj:`None`)\n        \"\"\"\n        if self.mode == 'graph':\n            if batch is None:\n                x = x - x.mean()\n                out = x / (x.std(unbiased=False) + self.eps)\n\n            else:\n                if batch_size is None:\n                    batch_size = int(batch.max()) + 1\n\n                norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)\n                norm = norm.mul_(x.size(-1)).view(-1, 1)\n\n                mean = scatter(x, batch, dim=0, dim_size=batch_size,\n                               reduce='sum').sum(dim=-1, keepdim=True) / norm\n\n                x = x - mean.index_select(0, batch)\n\n                var = scatter(x * x, batch, dim=0, dim_size=batch_size,\n                              reduce='sum').sum(dim=-1, keepdim=True)\n                var = var / norm\n\n                out = x / (var + self.eps).sqrt().index_select(0, batch)\n\n            if self.weight is not None and self.bias is not None:\n                out = out * self.weight + self.bias\n\n            return out\n\n        if self.mode == 'node':\n            return F.layer_norm(x, (self.in_channels, ), self.weight,\n                                self.bias, self.eps)\n\n        raise ValueError(f\"Unknown normalization mode: {self.mode}\")\n\n    def __repr__(self):\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'affine={self.affine}, mode={self.mode})')\n\n\nclass HeteroLayerNorm(torch.nn.Module):\n    r\"\"\"Applies layer normalization over each individual example in a batch\n    of heterogeneous features as described in the `\"Layer Normalization\"\n    <https://arxiv.org/abs/1607.06450>`_ paper.\n    Compared to :class:`LayerNorm`, :class:`HeteroLayerNorm` applies\n    normalization individually for each node or edge type.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        num_types (int): The number of types.\n        eps (float, optional): A value added to 
the denominator for numerical\n            stability. (default: :obj:`1e-5`)\n        affine (bool, optional): If set to :obj:`True`, this module has\n            learnable affine parameters :math:`\\gamma` and :math:`\\beta`.\n            (default: :obj:`True`)\n        mode (str, optional): The normalization mode to use for layer\n            normalization (:obj:`\"node\"`). If `\"node\"` is used, each node will\n            be considered as an element to be normalized.\n            (default: :obj:`\"node\"`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        num_types: int,\n        eps: float = 1e-5,\n        affine: bool = True,\n        mode: str = 'node',\n        device: Optional[torch.device] = None,\n    ):\n        super().__init__()\n        assert mode == 'node'\n\n        self.in_channels = in_channels\n        self.num_types = num_types\n        self.eps = eps\n        self.affine = affine\n\n        if affine:\n            self.weight = Parameter(\n                torch.empty(num_types, in_channels, device=device))\n            self.bias = Parameter(\n                torch.empty(num_types, in_channels, device=device))\n        else:\n            self.register_parameter('weight', None)\n            self.register_parameter('bias', None)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        if self.affine:\n            torch.nn.init.ones_(self.weight)\n            torch.nn.init.zeros_(self.bias)\n\n    def forward(\n        self,\n        x: Tensor,\n        type_vec: OptTensor = None,\n        type_ptr: Optional[Union[Tensor, List[int]]] = None,\n    ) -> Tensor:\n        r\"\"\"Forward pass.\n\n        .. 
note::\n            Either :obj:`type_vec` or :obj:`type_ptr` needs to be specified.\n            In general, relying on :obj:`type_ptr` is more efficient in case\n            the input tensor is sorted by types.\n\n        Args:\n            x (torch.Tensor): The input features.\n            type_vec (torch.Tensor, optional): A vector that maps each entry to\n                a type. (default: :obj:`None`)\n            type_ptr (torch.Tensor or List[int]): A vector denoting the\n                boundaries of types. (default: :obj:`None`)\n        \"\"\"\n        if type_vec is None and type_ptr is None:\n            raise ValueError(\"Either 'type_vec' or 'type_ptr' must be given\")\n\n        out = F.layer_norm(x, (self.in_channels, ), None, None, self.eps)\n\n        if self.affine:\n            # TODO Revisit this logic completely as it performs worse than just\n            # operating on a dictionary of tensors\n            # (especially the `type_vec` code path)\n            if type_ptr is not None:\n                h = torch.empty_like(out)\n                for i, (s, e) in enumerate(zip(type_ptr[:-1], type_ptr[1:])):\n                    h[s:e] = out[s:e] * self.weight[i] + self.bias[i]\n                out = h\n            else:\n                out = out * self.weight[type_vec] + self.bias[type_vec]\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'num_types={self.num_types})')\n"
  },
  {
    "path": "torch_geometric/nn/norm/mean_subtraction_norm.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import scatter\n\n\nclass MeanSubtractionNorm(torch.nn.Module):\n    r\"\"\"Applies layer normalization by subtracting the mean from the inputs\n    as described in the  `\"Revisiting 'Over-smoothing' in Deep GCNs\"\n    <https://arxiv.org/abs/2003.13663>`_ paper.\n\n    .. math::\n        \\mathbf{x}_i = \\mathbf{x}_i - \\frac{1}{|\\mathcal{V}|}\n        \\sum_{j \\in \\mathcal{V}} \\mathbf{x}_j\n    \"\"\"\n    def forward(self, x: Tensor, batch: Optional[Tensor] = None,\n                dim_size: Optional[int] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example. (default: :obj:`None`)\n            dim_size (int, optional): The number of examples :math:`B` in case\n                :obj:`batch` is given. (default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            return x - x.mean(dim=0, keepdim=True)\n\n        mean = scatter(x, batch, dim=0, dim_size=dim_size, reduce='mean')\n        return x - mean[batch]\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/nn/norm/msg_norm.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\n\nclass MessageNorm(torch.nn.Module):\n    r\"\"\"Applies message normalization over the aggregated messages as described\n    in the `\"DeeperGCNs: All You Need to Train Deeper GCNs\"\n    <https://arxiv.org/abs/2006.07739>`_ paper.\n\n    .. math::\n\n        \\mathbf{x}_i^{\\prime} = \\mathrm{MLP} \\left( \\mathbf{x}_{i} + s \\cdot\n        {\\| \\mathbf{x}_i \\|}_2 \\cdot\n        \\frac{\\mathbf{m}_{i}}{{\\|\\mathbf{m}_i\\|}_2} \\right)\n\n    Args:\n        learn_scale (bool, optional): If set to :obj:`True`, will learn the\n            scaling factor :math:`s` of message normalization.\n            (default: :obj:`False`)\n        device (torch.device, optional): The device to use for the module.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, learn_scale: bool = False,\n                 device: Optional[torch.device] = None):\n        super().__init__()\n        self.scale = Parameter(torch.empty(1, device=device),\n                               requires_grad=learn_scale)\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.scale.data.fill_(1.0)\n\n    def forward(self, x: Tensor, msg: Tensor, p: float = 2.0) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            msg (torch.Tensor): The message tensor :math:`\\mathbf{M}`.\n            p (float, optional): The norm :math:`p` to use for normalization.\n                (default: :obj:`2.0`)\n        \"\"\"\n        msg = F.normalize(msg, p=p, dim=-1)\n        x_norm = x.norm(p=p, dim=-1, keepdim=True)\n        return msg * x_norm * self.scale\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}'\n                
f'(learn_scale={self.scale.requires_grad})')\n"
  },
  {
    "path": "torch_geometric/nn/norm/pair_norm.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import scatter\n\n\nclass PairNorm(torch.nn.Module):\n    r\"\"\"Applies pair normalization over node features as described in the\n    `\"PairNorm: Tackling Oversmoothing in GNNs\"\n    <https://arxiv.org/abs/1909.12223>`_ paper.\n\n    .. math::\n        \\mathbf{x}_i^c &= \\mathbf{x}_i - \\frac{1}{n}\n        \\sum_{i=1}^n \\mathbf{x}_i \\\\\n\n        \\mathbf{x}_i^{\\prime} &= s \\cdot\n        \\frac{\\mathbf{x}_i^c}{\\sqrt{\\frac{1}{n} \\sum_{i=1}^n\n        {\\| \\mathbf{x}_i^c \\|}^2_2}}\n\n    Args:\n        scale (float, optional): Scaling factor :math:`s` of normalization.\n            (default, :obj:`1.`)\n        scale_individually (bool, optional): If set to :obj:`True`, will\n            compute the scaling step as :math:`\\mathbf{x}^{\\prime}_i = s \\cdot\n            \\frac{\\mathbf{x}_i^c}{{\\| \\mathbf{x}_i^c \\|}_2}`.\n            (default: :obj:`False`)\n        eps (float, optional): A value added to the denominator for numerical\n            stability. (default: :obj:`1e-5`)\n    \"\"\"\n    def __init__(self, scale: float = 1., scale_individually: bool = False,\n                 eps: float = 1e-5):\n        super().__init__()\n\n        self.scale = scale\n        self.scale_individually = scale_individually\n        self.eps = eps\n\n    def forward(self, x: Tensor, batch: OptTensor = None,\n                batch_size: Optional[int] = None) -> Tensor:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The source tensor.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example. 
(default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given. (default: :obj:`None`)\n        \"\"\"\n        scale = self.scale\n\n        if batch is None:\n            x = x - x.mean(dim=0, keepdim=True)\n\n            if not self.scale_individually:\n                return scale * x / (self.eps + x.pow(2).sum(-1).mean()).sqrt()\n            else:\n                return scale * x / (self.eps + x.norm(2, -1, keepdim=True))\n\n        else:\n            mean = scatter(x, batch, dim=0, dim_size=batch_size, reduce='mean')\n            x = x - mean.index_select(0, batch)\n\n            if not self.scale_individually:\n                return scale * x / torch.sqrt(self.eps + scatter(\n                    x.pow(2).sum(-1, keepdim=True), batch, dim=0,\n                    dim_size=batch_size, reduce='mean').index_select(0, batch))\n            else:\n                return scale * x / (self.eps + x.norm(2, -1, keepdim=True))\n\n    def __repr__(self):\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/nn/parameter_dict.py",
    "content": "from typing import Final, Iterable, Mapping, Optional, Tuple, Union\n\nimport torch\nfrom torch.nn import Parameter\n\nKey = Union[str, Tuple[str, ...]]\n\n\n# `torch.nn.ParameterDict` doesn't allow `.` to be used in key names.\n# This `ParameterDict` will support it by converting the `.` to `#` in the\n# internal representation and converts it back to `.` in the external\n# representation. It also allows passing tuples as keys.\nclass ParameterDict(torch.nn.ParameterDict):\n    CLASS_ATTRS: Final[Tuple[str, ...]] = set(dir(torch.nn.ParameterDict))\n\n    def __init__(\n        self,\n        parameters: Optional[Mapping[Key, Parameter]] = None,\n    ):\n        # Replace the keys in modules.\n        if parameters:\n            parameters = {\n                self.to_internal_key(key): parameter\n                for key, parameter in parameters.items()\n            }\n        super().__init__(parameters)\n\n    @classmethod\n    def to_internal_key(cls, key: Key) -> str:\n        if isinstance(key, tuple):  # ParameterDict can't handle tuples as keys\n            assert len(key) > 1\n            key = f\"<{'___'.join(key)}>\"\n        assert isinstance(key, str)\n\n        # ParameterDict cannot handle keys that exists as class attributes:\n        if key in cls.CLASS_ATTRS:\n            key = f'<{key}>'\n\n        # ParameterDict cannot handle dots in keys:\n        return key.replace('.', '#')\n\n    @classmethod\n    def to_external_key(cls, key: str) -> Key:\n        key = key.replace('#', '.')\n\n        if key[0] == '<' and key[-1] == '>' and key[1:-1] in cls.CLASS_ATTRS:\n            key = key[1:-1]\n\n        if key[0] == '<' and key[-1] == '>' and '___' in key:\n            key = tuple(key[1:-1].split('___'))\n\n        return key\n\n    def __getitem__(self, key: Key) -> Parameter:\n        return super().__getitem__(self.to_internal_key(key))\n\n    def __setitem__(self, key: Key, parameter: Parameter):\n        return 
super().__setitem__(self.to_internal_key(key), parameter)\n\n    def __delitem__(self, key: Key):\n        return super().__delitem__(self.to_internal_key(key))\n\n    def __contains__(self, key: Key) -> bool:\n        return super().__contains__(self.to_internal_key(key))\n\n    def keys(self) -> Iterable[Key]:\n        return [self.to_external_key(key) for key in super().keys()]\n\n    def items(self) -> Iterable[Tuple[Key, Parameter]]:\n        return [(self.to_external_key(k), v) for k, v in super().items()]\n"
  },
  {
    "path": "torch_geometric/nn/pool/__init__.py",
    "content": "r\"\"\"Pooling package.\"\"\"\n\nimport warnings\nfrom typing import Optional\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.typing import OptTensor, torch_cluster\n\nfrom .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x\nfrom .glob import global_add_pool, global_max_pool, global_mean_pool\nfrom .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,\n                  ApproxMIPSKNNIndex)\nfrom .graclus import graclus\nfrom .max_pool import max_pool, max_pool_neighbor_x, max_pool_x\nfrom .topk_pool import TopKPooling\nfrom .sag_pool import SAGPooling\nfrom .edge_pool import EdgePooling\nfrom .cluster_pool import ClusterPooling\nfrom .asap import ASAPooling\nfrom .pan_pool import PANPooling\nfrom .mem_pool import MemPooling\nfrom .voxel_grid import voxel_grid\nfrom .approx_knn import approx_knn, approx_knn_graph\n\n\ndef fps(\n    x: Tensor,\n    batch: OptTensor = None,\n    ratio: float = 0.5,\n    random_start: bool = True,\n    batch_size: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"A sampling algorithm from the `\"PointNet++: Deep Hierarchical Feature\n    Learning on Point Sets in a Metric Space\"\n    <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the\n    most distant point with regard to the rest points.\n\n    .. code-block:: python\n\n        import torch\n        from torch_geometric.nn import fps\n\n        x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n        batch = torch.tensor([0, 0, 0, 0])\n        index = fps(x, batch, ratio=0.5)\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        batch (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        ratio (float, optional): Sampling ratio. 
(default: :obj:`0.5`)\n        random_start (bool, optional): If set to :obj:`False`, use the first\n            node in :math:`\\mathbf{X}` as starting node. (default: :obj:`True`)\n        batch_size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:\n        return torch_cluster.fps(x, batch, ratio, random_start)\n    return torch_cluster.fps(x, batch, ratio, random_start, batch_size)\n\n\ndef knn(\n    x: Tensor,\n    y: Tensor,\n    k: int,\n    batch_x: OptTensor = None,\n    batch_y: OptTensor = None,\n    cosine: bool = False,\n    num_workers: int = 1,\n    batch_size: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Finds for each element in :obj:`y` the :obj:`k` nearest points in\n    :obj:`x`.\n\n    .. code-block:: python\n\n        import torch\n        from torch_geometric.nn import knn\n\n        x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n        batch_x = torch.tensor([0, 0, 0, 0])\n        y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])\n        batch_y = torch.tensor([0, 0])\n        assign_index = knn(x, y, 2, batch_x, batch_y)\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        y (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{Y} \\in \\mathbb{R}^{M \\times F}`.\n        k (int): The number of neighbors.\n        batch_x (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        batch_y (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^M`, which assigns each\n            node to a specific example. 
(default: :obj:`None`)\n        cosine (bool, optional): If :obj:`True`, will use the cosine\n            distance instead of euclidean distance to find nearest neighbors.\n            (default: :obj:`False`)\n        num_workers (int, optional): Number of workers to use for computation.\n            Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not\n            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)\n        batch_size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:\n        return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine,\n                                 num_workers)\n    return torch_cluster.knn(x, y, k, batch_x, batch_y, cosine, num_workers,\n                             batch_size)\n\n\ndef knn_graph(\n    x: Tensor,\n    k: int,\n    batch: OptTensor = None,\n    loop: bool = False,\n    flow: str = 'source_to_target',\n    cosine: bool = False,\n    num_workers: int = 1,\n    batch_size: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Computes graph edges to the nearest :obj:`k` points.\n\n    .. code-block:: python\n\n        import torch\n        from torch_geometric.nn import knn_graph\n\n        x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n        batch = torch.tensor([0, 0, 0, 0])\n        edge_index = knn_graph(x, k=2, batch=batch, loop=False)\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        k (int): The number of neighbors.\n        batch (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. 
(default: :obj:`None`)\n        loop (bool, optional): If :obj:`True`, the graph will contain\n            self-loops. (default: :obj:`False`)\n        flow (str, optional): The flow direction when using in combination with\n            message passing (:obj:`\"source_to_target\"` or\n            :obj:`\"target_to_source\"`). (default: :obj:`\"source_to_target\"`)\n        cosine (bool, optional): If :obj:`True`, will use the cosine\n            distance instead of euclidean distance to find nearest neighbors.\n            (default: :obj:`False`)\n        num_workers (int, optional): Number of workers to use for computation.\n            Has no effect in case :obj:`batch` is not :obj:`None`, or the input\n            lies on the GPU. (default: :obj:`1`)\n        batch_size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    if batch is not None and x.device != batch.device:\n        warnings.warn(\n            \"Input tensor 'x' and 'batch' are on different devices \"\n            \"in 'knn_graph'. Performing blocking device transfer\",\n            stacklevel=2)\n        batch = batch.to(x.device)\n\n    if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:\n        return torch_cluster.knn_graph(x, k, batch, loop, flow, cosine,\n                                       num_workers)\n    return torch_cluster.knn_graph(x, k, batch, loop, flow, cosine,\n                                   num_workers, batch_size)\n\n\ndef radius(\n    x: Tensor,\n    y: Tensor,\n    r: float,\n    batch_x: OptTensor = None,\n    batch_y: OptTensor = None,\n    max_num_neighbors: int = 32,\n    num_workers: int = 1,\n    batch_size: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Finds for each element in :obj:`y` all points in :obj:`x` within\n    distance :obj:`r`.\n\n    .. 
code-block:: python\n\n        import torch\n        from torch_geometric.nn import radius\n\n        x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n        batch_x = torch.tensor([0, 0, 0, 0])\n        y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])\n        batch_y = torch.tensor([0, 0])\n        assign_index = radius(x, y, 1.5, batch_x, batch_y)\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        y (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{Y} \\in \\mathbb{R}^{M \\times F}`.\n        r (float): The radius.\n        batch_x (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        batch_y (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^M`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            return for each element in :obj:`y`. (default: :obj:`32`)\n        num_workers (int, optional): Number of workers to use for computation.\n            Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not\n            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)\n        batch_size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n\n    .. 
warning::\n\n        The CPU implementation of :meth:`radius` with :obj:`max_num_neighbors`\n        is biased towards certain quadrants.\n        Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving\n        inputs to GPU before proceeding.\n    \"\"\"\n    if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:\n        return torch_cluster.radius(x, y, r, batch_x, batch_y,\n                                    max_num_neighbors, num_workers)\n    return torch_cluster.radius(x, y, r, batch_x, batch_y, max_num_neighbors,\n                                num_workers, batch_size)\n\n\ndef radius_graph(\n    x: Tensor,\n    r: float,\n    batch: OptTensor = None,\n    loop: bool = False,\n    max_num_neighbors: int = 32,\n    flow: str = 'source_to_target',\n    num_workers: int = 1,\n    batch_size: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Computes graph edges to all points within a given distance.\n\n    .. code-block:: python\n\n        import torch\n        from torch_geometric.nn import radius_graph\n\n        x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n        batch = torch.tensor([0, 0, 0, 0])\n        edge_index = radius_graph(x, r=1.5, batch=batch, loop=False)\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        r (float): The radius.\n        batch (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        loop (bool, optional): If :obj:`True`, the graph will contain\n            self-loops. (default: :obj:`False`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            return for each element in :obj:`x`. 
(default: :obj:`32`)\n        flow (str, optional): The flow direction when using in combination with\n            message passing (:obj:`\"source_to_target\"` or\n            :obj:`\"target_to_source\"`). (default: :obj:`\"source_to_target\"`)\n        num_workers (int, optional): Number of workers to use for computation.\n            Has no effect in case :obj:`batch` is not :obj:`None`, or the input\n            lies on the GPU. (default: :obj:`1`)\n        batch_size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n\n    .. warning::\n\n        The CPU implementation of :meth:`radius_graph` with\n        :obj:`max_num_neighbors` is biased towards certain quadrants.\n        Consider setting :obj:`max_num_neighbors` to :obj:`None` or moving\n        inputs to GPU before proceeding.\n    \"\"\"\n    if batch is not None and x.device != batch.device:\n        warnings.warn(\n            \"Input tensor 'x' and 'batch' are on different devices \"\n            \"in 'radius_graph'. Performing blocking device transfer\",\n            stacklevel=2)\n        batch = batch.to(x.device)\n\n    if not torch_geometric.typing.WITH_TORCH_CLUSTER_BATCH_SIZE:\n        return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors,\n                                          flow, num_workers)\n    return torch_cluster.radius_graph(x, r, batch, loop, max_num_neighbors,\n                                      flow, num_workers, batch_size)\n\n\ndef nearest(\n    x: Tensor,\n    y: Tensor,\n    batch_x: OptTensor = None,\n    batch_y: OptTensor = None,\n) -> Tensor:\n    r\"\"\"Finds for each element in :obj:`y` the nearest point in\n    :obj:`x`.\n\n    .. 
code-block:: python\n\n        import torch\n        from torch_geometric.nn import nearest\n\n        x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])\n        batch_x = torch.tensor([0, 0, 0, 0])\n        y = torch.tensor([[-1.0, 0.0], [1.0, 0.0]])\n        batch_y = torch.tensor([0, 0])\n        cluster = nearest(x, y, batch_x, batch_y)\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        y (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{Y} \\in \\mathbb{R}^{M \\times F}`.\n        batch_x (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        batch_y (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^M`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    return torch_cluster.nearest(x, y, batch_x, batch_y)\n\n\n__all__ = [\n    'global_add_pool',\n    'global_mean_pool',\n    'global_max_pool',\n    'KNNIndex',\n    'L2KNNIndex',\n    'MIPSKNNIndex',\n    'ApproxL2KNNIndex',\n    'ApproxMIPSKNNIndex',\n    'TopKPooling',\n    'SAGPooling',\n    'EdgePooling',\n    'ClusterPooling',\n    'ASAPooling',\n    'PANPooling',\n    'MemPooling',\n    'max_pool',\n    'avg_pool',\n    'max_pool_x',\n    'max_pool_neighbor_x',\n    'avg_pool_x',\n    'avg_pool_neighbor_x',\n    'graclus',\n    'voxel_grid',\n    'fps',\n    'knn',\n    'knn_graph',\n    'approx_knn',\n    'approx_knn_graph',\n    'radius',\n    'radius_graph',\n    'nearest',\n]\n\nclasses = __all__\n"
  },
  {
    "path": "torch_geometric/nn/pool/approx_knn.py",
    "content": "import torch\nfrom torch import Tensor\n\n\ndef approx_knn(\n    x: Tensor,\n    y: Tensor,\n    k: int,\n    batch_x: Tensor = None,\n    batch_y: Tensor = None,\n) -> Tensor:  # pragma: no cover\n    r\"\"\"Finds for each element in :obj:`y` the :obj:`k` approximated nearest\n    points in :obj:`x`.\n\n    .. note::\n\n        Approximated :math:`k`-nearest neighbor search is performed via the\n        `pynndescent <https://pynndescent.readthedocs.io/en/latest>`_ library.\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        y (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{M \\times F}`.\n        k (int): The number of neighbors.\n        batch_x (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        batch_y (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^M`, which assigns each\n            node to a specific example. 
(default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    from pynndescent import NNDescent\n\n    if batch_x is None:\n        batch_x = x.new_zeros(x.size(0), dtype=torch.long)\n    if batch_y is None:\n        batch_y = y.new_zeros(y.size(0), dtype=torch.long)\n\n    x = x.view(-1, 1) if x.dim() == 1 else x\n    y = y.view(-1, 1) if y.dim() == 1 else y\n\n    assert x.dim() == 2 and batch_x.dim() == 1\n    assert y.dim() == 2 and batch_y.dim() == 1\n    assert x.size(1) == y.size(1)\n    assert x.size(0) == batch_x.size(0)\n    assert y.size(0) == batch_y.size(0)\n\n    min_xy = min(x.min(), y.min())\n    x, y = x - min_xy, y - min_xy\n\n    max_xy = max(x.max(), y.max())\n    x, y, = x / max_xy, y / max_xy\n\n    # Concat batch/features to ensure no cross-links between examples exist:\n    x = torch.cat([x, 2 * x.size(1) * batch_x.view(-1, 1).to(x.dtype)], dim=-1)\n    y = torch.cat([y, 2 * y.size(1) * batch_y.view(-1, 1).to(y.dtype)], dim=-1)\n\n    index = NNDescent(x.detach().cpu().numpy())\n    col, dist = index.query(y.detach().cpu().numpy(), k=k)\n    dist = torch.from_numpy(dist).view(-1).to(x.device, x.dtype)\n    col = torch.from_numpy(col).view(-1).to(x.device, torch.long)\n    row = torch.arange(y.size(0), device=x.device, dtype=torch.long)\n    row = row.repeat_interleave(k)\n    mask = ~torch.isinf(dist)\n    row, col = row[mask], col[mask]\n\n    return torch.stack([row, col], dim=0)\n\n\ndef approx_knn_graph(\n    x: Tensor,\n    k: int,\n    batch: Tensor = None,\n    loop: bool = False,\n    flow: str = 'source_to_target',\n) -> Tensor:  # pragma: no cover\n    r\"\"\"Computes graph edges to the nearest approximated :obj:`k` points.\n\n    .. 
note::\n\n        Approximated :math:`k`-nearest neighbor search is performed via the\n        `pynndescent <https://pynndescent.readthedocs.io/en/latest>`_ library.\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        k (int): The number of neighbors.\n        batch (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        loop (bool, optional): If :obj:`True`, the graph will contain\n            self-loops. (default: :obj:`False`)\n        flow (str, optional): The flow direction when using in combination with\n            message passing (:obj:`\"source_to_target\"` or\n            :obj:`\"target_to_source\"`). (default: :obj:`\"source_to_target\"`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    assert flow in ['source_to_target', 'target_to_source']\n    row, col = approx_knn(x, x, k if loop else k + 1, batch, batch)\n    row, col = (col, row) if flow == 'source_to_target' else (row, col)\n    if not loop:\n        mask = row != col\n        row, col = row[mask], col[mask]\n    return torch.stack([row, col], dim=0)\n"
  },
  {
    "path": "torch_geometric/nn/pool/asap.py",
    "content": "from typing import Callable, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\nfrom torch.nn import Linear\n\nfrom torch_geometric.nn import LEConv\nfrom torch_geometric.nn.pool.select import SelectTopK\nfrom torch_geometric.utils import (\n    add_remaining_self_loops,\n    remove_self_loops,\n    scatter,\n    softmax,\n    to_edge_index,\n    to_torch_coo_tensor,\n    to_torch_csr_tensor,\n)\n\n\nclass ASAPooling(torch.nn.Module):\n    r\"\"\"The Adaptive Structure Aware Pooling operator from the\n    `\"ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical\n    Graph Representations\" <https://arxiv.org/abs/1911.07979>`_ paper.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        ratio (float or int): Graph pooling ratio, which is used to compute\n            :math:`k = \\lceil \\mathrm{ratio} \\cdot N \\rceil`, or the value\n            of :math:`k` itself, depending on whether the type of :obj:`ratio`\n            is :obj:`float` or :obj:`int`. (default: :obj:`0.5`)\n        GNN (torch.nn.Module, optional): A graph neural network layer for\n            using intra-cluster properties.\n            Especially helpful for graphs with higher degree of neighborhood\n            (one of :class:`torch_geometric.nn.conv.GraphConv`,\n            :class:`torch_geometric.nn.conv.GCNConv` or\n            any GNN which supports the :obj:`edge_weight` parameter).\n            (default: :obj:`None`)\n        dropout (float, optional): Dropout probability of the normalized\n            attention coefficients which exposes each node to a stochastically\n            sampled neighborhood during training. (default: :obj:`0`)\n        negative_slope (float, optional): LeakyReLU angle of the negative\n            slope. (default: :obj:`0.2`)\n        add_self_loops (bool, optional): If set to :obj:`True`, will add self\n            loops to the new graph connectivity. 
(default: :obj:`False`)\n        **kwargs (optional): Additional parameters for initializing the\n            graph neural network layer.\n    \"\"\"\n    def __init__(self, in_channels: int, ratio: Union[float, int] = 0.5,\n                 GNN: Optional[Callable] = None, dropout: float = 0.0,\n                 negative_slope: float = 0.2, add_self_loops: bool = False,\n                 **kwargs):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.ratio = ratio\n        self.negative_slope = negative_slope\n        self.dropout = dropout\n        self.GNN = GNN\n        self.add_self_loops = add_self_loops\n\n        self.lin = Linear(in_channels, in_channels)\n        self.att = Linear(2 * in_channels, 1)\n        self.gnn_score = LEConv(self.in_channels, 1)\n        if self.GNN is not None:\n            self.gnn_intra_cluster = GNN(self.in_channels, self.in_channels,\n                                         **kwargs)\n        else:\n            self.gnn_intra_cluster = None\n\n        self.select = SelectTopK(1, ratio)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin.reset_parameters()\n        self.att.reset_parameters()\n        self.gnn_score.reset_parameters()\n        if self.gnn_intra_cluster is not None:\n            self.gnn_intra_cluster.reset_parameters()\n        self.select.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_weight: Optional[Tensor] = None,\n        batch: Optional[Tensor] = None,\n    ) -> Tuple[Tensor, Tensor, Optional[Tensor], Tensor, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node feature matrix.\n            edge_index (torch.Tensor): The edge indices.\n            edge_weight (torch.Tensor, optional): The edge weights.\n                (default: :obj:`None`)\n            batch 
(torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific example. (default: :obj:`None`)\n\n        Return types:\n            * **x** (*torch.Tensor*): The pooled node embeddings.\n            * **edge_index** (*torch.Tensor*): The coarsened edge indices.\n            * **edge_weight** (*torch.Tensor, optional*): The coarsened edge\n              weights.\n            * **batch** (*torch.Tensor*): The coarsened batch vector.\n            * **index** (*torch.Tensor*): The top-:math:`k` node indices of\n              nodes which are kept after pooling.\n        \"\"\"\n        N = x.size(0)\n\n        edge_index, edge_weight = add_remaining_self_loops(\n            edge_index, edge_weight, fill_value=1., num_nodes=N)\n\n        if batch is None:\n            batch = edge_index.new_zeros(x.size(0))\n\n        x = x.unsqueeze(-1) if x.dim() == 1 else x\n\n        x_pool = x\n        if self.gnn_intra_cluster is not None:\n            x_pool = self.gnn_intra_cluster(x=x, edge_index=edge_index,\n                                            edge_weight=edge_weight)\n\n        x_pool_j = x_pool[edge_index[0]]\n        x_q = scatter(x_pool_j, edge_index[1], dim=0, reduce='max')\n        x_q = self.lin(x_q)[edge_index[1]]\n\n        score = self.att(torch.cat([x_q, x_pool_j], dim=-1)).view(-1)\n        score = F.leaky_relu(score, self.negative_slope)\n        score = softmax(score, edge_index[1], num_nodes=N)\n\n        # Sample attention coefficients stochastically.\n        score = F.dropout(score, p=self.dropout, training=self.training)\n\n        v_j = x[edge_index[0]] * score.view(-1, 1)\n        x = scatter(v_j, edge_index[1], dim=0, reduce='sum')\n\n        # Cluster selection.\n        fitness = self.gnn_score(x, edge_index).sigmoid().view(-1)\n        perm = self.select(fitness, batch).node_index\n        x = x[perm] * fitness[perm].view(-1, 1)\n        
batch = batch[perm]\n\n        # Graph coarsening.\n        A = to_torch_csr_tensor(edge_index, edge_weight, size=(N, N))\n        S = to_torch_coo_tensor(edge_index, score, size=(N, N))\n        S = S.index_select(1, perm).to_sparse_csr()\n        A = S.t().to_sparse_csr() @ (A @ S)\n\n        if edge_weight is None:\n            edge_index, _ = to_edge_index(A)\n        else:\n            edge_index, edge_weight = to_edge_index(A)\n\n        if self.add_self_loops:\n            edge_index, edge_weight = add_remaining_self_loops(\n                edge_index, edge_weight, num_nodes=A.size(0))\n        else:\n            edge_index, edge_weight = remove_self_loops(\n                edge_index, edge_weight)\n\n        return x, edge_index, edge_weight, batch, perm\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'ratio={self.ratio})')\n"
  },
  {
    "path": "torch_geometric/nn/pool/avg_pool.py",
    "content": "from typing import Callable, Optional, Tuple\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.nn.pool.consecutive import consecutive_cluster\nfrom torch_geometric.nn.pool.pool import pool_batch, pool_edge, pool_pos\nfrom torch_geometric.utils import add_self_loops, scatter\n\n\ndef _avg_pool_x(\n    cluster: Tensor,\n    x: Tensor,\n    size: Optional[int] = None,\n) -> Tensor:\n    return scatter(x, cluster, dim=0, dim_size=size, reduce='mean')\n\n\ndef avg_pool_x(\n    cluster: Tensor,\n    x: Tensor,\n    batch: Tensor,\n    batch_size: Optional[int] = None,\n    size: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Average pools node features according to the clustering defined in\n    :attr:`cluster`.\n    See :meth:`torch_geometric.nn.pool.max_pool_x` for more details.\n\n    Args:\n        cluster (torch.Tensor): The cluster vector\n            :math:`\\mathbf{c} \\in \\{ 0, \\ldots, N - 1 \\}^N`, which assigns each\n            node to a specific cluster.\n        x (Tensor): The node feature matrix.\n        batch (torch.Tensor): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example.\n        batch_size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n        size (int, optional): The maximum number of clusters in a single\n            example. 
(default: :obj:`None`)\n\n    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`) if :attr:`size` is\n        :obj:`None`, else :class:`torch.Tensor`\n    \"\"\"\n    if size is not None:\n        if batch_size is None:\n            batch_size = int(batch.max().item()) + 1\n        return _avg_pool_x(cluster, x, batch_size * size), None\n\n    cluster, perm = consecutive_cluster(cluster)\n    x = _avg_pool_x(cluster, x)\n    batch = pool_batch(perm, batch)\n\n    return x, batch\n\n\ndef avg_pool(\n    cluster: Tensor,\n    data: Data,\n    transform: Optional[Callable] = None,\n) -> Data:\n    r\"\"\"Pools and coarsens a graph given by the\n    :class:`torch_geometric.data.Data` object according to the clustering\n    defined in :attr:`cluster`.\n    Final node features are defined by the *average* features of all nodes\n    within the same cluster.\n    See :meth:`torch_geometric.nn.pool.max_pool` for more details.\n\n    Args:\n        cluster (torch.Tensor): The cluster vector\n            :math:`\\mathbf{c} \\in \\{ 0, \\ldots, N - 1 \\}^N`, which assigns each\n            node to a specific cluster.\n        data (Data): Graph data object.\n        transform (callable, optional): A function/transform that takes in the\n            coarsened and pooled :obj:`torch_geometric.data.Data` object and\n            returns a transformed version. 
(default: :obj:`None`)\n\n    :rtype: :class:`torch_geometric.data.Data`\n    \"\"\"\n    cluster, perm = consecutive_cluster(cluster)\n\n    x = None if data.x is None else _avg_pool_x(cluster, data.x)\n    index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)\n    batch = None if data.batch is None else pool_batch(perm, data.batch)\n    pos = None if data.pos is None else pool_pos(cluster, data.pos)\n\n    data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)\n\n    if transform is not None:\n        data = transform(data)\n\n    return data\n\n\ndef avg_pool_neighbor_x(\n    data: Data,\n    flow: Optional[str] = 'source_to_target',\n) -> Data:\n    r\"\"\"Average pools neighboring node features, where each feature in\n    :obj:`data.x` is replaced by the average feature values from the central\n    node and its neighbors.\n    \"\"\"\n    x, edge_index = data.x, data.edge_index\n\n    edge_index, _ = add_self_loops(edge_index, num_nodes=data.num_nodes)\n\n    row, col = edge_index\n    row, col = (row, col) if flow == 'source_to_target' else (col, row)\n\n    data.x = scatter(x[row], col, dim=0, dim_size=data.num_nodes,\n                     reduce='mean')\n    return data\n"
  },
  {
    "path": "torch_geometric/nn/pool/cluster_pool.py",
    "content": "from typing import NamedTuple, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.utils import (\n    dense_to_sparse,\n    one_hot,\n    to_dense_adj,\n    to_scipy_sparse_matrix,\n)\n\n\nclass UnpoolInfo(NamedTuple):\n    edge_index: Tensor\n    cluster: Tensor\n    batch: Tensor\n\n\nclass ClusterPooling(torch.nn.Module):\n    r\"\"\"The cluster pooling operator from the `\"Edge-Based Graph Component\n    Pooling\" <https://arxiv.org/abs/2409.11856>`_ paper.\n    :class:`ClusterPooling` computes a score for each edge.\n    Based on the selected edges, graph clusters are calculated and compressed\n    to one node using the injective :obj:`\"sum\"` aggregation function.\n    Edges are remapped based on the nodes created by each cluster and the\n    original edges.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        edge_score_method (str, optional): The function to apply\n            to compute the edge score from raw edge scores (:obj:`\"tanh\"`,\n            :obj:`\"sigmoid\"`, :obj:`\"log_softmax\"`). (default: :obj:`\"tanh\"`)\n        dropout (float, optional): The probability with\n            which to drop edge scores during training. (default: :obj:`0.0`)\n        threshold (float, optional): The threshold of edge scores. If set to\n            :obj:`None`, will be automatically inferred depending on\n            :obj:`edge_score_method`. 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        edge_score_method: str = 'tanh',\n        dropout: float = 0.0,\n        threshold: Optional[float] = None,\n    ):\n        super().__init__()\n        assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']\n\n        if threshold is None:\n            threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0\n\n        self.in_channels = in_channels\n        self.edge_score_method = edge_score_method\n        self.dropout = dropout\n        self.threshold = threshold\n\n        self.lin = torch.nn.Linear(2 * in_channels, 1)\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node features.\n            edge_index (torch.Tensor): The edge indices.\n            batch (torch.Tensor): Batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific example.\n\n        Return types:\n            * **x** *(torch.Tensor)* - The pooled node features.\n            * **edge_index** *(torch.Tensor)* - The coarsened edge indices.\n            * **batch** *(torch.Tensor)* - The coarsened batch vector.\n            * **unpool_info** *(UnpoolInfo)* - Information that can be consumed\n              for unpooling.\n        \"\"\"\n        mask = edge_index[0] != edge_index[1]\n        edge_index = edge_index[:, mask]\n\n        edge_attr = torch.cat(\n            [x[edge_index[0]], x[edge_index[1]]],\n            dim=-1,\n        )\n        edge_score = self.lin(edge_attr).view(-1)\n        edge_score = F.dropout(edge_score, p=self.dropout,\n                              
 training=self.training)\n\n        if self.edge_score_method == 'tanh':\n            edge_score = edge_score.tanh()\n        elif self.edge_score_method == 'sigmoid':\n            edge_score = edge_score.sigmoid()\n        else:\n            assert self.edge_score_method == 'log_softmax'\n            edge_score = F.log_softmax(edge_score, dim=0)\n\n        return self._merge_edges(x, edge_index, batch, edge_score)\n\n    def _merge_edges(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_score: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:\n\n        from scipy.sparse.csgraph import connected_components\n\n        edge_contract = edge_index[:, edge_score > self.threshold]\n\n        adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))\n        _, cluster_np = connected_components(adj, directed=True,\n                                             connection=\"weak\")\n\n        cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)\n        C = one_hot(cluster)\n        A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)\n        S = to_dense_adj(edge_index, edge_attr=edge_score,\n                         max_num_nodes=x.size(0)).squeeze(0)\n\n        A_contract = to_dense_adj(edge_contract,\n                                  max_num_nodes=x.size(0)).squeeze(0)\n        nodes_single = ((A_contract.sum(dim=-1) +\n                         A_contract.sum(dim=-2)) == 0).nonzero()\n        S[nodes_single, nodes_single] = 1.0\n\n        x_out = (S @ C).t() @ x\n        edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))\n        batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)\n        unpool_info = UnpoolInfo(edge_index, cluster, batch)\n\n        return x_out, edge_index_out, batch_out, unpool_info\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.in_channels})'\n"
  },
  {
    "path": "torch_geometric/nn/pool/connect/__init__.py",
    "content": "r\"\"\"Graph connection package.\n\nThis package provides classes for determining coarsened graph connections in\ngraph pooling scenarios.\n\"\"\"\n\nfrom .base import Connect, ConnectOutput\nfrom .filter_edges import FilterEdges\n\n__all__ = [\n    'Connect',\n    'ConnectOutput',\n    'FilterEdges',\n]\n"
  },
  {
    "path": "torch_geometric/nn/pool/connect/base.py",
    "content": "from dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.pool.select import SelectOutput\n\n\n@dataclass(init=False)\nclass ConnectOutput:\n    r\"\"\"The output of the :class:`Connect` method, which holds the coarsened\n    graph structure, and optional pooled edge features and batch vectors.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices of the cooarsened graph.\n        edge_attr (torch.Tensor, optional): The pooled edge features of the\n            coarsened graph. (default: :obj:`None`)\n        batch (torch.Tensor, optional): The pooled batch vector of the\n            coarsened graph. (default: :obj:`None`)\n    \"\"\"\n    edge_index: Tensor\n    edge_attr: Optional[Tensor] = None\n    batch: Optional[Tensor] = None\n\n    def __init__(\n        self,\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor] = None,\n        batch: Optional[Tensor] = None,\n    ):\n        if edge_index.dim() != 2:\n            raise ValueError(f\"Expected 'edge_index' to be two-dimensional \"\n                             f\"(got {edge_index.dim()} dimensions)\")\n\n        if edge_index.size(0) != 2:\n            raise ValueError(f\"Expected 'edge_index' to have size '2' in the \"\n                             f\"first dimension (got '{edge_index.size(0)}')\")\n\n        if edge_attr is not None and edge_attr.size(0) != edge_index.size(1):\n            raise ValueError(f\"Expected 'edge_index' and 'edge_attr' to \"\n                             f\"hold the same number of edges (got \"\n                             f\"{edge_index.size(1)} and {edge_attr.size(0)} \"\n                             f\"edges)\")\n\n        self.edge_index = edge_index\n        self.edge_attr = edge_attr\n        self.batch = batch\n\n\nConnectOutput = torch.jit.script(ConnectOutput)\n\n\nclass Connect(torch.nn.Module):\n    r\"\"\"An abstract base class for implementing custom 
edge connection\n    operators as described in the `\"Understanding Pooling in Graph Neural\n    Networks\" <https://arxiv.org/abs/1905.05178>`_ paper.\n\n    Specifically, :class:`Connect` determines for each pair of supernodes the\n    presence or absence of an edge based on the existing edges between the\n    nodes in the two supernodes.\n    The operator also computes pooled edge features and batch vectors\n    (if present).\n    \"\"\"\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n\n    def forward(\n        self,\n        select_output: SelectOutput,\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor] = None,\n        batch: Optional[Tensor] = None,\n    ) -> ConnectOutput:\n        r\"\"\"Forward pass.\n\n        Args:\n            select_output (SelectOutput): The output of :class:`Select`.\n            edge_index (torch.Tensor): The edge indices.\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific graph. (default: :obj:`None`)\n        \"\"\"\n        raise NotImplementedError\n\n    @staticmethod\n    def get_pooled_batch(\n        select_output: SelectOutput,\n        batch: Optional[Tensor],\n    ) -> Optional[Tensor]:\n        r\"\"\"Returns the batch vector of the coarsened graph.\n\n        Args:\n            select_output (SelectOutput): The output of :class:`Select`.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each element to a specific example. 
(default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            return batch\n\n        out = torch.arange(select_output.num_clusters, device=batch.device)\n        return out.scatter_(0, select_output.cluster_index,\n                            batch[select_output.node_index])\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/nn/pool/connect/filter_edges.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.pool.connect import Connect, ConnectOutput\nfrom torch_geometric.nn.pool.select import SelectOutput\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef filter_adj(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    node_index: Tensor,\n    cluster_index: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    if cluster_index is None:\n        cluster_index = torch.arange(node_index.size(0),\n                                     device=node_index.device)\n\n    mask = node_index.new_full((num_nodes, ), -1)\n    mask[node_index] = cluster_index\n\n    row, col = edge_index[0], edge_index[1]\n    row, col = mask[row], mask[col]\n    mask = (row >= 0) & (col >= 0)\n    row, col = row[mask], col[mask]\n\n    if edge_attr is not None:\n        edge_attr = edge_attr[mask]\n\n    return torch.stack([row, col], dim=0), edge_attr\n\n\nclass FilterEdges(Connect):\n    r\"\"\"Filters out edges if their incident nodes are not in any cluster.\n\n    .. 
math::\n            \\mathbf{A}^{\\prime} &= \\mathbf{A}_{\\mathbf{i},\\mathbf{i}},\n\n    where :math:`\\mathbf{i}` denotes the set of retained nodes.\n    It is assumed that each cluster contains only one node.\n    \"\"\"\n    def forward(\n        self,\n        select_output: SelectOutput,\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor] = None,\n        batch: Optional[Tensor] = None,\n    ) -> ConnectOutput:\n\n        if (not torch.jit.is_scripting() and select_output.num_clusters\n                != select_output.cluster_index.size(0)):\n            raise ValueError(f\"'{self.__class__.__name__}' requires each \"\n                             f\"cluster to contain only one node\")\n\n        edge_index, edge_attr = filter_adj(\n            edge_index,\n            edge_attr,\n            select_output.node_index,\n            select_output.cluster_index,\n            num_nodes=select_output.num_nodes,\n        )\n        batch = self.get_pooled_batch(select_output, batch)\n\n        return ConnectOutput(edge_index, edge_attr, batch)\n"
  },
  {
    "path": "torch_geometric/nn/pool/consecutive.py",
    "content": "import torch\n\n\ndef consecutive_cluster(src):\n    unique, inv = torch.unique(src, sorted=True, return_inverse=True)\n    perm = torch.arange(inv.size(0), dtype=inv.dtype, device=inv.device)\n    perm = inv.new_empty(unique.size(0)).scatter_(0, inv, perm)\n    return inv, perm\n"
  },
  {
    "path": "torch_geometric/nn/pool/decimation.py",
    "content": "from typing import Tuple, Union\n\nimport torch\nfrom torch import LongTensor, Tensor\n\nfrom torch_geometric.utils import cumsum\n\n\ndef decimation_indices(\n    ptr: LongTensor,\n    decimation_factor: Union[int, float],\n) -> Tuple[Tensor, LongTensor]:\n    \"\"\"Gets indices which downsample each point cloud by a decimation factor.\n\n    Decimation happens separately for each cloud to prevent emptying smaller\n    point clouds. Empty clouds are prevented: clouds will have a least\n    one node after decimation.\n\n    Args:\n        ptr (LongTensor): The indices of samples in the batch.\n        decimation_factor (int or float): The value to divide number of nodes\n            with. Should be higher than (or equal to) :obj:`1` for\n            downsampling.\n\n    :rtype: (:class:`LongTensor`, :class:`LongTensor`): The indices and\n        updated :obj:`ptr` after downsampling.\n    \"\"\"\n    if decimation_factor < 1:\n        raise ValueError(\n            f\"The argument `decimation_factor` should be higher than (or \"\n            f\"equal to) 1 for downsampling. (got {decimation_factor})\")\n\n    batch_size = ptr.size(0) - 1\n    count = ptr[1:] - ptr[:-1]\n    decim_count = torch.div(count, decimation_factor, rounding_mode='floor')\n    decim_count.clamp_(min=1)  # Prevent empty examples.\n\n    decim_indices = [\n        ptr[i] + torch.randperm(count[i], device=ptr.device)[:decim_count[i]]\n        for i in range(batch_size)\n    ]\n    decim_indices = torch.cat(decim_indices, dim=0)\n\n    # Get updated ptr (e.g., for future decimations):\n    decim_ptr = cumsum(decim_count)\n\n    return decim_indices, decim_ptr\n"
  },
  {
    "path": "torch_geometric/nn/pool/edge_pool.py",
    "content": "from typing import Callable, List, NamedTuple, Optional, Tuple\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nfrom torch_geometric.utils import coalesce, scatter, softmax\n\n\nclass UnpoolInfo(NamedTuple):\n    edge_index: Tensor\n    cluster: Tensor\n    batch: Tensor\n    new_edge_score: Tensor\n\n\nclass EdgePooling(torch.nn.Module):\n    r\"\"\"The edge pooling operator from the `\"Towards Graph Pooling by Edge\n    Contraction\" <https://graphreason.github.io/papers/17.pdf>`__ and\n    `\"Edge Contraction Pooling for Graph Neural Networks\"\n    <https://arxiv.org/abs/1905.10990>`__ papers.\n\n    In short, a score is computed for each edge.\n    Edges are contracted iteratively according to that score unless one of\n    their nodes has already been part of a contracted edge.\n\n    To duplicate the configuration from the `\"Towards Graph Pooling by Edge\n    Contraction\" <https://graphreason.github.io/papers/17.pdf>`__ paper, use\n    either :func:`EdgePooling.compute_edge_score_softmax`\n    or :func:`EdgePooling.compute_edge_score_tanh`, and set\n    :obj:`add_to_edge_score` to :obj:`0.0`.\n\n    To duplicate the configuration from the `\"Edge Contraction Pooling for\n    Graph Neural Networks\" <https://arxiv.org/abs/1905.10990>`__ paper,\n    set :obj:`dropout` to :obj:`0.2`.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        edge_score_method (callable, optional): The function to apply\n            to compute the edge score from raw edge scores. 
By default,\n            this is the softmax over all incoming edges for each node.\n            This function takes in a :obj:`raw_edge_score` tensor of shape\n            :obj:`[num_edges]`, an :obj:`edge_index` tensor and the number of\n            nodes :obj:`num_nodes`, and produces a new tensor of the same size\n            as :obj:`raw_edge_score` describing normalized edge scores.\n            Included functions are\n            :func:`EdgePooling.compute_edge_score_softmax`,\n            :func:`EdgePooling.compute_edge_score_tanh`, and\n            :func:`EdgePooling.compute_edge_score_sigmoid`.\n            (default: :func:`EdgePooling.compute_edge_score_softmax`)\n        dropout (float, optional): The probability with\n            which to drop edge scores during training. (default: :obj:`0.0`)\n        add_to_edge_score (float, optional): A value to be added to each\n            computed edge score. Adding this greatly helps with unpooling\n            stability. (default: :obj:`0.5`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        edge_score_method: Optional[Callable] = None,\n        dropout: float = 0.0,\n        add_to_edge_score: float = 0.5,\n    ):\n        super().__init__()\n        self.in_channels = in_channels\n        if edge_score_method is None:\n            edge_score_method = self.compute_edge_score_softmax\n        self.compute_edge_score = edge_score_method\n        self.add_to_edge_score = add_to_edge_score\n        self.dropout = dropout\n\n        self.lin = torch.nn.Linear(2 * in_channels, 1)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.lin.reset_parameters()\n\n    @staticmethod\n    def compute_edge_score_softmax(\n        raw_edge_score: Tensor,\n        edge_index: Tensor,\n        num_nodes: int,\n    ) -> Tensor:\n        r\"\"\"Normalizes edge scores via softmax application.\"\"\"\n  
      return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)\n\n    @staticmethod\n    def compute_edge_score_tanh(\n        raw_edge_score: Tensor,\n        edge_index: Optional[Tensor] = None,\n        num_nodes: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Normalizes edge scores via hyperbolic tangent application.\"\"\"\n        return torch.tanh(raw_edge_score)\n\n    @staticmethod\n    def compute_edge_score_sigmoid(\n        raw_edge_score: Tensor,\n        edge_index: Optional[Tensor] = None,\n        num_nodes: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Normalizes edge scores via sigmoid application.\"\"\"\n        return torch.sigmoid(raw_edge_score)\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node features.\n            edge_index (torch.Tensor): The edge indices.\n            batch (torch.Tensor): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific example.\n\n        Return types:\n            * **x** *(torch.Tensor)* - The pooled node features.\n            * **edge_index** *(torch.Tensor)* - The coarsened edge indices.\n            * **batch** *(torch.Tensor)* - The coarsened batch vector.\n            * **unpool_info** *(UnpoolInfo)* - Information that is\n              consumed by :func:`EdgePooling.unpool` for unpooling.\n        \"\"\"\n        e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)\n        e = self.lin(e).view(-1)\n        e = F.dropout(e, p=self.dropout, training=self.training)\n        e = self.compute_edge_score(e, edge_index, x.size(0))\n        e = e + self.add_to_edge_score\n\n        x, edge_index, batch, unpool_info = self._merge_edges(\n            x, edge_index, batch, e)\n\n        return x, 
edge_index, batch, unpool_info\n\n    def _merge_edges(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        batch: Tensor,\n        edge_score: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:\n\n        cluster = torch.empty_like(batch)\n        perm: List[int] = torch.argsort(edge_score, descending=True).tolist()\n\n        # Iterate through all edges, selecting it if it is not incident to\n        # another already chosen edge.\n        mask = torch.ones(x.size(0), dtype=torch.bool)\n\n        i = 0\n        new_edge_indices: List[int] = []\n        edge_index_cpu = edge_index.cpu()\n        for edge_idx in perm:\n            source = int(edge_index_cpu[0, edge_idx])\n            if not bool(mask[source]):\n                continue\n\n            target = int(edge_index_cpu[1, edge_idx])\n            if not bool(mask[target]):\n                continue\n\n            new_edge_indices.append(edge_idx)\n\n            cluster[source] = i\n            mask[source] = False\n\n            if source != target:\n                cluster[target] = i\n                mask[target] = False\n\n            i += 1\n\n        # The remaining nodes are simply kept:\n        j = int(mask.sum())\n        cluster[mask] = torch.arange(i, i + j, device=x.device)\n        i += j\n\n        # We compute the new features as an addition of the old ones.\n        new_x = scatter(x, cluster, dim=0, dim_size=i, reduce='sum')\n        new_edge_score = edge_score[new_edge_indices]\n        if int(mask.sum()) > 0:\n            remaining_score = x.new_ones(\n                (new_x.size(0) - len(new_edge_indices), ))\n            new_edge_score = torch.cat([new_edge_score, remaining_score])\n        new_x = new_x * new_edge_score.view(-1, 1)\n\n        new_edge_index = coalesce(cluster[edge_index], num_nodes=new_x.size(0))\n        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)\n        new_batch = new_batch.scatter_(0, cluster, batch)\n\n        
unpool_info = UnpoolInfo(edge_index=edge_index, cluster=cluster,\n                                 batch=batch, new_edge_score=new_edge_score)\n\n        return new_x, new_edge_index, new_batch, unpool_info\n\n    def unpool(\n        self,\n        x: Tensor,\n        unpool_info: UnpoolInfo,\n    ) -> Tuple[Tensor, Tensor, Tensor]:\n        r\"\"\"Unpools a previous edge pooling step.\n\n        For unpooling, :obj:`x` should be of same shape as those produced by\n        this layer's :func:`forward` function. Then, it will produce an\n        unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`.\n\n        Args:\n            x (torch.Tensor): The node features.\n            unpool_info (UnpoolInfo): Information that has been produced by\n                :func:`EdgePooling.forward`.\n\n        Return types:\n            * **x** *(torch.Tensor)* - The unpooled node features.\n            * **edge_index** *(torch.Tensor)* - The new edge indices.\n            * **batch** *(torch.Tensor)* - The new batch vector.\n        \"\"\"\n        new_x = x / unpool_info.new_edge_score.view(-1, 1)\n        new_x = new_x[unpool_info.cluster]\n        return new_x, unpool_info.edge_index, unpool_info.batch\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.in_channels})'\n"
  },
  {
    "path": "torch_geometric/nn/pool/glob.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nfrom torch_geometric.utils import scatter\n\n\ndef global_add_pool(x: Tensor, batch: Optional[Tensor],\n                    size: Optional[int] = None) -> Tensor:\n    r\"\"\"Returns batch-wise graph-level-outputs by adding node features\n    across the node dimension.\n\n    For a single graph :math:`\\mathcal{G}_i`, its output is computed by\n\n    .. math::\n        \\mathbf{r}_i = \\sum_{n=1}^{N_i} \\mathbf{x}_n.\n\n    Functional method of the\n    :class:`~torch_geometric.nn.aggr.SumAggregation` module.\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times F}`.\n        batch (torch.Tensor, optional): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n            each node to a specific example.\n        size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n    \"\"\"\n    dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2\n\n    if batch is None:\n        return x.sum(dim=dim, keepdim=x.dim() <= 2)\n    return scatter(x, batch, dim=dim, dim_size=size, reduce='sum')\n\n\ndef global_mean_pool(x: Tensor, batch: Optional[Tensor],\n                     size: Optional[int] = None) -> Tensor:\n    r\"\"\"Returns batch-wise graph-level-outputs by averaging node features\n    across the node dimension.\n\n    For a single graph :math:`\\mathcal{G}_i`, its output is computed by\n\n    .. 
math::\n        \\mathbf{r}_i = \\frac{1}{N_i} \\sum_{n=1}^{N_i} \\mathbf{x}_n.\n\n    Functional method of the\n    :class:`~torch_geometric.nn.aggr.MeanAggregation` module.\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times F}`.\n        batch (torch.Tensor, optional): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n            each node to a specific example.\n        size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n    \"\"\"\n    dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2\n\n    if batch is None:\n        return x.mean(dim=dim, keepdim=x.dim() <= 2)\n    return scatter(x, batch, dim=dim, dim_size=size, reduce='mean')\n\n\ndef global_max_pool(x: Tensor, batch: Optional[Tensor],\n                    size: Optional[int] = None) -> Tensor:\n    r\"\"\"Returns batch-wise graph-level-outputs by taking the channel-wise\n    maximum across the node dimension.\n\n    For a single graph :math:`\\mathcal{G}_i`, its output is computed by\n\n    .. math::\n        \\mathbf{r}_i = \\mathrm{max}_{n=1}^{N_i} \\, \\mathbf{x}_n.\n\n    Functional method of the\n    :class:`~torch_geometric.nn.aggr.MaxAggregation` module.\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times F}`.\n        batch (torch.Tensor, optional): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n            each element to a specific example.\n        size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. 
(default: :obj:`None`)\n    \"\"\"\n    dim = -1 if isinstance(x, Tensor) and x.dim() == 1 else -2\n\n    if batch is None:\n        return x.max(dim=dim, keepdim=x.dim() <= 2)[0]\n    return scatter(x, batch, dim=dim, dim_size=size, reduce='max')\n"
  },
  {
    "path": "torch_geometric/nn/pool/graclus.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nimport torch_geometric.typing\n\nif torch_geometric.typing.WITH_TORCH_CLUSTER:\n    from torch_cluster import graclus_cluster\nelse:\n    graclus_cluster = None\n\n\ndef graclus(edge_index: Tensor, weight: Optional[Tensor] = None,\n            num_nodes: Optional[int] = None):\n    r\"\"\"A greedy clustering algorithm from the `\"Weighted Graph Cuts without\n    Eigenvectors: A Multilevel Approach\" <http://www.cs.utexas.edu/users/\n    inderjit/public_papers/multilevel_pami.pdf>`_ paper of picking an unmarked\n    vertex and matching it with one of its unmarked neighbors (that maximizes\n    its edge weight).\n    The GPU algorithm is adapted from the `\"A GPU Algorithm for Greedy Graph\n    Matching\" <http://www.staff.science.uu.nl/~bisse101/Articles/match12.pdf>`_\n    paper.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices.\n        weight (torch.Tensor, optional): One-dimensional edge weights.\n            (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    if graclus_cluster is None:\n        raise ImportError('`graclus` requires `torch-cluster`.')\n\n    return graclus_cluster(edge_index[0], edge_index[1], weight, num_nodes)\n"
  },
  {
    "path": "torch_geometric/nn/pool/knn.py",
    "content": "import warnings\nfrom typing import NamedTuple, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import cumsum, degree, to_dense_batch\n\n\nclass KNNOutput(NamedTuple):\n    score: Tensor\n    index: Tensor\n\n\nclass KNNIndex:\n    r\"\"\"A base class to perform fast :math:`k`-nearest neighbor search\n    (:math:`k`-NN) via the :obj:`faiss` library.\n\n    Please ensure that :obj:`faiss` is installed by running\n\n    .. code-block:: bash\n\n        pip install faiss-cpu\n        # or\n        pip install faiss-gpu\n\n    depending on whether to plan to use GPU-processing for :math:`k`-NN search.\n\n    Args:\n        index_factory (str, optional): The name of the index factory to use,\n            *e.g.*, :obj:`\"IndexFlatL2\"` or :obj:`\"IndexFlatIP\"`. See `here\n            <https://github.com/facebookresearch/faiss/wiki/\n            The-index-factory>`_ for more information.\n        emb (torch.Tensor, optional): The data points to add.\n            (default: :obj:`None`)\n        reserve (int, optional): The number of elements to reserve memory for\n            before re-allocating (GPU-only). 
(default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        index_factory: Optional[str] = None,\n        emb: Optional[Tensor] = None,\n        reserve: Optional[int] = None,\n    ):\n        warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*')\n\n        import faiss\n\n        self.index_factory = index_factory\n        self.index: Optional[faiss.Index] = None\n        self.reserve = reserve\n\n        if emb is not None:\n            self.add(emb)\n\n    @property\n    def numel(self) -> int:\n        r\"\"\"The number of data points to search in.\"\"\"\n        if self.index is None:\n            return 0\n        return self.index.ntotal\n\n    def _create_index(self, channels: int):\n        import faiss\n        return faiss.index_factory(channels, self.index_factory)\n\n    def add(self, emb: Tensor):\n        r\"\"\"Adds new data points to the :class:`KNNIndex` to search in.\n\n        Args:\n            emb (torch.Tensor): The data points to add.\n        \"\"\"\n        import faiss\n        import faiss.contrib.torch_utils\n\n        if emb.dim() != 2:\n            raise ValueError(f\"'emb' needs to be two-dimensional \"\n                             f\"(got {emb.dim()} dimensions)\")\n\n        if self.index is None:\n            self.index = self._create_index(emb.size(1))\n\n            if emb.device != torch.device('cpu'):\n                self.index = faiss.index_cpu_to_gpu(\n                    faiss.StandardGpuResources(),\n                    emb.device.index,\n                    self.index,\n                )\n\n                if self.reserve is not None:\n                    if hasattr(self.index, 'reserveMemory'):\n                        self.index.reserveMemory(self.reserve)\n                    else:\n                        warnings.warn(\n                            f\"'{self.index.__class__.__name__}' \"\n                            f\"does not support pre-allocation of \"\n                            
f\"memory\", stacklevel=2)\n\n            self.index.train(emb)\n\n        self.index.add(emb.detach())\n\n    def search(\n        self,\n        emb: Tensor,\n        k: int,\n        exclude_links: Optional[Tensor] = None,\n    ) -> KNNOutput:\n        r\"\"\"Search for the :math:`k` nearest neighbors of the given data\n        points. Returns the distance/similarity score of the nearest neighbors\n        and their indices.\n\n        Args:\n            emb (torch.Tensor): The query data points to search for.\n            k (int): The number of nearest neighbors to return.\n            exclude_links (torch.Tensor, optional): The links to exclude from\n                searching. Needs to be a COO tensor of shape\n                :obj:`[2, num_links]`, where :obj:`exclude_links[0]` refers to\n                indices in :obj:`emb`, and :obj:`exclude_links[1]` refers to\n                the data points in the :class:`KNNIndex`.\n                (default: :obj:`None`)\n        \"\"\"\n        if self.index is None:\n            raise RuntimeError(f\"'{self.__class__.__name__}' is not yet \"\n                               \"initialized. Please call `add(...)` first.\")\n\n        if emb.dim() != 2:\n            raise ValueError(f\"'emb' needs to be two-dimensional \"\n                             f\"(got {emb.dim()} dimensions)\")\n\n        query_k = k\n\n        if exclude_links is not None:\n            deg = degree(exclude_links[0], num_nodes=emb.size(0)).max()\n            query_k = k + int(deg.max() if deg.numel() > 0 else 0)\n\n        query_k = min(query_k, self.numel)\n\n        if k > 2048:  # `faiss` supports up-to `k=2048`:\n            warnings.warn(\n                f\"Capping 'k' to faiss' upper limit of 2048 \"\n                f\"(got {k}). 
This may cause some relevant items to \"\n                f\"not be retrieved.\", stacklevel=2)\n        elif query_k > 2048:\n            warnings.warn(\n                f\"Capping 'k' to faiss' upper limit of 2048 \"\n                f\"(got {k} which got extended to {query_k} due to \"\n                f\"the exclusion of existing links). This may cause \"\n                f\"some relevant items to not be retrieved.\", stacklevel=2)\n            query_k = 2048\n\n        score, index = self.index.search(emb.detach(), query_k)\n\n        if exclude_links is not None:\n            # Drop indices to exclude by converting to flat vector:\n            flat_exclude = self.numel * exclude_links[0] + exclude_links[1]\n\n            offset = torch.arange(\n                start=0,\n                end=self.numel * index.size(0),\n                step=self.numel,\n                device=index.device,\n            ).view(-1, 1)\n            flat_index = (index + offset).view(-1)\n\n            notin = torch.isin(flat_index, flat_exclude).logical_not_()\n\n            score = score.view(-1)[notin]\n            index = index.view(-1)[notin]\n\n            # Only maintain top-k scores:\n            count = notin.view(-1, query_k).sum(dim=1)\n            cum_count = cumsum(count)\n\n            batch = torch.arange(count.numel(), device=count.device)\n            batch = batch.repeat_interleave(count, output_size=cum_count[-1])\n\n            batch_arange = torch.arange(count.sum(), device=count.device)\n            batch_arange = batch_arange - cum_count[batch]\n\n            mask = batch_arange < k\n            score = score[mask]\n            index = index[mask]\n\n            if count.min() < k:  # Fill with dummy scores:\n                batch = batch[mask]\n                score, _ = to_dense_batch(\n                    score,\n                    batch,\n                    fill_value=float('-inf'),\n                    max_num_nodes=k,\n                    
batch_size=emb.size(0),\n                )\n                index, _ = to_dense_batch(\n                    index,\n                    batch,\n                    fill_value=-1,\n                    max_num_nodes=k,\n                    batch_size=emb.size(0),\n                )\n\n            score = score.view(-1, k)\n            index = index.view(-1, k)\n\n        return KNNOutput(score, index)\n\n    def get_emb(self) -> Tensor:\n        r\"\"\"Returns the data points stored in the :class:`KNNIndex`.\"\"\"\n        if self.index is None:\n            raise RuntimeError(f\"'{self.__class__.__name__}' is not yet \"\n                               \"initialized. Please call `add(...)` first.\")\n\n        return self.index.reconstruct_n(0, self.numel)\n\n\nclass L2KNNIndex(KNNIndex):\n    r\"\"\"Performs fast :math:`k`-nearest neighbor search (:math:`k`-NN) based on\n    the :math:`L_2` metric via the :obj:`faiss` library.\n\n    Args:\n        emb (torch.Tensor, optional): The data points to add.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, emb: Optional[Tensor] = None):\n        super().__init__(index_factory=None, emb=emb)\n\n    def _create_index(self, channels: int):\n        import faiss\n        return faiss.IndexFlatL2(channels)\n\n\nclass MIPSKNNIndex(KNNIndex):\n    r\"\"\"Performs fast :math:`k`-nearest neighbor search (:math:`k`-NN) based on\n    the maximum inner product via the :obj:`faiss` library.\n\n    Args:\n        emb (torch.Tensor, optional): The data points to add.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, emb: Optional[Tensor] = None):\n        super().__init__(index_factory=None, emb=emb)\n\n    def _create_index(self, channels: int):\n        import faiss\n        return faiss.IndexFlatIP(channels)\n\n\nclass ApproxL2KNNIndex(KNNIndex):\n    r\"\"\"Performs fast approximate :math:`k`-nearest neighbor search\n    (:math:`k`-NN) based on the :math:`L_2` metric via the 
:obj:`faiss`\n    library.\n    Hyperparameters need to be tuned for the speed-accuracy trade-off.\n\n    Args:\n        num_cells (int): The number of cells.\n        num_cells_to_visit (int): The number of cells that are visited to\n            perform the search.\n        bits_per_vector (int): The number of bits per sub-vector.\n        emb (torch.Tensor, optional): The data points to add.\n            (default: :obj:`None`)\n        reserve (int, optional): The number of elements to reserve memory for\n            before re-allocating (GPU only). (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        num_cells: int,\n        num_cells_to_visit: int,\n        bits_per_vector: int,\n        emb: Optional[Tensor] = None,\n        reserve: Optional[int] = None,\n    ):\n        self.num_cells = num_cells\n        self.num_cells_to_visit = num_cells_to_visit\n        self.bits_per_vector = bits_per_vector\n        super().__init__(index_factory=None, emb=emb, reserve=reserve)\n\n    def _create_index(self, channels: int):\n        import faiss\n        index = faiss.IndexIVFPQ(\n            faiss.IndexFlatL2(channels),\n            channels,\n            self.num_cells,\n            self.bits_per_vector,\n            8,\n            faiss.METRIC_L2,\n        )\n        index.nprobe = self.num_cells_to_visit\n        return index\n\n\nclass ApproxMIPSKNNIndex(KNNIndex):\n    r\"\"\"Performs fast approximate :math:`k`-nearest neighbor search\n    (:math:`k`-NN) based on the maximum inner product via the :obj:`faiss`\n    library.\n    Hyperparameters need to be tuned for the speed-accuracy trade-off.\n\n    Args:\n        num_cells (int): The number of cells.\n        num_cells_to_visit (int): The number of cells that are visited to\n            perform the search.\n        bits_per_vector (int): The number of bits per sub-vector.\n        emb (torch.Tensor, optional): The data points to add.\n            (default: :obj:`None`)\n        reserve (int, 
optional): The number of elements to reserve memory for\n            before re-allocating (GPU only). (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        num_cells: int,\n        num_cells_to_visit: int,\n        bits_per_vector: int,\n        emb: Optional[Tensor] = None,\n        reserve: Optional[int] = None,\n    ):\n        self.num_cells = num_cells\n        self.num_cells_to_visit = num_cells_to_visit\n        self.bits_per_vector = bits_per_vector\n        super().__init__(index_factory=None, emb=emb, reserve=reserve)\n\n    def _create_index(self, channels: int):\n        import faiss\n        index = faiss.IndexIVFPQ(\n            faiss.IndexFlatIP(channels),\n            channels,\n            self.num_cells,\n            self.bits_per_vector,\n            8,\n            faiss.METRIC_INNER_PRODUCT,\n        )\n        index.nprobe = self.num_cells_to_visit\n        return index\n"
  },
  {
    "path": "torch_geometric/nn/pool/max_pool.py",
    "content": "from typing import Callable, Optional, Tuple\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import Batch, Data\nfrom torch_geometric.nn.pool.consecutive import consecutive_cluster\nfrom torch_geometric.nn.pool.pool import pool_batch, pool_edge, pool_pos\nfrom torch_geometric.utils import add_self_loops, scatter\n\n\ndef _max_pool_x(\n    cluster: Tensor,\n    x: Tensor,\n    size: Optional[int] = None,\n) -> Tensor:\n    return scatter(x, cluster, dim=0, dim_size=size, reduce='max')\n\n\ndef max_pool_x(\n    cluster: Tensor,\n    x: Tensor,\n    batch: Tensor,\n    batch_size: Optional[int] = None,\n    size: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Max-Pools node features according to the clustering defined in\n    :attr:`cluster`.\n\n    Args:\n        cluster (torch.Tensor): The cluster vector\n            :math:`\\mathbf{c} \\in \\{ 0, \\ldots, N - 1 \\}^N`, which assigns each\n            node to a specific cluster.\n        x (Tensor): The node feature matrix.\n        batch (torch.Tensor): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example.\n        batch_size (int, optional): The number of examples :math:`B`.\n            Automatically calculated if not given. (default: :obj:`None`)\n        size (int, optional): The maximum number of clusters in a single\n            example. This property is useful to obtain a batch-wise dense\n            representation, *e.g.* for applying FC layers, but should only be\n            used if the size of the maximum number of clusters per example is\n            known in advance. 
(default: :obj:`None`)\n\n    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`) if :attr:`size` is\n        :obj:`None`, else :class:`torch.Tensor`\n    \"\"\"\n    if size is not None:\n        if batch_size is None:\n            batch_size = int(batch.max().item()) + 1\n        return _max_pool_x(cluster, x, batch_size * size), None\n\n    cluster, perm = consecutive_cluster(cluster)\n    x = _max_pool_x(cluster, x)\n    batch = pool_batch(perm, batch)\n\n    return x, batch\n\n\ndef max_pool(\n    cluster: Tensor,\n    data: Data,\n    transform: Optional[Callable] = None,\n) -> Data:\n    r\"\"\"Pools and coarsens a graph given by the\n    :class:`torch_geometric.data.Data` object according to the clustering\n    defined in :attr:`cluster`.\n    All nodes within the same cluster will be represented as one node.\n    Final node features are defined by the *maximum* features of all nodes\n    within the same cluster, node positions are averaged and edge indices are\n    defined to be the union of the edge indices of all nodes within the same\n    cluster.\n\n    Args:\n        cluster (torch.Tensor): The cluster vector\n            :math:`\\mathbf{c} \\in \\{ 0, \\ldots, N - 1 \\}^N`, which assigns each\n            node to a specific cluster.\n        data (Data): Graph data object.\n        transform (callable, optional): A function/transform that takes in the\n            coarsened and pooled :obj:`torch_geometric.data.Data` object and\n            returns a transformed version. 
(default: :obj:`None`)\n\n    :rtype: :class:`torch_geometric.data.Data`\n    \"\"\"\n    cluster, perm = consecutive_cluster(cluster)\n\n    x = None if data.x is None else _max_pool_x(cluster, data.x)\n    index, attr = pool_edge(cluster, data.edge_index, data.edge_attr)\n    batch = None if data.batch is None else pool_batch(perm, data.batch)\n    pos = None if data.pos is None else pool_pos(cluster, data.pos)\n\n    data = Batch(batch=batch, x=x, edge_index=index, edge_attr=attr, pos=pos)\n\n    if transform is not None:\n        data = transform(data)\n\n    return data\n\n\ndef max_pool_neighbor_x(\n    data: Data,\n    flow: Optional[str] = 'source_to_target',\n) -> Data:\n    r\"\"\"Max pools neighboring node features, where each feature in\n    :obj:`data.x` is replaced by the feature value with the maximum value from\n    the central node and its neighbors.\n    \"\"\"\n    x, edge_index = data.x, data.edge_index\n\n    edge_index, _ = add_self_loops(edge_index, num_nodes=data.num_nodes)\n\n    row, col = edge_index\n    row, col = (row, col) if flow == 'source_to_target' else (col, row)\n\n    data.x = scatter(x[row], col, dim=0, dim_size=data.num_nodes, reduce='max')\n    return data\n"
  },
  {
    "path": "torch_geometric/nn/pool/mem_pool.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Conv2d, KLDivLoss, Linear, Parameter\n\nfrom torch_geometric.utils import to_dense_batch\n\nEPS = 1e-15\n\n\nclass MemPooling(torch.nn.Module):\n    r\"\"\"Memory based pooling layer from `\"Memory-Based Graph Networks\"\n    <https://arxiv.org/abs/2002.09518>`_ paper, which learns a coarsened graph\n    representation based on soft cluster assignments.\n\n    .. math::\n        S_{i,j}^{(h)} &= \\frac{\n        (1+{\\| \\mathbf{x}_i-\\mathbf{k}^{(h)}_j \\|}^2 / \\tau)^{\n        -\\frac{1+\\tau}{2}}}{\n        \\sum_{k=1}^K (1 + {\\| \\mathbf{x}_i-\\mathbf{k}^{(h)}_k \\|}^2 / \\tau)^{\n        -\\frac{1+\\tau}{2}}}\n\n        \\mathbf{S} &= \\textrm{softmax}(\\textrm{Conv2d}\n        (\\Vert_{h=1}^H \\mathbf{S}^{(h)})) \\in \\mathbb{R}^{N \\times K}\n\n        \\mathbf{X}^{\\prime} &= \\mathbf{S}^{\\top} \\mathbf{X} \\mathbf{W} \\in\n        \\mathbb{R}^{K \\times F^{\\prime}}\n\n    where :math:`H` denotes the number of heads, and :math:`K` denotes the\n    number of clusters.\n\n    Args:\n        in_channels (int): Size of each input sample :math:`F`.\n        out_channels (int): Size of each output sample :math:`F^{\\prime}`.\n        heads (int): The number of heads :math:`H`.\n        num_clusters (int): number of clusters :math:`K` per head.\n        tau (int, optional): The temperature :math:`\\tau`. 
(default: :obj:`1.`)\n    \"\"\"\n    def __init__(self, in_channels: int, out_channels: int, heads: int,\n                 num_clusters: int, tau: float = 1.):\n        super().__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.heads = heads\n        self.num_clusters = num_clusters\n        self.tau = tau\n\n        self.k = Parameter(torch.empty(heads, num_clusters, in_channels))\n        self.conv = Conv2d(heads, 1, kernel_size=1, padding=0, bias=False)\n        self.lin = Linear(in_channels, out_channels, bias=False)\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        torch.nn.init.uniform_(self.k.data, -1., 1.)\n        self.conv.reset_parameters()\n        self.lin.reset_parameters()\n\n    @staticmethod\n    def kl_loss(S: Tensor) -> Tensor:\n        r\"\"\"The additional KL divergence-based loss.\n\n        .. math::\n            P_{i,j} &= \\frac{S_{i,j}^2 / \\sum_{n=1}^N S_{n,j}}{\\sum_{k=1}^K\n            S_{i,k}^2 / \\sum_{n=1}^N S_{n,k}}\n\n            \\mathcal{L}_{\\textrm{KL}} &= \\textrm{KLDiv}(\\mathbf{P} \\Vert\n            \\mathbf{S})\n        \"\"\"\n        S_2 = S**2\n        P = S_2 / S.sum(dim=1, keepdim=True)\n        denom = P.sum(dim=2, keepdim=True)\n        denom[S.sum(dim=2, keepdim=True) == 0.0] = 1.0\n        P /= denom\n\n        loss = KLDivLoss(reduction='batchmean', log_target=False)\n        return loss(S.clamp(EPS).log(), P.clamp(EPS))\n\n    def forward(\n        self,\n        x: Tensor,\n        batch: Optional[Tensor] = None,\n        mask: Optional[Tensor] = None,\n        max_num_nodes: Optional[int] = None,\n        batch_size: Optional[int] = None,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node feature tensor of shape\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}` or\n        
        :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific example.\n                Should not be provided in case node features already have shape\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`.\n                (default: :obj:`None`)\n            mask (torch.Tensor, optional): A mask matrix\n                :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{B \\times N}`, which\n                indicates valid nodes for each graph when using\n                node features of shape\n                :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N \\times F}`.\n                (default: :obj:`None`)\n            max_num_nodes (int, optional): The size of the :math:`N` node\n                dimension. Automatically calculated if not given.\n                (default: :obj:`None`)\n            batch_size (int, optional): The number of examples :math:`B`.\n                Automatically calculated if not given. (default: :obj:`None`)\n        \"\"\"\n        if x.dim() <= 2:\n            x, mask = to_dense_batch(x, batch, max_num_nodes=max_num_nodes,\n                                     batch_size=batch_size)\n        elif mask is None:\n            mask = x.new_ones((x.size(0), x.size(1)), dtype=torch.bool)\n\n        (B, N, _), H, K = x.size(), self.heads, self.num_clusters\n\n        dist = torch.cdist(self.k.view(H * K, -1), x.view(B * N, -1), p=2)**2\n        dist = (1. 
+ dist / self.tau).pow(-(self.tau + 1.0) / 2.0)\n\n        dist = dist.view(H, K, B, N).permute(2, 0, 3, 1)  # [B, H, N, K]\n        S = dist / dist.sum(dim=-1, keepdim=True)\n\n        S = self.conv(S).squeeze(dim=1).softmax(dim=-1)  # [B, N, K]\n        S = S * mask.view(B, N, 1)\n\n        x = self.lin(S.transpose(1, 2) @ x)\n\n        return x, S\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.in_channels}, '\n                f'{self.out_channels}, heads={self.heads}, '\n                f'num_clusters={self.num_clusters})')\n"
  },
  {
    "path": "torch_geometric/nn/pool/pan_pool.py",
    "content": "from typing import Callable, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Parameter\n\nfrom torch_geometric.nn.pool.connect import FilterEdges\nfrom torch_geometric.nn.pool.select import SelectTopK\nfrom torch_geometric.typing import OptTensor, SparseTensor\nfrom torch_geometric.utils import scatter\n\n\nclass PANPooling(torch.nn.Module):\n    r\"\"\"The path integral based pooling operator from the\n    `\"Path Integral Based Convolution and Pooling for Graph Neural Networks\"\n    <https://arxiv.org/abs/2006.16811>`_ paper.\n\n    PAN pooling performs top-:math:`k` pooling where global node importance is\n    measured based on node features and the MET matrix:\n\n    .. math::\n        {\\rm score} = \\beta_1 \\mathbf{X} \\cdot \\mathbf{p} + \\beta_2\n        {\\rm deg}(\\mathbf{M})\n\n    Args:\n        in_channels (int): Size of each input sample.\n        ratio (float): Graph pooling ratio, which is used to compute\n            :math:`k = \\lceil \\mathrm{ratio} \\cdot N \\rceil`.\n            This value is ignored if min_score is not None.\n            (default: :obj:`0.5`)\n        min_score (float, optional): Minimal node score :math:`\\tilde{\\alpha}`\n            which is used to compute indices of pooled nodes\n            :math:`\\mathbf{i} = \\mathbf{y}_i > \\tilde{\\alpha}`.\n            When this value is not :obj:`None`, the :obj:`ratio` argument is\n            ignored. (default: :obj:`None`)\n        multiplier (float, optional): Coefficient by which features gets\n            multiplied after pooling. This can be useful for large graphs and\n            when :obj:`min_score` is used. 
(default: :obj:`1.0`)\n        nonlinearity (str or callable, optional): The non-linearity to use.\n            (default: :obj:`\"tanh\"`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        ratio: float = 0.5,\n        min_score: Optional[float] = None,\n        multiplier: float = 1.0,\n        nonlinearity: Union[str, Callable] = 'tanh',\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.ratio = ratio\n        self.min_score = min_score\n        self.multiplier = multiplier\n\n        self.p = Parameter(torch.empty(in_channels))\n        self.beta = Parameter(torch.empty(2))\n        self.select = SelectTopK(1, ratio, min_score, nonlinearity)\n        self.connect = FilterEdges()\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.p.data.fill_(1)\n        self.beta.data.fill_(0.5)\n        self.select.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        M: SparseTensor,\n        batch: OptTensor = None,\n    ) -> Tuple[Tensor, Tensor, Tensor, OptTensor, Tensor, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node feature matrix.\n            M (SparseTensor): The MET matrix :math:`\\mathbf{M}`.\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific example. 
(default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            batch = x.new_zeros(x.size(0), dtype=torch.long)\n\n        row, col, edge_weight = M.coo()\n        assert edge_weight is not None\n\n        score1 = (x * self.p).sum(dim=-1)\n        score2 = scatter(edge_weight, col, 0, dim_size=x.size(0), reduce='sum')\n        score = self.beta[0] * score1 + self.beta[1] * score2\n\n        select_out = self.select(score, batch)\n\n        perm = select_out.node_index\n        score = select_out.weight\n        assert score is not None\n\n        x = x[perm] * score.view(-1, 1)\n        x = self.multiplier * x if self.multiplier != 1 else x\n\n        edge_index = torch.stack([col, row], dim=0)\n        connect_out = self.connect(select_out, edge_index, edge_weight, batch)\n        edge_weight = connect_out.edge_attr\n        assert edge_weight is not None\n\n        return (x, connect_out.edge_index, edge_weight, connect_out.batch,\n                perm, score)\n\n    def __repr__(self) -> str:\n        if self.min_score is None:\n            ratio = f'ratio={self.ratio}'\n        else:\n            ratio = f'min_score={self.min_score}'\n\n        return (f'{self.__class__.__name__}({self.in_channels}, {ratio}, '\n                f'multiplier={self.multiplier})')\n"
  },
  {
    "path": "torch_geometric/nn/pool/pool.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom torch_geometric.utils import coalesce, remove_self_loops, scatter\n\n\ndef pool_edge(\n    cluster,\n    edge_index,\n    edge_attr: Optional[torch.Tensor] = None,\n    reduce: Optional[str] = 'sum',\n):\n    num_nodes = cluster.size(0)\n    edge_index = cluster[edge_index.view(-1)].view(2, -1)\n    edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)\n    if edge_index.numel() > 0:\n        edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes,\n                                         reduce=reduce)\n    return edge_index, edge_attr\n\n\ndef pool_batch(perm, batch):\n    return batch[perm]\n\n\ndef pool_pos(cluster, pos):\n    return scatter(pos, cluster, dim=0, reduce='mean')\n"
  },
  {
    "path": "torch_geometric/nn/pool/sag_pool.py",
    "content": "from typing import Callable, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn import GraphConv\nfrom torch_geometric.nn.pool.connect import FilterEdges\nfrom torch_geometric.nn.pool.select import SelectTopK\nfrom torch_geometric.typing import OptTensor\n\n\nclass SAGPooling(torch.nn.Module):\n    r\"\"\"The self-attention pooling operator from the `\"Self-Attention Graph\n    Pooling\" <https://arxiv.org/abs/1904.08082>`_ and `\"Understanding\n    Attention and Generalization in Graph Neural Networks\"\n    <https://arxiv.org/abs/1905.02850>`_ papers.\n\n    If :obj:`min_score` :math:`\\tilde{\\alpha}` is :obj:`None`, computes:\n\n        .. math::\n            \\mathbf{y} &= \\textrm{GNN}(\\mathbf{X}, \\mathbf{A})\n\n            \\mathbf{i} &= \\mathrm{top}_k(\\mathbf{y})\n\n            \\mathbf{X}^{\\prime} &= (\\mathbf{X} \\odot\n            \\mathrm{tanh}(\\mathbf{y}))_{\\mathbf{i}}\n\n            \\mathbf{A}^{\\prime} &= \\mathbf{A}_{\\mathbf{i},\\mathbf{i}}\n\n    If :obj:`min_score` :math:`\\tilde{\\alpha}` is a value in :obj:`[0, 1]`,\n    computes:\n\n        .. 
math::\n            \\mathbf{y} &= \\mathrm{softmax}(\\textrm{GNN}(\\mathbf{X},\\mathbf{A}))\n\n            \\mathbf{i} &= \\mathbf{y}_i > \\tilde{\\alpha}\n\n            \\mathbf{X}^{\\prime} &= (\\mathbf{X} \\odot \\mathbf{y})_{\\mathbf{i}}\n\n            \\mathbf{A}^{\\prime} &= \\mathbf{A}_{\\mathbf{i},\\mathbf{i}}.\n\n    Projection scores are learned based on a graph neural network layer.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        ratio (float or int): Graph pooling ratio, which is used to compute\n            :math:`k = \\lceil \\mathrm{ratio} \\cdot N \\rceil`, or the value\n            of :math:`k` itself, depending on whether the type of :obj:`ratio`\n            is :obj:`float` or :obj:`int`.\n            This value is ignored if :obj:`min_score` is not :obj:`None`.\n            (default: :obj:`0.5`)\n        GNN (torch.nn.Module, optional): A graph neural network layer for\n            calculating projection scores (one of\n            :class:`torch_geometric.nn.conv.GraphConv`,\n            :class:`torch_geometric.nn.conv.GCNConv`,\n            :class:`torch_geometric.nn.conv.GATConv` or\n            :class:`torch_geometric.nn.conv.SAGEConv`). (default:\n            :class:`torch_geometric.nn.conv.GraphConv`)\n        min_score (float, optional): Minimal node score :math:`\\tilde{\\alpha}`\n            which is used to compute indices of pooled nodes\n            :math:`\\mathbf{i} = \\mathbf{y}_i > \\tilde{\\alpha}`.\n            When this value is not :obj:`None`, the :obj:`ratio` argument is\n            ignored. (default: :obj:`None`)\n        multiplier (float, optional): Coefficient by which features get\n            multiplied after pooling. This can be useful for large graphs and\n            when :obj:`min_score` is used. 
(default: :obj:`1`)\n        nonlinearity (str or callable, optional): The non-linearity to use.\n            (default: :obj:`\"tanh\"`)\n        **kwargs (optional): Additional parameters for initializing the graph\n            neural network layer.\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        ratio: Union[float, int] = 0.5,\n        GNN: torch.nn.Module = GraphConv,\n        min_score: Optional[float] = None,\n        multiplier: float = 1.0,\n        nonlinearity: Union[str, Callable] = 'tanh',\n        **kwargs,\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.ratio = ratio\n        self.min_score = min_score\n        self.multiplier = multiplier\n\n        self.gnn = GNN(in_channels, 1, **kwargs)\n        self.select = SelectTopK(1, ratio, min_score, nonlinearity)\n        self.connect = FilterEdges()\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.gnn.reset_parameters()\n        self.select.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_attr: OptTensor = None,\n        batch: OptTensor = None,\n        attn: OptTensor = None,\n    ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node feature matrix.\n            edge_index (torch.Tensor): The edge indices.\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific example. 
(default: :obj:`None`)\n            attn (torch.Tensor, optional): Optional node-level matrix to use\n                for computing attention scores instead of using the node\n                feature matrix :obj:`x`. (default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            batch = edge_index.new_zeros(x.size(0))\n\n        attn = x if attn is None else attn\n        attn = attn.view(-1, 1) if attn.dim() == 1 else attn\n        attn = self.gnn(attn, edge_index)\n\n        select_out = self.select(attn, batch)\n\n        perm = select_out.node_index\n        score = select_out.weight\n        assert score is not None\n\n        x = x[perm] * score.view(-1, 1)\n        x = self.multiplier * x if self.multiplier != 1 else x\n\n        connect_out = self.connect(select_out, edge_index, edge_attr, batch)\n\n        return (x, connect_out.edge_index, connect_out.edge_attr,\n                connect_out.batch, perm, score)\n\n    def __repr__(self) -> str:\n        if self.min_score is None:\n            ratio = f'ratio={self.ratio}'\n        else:\n            ratio = f'min_score={self.min_score}'\n\n        return (f'{self.__class__.__name__}({self.gnn.__class__.__name__}, '\n                f'{self.in_channels}, {ratio}, multiplier={self.multiplier})')\n"
  },
  {
    "path": "torch_geometric/nn/pool/select/__init__.py",
    "content": "r\"\"\"Node-selection package.\n\nThis package provides classes for node selection methods in graph pooling\nscenarios.\n\"\"\"\n\nfrom .base import Select, SelectOutput\nfrom .topk import SelectTopK\n\n__all__ = [\n    'Select',\n    'SelectOutput',\n    'SelectTopK',\n]\n"
  },
  {
    "path": "torch_geometric/nn/pool/select/base.py",
    "content": "from dataclasses import dataclass\nfrom typing import Optional\n\nimport torch\nfrom torch import Tensor\n\n\n@dataclass(init=False)\nclass SelectOutput:\n    r\"\"\"The output of the :class:`Select` method, which holds an assignment\n    from selected nodes to their respective cluster(s).\n\n    Args:\n        node_index (torch.Tensor): The indices of the selected nodes.\n        num_nodes (int): The number of nodes.\n        cluster_index (torch.Tensor): The indices of the clusters each node in\n            :obj:`node_index` is assigned to.\n        num_clusters (int): The number of clusters.\n        weight (torch.Tensor, optional): A weight vector, denoting the strength\n            of the assignment of a node to its cluster. (default: :obj:`None`)\n    \"\"\"\n    node_index: Tensor\n    num_nodes: int\n    cluster_index: Tensor\n    num_clusters: int\n    weight: Optional[Tensor] = None\n\n    def __init__(\n        self,\n        node_index: Tensor,\n        num_nodes: int,\n        cluster_index: Tensor,\n        num_clusters: int,\n        weight: Optional[Tensor] = None,\n    ):\n        if node_index.dim() != 1:\n            raise ValueError(f\"Expected 'node_index' to be one-dimensional \"\n                             f\"(got {node_index.dim()} dimensions)\")\n\n        if cluster_index.dim() != 1:\n            raise ValueError(f\"Expected 'cluster_index' to be one-dimensional \"\n                             f\"(got {cluster_index.dim()} dimensions)\")\n\n        if node_index.numel() != cluster_index.numel():\n            raise ValueError(f\"Expected 'node_index' and 'cluster_index' to \"\n                             f\"hold the same number of values (got \"\n                             f\"{node_index.numel()} and \"\n                             f\"{cluster_index.numel()} values)\")\n\n        if weight is not None and weight.dim() != 1:\n            raise ValueError(f\"Expected 'weight' vector to be one-dimensional \"\n            
                 f\"(got {weight.dim()} dimensions)\")\n\n        if weight is not None and weight.numel() != node_index.numel():\n            raise ValueError(f\"Expected 'weight' to hold {node_index.numel()} \"\n                             f\"values (got {weight.numel()} values)\")\n\n        self.node_index = node_index\n        self.num_nodes = num_nodes\n        self.cluster_index = cluster_index\n        self.num_clusters = num_clusters\n        self.weight = weight\n\n\nSelectOutput = torch.jit.script(SelectOutput)\n\n\nclass Select(torch.nn.Module):\n    r\"\"\"An abstract base class for implementing custom node selections as\n    described in the `\"Understanding Pooling in Graph Neural Networks\"\n    <https://arxiv.org/abs/2110.05292>`_ paper, which maps the nodes of an\n    input graph to supernodes in the coarsened graph.\n\n    Specifically, :class:`Select` returns a :class:`SelectOutput` output, which\n    holds a (sparse) mapping :math:`\\mathbf{C} \\in {[0, 1]}^{N \\times C}` that\n    assigns selected nodes to one or more of :math:`C` super nodes.\n    \"\"\"\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n\n    def forward(self, *args, **kwargs) -> SelectOutput:\n        raise NotImplementedError\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/nn/pool/select/topk.py",
    "content": "from typing import Callable, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.inits import uniform\nfrom torch_geometric.nn.pool.select import Select, SelectOutput\nfrom torch_geometric.nn.resolver import activation_resolver\nfrom torch_geometric.utils import cumsum, scatter, softmax\n\n\n# TODO (matthias) Document this method.\ndef topk(\n    x: Tensor,\n    ratio: Optional[Union[float, int]],\n    batch: Tensor,\n    min_score: Optional[float] = None,\n    tol: float = 1e-7,\n) -> Tensor:\n    if min_score is not None:\n        # Make sure that we do not drop all nodes in a graph.\n        scores_max = scatter(x, batch, reduce='max')[batch] - tol\n        scores_min = scores_max.clamp(max=min_score)\n\n        perm = (x > scores_min).nonzero().view(-1)\n        return perm\n\n    if ratio is not None:\n        num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')\n\n        if ratio >= 1:\n            k = num_nodes.new_full((num_nodes.size(0), ), int(ratio))\n        else:\n            k = (float(ratio) * num_nodes.to(x.dtype)).ceil().to(torch.long)\n\n        x, x_perm = torch.sort(x.view(-1), descending=True)\n        batch = batch[x_perm]\n        batch, batch_perm = torch.sort(batch, descending=False, stable=True)\n\n        arange = torch.arange(x.size(0), dtype=torch.long, device=x.device)\n        ptr = cumsum(num_nodes)\n        batched_arange = arange - ptr[batch]\n        mask = batched_arange < k[batch]\n\n        return x_perm[batch_perm[mask]]\n\n    raise ValueError(\"At least one of the 'ratio' and 'min_score' parameters \"\n                     \"must be specified\")\n\n\nclass SelectTopK(Select):\n    r\"\"\"Selects the top-:math:`k` nodes with highest projection scores from the\n    `\"Graph U-Nets\" <https://arxiv.org/abs/1905.05178>`_, `\"Towards Sparse\n    Hierarchical Graph Classifiers\" <https://arxiv.org/abs/1811.01287>`_\n    and `\"Understanding Attention and 
Generalization in Graph Neural\n    Networks\" <https://arxiv.org/abs/1905.02850>`_ papers.\n\n    If :obj:`min_score` :math:`\\tilde{\\alpha}` is :obj:`None`, computes:\n\n        .. math::\n            \\mathbf{y} &= \\sigma \\left( \\frac{\\mathbf{X}\\mathbf{p}}{\\|\n            \\mathbf{p} \\|} \\right)\n\n            \\mathbf{i} &= \\mathrm{top}_k(\\mathbf{y})\n\n    If :obj:`min_score` :math:`\\tilde{\\alpha}` is a value in :obj:`[0, 1]`,\n    computes:\n\n        .. math::\n            \\mathbf{y} &= \\mathrm{softmax}(\\mathbf{X}\\mathbf{p})\n\n            \\mathbf{i} &= \\mathbf{y}_i > \\tilde{\\alpha}\n\n    where :math:`\\mathbf{p}` is the learnable projection vector.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        ratio (float or int): The graph pooling ratio, which is used to compute\n            :math:`k = \\lceil \\mathrm{ratio} \\cdot N \\rceil`, or the value\n            of :math:`k` itself, depending on whether the type of :obj:`ratio`\n            is :obj:`float` or :obj:`int`.\n            This value is ignored if :obj:`min_score` is not :obj:`None`.\n            (default: :obj:`0.5`)\n        min_score (float, optional): Minimal node score :math:`\\tilde{\\alpha}`\n            which is used to compute indices of pooled nodes\n            :math:`\\mathbf{i} = \\mathbf{y}_i > \\tilde{\\alpha}`.\n            When this value is not :obj:`None`, the :obj:`ratio` argument is\n            ignored. 
(default: :obj:`None`)\n        act (str or callable, optional): The non-linearity :math:`\\sigma`.\n            (default: :obj:`\"tanh\"`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        ratio: Union[int, float] = 0.5,\n        min_score: Optional[float] = None,\n        act: Union[str, Callable] = 'tanh',\n    ):\n        super().__init__()\n\n        if ratio is None and min_score is None:\n            raise ValueError(f\"At least one of the 'ratio' and 'min_score' \"\n                             f\"parameters must be specified in \"\n                             f\"'{self.__class__.__name__}'\")\n\n        self.in_channels = in_channels\n        self.ratio = ratio\n        self.min_score = min_score\n        self.act = activation_resolver(act)\n\n        self.weight = torch.nn.Parameter(torch.empty(1, in_channels))\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        uniform(self.in_channels, self.weight)\n\n    def forward(\n        self,\n        x: Tensor,\n        batch: Optional[Tensor] = None,\n    ) -> SelectOutput:\n        \"\"\"\"\"\"  # noqa: D419\n        if batch is None:\n            batch = x.new_zeros(x.size(0), dtype=torch.long)\n\n        x = x.view(-1, 1) if x.dim() == 1 else x\n        score = (x * self.weight).sum(dim=-1)\n\n        if self.min_score is None:\n            score = self.act(score / self.weight.norm(p=2, dim=-1))\n        else:\n            score = softmax(score, batch)\n\n        node_index = topk(score, self.ratio, batch, self.min_score)\n\n        return SelectOutput(\n            node_index=node_index,\n            num_nodes=x.size(0),\n            cluster_index=torch.arange(node_index.size(0), device=x.device),\n            num_clusters=node_index.size(0),\n            weight=score[node_index],\n        )\n\n    def __repr__(self) -> str:\n        if self.min_score is None:\n            arg = f'ratio={self.ratio}'\n        else:\n            arg = 
f'min_score={self.min_score}'\n        return f'{self.__class__.__name__}({self.in_channels}, {arg})'\n"
  },
  {
    "path": "torch_geometric/nn/pool/topk_pool.py",
    "content": "from typing import Callable, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.nn.pool.connect import FilterEdges\nfrom torch_geometric.nn.pool.select import SelectTopK\nfrom torch_geometric.typing import OptTensor\n\n\nclass TopKPooling(torch.nn.Module):\n    r\"\"\":math:`\\mathrm{top}_k` pooling operator from the `\"Graph U-Nets\"\n    <https://arxiv.org/abs/1905.05178>`_, `\"Towards Sparse\n    Hierarchical Graph Classifiers\" <https://arxiv.org/abs/1811.01287>`_\n    and `\"Understanding Attention and Generalization in Graph Neural\n    Networks\" <https://arxiv.org/abs/1905.02850>`_ papers.\n\n    If :obj:`min_score` :math:`\\tilde{\\alpha}` is :obj:`None`, computes:\n\n        .. math::\n            \\mathbf{y} &= \\sigma \\left( \\frac{\\mathbf{X}\\mathbf{p}}{\\|\n            \\mathbf{p} \\|} \\right)\n\n            \\mathbf{i} &= \\mathrm{top}_k(\\mathbf{y})\n\n            \\mathbf{X}^{\\prime} &= (\\mathbf{X} \\odot\n            \\mathrm{tanh}(\\mathbf{y}))_{\\mathbf{i}}\n\n            \\mathbf{A}^{\\prime} &= \\mathbf{A}_{\\mathbf{i},\\mathbf{i}}\n\n    If :obj:`min_score` :math:`\\tilde{\\alpha}` is a value in :obj:`[0, 1]`,\n    computes:\n\n        .. 
math::\n            \\mathbf{y} &= \\mathrm{softmax}(\\mathbf{X}\\mathbf{p})\n\n            \\mathbf{i} &= \\mathbf{y}_i > \\tilde{\\alpha}\n\n            \\mathbf{X}^{\\prime} &= (\\mathbf{X} \\odot \\mathbf{y})_{\\mathbf{i}}\n\n            \\mathbf{A}^{\\prime} &= \\mathbf{A}_{\\mathbf{i},\\mathbf{i}},\n\n    where nodes are dropped based on a learnable projection score\n    :math:`\\mathbf{p}`.\n\n    Args:\n        in_channels (int): Size of each input sample.\n        ratio (float or int): The graph pooling ratio, which is used to compute\n            :math:`k = \\lceil \\mathrm{ratio} \\cdot N \\rceil`, or the value\n            of :math:`k` itself, depending on whether the type of :obj:`ratio`\n            is :obj:`float` or :obj:`int`.\n            This value is ignored if :obj:`min_score` is not :obj:`None`.\n            (default: :obj:`0.5`)\n        min_score (float, optional): Minimal node score :math:`\\tilde{\\alpha}`\n            which is used to compute indices of pooled nodes\n            :math:`\\mathbf{i} = \\mathbf{y}_i > \\tilde{\\alpha}`.\n            When this value is not :obj:`None`, the :obj:`ratio` argument is\n            ignored. (default: :obj:`None`)\n        multiplier (float, optional): Coefficient by which features get\n            multiplied after pooling. This can be useful for large graphs and\n            when :obj:`min_score` is used. (default: :obj:`1`)\n        nonlinearity (str or callable, optional): The non-linearity\n            :math:`\\sigma`. 
(default: :obj:`\"tanh\"`)\n    \"\"\"\n    def __init__(\n        self,\n        in_channels: int,\n        ratio: Union[int, float] = 0.5,\n        min_score: Optional[float] = None,\n        multiplier: float = 1.,\n        nonlinearity: Union[str, Callable] = 'tanh',\n    ):\n        super().__init__()\n\n        self.in_channels = in_channels\n        self.ratio = ratio\n        self.min_score = min_score\n        self.multiplier = multiplier\n\n        self.select = SelectTopK(in_channels, ratio, min_score, nonlinearity)\n        self.connect = FilterEdges()\n\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        self.select.reset_parameters()\n\n    def forward(\n        self,\n        x: Tensor,\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor] = None,\n        batch: Optional[Tensor] = None,\n        attn: Optional[Tensor] = None,\n    ) -> Tuple[Tensor, Tensor, OptTensor, OptTensor, Tensor, Tensor]:\n        r\"\"\"Forward pass.\n\n        Args:\n            x (torch.Tensor): The node feature matrix.\n            edge_index (torch.Tensor): The edge indices.\n            edge_attr (torch.Tensor, optional): The edge features.\n                (default: :obj:`None`)\n            batch (torch.Tensor, optional): The batch vector\n                :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n                each node to a specific example. (default: :obj:`None`)\n            attn (torch.Tensor, optional): Optional node-level matrix to use\n                for computing attention scores instead of using the node\n                feature matrix :obj:`x`. 
(default: :obj:`None`)\n        \"\"\"\n        if batch is None:\n            batch = edge_index.new_zeros(x.size(0))\n\n        attn = x if attn is None else attn\n        select_out = self.select(attn, batch)\n\n        perm = select_out.node_index\n        score = select_out.weight\n        assert score is not None\n\n        x = x[perm] * score.view(-1, 1)\n        x = self.multiplier * x if self.multiplier != 1 else x\n\n        connect_out = self.connect(select_out, edge_index, edge_attr, batch)\n\n        return (x, connect_out.edge_index, connect_out.edge_attr,\n                connect_out.batch, perm, score)\n\n    def __repr__(self) -> str:\n        if self.min_score is None:\n            ratio = f'ratio={self.ratio}'\n        else:\n            ratio = f'min_score={self.min_score}'\n\n        return (f'{self.__class__.__name__}({self.in_channels}, {ratio}, '\n                f'multiplier={self.multiplier})')\n"
  },
  {
    "path": "torch_geometric/nn/pool/voxel_grid.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.utils.repeat import repeat\n\nif torch_geometric.typing.WITH_TORCH_CLUSTER:\n    from torch_cluster import grid_cluster\nelse:\n    grid_cluster = None\n\n\ndef voxel_grid(\n    pos: Tensor,\n    size: Union[float, List[float], Tensor],\n    batch: Optional[Tensor] = None,\n    start: Optional[Union[float, List[float], Tensor]] = None,\n    end: Optional[Union[float, List[float], Tensor]] = None,\n) -> Tensor:\n    r\"\"\"Voxel grid pooling from the, *e.g.*, `Dynamic Edge-Conditioned Filters\n    in Convolutional Networks on Graphs <https://arxiv.org/abs/1704.02901>`_\n    paper, which overlays a regular grid of user-defined size over a point\n    cloud and clusters all points within the same voxel.\n\n    Args:\n        pos (torch.Tensor): Node position matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times D}`.\n        size (float or [float] or Tensor): Size of a voxel (in each dimension).\n        batch (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots,B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        start (float or [float] or Tensor, optional): Start coordinates of the\n            grid (in each dimension). If set to :obj:`None`, will be set to the\n            minimum coordinates found in :attr:`pos`. (default: :obj:`None`)\n        end (float or [float] or Tensor, optional): End coordinates of the grid\n            (in each dimension). If set to :obj:`None`, will be set to the\n            maximum coordinates found in :attr:`pos`. 
(default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n    \"\"\"\n    if grid_cluster is None:\n        raise ImportError('`voxel_grid` requires `torch-cluster`.')\n\n    pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos\n    dim = pos.size(1)\n\n    if batch is None:\n        batch = pos.new_zeros(pos.size(0), dtype=torch.long)\n\n    pos = torch.cat([pos, batch.view(-1, 1).to(pos.dtype)], dim=-1)\n\n    if not isinstance(size, Tensor):\n        size = torch.tensor(size, dtype=pos.dtype, device=pos.device)\n    size = repeat(size, dim)\n    size = torch.cat([size, size.new_ones(1)])  # Add additional batch dim.\n\n    if start is not None:\n        if not isinstance(start, Tensor):\n            start = torch.tensor(start, dtype=pos.dtype, device=pos.device)\n        start = repeat(start, dim)\n        start = torch.cat([start, start.new_zeros(1)])\n\n    if end is not None:\n        if not isinstance(end, Tensor):\n            end = torch.tensor(end, dtype=pos.dtype, device=pos.device)\n        end = repeat(end, dim)\n        end = torch.cat([end, batch.max().unsqueeze(0)])\n\n    return grid_cluster(pos, size, start, end)\n"
  },
  {
    "path": "torch_geometric/nn/reshape.py",
    "content": "import torch\nfrom torch import Tensor\n\n\nclass Reshape(torch.nn.Module):\n    def __init__(self, *shape):\n        super().__init__()\n        self.shape = shape\n\n    def forward(self, x: Tensor) -> Tensor:\n        \"\"\"\"\"\"  # noqa: D419\n        x = x.view(*self.shape)\n        return x\n\n    def __repr__(self) -> str:\n        shape = ', '.join([str(dim) for dim in self.shape])\n        return f'{self.__class__.__name__}({shape})'\n"
  },
  {
    "path": "torch_geometric/nn/resolver.py",
    "content": "import inspect\nfrom typing import Any, Optional, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.optim import Optimizer\nfrom torch.optim.lr_scheduler import ReduceLROnPlateau\n\nfrom torch_geometric.nn.lr_scheduler import (\n    ConstantWithWarmupLR,\n    CosineWithWarmupLR,\n    CosineWithWarmupRestartsLR,\n    LinearWithWarmupLR,\n    PolynomialWithWarmupLR,\n)\nfrom torch_geometric.resolver import normalize_string, resolver\n\ntry:\n    from torch.optim.lr_scheduler import LRScheduler\nexcept ImportError:  # PyTorch < 2.0\n    from torch.optim.lr_scheduler import _LRScheduler as LRScheduler\n\n# Activation Resolver #########################################################\n\n\ndef swish(x: Tensor) -> Tensor:\n    return x * x.sigmoid()\n\n\ndef activation_resolver(query: Union[Any, str] = 'relu', *args, **kwargs):\n    base_cls = torch.nn.Module\n    base_cls_repr = 'Act'\n    acts = [\n        act for act in vars(torch.nn.modules.activation).values()\n        if isinstance(act, type) and issubclass(act, base_cls)\n    ]\n    acts += [\n        swish,\n    ]\n    act_dict = {}\n    return resolver(acts, act_dict, query, base_cls, base_cls_repr, *args,\n                    **kwargs)\n\n\n# Normalization Resolver ######################################################\n\n\ndef normalization_resolver(query: Union[Any, str], *args, **kwargs):\n    import torch_geometric.nn.norm as norm\n    base_cls = torch.nn.Module\n    base_cls_repr = 'Norm'\n    norms = [\n        norm for norm in vars(norm).values()\n        if isinstance(norm, type) and issubclass(norm, base_cls)\n    ]\n    norm_dict = {}\n    return resolver(norms, norm_dict, query, base_cls, base_cls_repr, *args,\n                    **kwargs)\n\n\n# Aggregation Resolver ########################################################\n\n\ndef aggregation_resolver(query: Union[Any, str], *args, **kwargs):\n    import torch_geometric.nn.aggr as aggr\n    if isinstance(query, (list, 
tuple)):\n        return aggr.MultiAggregation(query, *args, **kwargs)\n\n    base_cls = aggr.Aggregation\n    aggrs = [\n        aggr for aggr in vars(aggr).values()\n        if isinstance(aggr, type) and issubclass(aggr, base_cls)\n    ]\n    aggr_dict = {\n        'add': aggr.SumAggregation,\n    }\n    return resolver(aggrs, aggr_dict, query, base_cls, None, *args, **kwargs)\n\n\n# Optimizer Resolver ##########################################################\n\n\ndef optimizer_resolver(query: Union[Any, str], *args, **kwargs):\n    base_cls = Optimizer\n    optimizers = [\n        optimizer for optimizer in vars(torch.optim).values()\n        if isinstance(optimizer, type) and issubclass(optimizer, base_cls)\n    ]\n    return resolver(optimizers, {}, query, base_cls, None, *args, **kwargs)\n\n\n# Learning Rate Scheduler Resolver ############################################\n\n\ndef lr_scheduler_resolver(\n    query: Union[Any, str],\n    optimizer: Optimizer,\n    warmup_ratio_or_steps: Optional[Union[float, int]] = 0.1,\n    num_training_steps: Optional[int] = None,\n    **kwargs,\n) -> Union[LRScheduler, ReduceLROnPlateau]:\n    r\"\"\"A resolver to obtain a learning rate scheduler implemented in either\n    PyG or PyTorch from its name or type.\n\n    Args:\n        query (Any or str): The query name of the learning rate scheduler.\n        optimizer (Optimizer): The optimizer to be scheduled.\n        warmup_ratio_or_steps (float or int, optional): The number of warmup\n            steps. If given as a `float`, it will act as a ratio that gets\n            multiplied with the number of training steps to obtain the number\n            of warmup steps. 
Only required for warmup-based LR schedulers.\n            (default: :obj:`0.1`)\n        num_training_steps (int, optional): The total number of training steps.\n            (default: :obj:`None`)\n        **kwargs (optional): Additional arguments of the LR scheduler.\n    \"\"\"\n    if not isinstance(query, str):\n        return query\n\n    if isinstance(warmup_ratio_or_steps, float):\n        if warmup_ratio_or_steps < 0 or warmup_ratio_or_steps > 1:\n            raise ValueError(f\"`warmup_ratio_or_steps` needs to be between \"\n                             f\"0.0 and 1.0 when given as a floating point \"\n                             f\"number (got {warmup_ratio_or_steps}).\")\n        if num_training_steps is not None:\n            warmup_steps = round(warmup_ratio_or_steps * num_training_steps)\n    elif isinstance(warmup_ratio_or_steps, int):\n        if warmup_ratio_or_steps < 0:\n            raise ValueError(f\"`warmup_ratio_or_steps` needs to be \"\n                             f\"non-negative when given as an integer \"\n                             f\"(got {warmup_ratio_or_steps}).\")\n        warmup_steps = warmup_ratio_or_steps\n    else:\n        raise ValueError(f\"Found invalid type of `warmup_ratio_or_steps` \"\n                         f\"(got {type(warmup_ratio_or_steps)})\")\n\n    base_cls = LRScheduler\n    classes = [\n        scheduler for scheduler in vars(torch.optim.lr_scheduler).values()\n        if isinstance(scheduler, type) and issubclass(scheduler, base_cls)\n    ] + [ReduceLROnPlateau]\n\n    customized_lr_schedulers = [\n        ConstantWithWarmupLR,\n        LinearWithWarmupLR,\n        CosineWithWarmupLR,\n        CosineWithWarmupRestartsLR,\n        PolynomialWithWarmupLR,\n    ]\n    classes += customized_lr_schedulers\n\n    query_repr = normalize_string(query)\n    base_cls_repr = normalize_string('LR')\n\n    for cls in classes:\n        cls_repr = normalize_string(cls.__name__)\n        if query_repr in [cls_repr, 
cls_repr.replace(base_cls_repr, '')]:\n            if inspect.isclass(cls):\n                if cls in customized_lr_schedulers:\n                    cls_keys = inspect.signature(cls).parameters.keys()\n                    if 'num_warmup_steps' in cls_keys:\n                        kwargs['num_warmup_steps'] = warmup_steps\n                    if 'num_training_steps' in cls_keys:\n                        kwargs['num_training_steps'] = num_training_steps\n                obj = cls(optimizer, **kwargs)\n                return obj\n            return cls\n\n    choices = {cls.__name__ for cls in classes}\n    raise ValueError(f\"Could not resolve '{query}' among choices {choices}\")\n"
  },
  {
    "path": "torch_geometric/nn/sequential.jinja",
    "content": "import typing\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\n{% for module in modules %}\nfrom {{module}} import *\n{%- endfor %}\n\n\ndef forward(\n    self,\n{%- for param in signature.param_dict.values() %}\n    {{param.name}}: {{param.type_repr}},\n{%- endfor %}\n) -> {{signature.return_type_repr}}:\n\n{%- for child in children %}\n    {{child.return_names|join(', ')}} = self.{{child.name}}({{child.param_names|join(', ')}})\n{%- endfor %}\n    return {{children[-1].return_names|join(', ')}}\n"
  },
  {
    "path": "torch_geometric/nn/sequential.py",
    "content": "import copy\nimport inspect\nimport os.path as osp\nimport random\nimport sys\nfrom typing import (\n    Any,\n    Callable,\n    Dict,\n    List,\n    NamedTuple,\n    Optional,\n    Tuple,\n    Union,\n)\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.inspector import Parameter, Signature, eval_type, split\nfrom torch_geometric.template import module_from_template\n\n\nclass Child(NamedTuple):\n    name: str\n    param_names: List[str]\n    return_names: List[str]\n\n\nclass Sequential(torch.nn.Module):\n    r\"\"\"An extension of the :class:`torch.nn.Sequential` container in order to\n    define a sequential GNN model.\n\n    Since GNN operators take in multiple input arguments,\n    :class:`torch_geometric.nn.Sequential` additionally expects both global\n    input arguments, and function header definitions of individual operators.\n    If omitted, an intermediate module will operate on the *output* of its\n    preceding module:\n\n    .. code-block:: python\n\n        from torch.nn import Linear, ReLU\n        from torch_geometric.nn import Sequential, GCNConv\n\n        model = Sequential('x, edge_index', [\n            (GCNConv(in_channels, 64), 'x, edge_index -> x'),\n            ReLU(inplace=True),\n            (GCNConv(64, 64), 'x, edge_index -> x'),\n            ReLU(inplace=True),\n            Linear(64, out_channels),\n        ])\n\n    Here, :obj:`'x, edge_index'` defines the input arguments of :obj:`model`,\n    and :obj:`'x, edge_index -> x'` defines the function header, *i.e.* input\n    arguments *and* return types of :class:`~torch_geometric.nn.conv.GCNConv`.\n\n    In particular, this also allows to create more sophisticated models,\n    such as utilizing :class:`~torch_geometric.nn.models.JumpingKnowledge`:\n\n    .. 
code-block:: python\n\n        from torch.nn import Linear, ReLU, Dropout\n        from torch_geometric.nn import Sequential, GCNConv, JumpingKnowledge\n        from torch_geometric.nn import global_mean_pool\n\n        model = Sequential('x, edge_index, batch', [\n            (Dropout(p=0.5), 'x -> x'),\n            (GCNConv(dataset.num_features, 64), 'x, edge_index -> x1'),\n            ReLU(inplace=True),\n            (GCNConv(64, 64), 'x1, edge_index -> x2'),\n            ReLU(inplace=True),\n            (lambda x1, x2: [x1, x2], 'x1, x2 -> xs'),\n            (JumpingKnowledge(\"cat\", 64, num_layers=2), 'xs -> x'),\n            (global_mean_pool, 'x, batch -> x'),\n            Linear(2 * 64, dataset.num_classes),\n        ])\n\n    Args:\n        input_args (str): The input arguments of the model.\n        modules ([(Callable, str) or Callable]): A list of modules (with\n            optional function header definitions). Alternatively, an\n            :obj:`OrderedDict` of modules (and function header definitions) can\n            be passed.\n    \"\"\"\n    _children: List[Child]\n\n    def __init__(\n        self,\n        input_args: str,\n        modules: List[Union[Tuple[Callable, str], Callable]],\n    ) -> None:\n        super().__init__()\n\n        caller_path = inspect.stack()[1].filename\n        self._caller_module = osp.splitext(osp.basename(caller_path))[0]\n\n        _globals = copy.copy(globals())\n        _globals.update(sys.modules['__main__'].__dict__)\n        if self._caller_module in sys.modules:\n            _globals.update(sys.modules[self._caller_module].__dict__)\n\n        signature = input_args.split('->')\n        if len(signature) == 1:\n            args_repr = signature[0]\n            return_type_repr = 'Tensor'\n            return_type = Tensor\n        elif len(signature) == 2:\n            args_repr = signature[0]\n            return_type_repr = signature[1].strip()\n            return_type = eval_type(return_type_repr, 
_globals)\n        else:\n            raise ValueError(f\"Failed to parse arguments (got '{input_args}')\")\n\n        param_dict: Dict[str, Parameter] = {}\n        for arg in split(args_repr, sep=','):\n            signature = arg.split(':')\n            if len(signature) == 1:\n                name = signature[0].strip()\n                param_dict[name] = Parameter(\n                    name=name,\n                    type=Tensor,\n                    type_repr='Tensor',\n                    default=inspect._empty,\n                )\n            elif len(signature) == 2:\n                name = signature[0].strip()\n                param_dict[name] = Parameter(\n                    name=name,\n                    type=eval_type(signature[1].strip(), _globals),\n                    type_repr=signature[1].strip(),\n                    default=inspect._empty,\n                )\n            else:\n                raise ValueError(f\"Failed to parse argument \"\n                                 f\"(got '{arg.strip()}')\")\n\n        self.signature = Signature(param_dict, return_type, return_type_repr)\n\n        if not isinstance(modules, dict):\n            modules = {\n                f'module_{i}': module\n                for i, module in enumerate(modules)\n            }\n        if len(modules) == 0:\n            raise ValueError(f\"'{self.__class__.__name__}' expects a \"\n                             f\"non-empty list of modules\")\n\n        self._children: List[Child] = []\n        for i, (name, module) in enumerate(modules.items()):\n            desc: Optional[str] = None\n            if isinstance(module, (tuple, list)):\n                if len(module) == 1:\n                    module = module[0]\n                elif len(module) == 2:\n                    module, desc = module\n                else:\n                    raise ValueError(f\"Expected tuple of length 2 \"\n                                     f\"(got {module})\")\n\n            if i == 0 
and desc is None:\n                raise ValueError(\"Signature for first module required\")\n            if not callable(module):\n                raise ValueError(f\"Expected callable module (got {module})\")\n            if desc is not None and not isinstance(desc, str):\n                raise ValueError(f\"Expected type hint representation \"\n                                 f\"(got {desc})\")\n\n            if desc is not None:\n                signature = desc.split('->')\n                if len(signature) != 2:\n                    raise ValueError(\n                        f\"Failed to parse arguments (got '{desc}')\")\n                param_names = [v.strip() for v in signature[0].split(',')]\n                return_names = [v.strip() for v in signature[1].split(',')]\n                child = Child(name, param_names, return_names)\n            else:\n                param_names = self._children[-1].return_names\n                child = Child(name, param_names, param_names)\n\n            setattr(self, name, module)\n            self._children.append(child)\n\n        self._set_jittable_template()\n\n    def reset_parameters(self) -> None:\n        r\"\"\"Resets all learnable parameters of the module.\"\"\"\n        for child in self._children:\n            module = getattr(self, child.name)\n            if hasattr(module, 'reset_parameters'):\n                module.reset_parameters()\n\n    def __len__(self) -> int:\n        return len(self._children)\n\n    def __getitem__(self, idx: int) -> torch.nn.Module:\n        return getattr(self, self._children[idx].name)\n\n    def __setstate__(self, data: Dict[str, Any]) -> None:\n        super().__setstate__(data)\n        self._set_jittable_template()\n\n    def __repr__(self) -> str:\n        module_descs = [\n            f\"{', '.join(c.param_names)} -> {', '.join(c.return_names)}\"\n            for c in self._children\n        ]\n        module_reprs = [\n            f'  ({i}) - {self[i]}: 
{module_descs[i]}' for i in range(len(self))\n        ]\n        return '{}(\\n{}\\n)'.format(\n            self.__class__.__name__,\n            '\\n'.join(module_reprs),\n        )\n\n    def forward(self, *args: Any, **kwargs: Any) -> Any:\n        \"\"\"\"\"\"  # noqa: D419\n        value_dict = {\n            name: arg\n            for name, arg in zip(self.signature.param_dict.keys(), args)\n        }\n        for key, arg in kwargs.items():\n            if key in value_dict:\n                raise TypeError(f\"'{self.__class__.__name__}' got multiple \"\n                                f\"values for argument '{key}'\")\n            value_dict[key] = arg\n\n        for child in self._children:\n            args = [value_dict[name] for name in child.param_names]\n            outs = getattr(self, child.name)(*args)\n            if len(child.return_names) == 1:\n                value_dict[child.return_names[0]] = outs\n            else:\n                for name, out in zip(child.return_names, outs):\n                    value_dict[name] = out\n\n        return outs\n\n    # TorchScript Support #####################################################\n\n    def _set_jittable_template(self, raise_on_error: bool = False) -> None:\n        try:  # Optimize `forward()` via `*.jinja` templates:\n            if ('forward' in self.__class__.__dict__ and\n                    self.__class__.__dict__['forward'] != Sequential.forward):\n                raise ValueError(\"Cannot compile custom 'forward' method\")\n\n            root_dir = osp.dirname(osp.realpath(__file__))\n            uid = '%06x' % random.randrange(16**6)\n            jinja_prefix = f'{self.__module__}_{self.__class__.__name__}_{uid}'\n            module = module_from_template(\n                module_name=jinja_prefix,\n                template_path=osp.join(root_dir, 'sequential.jinja'),\n                tmp_dirname='sequential',\n                # Keyword arguments:\n                
modules=[self._caller_module],\n                signature=self.signature,\n                children=self._children,\n            )\n\n            self.forward = module.forward.__get__(self)\n\n            # NOTE We override `forward` on the class level here in order to\n            # support `torch.jit.trace` - this is generally dangerous to do,\n            # and limits `torch.jit.trace` to a single `Sequential` module:\n            self.__class__.forward = module.forward\n        except Exception as e:  # pragma: no cover\n            if raise_on_error:\n                raise e\n\n    def __prepare_scriptable__(self) -> 'Sequential':\n        # Prevent type sharing when scripting `Sequential` modules:\n        type_store = torch.jit._recursive.concrete_type_store.type_store\n        type_store.pop(self.__class__, None)\n        return self\n"
  },
  {
    "path": "torch_geometric/nn/summary.py",
    "content": "from collections import defaultdict\nfrom typing import Any, List, Optional, Union\n\nimport torch\nfrom torch.jit import ScriptModule\nfrom torch.nn import Module\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense.linear import is_uninitialized_parameter\nfrom torch_geometric.typing import SparseTensor\n\n\ndef summary(\n    model: torch.nn.Module,\n    *args,\n    max_depth: int = 3,\n    leaf_module: Optional[Union[Module, List[Module]]] = 'MessagePassing',\n    **kwargs,\n) -> str:\n    r\"\"\"Summarizes a given :class:`torch.nn.Module`.\n    The summarized information includes (1) layer names, (2) input and output\n    shapes, and (3) the number of parameters.\n\n    .. code-block:: python\n\n        import torch\n        from torch_geometric.nn import GCN, summary\n\n        model = GCN(128, 64, num_layers=2, out_channels=32)\n        x = torch.randn(100, 128)\n        edge_index = torch.randint(100, size=(2, 20))\n\n        print(summary(model, x, edge_index))\n\n    .. 
code-block::\n\n        +---------------------+---------------------+--------------+--------+\n        | Layer               | Input Shape         | Output Shape | #Param |\n        |---------------------+---------------------+--------------+--------|\n        | GCN                 | [100, 128], [2, 20] | [100, 32]    | 10,336 |\n        | ├─(act)ReLU         | [100, 64]           | [100, 64]    | --     |\n        | ├─(convs)ModuleList | --                  | --           | 10,336 |\n        | │    └─(0)GCNConv   | [100, 128], [2, 20] | [100, 64]    | 8,256  |\n        | │    └─(1)GCNConv   | [100, 64], [2, 20]  | [100, 32]    | 2,080  |\n        +---------------------+---------------------+--------------+--------+\n\n    Args:\n        model (torch.nn.Module): The model to summarize.\n        *args: The arguments of the :obj:`model`.\n        max_depth (int, optional): The depth of nested layers to display.\n            Any layers deeper than this depth will not be displayed in the\n            summary. 
(default: :obj:`3`)\n        leaf_module (torch.nn.Module or [torch.nn.Module], optional): The\n            modules to be treated as leaf modules, whose submodules are\n            excluded from the summary.\n            (default: :class:`~torch_geometric.nn.conv.MessagePassing`)\n        **kwargs: Additional arguments of the :obj:`model`.\n    \"\"\"\n    # NOTE This is just for the doc-string to render nicely:\n    if leaf_module == 'MessagePassing':\n        leaf_module = MessagePassing\n\n    def register_hook(info):\n        def hook(module, inputs, output):\n            info['input_shape'].append(get_shape(inputs))\n            info['output_shape'].append(get_shape(output))\n\n        return hook\n\n    hooks = {}\n    depth = 0\n    stack = [(model.__class__.__name__, model, depth)]\n\n    info_list = []\n    input_shape = defaultdict(list)\n    output_shape = defaultdict(list)\n    while stack:\n        name, module, depth = stack.pop()\n        module_id = id(module)\n\n        if name.startswith('(_'):  # Do not summarize private modules.\n            continue\n\n        if module_id in hooks:  # Avoid duplicated hooks.\n            hooks[module_id].remove()\n\n        info = {}\n        info['name'] = name\n        info['input_shape'] = input_shape[module_id]\n        info['output_shape'] = output_shape[module_id]\n        info['depth'] = depth\n        if any([is_uninitialized_parameter(p) for p in module.parameters()]):\n            info['#param'] = '-1'\n        else:\n            num_params = sum(p.numel() for p in module.parameters())\n            info['#param'] = f'{num_params:,}' if num_params > 0 else '--'\n        info_list.append(info)\n\n        if not isinstance(module, ScriptModule):\n            hooks[module_id] = module.register_forward_hook(\n                register_hook(info))\n\n        if depth >= max_depth:\n            continue\n\n        if (leaf_module is not None and isinstance(module, leaf_module)):\n            continue\n\n     
   module_items = reversed(module._modules.items())\n        stack += [(f\"({name}){mod.__class__.__name__}\", mod, depth + 1)\n                  for name, mod in module_items if mod is not None]\n\n    training = model.training\n    model.eval()\n\n    with torch.no_grad():\n        model(*args, **kwargs)\n\n    model.train(training)\n\n    for h in hooks.values():  # Remove hooks.\n        h.remove()\n\n    info_list = postprocess(info_list)\n    return make_table(info_list, max_depth=max_depth)\n\n\ndef get_shape(inputs: Any) -> str:\n    if not isinstance(inputs, (tuple, list)):\n        inputs = (inputs, )\n\n    out = []\n    for x in inputs:\n        if isinstance(x, SparseTensor):\n            out.append(str(list(x.sizes())))\n        elif hasattr(x, 'size'):\n            out.append(str(list(x.size())))\n    return ', '.join(out)\n\n\ndef postprocess(info_list: List[dict]) -> List[dict]:\n    for idx, info in enumerate(info_list):\n        depth = info['depth']\n        if idx > 0:  # root module (0) is excluded\n            if depth == 1:\n                prefix = '├─'\n            else:\n                prefix = f\"{'│    '*(depth-1)}└─\"\n            info['name'] = prefix + info['name']\n\n        if info['input_shape']:\n            info['input_shape'] = info['input_shape'].pop(0)\n            info['output_shape'] = info['output_shape'].pop(0)\n        else:\n            info['input_shape'] = '--'\n            info['output_shape'] = '--'\n    return info_list\n\n\ndef make_table(info_list: List[dict], max_depth: int) -> str:\n    from tabulate import tabulate\n    content = [['Layer', 'Input Shape', 'Output Shape', '#Param']]\n    for info in info_list:\n        content.append([\n            info['name'],\n            info['input_shape'],\n            info['output_shape'],\n            info['#param'],\n        ])\n    return tabulate(content, headers='firstrow', tablefmt='psql')\n"
  },
  {
    "path": "torch_geometric/nn/to_fixed_size_transformer.py",
    "content": "from typing import Any\n\nfrom torch.nn import Module\n\nfrom torch_geometric.nn.fx import Transformer\n\ntry:\n    from torch.fx import Graph, GraphModule, Node\nexcept (ImportError, ModuleNotFoundError, AttributeError):\n    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'\n\n\ndef to_fixed_size(module: Module, batch_size: int,\n                  debug: bool = False) -> GraphModule:\n    r\"\"\"Converts a model and injects a pre-computed and fixed batch size to all\n    global pooling operators.\n\n    Args:\n        module (torch.nn.Module): The model to transform.\n        batch_size (int): The fixed batch size used in global pooling modules.\n        debug (bool, optional): If set to :obj:`True`, will perform\n            transformation in debug mode. (default: :obj:`False`)\n    \"\"\"\n    transformer = ToFixedSizeTransformer(module, batch_size, debug)\n    return transformer.transform()\n\n\nclass ToFixedSizeTransformer(Transformer):\n    def __init__(self, module: Module, batch_size: int, debug: bool = False):\n        super().__init__(module, debug=debug)\n        self.batch_size = batch_size\n\n    def call_global_pooling_module(self, node: Node, target: Any, name: str):\n        kwargs = node.kwargs.copy()\n        kwargs['dim_size'] = self.batch_size\n        node.kwargs = kwargs\n"
  },
  {
    "path": "torch_geometric/nn/to_hetero_module.py",
    "content": "import copy\nimport warnings\nfrom typing import Dict, List, Optional, Union\n\nimport torch\nimport torch.nn.functional as F\nfrom torch import Tensor\n\nimport torch_geometric\nfrom torch_geometric import is_compiling\nfrom torch_geometric.typing import EdgeType, NodeType, OptTensor\nfrom torch_geometric.utils import cumsum, scatter\n\n\nclass ToHeteroLinear(torch.nn.Module):\n    def __init__(\n        self,\n        module: torch.nn.Module,\n        types: Union[List[NodeType], List[EdgeType]],\n    ):\n        from torch_geometric.nn import HeteroLinear, Linear\n\n        super().__init__()\n\n        self.types = types\n\n        if isinstance(module, Linear):\n            in_channels = module.in_channels\n            out_channels = module.out_channels\n            bias = module.bias is not None\n\n        elif isinstance(module, torch.nn.Linear):\n            in_channels = module.in_features\n            out_channels = module.out_features\n            bias = module.bias is not None\n\n        else:\n            raise ValueError(f\"Expected 'Linear' module (got '{type(module)}'\")\n\n        # TODO We currently assume that `x` is sorted according to `type`.\n        self.hetero_module = HeteroLinear(\n            in_channels,\n            out_channels,\n            num_types=len(types),\n            is_sorted=True,\n            bias=bias,\n        )\n\n    def fused_forward(self, x: Tensor, type_vec: Tensor) -> Tensor:\n        return self.hetero_module(x, type_vec)\n\n    def dict_forward(\n        self,\n        x_dict: Dict[Union[NodeType, EdgeType], Tensor],\n    ) -> Dict[Union[NodeType, EdgeType], Tensor]:\n\n        if not torch_geometric.typing.WITH_PYG_LIB or is_compiling():\n            return {\n                key:\n                F.linear(x_dict[key], self.hetero_module.weight[i].t()) +\n                self.hetero_module.bias[i]\n                for i, key in enumerate(self.types)\n            }\n\n        x = 
torch.cat([x_dict[key] for key in self.types], dim=0)\n        sizes = [x_dict[key].size(0) for key in self.types]\n        type_vec = torch.arange(len(self.types), device=x.device)\n        size = torch.tensor(sizes, device=x.device)\n        type_vec = type_vec.repeat_interleave(size)\n        outs = self.hetero_module(x, type_vec).split(sizes)\n        return {key: out for key, out in zip(self.types, outs)}\n\n    def forward(\n        self,\n        x: Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]],\n        type_vec: Optional[Tensor] = None,\n    ) -> Union[Tensor, Dict[Union[NodeType, EdgeType], Tensor]]:\n\n        if isinstance(x, dict):\n            return self.dict_forward(x)\n\n        elif isinstance(x, Tensor) and type_vec is not None:\n            return self.fused_forward(x, type_vec)\n\n        raise ValueError(f\"Encountered invalid forward types in \"\n                         f\"'{self.__class__.__name__}'\")\n\n\nclass ToHeteroMessagePassing(torch.nn.Module):\n    def __init__(\n        self,\n        module: torch.nn.Module,\n        node_types: List[NodeType],\n        edge_types: List[EdgeType],\n        aggr: str = 'sum',\n    ):\n        from torch_geometric.nn import HeteroConv, MessagePassing\n\n        super().__init__()\n\n        self.node_types = node_types\n        self.node_type_to_index = {key: i for i, key in enumerate(node_types)}\n        self.edge_types = edge_types\n\n        if not isinstance(module, MessagePassing):\n            raise ValueError(f\"Expected 'MessagePassing' module \"\n                             f\"(got '{type(module)}')\")\n\n        if (not hasattr(module, 'reset_parameters')\n                and sum([p.numel() for p in module.parameters()]) > 0):\n            warnings.warn(\n                f\"'{module}' will be duplicated, but its parameters \"\n                f\"cannot be reset. 
To suppress this warning, add a \"\n                f\"'reset_parameters()' method to '{module}'\", stacklevel=2)\n\n        convs = {edge_type: copy.deepcopy(module) for edge_type in edge_types}\n        self.hetero_module = HeteroConv(convs, aggr)\n        self.hetero_module.reset_parameters()\n\n    def fused_forward(self, x: Tensor, edge_index: Tensor, node_type: Tensor,\n                      edge_type: Tensor) -> Tensor:\n        # TODO This currently does not fuse at all :(\n        # TODO We currently assume that `x` and `edge_index` are both sorted\n        # according to `type`.\n\n        node_sizes = scatter(torch.ones_like(node_type), node_type, dim=0,\n                             dim_size=len(self.node_types), reduce='sum')\n        edge_sizes = scatter(torch.ones_like(edge_type), edge_type, dim=0,\n                             dim_size=len(self.edge_types), reduce='sum')\n\n        ptr = cumsum(node_sizes)\n\n        xs = x.split(node_sizes.tolist())\n        x_dict = {node_type: x for node_type, x in zip(self.node_types, xs)}\n\n        # TODO Consider out-sourcing to its own function.\n        edge_indices = edge_index.clone().split(edge_sizes.tolist(), dim=1)\n        for (src, _, dst), index in zip(self.edge_types, edge_indices):\n            index[0] -= ptr[self.node_type_to_index[src]]\n            index[1] -= ptr[self.node_type_to_index[dst]]\n\n        edge_index_dict = {\n            edge_type: edge_index\n            for edge_type, edge_index in zip(self.edge_types, edge_indices)\n        }\n\n        out_dict = self.hetero_module(x_dict, edge_index_dict)\n        return torch.cat([out_dict[key] for key in self.node_types], dim=0)\n\n    def dict_forward(\n        self,\n        x_dict: Dict[NodeType, Tensor],\n        edge_index_dict: Dict[EdgeType, Tensor],\n        **kwargs,\n    ) -> Dict[NodeType, Tensor]:\n        return self.hetero_module(x_dict, edge_index_dict, **kwargs)\n\n    def forward(\n        self,\n        x: Union[Tensor, 
Dict[NodeType, Tensor]],\n        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],\n        node_type: OptTensor = None,\n        edge_type: OptTensor = None,\n        **kwargs,\n    ) -> Union[Tensor, Dict[NodeType, Tensor]]:\n\n        if isinstance(x, dict) and isinstance(edge_index, dict):\n            return self.dict_forward(x, edge_index, **kwargs)\n\n        elif (isinstance(x, Tensor) and isinstance(edge_index, Tensor)\n              and node_type is not None and edge_type is not None):\n\n            if len(kwargs) > 0:\n                raise ValueError(\"Additional forward arguments not yet \"\n                                 \"supported in fused mode\")\n\n            return self.fused_forward(x, edge_index, node_type, edge_type)\n\n        raise ValueError(f\"Encountered invalid forward types in \"\n                         f\"'{self.__class__.__name__}'\")\n"
  },
  {
    "path": "torch_geometric/nn/to_hetero_transformer.py",
    "content": "import copy\nimport warnings\nfrom collections import defaultdict, deque\nfrom typing import Any, Dict, Optional, Tuple, Union\n\nimport torch\nfrom torch.nn import Module\n\nfrom torch_geometric.nn.dense.linear import is_uninitialized_parameter\nfrom torch_geometric.nn.fx import Transformer, get_submodule\nfrom torch_geometric.typing import EdgeType, Metadata, NodeType\nfrom torch_geometric.utils.hetero import (\n    check_add_self_loops,\n    get_unused_node_types,\n)\n\ntry:\n    from torch.fx import Graph, GraphModule, Node\nexcept (ImportError, ModuleNotFoundError, AttributeError):\n    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'\n\n\ndef get_dict(mapping: Optional[Dict[str, Any]]) -> Dict[str, Any]:\n    return mapping if mapping is not None else {}\n\n\ndef to_hetero(module: Module, metadata: Metadata, aggr: str = \"sum\",\n              input_map: Optional[Dict[str, str]] = None,\n              debug: bool = False) -> GraphModule:\n    r\"\"\"Converts a homogeneous GNN model into its heterogeneous equivalent in\n    which node representations are learned for each node type in\n    :obj:`metadata[0]`, and messages are exchanged between each edge type in\n    :obj:`metadata[1]`, as denoted in the `\"Modeling Relational Data with Graph\n    Convolutional Networks\" <https://arxiv.org/abs/1703.06103>`_ paper.\n\n    .. 
code-block:: python\n\n        import torch\n        from torch_geometric.nn import SAGEConv, to_hetero\n\n        class GNN(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv1 = SAGEConv((-1, -1), 32)\n                self.conv2 = SAGEConv((32, 32), 32)\n\n            def forward(self, x, edge_index):\n                x = self.conv1(x, edge_index).relu()\n                x = self.conv2(x, edge_index).relu()\n                return x\n\n        model = GNN()\n\n        node_types = ['paper', 'author']\n        edge_types = [\n            ('paper', 'cites', 'paper'),\n            ('paper', 'written_by', 'author'),\n            ('author', 'writes', 'paper'),\n        ]\n        metadata = (node_types, edge_types)\n\n        model = to_hetero(model, metadata)\n        model(x_dict, edge_index_dict)\n\n    where :obj:`x_dict` and :obj:`edge_index_dict` denote dictionaries that\n    hold node features and edge connectivity information for each node type and\n    edge type, respectively.\n\n    The below illustration shows the original computation graph of the\n    homogeneous model on the left, and the newly obtained computation graph of\n    the heterogeneous model on the right:\n\n    .. figure:: ../_figures/to_hetero.svg\n      :align: center\n      :width: 90%\n\n      Transforming a model via :func:`to_hetero`.\n\n    Here, each :class:`~torch_geometric.nn.conv.MessagePassing` instance\n    :math:`f_{\\theta}^{(\\ell)}` is duplicated and stored in a set\n    :math:`\\{ f_{\\theta}^{(\\ell, r)} : r \\in \\mathcal{R} \\}` (one instance for\n    each relation in :math:`\\mathcal{R}`), and message passing in layer\n    :math:`\\ell` is performed via\n\n    .. 
math::\n\n        \\mathbf{h}^{(\\ell)}_v = \\bigoplus_{r \\in \\mathcal{R}}\n        f_{\\theta}^{(\\ell, r)} ( \\mathbf{h}^{(\\ell - 1)}_v, \\{\n        \\mathbf{h}^{(\\ell - 1)}_w : w \\in \\mathcal{N}^{(r)}(v) \\}),\n\n    where :math:`\\mathcal{N}^{(r)}(v)` denotes the neighborhood of :math:`v \\in\n    \\mathcal{V}` under relation :math:`r \\in \\mathcal{R}`, and\n    :math:`\\bigoplus` denotes the aggregation scheme :attr:`aggr` to use for\n    grouping node embeddings generated by different relations\n    (:obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"` or :obj:`\"mul\"`).\n\n    Args:\n        module (torch.nn.Module): The homogeneous model to transform.\n        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata\n            of the heterogeneous graph, *i.e.* its node and edge types given\n            by a list of strings and a list of string triplets, respectively.\n            See :meth:`torch_geometric.data.HeteroData.metadata` for more\n            information.\n        aggr (str, optional): The aggregation scheme to use for grouping node\n            embeddings generated by different relations\n            (:obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`,\n            :obj:`\"mul\"`). (default: :obj:`\"sum\"`)\n        input_map (Dict[str, str], optional): A dictionary holding information\n            about the type of input arguments of :obj:`module.forward`.\n            For example, in case :obj:`arg` is a node-level argument, then\n            :obj:`input_map['arg'] = 'node'`, and\n            :obj:`input_map['arg'] = 'edge'` otherwise.\n            In case :obj:`input_map` is not further specified, will try to\n            automatically determine the correct type of input arguments.\n            (default: :obj:`None`)\n        debug (bool, optional): If set to :obj:`True`, will perform\n            transformation in debug mode. 
(default: :obj:`False`)\n    \"\"\"\n    transformer = ToHeteroTransformer(module, metadata, aggr, input_map, debug)\n    return transformer.transform()\n\n\nclass ToHeteroTransformer(Transformer):\n\n    aggrs = {\n        'sum': torch.add,\n        # For 'mean' aggregation, we first sum up all feature matrices, and\n        # divide by the number of matrices in a later step.\n        'mean': torch.add,\n        'max': torch.max,\n        'min': torch.min,\n        'mul': torch.mul,\n    }\n\n    def __init__(\n        self,\n        module: Module,\n        metadata: Metadata,\n        aggr: str = 'sum',\n        input_map: Optional[Dict[str, str]] = None,\n        debug: bool = False,\n    ):\n        super().__init__(module, input_map, debug)\n\n        self.metadata = metadata\n        self.aggr = aggr\n        assert len(metadata) == 2\n        assert len(metadata[0]) > 0 and len(metadata[1]) > 0\n        assert aggr in self.aggrs.keys()\n\n        self.validate()\n\n    def validate(self):\n        unused_node_types = get_unused_node_types(*self.metadata)\n        if len(unused_node_types) > 0:\n            warnings.warn(\n                f\"There exist node types ({unused_node_types}) whose \"\n                f\"representations do not get updated during message passing \"\n                f\"as they do not occur as destination type in any edge type. \"\n                f\"This may lead to unexpected behavior.\", stacklevel=2)\n\n        names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]\n        for name in names:\n            if not name.isidentifier():\n                warnings.warn(\n                    f\"The type '{name}' contains invalid characters which \"\n                    f\"may lead to unexpected behavior. 
To avoid any issues, \"\n                    f\"ensure that your types only contain letters, numbers \"\n                    f\"and underscores.\", stacklevel=2)\n\n    def placeholder(self, node: Node, target: Any, name: str):\n        # Adds a `get` call to the input dictionary for every node-type or\n        # edge-type.\n        if node.type is not None:\n            Type = EdgeType if self.is_edge_level(node) else NodeType\n            node.type = Dict[Type, node.type]\n\n        self.graph.inserting_after(node)\n\n        dict_node = self.graph.create_node('call_function', target=get_dict,\n                                           args=(node, ), name=f'{name}_dict')\n        self.graph.inserting_after(dict_node)\n\n        for key in self.metadata[int(self.is_edge_level(node))]:\n            out = self.graph.create_node('call_method', target='get',\n                                         args=(dict_node, key, None),\n                                         name=f'{name}__{key2str(key)}')\n            self.graph.inserting_after(out)\n\n    def get_attr(self, node: Node, target: Any, name: str):\n        raise NotImplementedError\n\n    def call_message_passing_module(self, node: Node, target: Any, name: str):\n        # Add calls to edge type-wise `MessagePassing` modules and aggregate\n        # the outputs to node type-wise embeddings afterwards.\n\n        module = get_submodule(self.module, target)\n        check_add_self_loops(module, self.metadata[1])\n\n        # Group edge-wise keys per destination:\n        key_name, keys_per_dst = {}, defaultdict(list)\n        for key in self.metadata[1]:\n            keys_per_dst[key[-1]].append(key)\n            key_name[key] = f'{name}__{key[-1]}{len(keys_per_dst[key[-1]])}'\n\n        for dst, keys in dict(keys_per_dst).items():\n            # In case there is only a single edge-wise connection, there is no\n            # need for any destination-wise aggregation, and we can already set\n            # the 
intermediate variable name to the final output name.\n            if len(keys) == 1:\n                key_name[keys[0]] = f'{name}__{dst}'\n                del keys_per_dst[dst]\n\n        self.graph.inserting_after(node)\n        for key in self.metadata[1]:\n            args, kwargs = self.map_args_kwargs(node, key)\n            out = self.graph.create_node('call_module',\n                                         target=f'{target}.{key2str(key)}',\n                                         args=args, kwargs=kwargs,\n                                         name=key_name[key])\n            self.graph.inserting_after(out)\n\n        # Perform destination-wise aggregation.\n        # Here, we aggregate in pairs, popping the first two elements of\n        # `keys_per_dst` and append the result to the list.\n        for dst, keys in keys_per_dst.items():\n            queue = deque([key_name[key] for key in keys])\n            i = 1\n            while len(queue) >= 2:\n                key1, key2 = queue.popleft(), queue.popleft()\n                args = (self.find_by_name(key1), self.find_by_name(key2))\n\n                new_name = f'{name}__{dst}'\n                if self.aggr == 'mean' or len(queue) > 0:\n                    new_name = f'{new_name}_{i}'\n\n                out = self.graph.create_node('call_function',\n                                             target=self.aggrs[self.aggr],\n                                             args=args, name=new_name)\n                self.graph.inserting_after(out)\n                queue.append(new_name)\n                i += 1\n\n            if self.aggr == 'mean':\n                key = queue.popleft()\n                out = self.graph.create_node(\n                    'call_function', target=torch.div,\n                    args=(self.find_by_name(key), len(keys_per_dst[dst])),\n                    name=f'{name}__{dst}')\n                self.graph.inserting_after(out)\n\n    def call_global_pooling_module(self, node: 
Node, target: Any, name: str):\n        # Add calls to node type-wise `GlobalPooling` modules and aggregate\n        # the outputs to graph type-wise embeddings afterwards.\n        self.graph.inserting_after(node)\n        for key in self.metadata[0]:\n            args, kwargs = self.map_args_kwargs(node, key)\n            out = self.graph.create_node('call_module',\n                                         target=f'{target}.{key2str(key)}',\n                                         args=args, kwargs=kwargs,\n                                         name=f'{node.name}__{key2str(key)}')\n            self.graph.inserting_after(out)\n\n        # Perform node-wise aggregation.\n        queue = deque(\n            [f'{node.name}__{key2str(key)}' for key in self.metadata[0]])\n        i = 1\n        while len(queue) >= 2:\n            key1, key2 = queue.popleft(), queue.popleft()\n            args = (self.find_by_name(key1), self.find_by_name(key2))\n            out = self.graph.create_node('call_function',\n                                         target=self.aggrs[self.aggr],\n                                         args=args, name=f'{name}_{i}')\n            self.graph.inserting_after(out)\n            queue.append(f'{name}_{i}')\n            i += 1\n\n        if self.aggr == 'mean':\n            key = queue.popleft()\n            out = self.graph.create_node(\n                'call_function', target=torch.div,\n                args=(self.find_by_name(key), len(self.metadata[0])),\n                name=f'{name}_{i}')\n            self.graph.inserting_after(out)\n        self.replace_all_uses_with(node, out)\n\n    def call_module(self, node: Node, target: Any, name: str):\n        if self.is_graph_level(node):\n            return\n\n        # Add calls to node type-wise or edge type-wise modules.\n        self.graph.inserting_after(node)\n        for key in self.metadata[int(self.is_edge_level(node))]:\n            args, kwargs = self.map_args_kwargs(node, key)\n    
        out = self.graph.create_node('call_module',\n                                         target=f'{target}.{key2str(key)}',\n                                         args=args, kwargs=kwargs,\n                                         name=f'{name}__{key2str(key)}')\n            self.graph.inserting_after(out)\n\n    def call_method(self, node: Node, target: Any, name: str):\n        if self.is_graph_level(node):\n            return\n\n        # Add calls to node type-wise or edge type-wise methods.\n        self.graph.inserting_after(node)\n        for key in self.metadata[int(self.is_edge_level(node))]:\n            args, kwargs = self.map_args_kwargs(node, key)\n            out = self.graph.create_node('call_method', target=target,\n                                         args=args, kwargs=kwargs,\n                                         name=f'{name}__{key2str(key)}')\n            self.graph.inserting_after(out)\n\n    def call_function(self, node: Node, target: Any, name: str):\n        if self.is_graph_level(node):\n            return\n\n        # Add calls to node type-wise or edge type-wise functions.\n        self.graph.inserting_after(node)\n        for key in self.metadata[int(self.is_edge_level(node))]:\n            args, kwargs = self.map_args_kwargs(node, key)\n            out = self.graph.create_node('call_function', target=target,\n                                         args=args, kwargs=kwargs,\n                                         name=f'{name}__{key2str(key)}')\n            self.graph.inserting_after(out)\n\n    def output(self, node: Node, target: Any, name: str):\n        # Replace the output by dictionaries, holding either node type-wise or\n        # edge type-wise data.\n        def _recurse(value: Any) -> Any:\n            if isinstance(value, Node):\n                if self.is_graph_level(value):\n                    return value\n                return {\n                    key: 
self.find_by_name(f'{value.name}__{key2str(key)}')\n                    for key in self.metadata[int(self.is_edge_level(value))]\n                }\n            elif isinstance(value, dict):\n                return {k: _recurse(v) for k, v in value.items()}\n            elif isinstance(value, list):\n                return [_recurse(v) for v in value]\n            elif isinstance(value, tuple):\n                return tuple(_recurse(v) for v in value)\n            else:\n                return value\n\n        if node.type is not None and isinstance(node.args[0], Node):\n            output = node.args[0]\n            if self.is_node_level(output):\n                node.type = Dict[NodeType, node.type]\n            elif self.is_edge_level(output):\n                node.type = Dict[EdgeType, node.type]\n        else:\n            node.type = None\n\n        node.args = (_recurse(node.args[0]), )\n\n    def init_submodule(self, module: Module, target: str) -> Module:\n        # Replicate each module for each node type or edge type.\n        has_node_level_target = bool(\n            self.find_by_target(f'{target}.{key2str(self.metadata[0][0])}'))\n        has_edge_level_target = bool(\n            self.find_by_target(f'{target}.{key2str(self.metadata[1][0])}'))\n\n        if not has_node_level_target and not has_edge_level_target:\n            return module\n\n        module_dict = torch.nn.ModuleDict()\n        for key in self.metadata[int(has_edge_level_target)]:\n            module_dict[key2str(key)] = copy.deepcopy(module)\n            if len(self.metadata[int(has_edge_level_target)]) <= 1:\n                continue\n            if hasattr(module, 'reset_parameters'):\n                module_dict[key2str(key)].reset_parameters()\n            elif sum([\n                    is_uninitialized_parameter(p) or p.numel()\n                    for p in module.parameters()\n            ]) > 0:\n                warnings.warn(\n                    f\"'{target}' will be 
duplicated, but its parameters \"\n                    f\"cannot be reset. To suppress this warning, add a \"\n                    f\"'reset_parameters()' method to '{target}'\", stacklevel=2)\n\n        return module_dict\n\n    # Helper methods ##########################################################\n\n    def map_args_kwargs(self, node: Node,\n                        key: Union[NodeType, EdgeType]) -> Tuple[Tuple, Dict]:\n        def _recurse(value: Any) -> Any:\n            if isinstance(value, Node):\n                out = self.find_by_name(f'{value.name}__{key2str(key)}')\n                if out is not None:\n                    return out\n                elif isinstance(key, tuple) and key[0] == key[-1]:\n                    name = f'{value.name}__{key2str(key[0])}'\n                    return self.find_by_name(name)\n                elif isinstance(key, tuple) and key[0] != key[-1]:\n                    return (\n                        self.find_by_name(f'{value.name}__{key2str(key[0])}'),\n                        self.find_by_name(f'{value.name}__{key2str(key[-1])}'),\n                    )\n                else:\n                    raise ValueError(f\"Cannot generate a graph node '{node}' \"\n                                     f\"for type '{key}' since it does not \"\n                                     f\"exist. 
Please make sure that all \"\n                                     f\"node types get updated during message \"\n                                     f\"passing.\")\n            elif isinstance(value, dict):\n                return {k: _recurse(v) for k, v in value.items()}\n            elif isinstance(value, list):\n                return [_recurse(v) for v in value]\n            elif isinstance(value, tuple):\n                return tuple(_recurse(v) for v in value)\n            else:\n                return value\n\n        args = tuple(_recurse(v) for v in node.args)\n        kwargs = {k: _recurse(v) for k, v in node.kwargs.items()}\n        return args, kwargs\n\n\ndef key2str(key: Union[NodeType, EdgeType]) -> str:\n    key = '__'.join(key) if isinstance(key, tuple) else key\n    return key.replace(' ', '_').replace('-', '_').replace(':', '_')\n"
  },
  {
    "path": "torch_geometric/nn/to_hetero_with_bases_transformer.py",
    "content": "import copy\nimport warnings\nfrom typing import Any, Dict, List, Optional, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import Module, Parameter\n\nfrom torch_geometric.nn.conv import MessagePassing\nfrom torch_geometric.nn.dense import Linear\nfrom torch_geometric.nn.fx import Transformer\nfrom torch_geometric.typing import EdgeType, Metadata, NodeType, SparseTensor\nfrom torch_geometric.utils.hetero import get_unused_node_types\n\ntry:\n    from torch.fx import Graph, GraphModule, Node\nexcept (ImportError, ModuleNotFoundError, AttributeError):\n    GraphModule, Graph, Node = 'GraphModule', 'Graph', 'Node'\n\n\ndef to_hetero_with_bases(module: Module, metadata: Metadata, num_bases: int,\n                         in_channels: Optional[Dict[str, int]] = None,\n                         input_map: Optional[Dict[str, str]] = None,\n                         debug: bool = False) -> GraphModule:\n    r\"\"\"Converts a homogeneous GNN model into its heterogeneous equivalent\n    via the basis-decomposition technique introduced in the\n    `\"Modeling Relational Data with Graph Convolutional Networks\"\n    <https://arxiv.org/abs/1703.06103>`_ paper.\n\n    For this, the heterogeneous graph is mapped to a typed homogeneous graph,\n    in which its feature representations are aligned and grouped to a single\n    representation.\n    All GNN layers inside the model will then perform message passing via\n    basis-decomposition regularization.\n    This transformation is especially useful in highly multi-relational data,\n    such that the number of parameters no longer depend on the number of\n    relations of the input graph:\n\n    .. 
code-block:: python\n\n        import torch\n        from torch_geometric.nn import SAGEConv, to_hetero_with_bases\n\n        class GNN(torch.nn.Module):\n            def __init__(self):\n                super().__init__()\n                self.conv1 = SAGEConv((16, 16), 32)\n                self.conv2 = SAGEConv((32, 32), 32)\n\n            def forward(self, x, edge_index):\n                x = self.conv1(x, edge_index).relu()\n                x = self.conv2(x, edge_index).relu()\n                return x\n\n        model = GNN()\n\n        node_types = ['paper', 'author']\n        edge_types = [\n            ('paper', 'cites', 'paper'),\n            ('paper', 'written_by', 'author'),\n            ('author', 'writes', 'paper'),\n        ]\n        metadata = (node_types, edge_types)\n\n        model = to_hetero_with_bases(model, metadata, num_bases=3,\n                                     in_channels={'x': 16})\n        model(x_dict, edge_index_dict)\n\n    where :obj:`x_dict` and :obj:`edge_index_dict` denote dictionaries that\n    hold node features and edge connectivity information for each node type and\n    edge type, respectively.\n    In case :obj:`in_channels` is given for a specific input argument, its\n    heterogeneous feature information is first aligned to the given\n    dimensionality.\n\n    The below illustration shows the original computation graph of the\n    homogeneous model on the left, and the newly obtained computation graph of\n    the regularized heterogeneous model on the right:\n\n    .. 
figure:: ../_figures/to_hetero_with_bases.svg\n      :align: center\n      :width: 90%\n\n      Transforming a model via :func:`to_hetero_with_bases`.\n\n    Here, each :class:`~torch_geometric.nn.conv.MessagePassing` instance\n    :math:`f_{\\theta}^{(\\ell)}` is duplicated :obj:`num_bases` times and\n    stored in a set :math:`\\{ f_{\\theta}^{(\\ell, b)} : b \\in \\{ 1, \\ldots, B \\}\n    \\}` (one instance for each basis in\n    :obj:`num_bases`), and message passing in layer :math:`\\ell` is performed\n    via\n\n    .. math::\n\n        \\mathbf{h}^{(\\ell)}_v = \\sum_{r \\in \\mathcal{R}} \\sum_{b=1}^B\n        f_{\\theta}^{(\\ell, b)} ( \\mathbf{h}^{(\\ell - 1)}_v, \\{\n        a^{(\\ell)}_{r, b} \\cdot \\mathbf{h}^{(\\ell - 1)}_w :\n        w \\in \\mathcal{N}^{(r)}(v) \\}),\n\n    where :math:`\\mathcal{N}^{(r)}(v)` denotes the neighborhood of :math:`v \\in\n    \\mathcal{V}` under relation :math:`r \\in \\mathcal{R}`.\n    Notably, only the trainable basis coefficients :math:`a^{(\\ell)}_{r, b}`\n    depend on the relations in :math:`\\mathcal{R}`.\n\n    Args:\n        module (torch.nn.Module): The homogeneous model to transform.\n        metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata\n            of the heterogeneous graph, *i.e.* its node and edge types given\n            by a list of strings and a list of string triplets, respectively.\n            See :meth:`torch_geometric.data.HeteroData.metadata` for more\n            information.\n        num_bases (int): The number of bases to use.\n        in_channels (Dict[str, int], optional): A dictionary holding\n            information about the desired input feature dimensionality of\n            input arguments of :obj:`module.forward`.\n            In case :obj:`in_channels` is given for a specific input argument,\n            its heterogeneous feature information is first aligned to the given\n            dimensionality.\n            This allows handling of node and edge 
features with varying feature\n            dimensionality across different types. (default: :obj:`None`)\n        input_map (Dict[str, str], optional): A dictionary holding information\n            about the type of input arguments of :obj:`module.forward`.\n            For example, in case :obj:`arg` is a node-level argument, then\n            :obj:`input_map['arg'] = 'node'`, and\n            :obj:`input_map['arg'] = 'edge'` otherwise.\n            In case :obj:`input_map` is not further specified, will try to\n            automatically determine the correct type of input arguments.\n            (default: :obj:`None`)\n        debug (bool, optional): If set to :obj:`True`, will perform\n            transformation in debug mode. (default: :obj:`False`)\n    \"\"\"\n    transformer = ToHeteroWithBasesTransformer(module, metadata, num_bases,\n                                               in_channels, input_map, debug)\n    return transformer.transform()\n\n\nclass ToHeteroWithBasesTransformer(Transformer):\n    def __init__(\n        self,\n        module: Module,\n        metadata: Metadata,\n        num_bases: int,\n        in_channels: Optional[Dict[str, int]] = None,\n        input_map: Optional[Dict[str, str]] = None,\n        debug: bool = False,\n    ):\n        super().__init__(module, input_map, debug)\n\n        self.metadata = metadata\n        self.num_bases = num_bases\n        self.in_channels = in_channels or {}\n        assert len(metadata) == 2\n        assert len(metadata[0]) > 0 and len(metadata[1]) > 0\n\n        self.validate()\n\n        # Compute IDs for each node and edge type:\n        self.node_type2id = {k: i for i, k in enumerate(metadata[0])}\n        self.edge_type2id = {k: i for i, k in enumerate(metadata[1])}\n\n    def validate(self):\n        unused_node_types = get_unused_node_types(*self.metadata)\n        if len(unused_node_types) > 0:\n            warnings.warn(\n                f\"There exist node types ({unused_node_types}) 
whose \"\n                f\"representations do not get updated during message passing \"\n                f\"as they do not occur as destination type in any edge type. \"\n                f\"This may lead to unexpected behavior.\", stacklevel=2)\n\n        names = self.metadata[0] + [rel for _, rel, _ in self.metadata[1]]\n        for name in names:\n            if not name.isidentifier():\n                warnings.warn(\n                    f\"The type '{name}' contains invalid characters which \"\n                    f\"may lead to unexpected behavior. To avoid any issues, \"\n                    f\"ensure that your types only contain letters, numbers \"\n                    f\"and underscores.\", stacklevel=2)\n\n    def transform(self) -> GraphModule:\n        self._node_offset_dict_initialized = False\n        self._edge_offset_dict_initialized = False\n        self._edge_type_initialized = False\n        out = super().transform()\n        del self._node_offset_dict_initialized\n        del self._edge_offset_dict_initialized\n        del self._edge_type_initialized\n        return out\n\n    def placeholder(self, node: Node, target: Any, name: str):\n        if node.type is not None:\n            Type = EdgeType if self.is_edge_level(node) else NodeType\n            node.type = Dict[Type, node.type]\n\n        out = node\n\n        # Create `node_offset_dict` and `edge_offset_dict` dictionaries in case\n        # they are not yet initialized. 
These dictionaries hold the cumulative\n        # sizes used to create a unified graph representation and to split the\n        # output data.\n        if self.is_edge_level(node) and not self._edge_offset_dict_initialized:\n            self.graph.inserting_after(out)\n            out = self.graph.create_node('call_function',\n                                         target=get_edge_offset_dict,\n                                         args=(node, self.edge_type2id),\n                                         name='edge_offset_dict')\n            self._edge_offset_dict_initialized = True\n\n        elif not self._node_offset_dict_initialized:\n            self.graph.inserting_after(out)\n            out = self.graph.create_node('call_function',\n                                         target=get_node_offset_dict,\n                                         args=(node, self.node_type2id),\n                                         name='node_offset_dict')\n            self._node_offset_dict_initialized = True\n\n        # Create an `edge_type` tensor used as input to `HeteroBasisConv`:\n        if self.is_edge_level(node) and not self._edge_type_initialized:\n            self.graph.inserting_after(out)\n            out = self.graph.create_node('call_function', target=get_edge_type,\n                                         args=(node, self.edge_type2id),\n                                         name='edge_type')\n            self._edge_type_initialized = True\n\n        # Add `Linear` operation to align features to the same dimensionality:\n        if name in self.in_channels:\n            self.graph.inserting_after(out)\n            out = self.graph.create_node('call_module',\n                                         target=f'align_lin__{name}',\n                                         args=(node, ),\n                                         name=f'{name}__aligned')\n            self._state[out.name] = self._state[name]\n\n            lin = 
LinearAlign(self.metadata[int(self.is_edge_level(node))],\n                              self.in_channels[name])\n            setattr(self.module, f'align_lin__{name}', lin)\n\n        # Perform grouping of type-wise values into a single tensor:\n        if self.is_edge_level(node):\n            self.graph.inserting_after(out)\n            out = self.graph.create_node(\n                'call_function', target=group_edge_placeholder,\n                args=(out if name in self.in_channels else node,\n                      self.edge_type2id,\n                      self.find_by_name('node_offset_dict')),\n                name=f'{name}__grouped')\n            self._state[out.name] = 'edge'\n\n        else:\n            self.graph.inserting_after(out)\n            out = self.graph.create_node(\n                'call_function', target=group_node_placeholder,\n                args=(out if name in self.in_channels else node,\n                      self.node_type2id), name=f'{name}__grouped')\n            self._state[out.name] = 'node'\n\n        self.replace_all_uses_with(node, out)\n\n    def call_message_passing_module(self, node: Node, target: Any, name: str):\n        # Call the `HeteroBasisConv` wrapper instead of a single\n        # message passing layer. 
We need to inject the `edge_type` as first\n        # argument in order to do so.\n        node.args = (self.find_by_name('edge_type'), ) + node.args\n\n    def output(self, node: Node, target: Any, name: str):\n        # Split the output to dictionaries, holding either node type-wise or\n        # edge type-wise data.\n        def _recurse(value: Any) -> Any:\n            if isinstance(value, Node) and self.is_edge_level(value):\n                self.graph.inserting_before(node)\n                return self.graph.create_node(\n                    'call_function', target=split_output,\n                    args=(value, self.find_by_name('edge_offset_dict')),\n                    name=f'{value.name}__split')\n\n            elif isinstance(value, Node):\n                self.graph.inserting_before(node)\n                return self.graph.create_node(\n                    'call_function', target=split_output,\n                    args=(value, self.find_by_name('node_offset_dict')),\n                    name=f'{value.name}__split')\n\n            elif isinstance(value, dict):\n                return {k: _recurse(v) for k, v in value.items()}\n            elif isinstance(value, list):\n                return [_recurse(v) for v in value]\n            elif isinstance(value, tuple):\n                return tuple(_recurse(v) for v in value)\n            else:\n                return value\n\n        if node.type is not None and isinstance(node.args[0], Node):\n            output = node.args[0]\n            Type = EdgeType if self.is_edge_level(output) else NodeType\n            node.type = Dict[Type, node.type]\n        else:\n            node.type = None\n\n        node.args = (_recurse(node.args[0]), )\n\n    def init_submodule(self, module: Module, target: str) -> Module:\n        if not isinstance(module, MessagePassing):\n            return module\n\n        # Replace each `MessagePassing` module by a `HeteroBasisConv` wrapper:\n        return HeteroBasisConv(module, 
len(self.metadata[1]), self.num_bases)\n\n\n###############################################################################\n\n\n# We make use of a post-message computation hook to inject the\n# basis re-weighting for each individual edge type.\n# This currently requires us to set `conv.fuse = False`, which leads\n# to a materialization of messages.\ndef hook(module, inputs, output):\n    assert isinstance(module._edge_type, Tensor)\n    if module._edge_type.size(0) != output.size(-2):\n        raise ValueError(\n            f\"Number of messages ({output.size(0)}) does not match \"\n            f\"with the number of original edges \"\n            f\"({module._edge_type.size(0)}). Does your message \"\n            f\"passing layer create additional self-loops? Try to \"\n            f\"remove them via 'add_self_loops=False'\")\n    weight = module.edge_type_weight.view(-1)[module._edge_type]\n    weight = weight.view([1] * (output.dim() - 2) + [-1, 1])\n    return weight * output\n\n\nclass HeteroBasisConv(torch.nn.Module):\n    # A wrapper layer that applies the basis-decomposition technique to a\n    # heterogeneous graph.\n    def __init__(self, module: MessagePassing, num_relations: int,\n                 num_bases: int):\n        super().__init__()\n\n        self.num_relations = num_relations\n        self.num_bases = num_bases\n\n        params = list(module.parameters())\n        device = params[0].device if len(params) > 0 else 'cpu'\n\n        self.convs = torch.nn.ModuleList()\n        for _ in range(num_bases):\n            conv = copy.deepcopy(module)\n            conv.fuse = False  # Disable `message_and_aggregate` functionality.\n            # We learn a single scalar weight for each individual edge type,\n            # which is used to weight the output message based on edge type:\n            conv.edge_type_weight = Parameter(\n                torch.empty(1, num_relations, device=device))\n            conv.register_message_forward_hook(hook)\n      
      self.convs.append(conv)\n\n        if self.num_bases > 1:\n            self.reset_parameters()\n\n    def reset_parameters(self):\n        for conv in self.convs:\n            if hasattr(conv, 'reset_parameters'):\n                conv.reset_parameters()\n            elif sum([p.numel() for p in conv.parameters()]) > 0:\n                warnings.warn(\n                    f\"'{conv}' will be duplicated, but its parameters cannot \"\n                    f\"be reset. To suppress this warning, add a \"\n                    f\"'reset_parameters()' method to '{conv}'\", stacklevel=2)\n            torch.nn.init.xavier_uniform_(conv.edge_type_weight)\n\n    def forward(self, edge_type: Tensor, *args, **kwargs) -> Tensor:\n        out = None\n        # Call message passing modules and perform aggregation:\n        for conv in self.convs:\n            conv._edge_type = edge_type\n            res = conv(*args, **kwargs)\n            del conv._edge_type\n            out = res if out is None else out.add_(res)\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(num_relations='\n                f'{self.num_relations}, num_bases={self.num_bases})')\n\n\nclass LinearAlign(torch.nn.Module):\n    # Aligns representations to the same dimensionality. 
Note that this will\n    # create lazy modules, and as such requires a forward pass in order to\n    # initialize parameters.\n    def __init__(self, keys: List[Union[NodeType, EdgeType]],\n                 out_channels: int):\n        super().__init__()\n        self.out_channels = out_channels\n        self.lins = torch.nn.ModuleDict()\n        for key in keys:\n            self.lins[key2str(key)] = Linear(-1, out_channels, bias=False)\n\n    def forward(\n        self, x_dict: Dict[Union[NodeType, EdgeType], Tensor]\n    ) -> Dict[Union[NodeType, EdgeType], Tensor]:\n        return {key: self.lins[key2str(key)](x) for key, x in x_dict.items()}\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(num_relations={len(self.lins)}, '\n                f'out_channels={self.out_channels})')\n\n\n###############################################################################\n\n# These methods are used in order to receive the cumulated sizes of input\n# dictionaries. 
We make use of them for creating a unified homogeneous graph\n# representation, as well as to split the final output data once again.\n\n\ndef get_node_offset_dict(\n    input_dict: Dict[NodeType, Union[Tensor, SparseTensor]],\n    type2id: Dict[NodeType, int],\n) -> Dict[NodeType, int]:\n    cumsum = 0\n    out: Dict[NodeType, int] = {}\n    for key in type2id.keys():\n        out[key] = cumsum\n        cumsum += input_dict[key].size(-2)\n    return out\n\n\ndef get_edge_offset_dict(\n    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],\n    type2id: Dict[EdgeType, int],\n) -> Dict[EdgeType, int]:\n    cumsum = 0\n    out: Dict[EdgeType, int] = {}\n    for key in type2id.keys():\n        out[key] = cumsum\n        value = input_dict[key]\n        if isinstance(value, SparseTensor):\n            cumsum += value.nnz()\n        elif value.dtype == torch.long and value.size(0) == 2:\n            cumsum += value.size(-1)\n        else:\n            cumsum += value.size(-2)\n    return out\n\n\n###############################################################################\n\n# This method computes the edge type of the final homogeneous graph\n# representation. 
It will be used in the `HeteroBasisConv` wrapper.\n\n\ndef get_edge_type(\n    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],\n    type2id: Dict[EdgeType, int],\n) -> Tensor:\n\n    inputs = [input_dict[key] for key in type2id.keys()]\n    outs = []\n\n    for i, value in enumerate(inputs):\n        if value.size(0) == 2 and value.dtype == torch.long:  # edge_index\n            out = value.new_full((value.size(-1), ), i, dtype=torch.long)\n        elif isinstance(value, SparseTensor):\n            out = torch.full((value.nnz(), ), i, dtype=torch.long,\n                             device=value.device())\n        else:\n            out = value.new_full((value.size(-2), ), i, dtype=torch.long)\n        outs.append(out)\n\n    return outs[0] if len(outs) == 1 else torch.cat(outs, dim=0)\n\n\n###############################################################################\n\n# These methods are used to group the individual type-wise components into a\n# unified single representation.\n\n\ndef group_node_placeholder(input_dict: Dict[NodeType, Tensor],\n                           type2id: Dict[NodeType, int]) -> Tensor:\n\n    inputs = [input_dict[key] for key in type2id.keys()]\n    return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=-2)\n\n\ndef group_edge_placeholder(\n    input_dict: Dict[EdgeType, Union[Tensor, SparseTensor]],\n    type2id: Dict[EdgeType, int],\n    offset_dict: Dict[NodeType, int] = None,\n) -> Union[Tensor, SparseTensor]:\n\n    inputs = [input_dict[key] for key in type2id.keys()]\n\n    if len(inputs) == 1:\n        return inputs[0]\n\n    # In case of grouping a graph connectivity tensor `edge_index` or `adj_t`,\n    # we need to increment its indices:\n    elif inputs[0].size(0) == 2 and inputs[0].dtype == torch.long:\n        if offset_dict is None:\n            raise AttributeError(\n                \"Can not infer node-level offsets. 
Please ensure that there \"\n                \"exists a node-level argument before the 'edge_index' \"\n                \"argument in your forward header.\")\n\n        outputs = []\n        for value, (src_type, _, dst_type) in zip(inputs, type2id):\n            value = value.clone()\n            value[0, :] += offset_dict[src_type]\n            value[1, :] += offset_dict[dst_type]\n            outputs.append(value)\n\n        return torch.cat(outputs, dim=-1)\n\n    elif isinstance(inputs[0], SparseTensor):\n        if offset_dict is None:\n            raise AttributeError(\n                \"Can not infer node-level offsets. Please ensure that there \"\n                \"exists a node-level argument before the 'SparseTensor' \"\n                \"argument in your forward header.\")\n\n        # For grouping a list of SparseTensors, we convert them into a\n        # unified `edge_index` representation in order to avoid conflicts\n        # induced by re-shuffling the data.\n        rows, cols = [], []\n        for value, (src_type, _, dst_type) in zip(inputs, type2id):\n            col, row, value = value.coo()\n            assert value is None\n            rows.append(row + offset_dict[src_type])\n            cols.append(col + offset_dict[dst_type])\n\n        row = torch.cat(rows, dim=0)\n        col = torch.cat(cols, dim=0)\n        return torch.stack([row, col], dim=0)\n\n    else:\n        return torch.cat(inputs, dim=-2)\n\n\n###############################################################################\n\n# This method is used to split the output tensors into individual type-wise\n# components:\n\n\ndef split_output(\n    output: Tensor,\n    offset_dict: Union[Dict[NodeType, int], Dict[EdgeType, int]],\n) -> Union[Dict[NodeType, Tensor], Dict[EdgeType, Tensor]]:\n\n    cumsums = list(offset_dict.values()) + [output.size(-2)]\n    sizes = [cumsums[i + 1] - cumsums[i] for i in range(len(offset_dict))]\n    outputs = output.split(sizes, dim=-2)\n    return 
{key: output for key, output in zip(offset_dict, outputs)}\n\n\n###############################################################################\n\n\ndef key2str(key: Union[NodeType, EdgeType]) -> str:\n    key = '__'.join(key) if isinstance(key, tuple) else key\n    return key.replace(' ', '_').replace('-', '_').replace(':', '_')\n"
  },
  {
    "path": "torch_geometric/nn/unpool/__init__.py",
    "content": "r\"\"\"Unpooling package.\"\"\"\n\nfrom .knn_interpolate import knn_interpolate\n\n__all__ = [\n    'knn_interpolate',\n]\n\nclasses = __all__\n"
  },
  {
    "path": "torch_geometric/nn/unpool/knn_interpolate.py",
    "content": "import torch\n\nfrom torch_geometric.nn import knn\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import scatter\n\n\ndef knn_interpolate(x: torch.Tensor, pos_x: torch.Tensor, pos_y: torch.Tensor,\n                    batch_x: OptTensor = None, batch_y: OptTensor = None,\n                    k: int = 3, num_workers: int = 1):\n    r\"\"\"The k-NN interpolation from the `\"PointNet++: Deep Hierarchical\n    Feature Learning on Point Sets in a Metric Space\"\n    <https://arxiv.org/abs/1706.02413>`_ paper.\n\n    For each point :math:`y` with position :math:`\\mathbf{p}(y)`, its\n    interpolated features :math:`\\mathbf{f}(y)` are given by\n\n    .. math::\n        \\mathbf{f}(y) = \\frac{\\sum_{i=1}^k w(x_i) \\mathbf{f}(x_i)}{\\sum_{i=1}^k\n        w(x_i)} \\textrm{, where } w(x_i) = \\frac{1}{d(\\mathbf{p}(y),\n        \\mathbf{p}(x_i))^2}\n\n    and :math:`\\{ x_1, \\ldots, x_k \\}` denoting the :math:`k` nearest points\n    to :math:`y`.\n\n    Args:\n        x (torch.Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{N \\times F}`.\n        pos_x (torch.Tensor): Node position matrix\n            :math:`\\in \\mathbb{R}^{N \\times d}`.\n        pos_y (torch.Tensor): Upsampled node position matrix\n            :math:`\\in \\mathbb{R}^{M \\times d}`.\n        batch_x (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b_x} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n            each node from :math:`\\mathbf{X}` to a specific example.\n            (default: :obj:`None`)\n        batch_y (torch.Tensor, optional): Batch vector\n            :math:`\\mathbf{b_y} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns\n            each node from :math:`\\mathbf{Y}` to a specific example.\n            (default: :obj:`None`)\n        k (int, optional): Number of neighbors. 
(default: :obj:`3`)\n        num_workers (int, optional): Number of workers to use for computation.\n            Has no effect in case :obj:`batch_x` or :obj:`batch_y` is not\n            :obj:`None`, or the input lies on the GPU. (default: :obj:`1`)\n    \"\"\"\n    with torch.no_grad():\n        assign_index = knn(pos_x, pos_y, k, batch_x=batch_x, batch_y=batch_y,\n                           num_workers=num_workers)\n        y_idx, x_idx = assign_index[0], assign_index[1]\n        diff = pos_x[x_idx] - pos_y[y_idx]\n        squared_distance = (diff * diff).sum(dim=-1, keepdim=True)\n        weights = 1.0 / torch.clamp(squared_distance, min=1e-16)\n\n    y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce='sum')\n    y = y / scatter(weights, y_idx, 0, pos_y.size(0), reduce='sum')\n\n    return y\n"
  },
  {
    "path": "torch_geometric/profile/__init__.py",
    "content": "r\"\"\"GNN profiling package.\"\"\"\n\nfrom .benchmark import benchmark\nfrom .profile import (\n    get_stats_summary,\n    print_time_total,\n    profileit,\n    rename_profile_file,\n    timeit,\n    torch_profile,\n    trace_handler,\n    xpu_profile,\n)\nfrom .utils import (\n    count_parameters,\n    get_cpu_memory_from_gc,\n    get_data_size,\n    get_gpu_memory_from_gc,\n    get_gpu_memory_from_ipex,\n    get_gpu_memory_from_nvidia_smi,\n    get_model_size,\n)\nfrom .nvtx import nvtxit\n\n__all__ = [\n    'profileit',\n    'timeit',\n    'get_stats_summary',\n    'trace_handler',\n    'print_time_total',\n    'rename_profile_file',\n    'torch_profile',\n    'xpu_profile',\n    'count_parameters',\n    'get_model_size',\n    'get_data_size',\n    'get_cpu_memory_from_gc',\n    'get_gpu_memory_from_gc',\n    'get_gpu_memory_from_nvidia_smi',\n    'get_gpu_memory_from_ipex',\n    'benchmark',\n    'nvtxit',\n]\n\nclasses = __all__\n"
  },
  {
    "path": "torch_geometric/profile/benchmark.py",
    "content": "import time\nfrom typing import Any, Callable, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import is_torch_sparse_tensor\n\n\ndef require_grad(x: Any, requires_grad: bool = True) -> Any:\n    if (isinstance(x, Tensor) and x.is_floating_point()\n            and not is_torch_sparse_tensor(x)):\n        return x.detach().requires_grad_(requires_grad)\n    elif isinstance(x, list):\n        return [require_grad(v, requires_grad) for v in x]\n    elif isinstance(x, tuple):\n        return tuple(require_grad(v, requires_grad) for v in x)\n    elif isinstance(x, dict):\n        return {k: require_grad(v, requires_grad) for k, v in x.items()}\n    return x\n\n\ndef benchmark(\n    funcs: List[Callable],\n    args: Union[Tuple[Any], List[Tuple[Any]]],\n    num_steps: int,\n    func_names: Optional[List[str]] = None,\n    num_warmups: int = 10,\n    backward: bool = False,\n    per_step: bool = False,\n    progress_bar: bool = False,\n):\n    r\"\"\"Benchmark a list of functions :obj:`funcs` that receive the same set\n    of arguments :obj:`args`.\n\n    Args:\n        funcs ([Callable]): The list of functions to benchmark.\n        args ((Any, ) or [(Any, )]): The arguments to pass to the functions.\n            Can be a list of arguments for each function in :obj:`funcs` in\n            case their headers differ.\n            Alternatively, you can pass in functions that generate arguments\n            on-the-fly (e.g., useful for benchmarking models on various sizes).\n        num_steps (int): The number of steps to run the benchmark.\n        func_names ([str], optional): The names of the functions. 
If not given,\n            will try to infer the name from the function itself.\n            (default: :obj:`None`)\n        num_warmups (int, optional): The number of warmup steps.\n            (default: :obj:`10`)\n        backward (bool, optional): If set to :obj:`True`, will benchmark both\n            forward and backward passes. (default: :obj:`False`)\n        per_step (bool, optional): If set to :obj:`True`, will report runtimes\n            per step. (default: :obj:`False`)\n        progress_bar (bool, optional): If set to :obj:`True`, will print a\n            progress bar during benchmarking. (default: :obj:`False`)\n    \"\"\"\n    from tabulate import tabulate\n\n    if num_steps <= 0:\n        raise ValueError(f\"'num_steps' must be a positive integer \"\n                         f\"(got {num_steps})\")\n\n    if num_warmups <= 0:\n        raise ValueError(f\"'num_warmups' must be a positive integer \"\n                         f\"(got {num_warmups})\")\n\n    if func_names is None:\n        func_names = [get_func_name(func) for func in funcs]\n\n    if len(funcs) != len(func_names):\n        raise ValueError(f\"Length of 'funcs' (got {len(funcs)}) and \"\n                         f\"'func_names' (got {len(func_names)}) must be equal\")\n\n    # Zero-copy `args` for each function (if necessary):\n    args_list = [args] * len(funcs) if not isinstance(args, list) else args\n\n    iterator = zip(funcs, args_list, func_names)\n    if progress_bar:\n        from tqdm import tqdm\n        iterator = tqdm(iterator, total=len(funcs))\n\n    ts: List[List[str]] = []\n    for func, inputs, name in iterator:\n        t_forward = t_backward = 0\n        for i in range(num_warmups + num_steps):\n            args = inputs() if callable(inputs) else inputs\n            args = require_grad(args, backward)\n\n            if torch.cuda.is_available():\n                torch.cuda.synchronize()\n            t_start = time.perf_counter()\n\n            out = 
func(*args)\n\n            if torch.cuda.is_available():\n                torch.cuda.synchronize()\n            if i >= num_warmups:\n                t_forward += time.perf_counter() - t_start\n\n            if backward:\n                if isinstance(out, (tuple, list)):\n                    out = sum(o.sum() for o in out if isinstance(o, Tensor))\n                elif isinstance(out, dict):\n                    out = out.values()\n                    out = sum(o.sum() for o in out if isinstance(o, Tensor))\n\n                out_grad = torch.randn_like(out)\n                t_start = time.perf_counter()\n\n                out.backward(out_grad)\n\n                if torch.cuda.is_available():\n                    torch.cuda.synchronize()\n                if i >= num_warmups:\n                    t_backward += time.perf_counter() - t_start\n\n        if per_step:\n            ts.append([name, f'{t_forward/num_steps:.6f}s'])\n        else:\n            ts.append([name, f'{t_forward:.4f}s'])\n        if backward:\n            if per_step:\n                ts[-1].append(f'{t_backward/num_steps:.6f}s')\n                ts[-1].append(f'{(t_forward + t_backward)/num_steps:.6f}s')\n            else:\n                ts[-1].append(f'{t_backward:.4f}s')\n                ts[-1].append(f'{t_forward + t_backward:.4f}s')\n\n    header = ['Name', 'Forward']\n    if backward:\n        header.extend(['Backward', 'Total'])\n\n    print(tabulate(ts, headers=header, tablefmt='psql'))\n\n\ndef get_func_name(func: Callable) -> str:\n    if hasattr(func, '__name__'):\n        return func.__name__\n    elif hasattr(func, '__class__'):\n        return func.__class__.__name__\n    raise ValueError(f\"Could not infer name for function '{func}'\")\n"
  },
  {
    "path": "torch_geometric/profile/nvtx.py",
    "content": "from functools import wraps\nfrom typing import Optional\n\nimport torch\n\nCUDA_PROFILE_STARTED = False\n\n\ndef begin_cuda_profile():\n    global CUDA_PROFILE_STARTED\n    prev_state = CUDA_PROFILE_STARTED\n    if prev_state is False:\n        CUDA_PROFILE_STARTED = True\n        torch.cuda.cudart().cudaProfilerStart()\n    return prev_state\n\n\ndef end_cuda_profile(prev_state: bool):\n    global CUDA_PROFILE_STARTED\n    CUDA_PROFILE_STARTED = prev_state\n    if prev_state is False:\n        torch.cuda.cudart().cudaProfilerStop()\n\n\ndef nvtxit(name: Optional[str] = None, n_warmups: int = 0,\n           n_iters: Optional[int] = None):\n    \"\"\"Enables NVTX profiling for a function.\n\n    Args:\n        name (Optional[str], optional): Name to give the reference frame for\n            the function being wrapped. Defaults to the name of the\n            function in code.\n        n_warmups (int, optional): Number of iters to call that function\n            before starting. Defaults to 0.\n        n_iters (Optional[int], optional): Number of iters of that function to\n            record. 
Defaults to all of them.\n    \"\"\"\n    def nvtx(func):\n\n        nonlocal name\n        iters_so_far = 0\n        if name is None:\n            name = func.__name__\n\n        @wraps(func)\n        def wrapper(*args, **kwargs):\n            nonlocal iters_so_far\n            if not torch.cuda.is_available():\n                return func(*args, **kwargs)\n            elif iters_so_far < n_warmups:\n                iters_so_far += 1\n                return func(*args, **kwargs)\n            elif n_iters is None or iters_so_far < n_iters + n_warmups:\n                prev_state = begin_cuda_profile()\n                torch.cuda.nvtx.range_push(f\"{name}_{iters_so_far}\")\n                result = func(*args, **kwargs)\n                torch.cuda.nvtx.range_pop()\n                end_cuda_profile(prev_state)\n                iters_so_far += 1\n                return result\n            else:\n                return func(*args, **kwargs)\n\n        return wrapper\n\n    return nvtx\n"
  },
  {
    "path": "torch_geometric/profile/profile.py",
    "content": "import os\nimport pathlib\nimport time\nfrom contextlib import ContextDecorator, contextmanager\nfrom dataclasses import dataclass\nfrom typing import Any, List, Tuple, Union\n\nimport torch\nfrom torch.autograd.profiler import EventList\nfrom torch.profiler import ProfilerActivity, profile\n\nfrom torch_geometric.profile.utils import (\n    byte_to_megabyte,\n    get_gpu_memory_from_ipex,\n    get_gpu_memory_from_nvidia_smi,\n)\n\n\n@dataclass\nclass GPUStats:\n    time: float\n    max_allocated_gpu: float\n    max_reserved_gpu: float\n    max_active_gpu: float\n\n\n@dataclass\nclass CUDAStats(GPUStats):\n    nvidia_smi_free_cuda: float\n    nvidia_smi_used_cuda: float\n\n\n@dataclass\nclass GPUStatsSummary:\n    time_mean: float\n    time_std: float\n    max_allocated_gpu: float\n    max_reserved_gpu: float\n    max_active_gpu: float\n\n\n@dataclass\nclass CUDAStatsSummary(GPUStatsSummary):\n    min_nvidia_smi_free_cuda: float\n    max_nvidia_smi_used_cuda: float\n\n\ndef profileit(device: str):  # pragma: no cover\n    r\"\"\"A decorator to facilitate profiling a function, *e.g.*, obtaining\n    training runtime and memory statistics of a specific model on a specific\n    dataset.\n    Returns a :obj:`GPUStats` if :obj:`device` is :obj:`xpu` or extended\n    object :obj:`CUDAStats`, if :obj:`device` is :obj:`cuda`.\n\n    Args:\n        device (str): Target device for profiling. Options are:\n            :obj:`cuda` and obj:`xpu`.\n\n    .. 
code-block:: python\n\n        @profileit(\"cuda\")\n        def train(model, optimizer, x, edge_index, y):\n            optimizer.zero_grad()\n            out = model(x, edge_index)\n            loss = criterion(out, y)\n            loss.backward()\n            optimizer.step()\n            return float(loss)\n\n        loss, stats = train(model, optimizer, x, edge_index, y)\n    \"\"\"\n    def decorator(func):\n        def wrapper(\n                *args, **kwargs\n        ) -> Union[Tuple[Any, GPUStats], Tuple[Any, CUDAStats]]:\n            model = args[0]\n            if not isinstance(model, torch.nn.Module):\n                raise AttributeError(\n                    'First argument for profiling needs to be torch.nn.Module')\n            if device not in ['cuda', 'xpu']:\n                raise AttributeError(\n                    \"The profiling decorator supports only CUDA and \"\n                    \"XPU devices\")\n\n            device_id = None\n            for arg in list(args) + list(kwargs.values()):\n                if isinstance(arg, torch.Tensor):\n                    device_id = arg.get_device()\n                    break\n            if device_id is None:\n                raise AttributeError(\n                    \"Could not infer GPU device from the args in the \"\n                    \"function being profiled\")\n            if device_id == -1:\n                raise RuntimeError(\n                    \"The profiling decorator does not support profiling \"\n                    \"on non GPU devices\")\n\n            is_cuda = device == 'cuda'\n            torch_gpu = torch.cuda if is_cuda else torch.xpu\n\n            # `pytorch_memlab` supports only CUDA devices\n            if is_cuda:\n                from pytorch_memlab import LineProfiler\n\n                # Init `pytorch_memlab` for analyzing the model forward pass:\n                line_profiler = LineProfiler(target_gpu=device_id)\n                line_profiler.enable()\n                
line_profiler.add_function(args[0].forward)\n\n            start = torch_gpu.Event(enable_timing=True)\n            end = torch_gpu.Event(enable_timing=True)\n            start.record()\n\n            out = func(*args, **kwargs)\n\n            end.record()\n            torch_gpu.synchronize()\n            time = start.elapsed_time(end) / 1000\n\n            if is_cuda:\n                # Get the global memory statistics collected\n                # by `pytorch_memlab`:\n                memlab = read_from_memlab(line_profiler)\n                max_allocated, max_reserved, max_active = memlab\n                line_profiler.disable()\n\n                # Get additional information from `nvidia-smi`:\n                free_cuda, used_cuda = get_gpu_memory_from_nvidia_smi(\n                    device=device_id)\n\n                stats = CUDAStats(time, max_allocated, max_reserved,\n                                  max_active, free_cuda, used_cuda)\n                return out, stats\n            else:\n                stats = GPUStats(time, *get_gpu_memory_from_ipex(device_id))\n                return out, stats\n\n        return wrapper\n\n    return decorator\n\n\nclass timeit(ContextDecorator):\n    r\"\"\"A context decorator to facilitate timing a function, *e.g.*, obtaining\n    the runtime of a specific model on a specific dataset.\n\n    .. code-block:: python\n\n        @torch.no_grad()\n        def test(model, x, edge_index):\n            return model(x, edge_index)\n\n        with timeit() as t:\n            z = test(model, x, edge_index)\n        time = t.duration\n\n    Args:\n        log (bool, optional): If set to :obj:`False`, will not log any runtime\n            to the console. (default: :obj:`True`)\n        avg_time_divisor (int, optional): If set to a value greater than\n            :obj:`1`, will divide the total time by this value. 
Useful for\n            calculating the average of runtimes within a for-loop.\n            (default: :obj:`0`)\n    \"\"\"\n    def __init__(self, log: bool = True, avg_time_divisor: int = 0):\n        self.log = log\n        self.avg_time_divisor = avg_time_divisor\n        # Initialized here so that `reset()` can detect an unstarted timer:\n        self.t_start = None\n\n    def __enter__(self):\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        self.t_start = time.time()\n        return self\n\n    def __exit__(self, *args):\n        if torch.cuda.is_available():\n            torch.cuda.synchronize()\n        self.t_end = time.time()\n        self.duration = self.t_end - self.t_start\n        if self.avg_time_divisor > 1:\n            self.duration = self.duration / self.avg_time_divisor\n        if self.log:  # pragma: no cover\n            print(f'Time: {self.duration:.8f}s', flush=True)\n\n    def reset(self):\n        r\"\"\"Prints the duration and resets current timer.\"\"\"\n        if self.t_start is None:\n            raise RuntimeError(\"Timer wasn't started.\")\n        else:\n            self.__exit__()\n            self.__enter__()\n\n\ndef get_stats_summary(\n    stats_list: Union[List[GPUStats], List[CUDAStats]]\n) -> Union[GPUStatsSummary, CUDAStatsSummary]:  # pragma: no cover\n    r\"\"\"Creates a summary of collected runtime and memory statistics.\n    Returns a :obj:`GPUStatsSummary` if list of :obj:`GPUStats` was passed,\n    otherwise (list of :obj:`CUDAStats` was passed),\n    returns a :obj:`CUDAStatsSummary`.\n\n    Args:\n        stats_list (Union[List[GPUStats], List[CUDAStats]]): A list of\n            :obj:`GPUStats` or :obj:`CUDAStats` objects, as returned by\n            :meth:`~torch_geometric.profile.profileit`.\n    \"\"\"\n    # calculate common statistics\n    kwargs = dict(\n        time_mean=float(torch.tensor([s.time for s in stats_list]).mean()),\n        time_std=float(torch.tensor([s.time for s in stats_list]).std()),\n        max_allocated_gpu=max([s.max_allocated_gpu for s in 
stats_list]),\n        max_reserved_gpu=max([s.max_reserved_gpu for s in stats_list]),\n        max_active_gpu=max([s.max_active_gpu for s in stats_list]))\n\n    if all(isinstance(s, CUDAStats) for s in stats_list):\n        return CUDAStatsSummary(\n            **kwargs,\n            min_nvidia_smi_free_cuda=min(\n                [s.nvidia_smi_free_cuda for s in stats_list]),\n            max_nvidia_smi_used_cuda=max(\n                [s.nvidia_smi_used_cuda for s in stats_list]),\n        )\n    else:\n        return GPUStatsSummary(**kwargs)\n\n\n###############################################################################\n\n\ndef read_from_memlab(line_profiler: Any) -> List[float]:  # pragma: no cover\n    from pytorch_memlab.line_profiler.line_records import LineRecords\n\n    # See: https://pytorch.org/docs/stable/cuda.html#torch.cuda.memory_stats\n\n    track_stats = [  # Different statistic can be collected as needed.\n        'allocated_bytes.all.peak',\n        'reserved_bytes.all.peak',\n        'active_bytes.all.peak',\n    ]\n\n    records = LineRecords(line_profiler._raw_line_records,\n                          line_profiler._code_infos)\n    stats = records.display(None, track_stats)._line_records\n    return [byte_to_megabyte(x) for x in stats.values.max(axis=0).tolist()]\n\n\ndef trace_handler(p):\n    print_time_total(p)\n    profile_dir = str(pathlib.Path.cwd()) + '/'\n    timeline_file = profile_dir + 'timeline' + '.json'\n    p.export_chrome_trace(timeline_file)\n\n\ndef print_time_total(p):\n    if torch.cuda.is_available():\n        profile_sort = 'self_cuda_time_total'\n    else:\n        profile_sort = 'self_cpu_time_total'\n    output = p.key_averages().table(sort_by=profile_sort)\n    print(output)\n\n\ndef rename_profile_file(*args):\n    profile_dir = str(pathlib.Path.cwd()) + '/'\n    timeline_file = profile_dir + 'profile'\n    for arg in args:\n        timeline_file += '-' + arg\n    timeline_file += '.json'\n    
os.rename('timeline.json', timeline_file)\n\n\n@contextmanager\ndef torch_profile(export_chrome_trace=True, csv_data=None, write_csv=None):\n    use_cuda = torch.cuda.is_available()\n\n    activities = [ProfilerActivity.CPU]\n    if use_cuda:\n        activities.append(ProfilerActivity.CUDA)\n\n    if export_chrome_trace:\n        p_trace_handler = trace_handler\n    else:\n        p_trace_handler = print_time_total\n\n    p = profile(activities=activities, on_trace_ready=p_trace_handler)\n\n    with p:\n        yield\n        p.step()\n\n    if csv_data is not None and write_csv == 'prof':\n        if use_cuda:\n            profile_sort = 'self_cuda_time_total'\n        else:\n            profile_sort = 'self_cpu_time_total'\n        events = EventList(\n            sorted(\n                p.key_averages(),\n                key=lambda evt: getattr(evt, profile_sort),\n                reverse=True,\n            ), use_cuda=use_cuda)\n\n        save_profile_data(csv_data, events, use_cuda)\n\n\n@contextmanager\ndef xpu_profile(export_chrome_trace=True):\n    with torch.autograd.profiler_legacy.profile(use_xpu=True) as profile:\n        yield\n    print(profile.key_averages().table(sort_by='self_xpu_time_total'))\n    if export_chrome_trace:\n        profile.export_chrome_trace('timeline.json')\n\n\ndef format_prof_time(time):\n    # Profile time is in micro seconds, so format it appropriately:\n    return round(time / 1e6, 3)\n\n\ndef save_profile_data(csv_data, events, use_cuda):\n    sum_self_cpu_time_total = sum(\n        [event.self_cpu_time_total for event in events])\n    sum_cpu_time_total = sum([event.cpu_time_total for event in events])\n    sum_self_cuda_time_total = sum(\n        [event.self_cuda_time_total for event in events]) if use_cuda else 0\n\n    for e in events[:5]:  # Save top 5 most time consuming operations:\n        csv_data['NAME'].append(e.key)\n        csv_data['SELF CPU %'].append(\n            round(e.self_cpu_time_total * 100.0 / 
sum_self_cpu_time_total, 3))\n        csv_data['SELF CPU'].append(format_prof_time(e.self_cpu_time_total))\n        csv_data['CPU TOTAL %'].append(\n            round(e.cpu_time_total * 100.0 / sum_cpu_time_total, 3))\n        csv_data['CPU TOTAL'].append(format_prof_time(e.cpu_time_total))\n        # `cpu_time` is the per-call average (`cpu_time_total / count`):\n        csv_data['CPU TIME AVG'].append(format_prof_time(e.cpu_time))\n        if use_cuda:\n            csv_data['SELF CUDA %'].append(e.self_cuda_time_total * 100.0 /\n                                           sum_self_cuda_time_total)\n            csv_data['SELF CUDA'].append(\n                format_prof_time(e.self_cuda_time_total))\n            csv_data['CUDA TOTAL'].append(format_prof_time(e.cuda_time_total))\n            csv_data['CUDA TIME AVG'].append(format_prof_time(\n                e.cuda_time))\n        csv_data['# OF CALLS'].append(e.count)\n"
  },
  {
    "path": "torch_geometric/profile/profiler.py",
    "content": "import functools\nfrom collections import OrderedDict, defaultdict, namedtuple\nfrom typing import Any, List, NamedTuple, Optional, Tuple\n\nimport torch\nimport torch.profiler as torch_profiler\n\nimport torch_geometric.typing\n\n# predefined namedtuple for variable setting (global template)\nTrace = namedtuple('Trace', ['path', 'leaf', 'module'])\n\n# the metrics returned from the torch profiler\nMeasure = namedtuple('Measure', [\n    'self_cpu_total',\n    'cpu_total',\n    'self_cuda_total',\n    'cuda_total',\n    'self_cpu_memory',\n    'cpu_memory',\n    'self_cuda_memory',\n    'cuda_memory',\n    'occurrences',\n])\n\n\nclass Profiler:\n    r\"\"\"Layer by layer profiling of PyTorch models, using the PyTorch profiler\n    for memory profiling. Parts of the code are adapted from :obj:`torchprof`\n    for layer-wise grouping.\n\n    Args:\n        model (torch.nn.Module): The underlying model to be profiled.\n        enabled (bool, optional): If set to :obj:`True`, turn on the profiler.\n            (default: :obj:`True`)\n        use_cuda (bool, optional): Whether to profile CUDA execution.\n            (default: :obj:`False`)\n        profile_memory (bool, optional): If set to :obj:`True`, also profile\n            memory usage. 
(default: :obj:`False`)\n        paths ([str], optional): Pre-defined paths for fast loading.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        model: torch.nn.Module,\n        enabled: bool = True,\n        use_cuda: bool = False,\n        profile_memory: bool = False,\n        paths: Optional[List[str]] = None,\n    ):\n        self._model = model\n        self.enabled = enabled\n        self.use_cuda = use_cuda\n        self.profile_memory = profile_memory\n        self.paths = paths\n\n        self.entered = False\n        self.exited = False\n        self.traces = ()\n        self._ids = set()\n        self.trace_profile_events = defaultdict(list)\n\n    def __enter__(self):\n        if not self.enabled:\n            return self\n        if self.entered:\n            raise RuntimeError(\"the profiler can be initialized only once\")\n        self.entered = True\n        self._forwards = {}  # store the original forward functions\n\n        # generate the trace and conduct profiling\n        self.traces = tuple(map(self._hook_trace, _walk_modules(self._model)))\n        return self\n\n    def __exit__(self, exc_type, exc_val, exc_tb):\n        if not self.enabled:\n            return\n        tuple(map(self._remove_hook_trace, self.traces))\n        del self._forwards  # remove unnecessary forwards\n        self.exited = True\n\n    def get_trace(self):\n        return _layer_trace(self.traces, self.trace_profile_events)\n\n    def __repr__(self) -> str:\n        return self.get_trace()[0]\n\n    def __call__(self, *args, **kwargs):\n        return self._model(*args, **kwargs)\n\n    def _hook_trace(self, trace):\n        \"\"\"Add hooks to torch modules for profiling. 
The underlying model's\n        forward pass is hooked/decorated here.\n        \"\"\"\n        [path, leaf, module] = trace\n\n        # the id of the module is guaranteed to be unique\n        _id = id(module)\n        if (self.paths is not None\n                and path in self.paths) or (self.paths is None and leaf):\n            if _id in self._ids:\n                # already wrapped\n                return trace\n            self._ids.add(_id)\n            _forward = module.forward\n            self._forwards[path] = _forward\n\n            @functools.wraps(_forward)\n            def wrap_forward(*args, **kwargs):\n                \"\"\"The forward pass is decorated and profiled here.\"\"\"\n                # only torch 1.8.1+ is supported\n                torch_version = torch.__version__\n                if torch_version <= '1.8.1':\n                    raise NotImplementedError(\n                        \"Profiler requires at least torch 1.8.1\")\n\n                activities = [torch.profiler.ProfilerActivity.CPU]\n                if self.use_cuda:\n                    activities.append(torch.profiler.ProfilerActivity.CUDA)\n                with torch_profiler.profile(\n                        activities=activities,\n                        profile_memory=self.profile_memory,\n                ) as prof:\n                    res = _forward(*args, **kwargs)\n\n                event_list = prof.events()\n\n                # each profile call should be contained in its own list\n                self.trace_profile_events[path].append(event_list)\n                return res\n\n            # decorate the underlying model's forward pass\n            module.forward = wrap_forward\n        return trace\n\n    def _remove_hook_trace(self, trace):\n        \"\"\"Clean it up after the profiling is done.\"\"\"\n        [path, leaf, module] = trace\n        _id = id(module)\n        if _id in self._ids:\n            self._ids.discard(_id)\n        else:\n            
return\n        if (self.paths is not None\n                and path in self.paths) or (self.paths is None and leaf):\n            module.forward = self._forwards[path]\n\n\ndef _layer_trace(\n        traces: NamedTuple,\n        trace_events: Any,\n        show_events: bool = True,\n        paths: Optional[List[str]] = None,\n        use_cuda: bool = False,\n        profile_memory: bool = False,\n        dt: Tuple[str, ...] = ('-', '-', '-', ' '),\n) -> object:\n    \"\"\"Construct human-readable output of the profiler traces and events. The\n    information is presented in layers, and each layer contains its underlying\n    operators.\n\n    Args:\n        traces (trace object): Raw trace to be parsed.\n        trace_events (trace object): Raw events to be parsed.\n        show_events (bool, optional): If True, show detailed event information.\n            (default: :obj:`True`)\n        paths ([str], optional): Pre-defined paths for fast loading. If\n            :obj:`None`, all leaf modules are reported.\n            (default: :obj:`None`)\n        use_cuda (bool, optional): Enables timing of CUDA events.\n            (default: :obj:`False`)\n        profile_memory (bool, optional): If True, also profile for the memory\n            usage information.\n            (default: :obj:`False`)\n        dt (object, optional): Delimiters for showing the events.\n    \"\"\"\n    tree = OrderedDict()\n\n    for trace in traces:\n        [path, leaf, module] = trace\n        current_tree = tree\n        # unwrap all of the events, in case model is called multiple times\n        events = [te for t_events in trace_events[path] for te in t_events]\n        for depth, name in enumerate(path, 1):\n            if name not in current_tree:\n                current_tree[name] = OrderedDict()\n            if depth == len(path) and ((paths is None and leaf) or\n                                       (paths is not None and path in paths)):\n                # tree measurements have key None, avoiding name 
conflict\n                if show_events:\n                    for event_name, event_group in _group_by(\n                            events, lambda e: e.name):\n                        event_group = list(event_group)\n                        current_tree[name][event_name] = {\n                            None:\n                            _build_measure_tuple(event_group, len(event_group))\n                        }\n                else:\n                    current_tree[name][None] = _build_measure_tuple(\n                        events, len(trace_events[path]))\n            current_tree = current_tree[name]\n    tree_lines = _flatten_tree(tree)\n\n    format_lines = []\n    has_self_cuda_total = False\n    has_self_cpu_memory = False\n    has_cpu_memory = False\n    has_self_cuda_memory = False\n    has_cuda_memory = False\n\n    raw_results = {}\n    for idx, tree_line in enumerate(tree_lines):\n        depth, name, measures = tree_line\n\n        next_depths = [pl[0] for pl in tree_lines[idx + 1:]]\n        pre = \"-\"\n        if depth > 0:\n            pre = dt[1] if depth in next_depths and next_depths[\n                0] >= depth else dt[2]\n            depth -= 1\n        while depth > 0:\n            pre = (dt[0] + pre) if depth in next_depths else (dt[3] + pre)\n            depth -= 1\n\n        format_lines.append([pre + name, *_format_measure_tuple(measures)])\n        if measures:\n            has_self_cuda_total = (has_self_cuda_total\n                                   or measures.self_cuda_total is not None)\n            has_self_cpu_memory = (has_self_cpu_memory\n                                   or measures.self_cpu_memory is not None)\n            has_cpu_memory = has_cpu_memory or measures.cpu_memory is not None\n            has_self_cuda_memory = (has_self_cuda_memory\n                                    or measures.self_cuda_memory is not None)\n            has_cuda_memory = (has_cuda_memory\n                               or 
measures.cuda_memory is not None)\n\n            raw_results[name] = [\n                measures.self_cpu_total, measures.cpu_total,\n                measures.self_cuda_total, measures.cuda_total,\n                measures.self_cpu_memory, measures.cpu_memory,\n                measures.self_cuda_memory, measures.cuda_memory,\n                measures.occurrences\n            ]\n\n    # construct the table (this is pretty ugly and can probably be optimized)\n    heading = (\n        \"Module\",\n        \"Self CPU total\",\n        \"CPU total\",\n        \"Self CUDA total\",\n        \"CUDA total\",\n        \"Self CPU Mem\",\n        \"CPU Mem\",\n        \"Self CUDA Mem\",\n        \"CUDA Mem\",\n        \"Number of Calls\",\n    )\n\n    # get the output aligned\n    max_lens = [max(map(len, col)) for col in zip(*([heading] + format_lines))]\n\n    # not all columns should be displayed, specify kept indexes\n    keep_indexes = [0, 1, 2, 9]\n    if profile_memory:\n        if has_self_cpu_memory:\n            keep_indexes.append(5)\n        if has_cpu_memory:\n            keep_indexes.append(6)\n    if use_cuda:\n        if has_self_cuda_total:\n            keep_indexes.append(3)\n        keep_indexes.append(4)\n        if profile_memory:\n            if has_self_cuda_memory:\n                keep_indexes.append(7)\n            if has_cuda_memory:\n                keep_indexes.append(8)\n\n    # the final columns to be shown\n    keep_indexes = tuple(sorted(keep_indexes))\n\n    heading_list = list(heading)\n\n    display = (  # table heading\n        \" | \".join([\n            \"{:<{}s}\".format(heading[keep_index], max_lens[keep_index])\n            for keep_index in keep_indexes\n        ]) + \"\\n\")\n    display += (  # separator\n        \"-|-\".join([\n            \"-\" * max_len for val_idx, max_len in enumerate(max_lens)\n            if val_idx in keep_indexes\n        ]) + \"\\n\")\n    for format_line in format_lines:  # body\n        display += (\" | 
\".join([\n            \"{:<{}s}\".format(value, max_lens[val_idx])\n            for val_idx, value in enumerate(format_line)\n            if val_idx in keep_indexes\n        ]) + \"\\n\")\n    # layer information readable\n    key_dict = {}\n    layer_names = []\n    layer_stats = []\n    for format_line in format_lines:  # body\n        if format_line[1] == '':  # key line\n            key_dict[format_line[0].count(\"-\")] = format_line[0]\n        else:  # must print\n            # get current line's level\n            curr_level = format_line[0].count(\"-\")\n            par_str = \"\"\n            for i in range(1, curr_level):\n                par_str += key_dict[i]\n            curr_key = par_str + format_line[0]\n            layer_names.append(curr_key)\n            layer_stats.append(format_line[1:])\n\n    return display, heading_list, raw_results, layer_names, layer_stats\n\n\ndef _flatten_tree(t, depth=0):\n    flat = []\n    for name, st in t.items():\n        measures = st.pop(None, None)\n        flat.append([depth, name, measures])\n        flat.extend(_flatten_tree(st, depth=depth + 1))\n    return flat\n\n\ndef _build_measure_tuple(events: List, occurrences: List) -> NamedTuple:\n    device_str = 'device' if torch_geometric.typing.WITH_PT24 else 'cuda'\n\n    # memory profiling supported in torch >= 1.6\n    self_cpu_memory = None\n    has_self_cpu_memory = any(\n        hasattr(e, \"self_cpu_memory_usage\") for e in events)\n    if has_self_cpu_memory:\n        self_cpu_memory = sum(\n            [getattr(e, \"self_cpu_memory_usage\", 0) or 0 for e in events])\n    cpu_memory = None\n    has_cpu_memory = any(hasattr(e, \"cpu_memory_usage\") for e in events)\n    if has_cpu_memory:\n        cpu_memory = sum(\n            [getattr(e, \"cpu_memory_usage\", 0) or 0 for e in events])\n    self_cuda_memory = None\n    has_self_cuda_memory = any(\n        hasattr(e, f\"self_{device_str}_memory_usage\") for e in events)\n    if has_self_cuda_memory:\n    
    self_cuda_memory = sum([\n            getattr(e, f\"self_{device_str}_memory_usage\", 0) or 0\n            for e in events\n        ])\n    cuda_memory = None\n    has_cuda_memory = any(\n        hasattr(e, f\"{device_str}_memory_usage\") for e in events)\n    if has_cuda_memory:\n        cuda_memory = sum(\n            [getattr(e, f\"{device_str}_memory_usage\", 0) or 0 for e in events])\n\n    # self CUDA time supported in torch >= 1.7\n    self_cuda_total = None\n    has_self_cuda_time = any(\n        hasattr(e, f\"self_{device_str}_time_total\") for e in events)\n    if has_self_cuda_time:\n        self_cuda_total = sum([\n            getattr(e, f\"self_{device_str}_time_total\", 0) or 0 for e in events\n        ])\n\n    return Measure(\n        self_cpu_total=sum([e.self_cpu_time_total or 0 for e in events]),\n        cpu_total=sum([e.cpu_time_total or 0 for e in events]),\n        self_cuda_total=self_cuda_total,\n        cuda_total=sum(\n            [getattr(e, f\"{device_str}_time_total\") or 0 for e in events]),\n        self_cpu_memory=self_cpu_memory,\n        cpu_memory=cpu_memory,\n        self_cuda_memory=self_cuda_memory,\n        cuda_memory=cuda_memory,\n        occurrences=occurrences,\n    )\n\n\ndef _format_measure_tuple(measure: NamedTuple) -> NamedTuple:\n    self_cpu_total = (format_time(measure.self_cpu_total) if measure else \"\")\n    cpu_total = format_time(measure.cpu_total) if measure else \"\"\n    self_cuda_total = (format_time(measure.self_cuda_total) if measure\n                       and measure.self_cuda_total is not None else \"\")\n    cuda_total = format_time(measure.cuda_total) if measure else \"\"\n    self_cpu_memory = (format_memory(measure.self_cpu_memory) if measure\n                       and measure.self_cpu_memory is not None else \"\")\n    cpu_memory = (format_memory(measure.cpu_memory)\n                  if measure and measure.cpu_memory is not None else \"\")\n    self_cuda_memory = 
(format_memory(measure.self_cuda_memory) if measure\n                        and measure.self_cuda_memory is not None else \"\")\n    cuda_memory = (format_memory(measure.cuda_memory)\n                   if measure and measure.cuda_memory is not None else \"\")\n    occurrences = str(measure.occurrences) if measure else \"\"\n\n    return Measure(\n        self_cpu_total=self_cpu_total,\n        cpu_total=cpu_total,\n        self_cuda_total=self_cuda_total,\n        cuda_total=cuda_total,\n        self_cpu_memory=self_cpu_memory,\n        cpu_memory=cpu_memory,\n        self_cuda_memory=self_cuda_memory,\n        cuda_memory=cuda_memory,\n        occurrences=occurrences,\n    )\n\n\ndef _group_by(events, keyfn):\n    event_groups = OrderedDict()\n    for event in events:\n        key = keyfn(event)\n        key_events = event_groups.get(key, [])\n        key_events.append(event)\n        event_groups[key] = key_events\n    return event_groups.items()\n\n\ndef _walk_modules(module, name: str = \"\", path=()):\n    # Walk through a PyTorch model and yield trace tuples (its path, whether\n    # it is a leaf node, and the module itself).\n    if not name:\n        name = module.__class__.__name__\n\n    # This will track the children of the module (layers)\n    # for instance, [('conv1', GCNConv(10, 16)), ('conv2', GCNConv(16, 3))]\n    named_children = list(module.named_children())\n\n    # it builds the path of the structure\n    # for instance, ('GCN', 'conv1', 'lin')\n    path = path + (name, )\n\n    # create namedtuple [path, whether it is a leaf, module]\n    yield Trace(path, len(named_children) == 0, module)\n\n    # recursively walk into all submodules\n    for name, child_module in named_children:\n        yield from _walk_modules(child_module, name=name, path=path)\n\n\ndef format_time(time_us: int) -> str:\n    r\"\"\"Returns a formatted time string.\"\"\"\n    US_IN_SECOND = 1000.0 * 1000.0\n    US_IN_MS = 1000.0\n    if time_us >= US_IN_SECOND:\n        return f'{time_us / 
US_IN_SECOND:.3f}s'\n    if time_us >= US_IN_MS:\n        return f'{time_us / US_IN_MS:.3f}ms'\n    return f'{time_us:.3f}us'\n\n\ndef format_memory(nbytes: int) -> str:\n    \"\"\"Returns a formatted memory size string.\"\"\"\n    KB = 1024\n    MB = 1024 * KB\n    GB = 1024 * MB\n    if (abs(nbytes) >= GB):\n        return f'{nbytes * 1.0 / GB:.2f} Gb'\n    elif (abs(nbytes) >= MB):\n        return f'{nbytes * 1.0 / MB:.2f} Mb'\n    elif (abs(nbytes) >= KB):\n        return f'{nbytes * 1.0 / KB:.2f} Kb'\n    else:\n        return str(nbytes) + ' b'\n"
  },
  {
    "path": "torch_geometric/profile/utils.py",
    "content": "import gc\nimport os\nimport os.path as osp\nimport random\nimport subprocess as sp\nimport sys\nimport warnings\nfrom collections.abc import Mapping, Sequence\nfrom typing import Any, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data.data import BaseData\nfrom torch_geometric.typing import SparseTensor\n\n\ndef count_parameters(model: torch.nn.Module) -> int:\n    r\"\"\"Given a :class:`torch.nn.Module`, count its trainable parameters.\n\n    Args:\n        model (torch.nn.Module): The model.\n    \"\"\"\n    return sum([p.numel() for p in model.parameters() if p.requires_grad])\n\n\ndef get_model_size(model: torch.nn.Module) -> int:\n    r\"\"\"Given a :class:`torch.nn.Module`, get its actual disk size in bytes.\n\n    Args:\n        model (torch.nn.Module): The model.\n    \"\"\"\n    path = f'{random.randrange(sys.maxsize)}.pt'\n    torch.save(model.state_dict(), path)\n    model_size = osp.getsize(path)\n    os.remove(path)\n    return model_size\n\n\ndef get_data_size(data: BaseData) -> int:\n    r\"\"\"Given a :class:`torch_geometric.data.Data` object, get its theoretical\n    memory usage in bytes.\n\n    Args:\n        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):\n            The :class:`~torch_geometric.data.Data` or\n            :class:`~torch_geometric.data.HeteroData` graph object.\n    \"\"\"\n    data_ptrs = set()\n\n    def _get_size(obj: Any) -> int:\n        if isinstance(obj, Tensor):\n            if obj.data_ptr() in data_ptrs:\n                return 0\n            data_ptrs.add(obj.data_ptr())\n            return obj.numel() * obj.element_size()\n        elif isinstance(obj, SparseTensor):\n            return _get_size(obj.csr())\n        elif isinstance(obj, Sequence) and not isinstance(obj, str):\n            return sum([_get_size(x) for x in obj])\n        elif isinstance(obj, Mapping):\n            return sum([_get_size(x) for x in obj.values()])\n        else:\n            
return 0\n\n    return sum([_get_size(store) for store in data.stores])\n\n\ndef get_cpu_memory_from_gc() -> int:\n    r\"\"\"Returns the used CPU memory in bytes, as reported by the\n    :python:`Python` garbage collector.\n    \"\"\"\n    warnings.filterwarnings('ignore', '.*torch.distributed.reduce_op.*')\n\n    mem = 0\n    for obj in gc.get_objects():\n        try:\n            if isinstance(obj, Tensor) and not obj.is_cuda:\n                mem += obj.numel() * obj.element_size()\n        except Exception:\n            pass\n    return mem\n\n\ndef get_gpu_memory_from_gc(device: int = 0) -> int:  # pragma: no cover\n    r\"\"\"Returns the used GPU memory in bytes, as reported by the\n    :python:`Python` garbage collector.\n\n    Args:\n        device (int, optional): The GPU device identifier. (default: :obj:`0`)\n    \"\"\"\n    warnings.filterwarnings('ignore', '.*torch.distributed.reduce_op.*')\n\n    mem = 0\n    for obj in gc.get_objects():\n        try:\n            if isinstance(obj, Tensor) and obj.get_device() == device:\n                mem += obj.numel() * obj.element_size()\n        except Exception:\n            pass\n    return mem\n\n\ndef get_gpu_memory_from_nvidia_smi(  # pragma: no cover\n    device: int = 0,\n    digits: int = 2,\n) -> Tuple[float, float]:\n    r\"\"\"Returns the free and used GPU memory in megabytes, as reported by\n    :obj:`nvidia-smi`.\n\n    .. note::\n\n        :obj:`nvidia-smi` will generally overestimate the amount of memory used\n        by the actual program, see `here <https://pytorch.org/docs/stable/\n        notes/faq.html#my-gpu-memory-isn-t-freed-properly>`__.\n\n    Args:\n        device (int, optional): The GPU device identifier. 
(default: :obj:`0`)\n        digits (int): The number of decimals to use for megabytes.\n            (default: :obj:`2`)\n    \"\"\"\n    def parse_memory(output: bytes) -> list:\n        lines = output.decode('utf-8').split('\\n')[1:-1]\n        mem_list = []\n        for line in lines:\n            try:\n                mem_list.append(int(line.split()[0]))\n            except (TypeError, ValueError):\n                mem_list.append(None)\n        return mem_list\n\n    def get_gpu_memory(out_device, digits):\n        if out_device is None:\n            return 0\n\n        return medibyte_to_megabyte(out_device, digits)\n\n    CMD = 'nvidia-smi --query-gpu=memory.free --format=csv'\n    free_out = parse_memory(sp.check_output(CMD.split()))\n\n    CMD = 'nvidia-smi --query-gpu=memory.used --format=csv'\n    used_out = parse_memory(sp.check_output(CMD.split()))\n\n    if device < 0 or device >= len(free_out):\n        raise AttributeError(\n            f'GPU {device} not available (found {len(free_out)} GPUs)')\n\n    free_mem = get_gpu_memory(free_out[device], digits)\n    used_mem = get_gpu_memory(used_out[device], digits)\n    return free_mem, used_mem\n\n\ndef get_gpu_memory_from_ipex(\n        device: int = 0,\n        digits=2) -> Tuple[float, float, float]:  # pragma: no cover\n    r\"\"\"Returns the XPU memory statistics.\n\n    Args:\n        device (int, optional): The GPU device identifier. 
(default: :obj:`0`)\n        digits (int): The number of decimals to use for megabytes.\n            (default: :obj:`2`)\n    \"\"\"\n    import intel_extension_for_pytorch as ipex\n    stats = ipex.xpu.memory_stats_as_nested_dict(device)\n    max_allocated = stats['allocated_bytes']['all']['peak']\n    max_reserved = stats['reserved_bytes']['all']['peak']\n    max_active = stats['active_bytes']['all']['peak']\n    max_allocated = byte_to_megabyte(max_allocated, digits)\n    max_reserved = byte_to_megabyte(max_reserved, digits)\n    max_active = byte_to_megabyte(max_active, digits)\n    ipex.xpu.reset_peak_memory_stats(device)\n    return max_allocated, max_reserved, max_active\n\n\n###############################################################################\n\n\ndef byte_to_megabyte(value: int, digits: int = 2) -> float:\n    return round(value / (1024 * 1024), digits)\n\n\ndef medibyte_to_megabyte(value: int, digits: int = 2) -> float:\n    return round(1.0485 * value, digits)\n"
  },
  {
    "path": "torch_geometric/resolver.py",
    "content": "import inspect\nfrom typing import Any, Dict, List, Optional, Union\n\n\ndef normalize_string(s: str) -> str:\n    return s.lower().replace('-', '').replace('_', '').replace(' ', '')\n\n\ndef resolver(\n    classes: List[Any],\n    class_dict: Dict[str, Any],\n    query: Union[Any, str],\n    base_cls: Optional[Any],\n    base_cls_repr: Optional[str],\n    *args: Any,\n    **kwargs: Any,\n) -> Any:\n\n    if not isinstance(query, str):\n        return query\n\n    query_repr = normalize_string(query)\n    if base_cls_repr is None:\n        base_cls_repr = base_cls.__name__ if base_cls else ''\n    base_cls_repr = normalize_string(base_cls_repr)\n\n    for key_repr, cls in class_dict.items():\n        if query_repr == key_repr:\n            if inspect.isclass(cls):\n                obj = cls(*args, **kwargs)\n                return obj\n            return cls\n\n    for cls in classes:\n        cls_repr = normalize_string(cls.__name__)\n        if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]:\n            if inspect.isclass(cls):\n                obj = cls(*args, **kwargs)\n                return obj\n            return cls\n\n    choices = {cls.__name__ for cls in classes} | set(class_dict.keys())\n    raise ValueError(f\"Could not resolve '{query}' among choices {choices}\")\n"
  },
  {
    "path": "torch_geometric/sampler/__init__.py",
    "content": "r\"\"\"Graph sampler package.\"\"\"\n\nfrom .base import (BaseSampler, NodeSamplerInput, EdgeSamplerInput,\n                   SamplerOutput, HeteroSamplerOutput, NegativeSampling,\n                   NumNeighbors)\nfrom .neighbor_sampler import NeighborSampler, BidirectionalNeighborSampler\nfrom .hgt_sampler import HGTSampler\n\n__all__ = classes = [\n    'BaseSampler',\n    'NodeSamplerInput',\n    'EdgeSamplerInput',\n    'SamplerOutput',\n    'HeteroSamplerOutput',\n    'NumNeighbors',\n    'NegativeSampling',\n    'NeighborSampler',\n    'BidirectionalNeighborSampler',\n    'HGTSampler',\n]\n"
  },
  {
    "path": "torch_geometric/sampler/base.py",
    "content": "import copy\nimport math\nimport warnings\nfrom abc import ABC, abstractmethod\nfrom collections import defaultdict\nfrom dataclasses import dataclass, field\nfrom enum import Enum\nfrom typing import Any, Dict, List, Literal, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData\nfrom torch_geometric.sampler.utils import (\n    global_to_local_node_idx,\n    local_to_global_node_idx,\n    to_bidirectional,\n    unique_unsorted,\n)\nfrom torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor\nfrom torch_geometric.utils.mixin import CastMixin\n\n\nclass DataType(Enum):\n    r\"\"\"The data type a sampler is operating on.\"\"\"\n    homogeneous = 'homogeneous'\n    heterogeneous = 'heterogeneous'\n    remote = 'remote'\n\n    @classmethod\n    def from_data(cls, data: Any):\n        if isinstance(data, Data):\n            return cls.homogeneous\n        elif isinstance(data, HeteroData):\n            return cls.heterogeneous\n        elif (isinstance(data, (list, tuple)) and len(data) == 2\n              and isinstance(data[0], FeatureStore)\n              and isinstance(data[1], GraphStore)):\n            return cls.remote\n\n        raise ValueError(f\"Expected a 'Data', 'HeteroData', or a tuple of \"\n                         f\"'FeatureStore' and 'GraphStore' \"\n                         f\"(got '{type(data)}')\")\n\n\nclass SubgraphType(Enum):\n    r\"\"\"The type of the returned subgraph.\"\"\"\n    directional = 'directional'\n    bidirectional = 'bidirectional'\n    induced = 'induced'\n\n\n@dataclass(init=False)\nclass NodeSamplerInput(CastMixin):\n    r\"\"\"The sampling input of\n    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_nodes`.\n\n    Args:\n        input_id (torch.Tensor, optional): The indices of the data loader input\n            of the current mini-batch.\n        node (torch.Tensor): The indices of seed nodes to start 
sampling from.\n        time (torch.Tensor, optional): The timestamp for the seed nodes.\n            (default: :obj:`None`)\n        input_type (str, optional): The input node type (in case of sampling in\n            a heterogeneous graph). (default: :obj:`None`)\n    \"\"\"\n    input_id: OptTensor\n    node: Tensor\n    time: OptTensor = None\n    input_type: Optional[NodeType] = None\n\n    def __init__(\n        self,\n        input_id: OptTensor,\n        node: Tensor,\n        time: OptTensor = None,\n        input_type: Optional[NodeType] = None,\n    ):\n        if input_id is not None:\n            input_id = input_id.cpu()\n        node = node.cpu()\n        if time is not None:\n            time = time.cpu()\n\n        self.input_id = input_id\n        self.node = node\n        self.time = time\n        self.input_type = input_type\n\n    def __getitem__(self, index: Union[Tensor, Any]) -> 'NodeSamplerInput':\n        if not isinstance(index, Tensor):\n            index = torch.tensor(index, dtype=torch.long)\n\n        return NodeSamplerInput(\n            self.input_id[index] if self.input_id is not None else index,\n            self.node[index],\n            self.time[index] if self.time is not None else None,\n            self.input_type,\n        )\n\n\n@dataclass(init=False)\nclass EdgeSamplerInput(CastMixin):\n    r\"\"\"The sampling input of\n    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.\n\n    Args:\n        input_id (torch.Tensor, optional): The indices of the data loader input\n            of the current mini-batch.\n        row (torch.Tensor): The source node indices of seed links to start\n            sampling from.\n        col (torch.Tensor): The destination node indices of seed links to start\n            sampling from.\n        label (torch.Tensor, optional): The label for the seed links.\n            (default: :obj:`None`)\n        time (torch.Tensor, optional): The timestamp for the seed links.\n            
(default: :obj:`None`)\n        input_type (Tuple[str, str, str], optional): The input edge type (in\n            case of sampling in a heterogeneous graph). (default: :obj:`None`)\n    \"\"\"\n    input_id: OptTensor\n    row: Tensor\n    col: Tensor\n    label: OptTensor = None\n    time: OptTensor = None\n    input_type: Optional[EdgeType] = None\n\n    def __init__(\n        self,\n        input_id: OptTensor,\n        row: Tensor,\n        col: Tensor,\n        label: OptTensor = None,\n        time: OptTensor = None,\n        input_type: Optional[EdgeType] = None,\n    ):\n        if input_id is not None:\n            input_id = input_id.cpu()\n        row = row.clone().cpu()\n        col = col.clone().cpu()\n        if label is not None:\n            label = label.cpu()\n        if time is not None:\n            time = time.cpu()\n\n        self.input_id = input_id\n        self.row = row\n        self.col = col\n        self.label = label\n        self.time = time\n        self.input_type = input_type\n\n    def __getitem__(self, index: Union[Tensor, Any]) -> 'EdgeSamplerInput':\n        if not isinstance(index, Tensor):\n            index = torch.tensor(index, dtype=torch.long)\n\n        return EdgeSamplerInput(\n            self.input_id[index] if self.input_id is not None else index,\n            self.row[index],\n            self.col[index],\n            self.label[index] if self.label is not None else None,\n            self.time[index] if self.time is not None else None,\n            self.input_type,\n        )\n\n\n@dataclass\nclass SamplerOutput(CastMixin):\n    r\"\"\"The sampling output of a :class:`~torch_geometric.sampler.BaseSampler`\n    on homogeneous graphs.\n\n    Args:\n        node (torch.Tensor): The sampled nodes in the original graph.\n        row (torch.Tensor): The source node indices of the sampled subgraph.\n            Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`\n            corresponding to the nodes in the 
:obj:`node` tensor.\n        col (torch.Tensor): The destination node indices of the sampled\n            subgraph.\n            Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`\n            corresponding to the nodes in the :obj:`node` tensor.\n        edge (torch.Tensor, optional): The sampled edges in the original graph.\n            This tensor is used to obtain edge features from the original\n            graph. If no edge attributes are present, it may be omitted.\n        batch (torch.Tensor, optional): The vector to identify the seed node\n            for each sampled node. Can be present in case of disjoint subgraph\n            sampling per seed node. (default: :obj:`None`)\n        num_sampled_nodes (List[int], optional): The number of sampled nodes\n            per hop. (default: :obj:`None`)\n        num_sampled_edges (List[int], optional): The number of sampled edges\n            per hop. (default: :obj:`None`)\n        orig_row (torch.Tensor, optional): The original source node indices\n            returned by the sampler.\n            Filled in case :meth:`to_bidirectional` is called with the\n            :obj:`keep_orig_edges` option. (default: :obj:`None`)\n        orig_col (torch.Tensor, optional): The original destination node\n            indices returned by the sampler.\n            Filled in case :meth:`to_bidirectional` is called with the\n            :obj:`keep_orig_edges` option. 
(default: :obj:`None`)\n        metadata (Any, optional): Additional metadata information.\n            (default: :obj:`None`)\n    \"\"\"\n    node: Tensor\n    row: Tensor\n    col: Tensor\n    edge: OptTensor\n    batch: OptTensor = None\n    num_sampled_nodes: Optional[List[int]] = None\n    num_sampled_edges: Optional[List[int]] = None\n    orig_row: OptTensor = None\n    orig_col: OptTensor = None\n    # TODO(manan): refine this further; it does not currently define a proper\n    # API for the expected output of a sampler.\n    metadata: Optional[Any] = None\n    _seed_node: OptTensor = field(repr=False, default=None)\n\n    @property\n    def global_row(self) -> Tensor:\n        return local_to_global_node_idx(self.node, self.row)\n\n    @property\n    def global_col(self) -> Tensor:\n        return local_to_global_node_idx(self.node, self.col)\n\n    @property\n    def seed_node(self) -> OptTensor:\n        # can be set manually if the seed nodes are not contained in the\n        # sampled nodes\n        if self._seed_node is None:\n            self._seed_node = local_to_global_node_idx(\n                self.node, self.batch) if self.batch is not None else None\n        return self._seed_node\n\n    @seed_node.setter\n    def seed_node(self, value: Tensor):\n        assert len(value) == len(self.node)\n        self._seed_node = value\n\n    @property\n    def global_orig_row(self) -> OptTensor:\n        return local_to_global_node_idx(\n            self.node, self.orig_row) if self.orig_row is not None else None\n\n    @property\n    def global_orig_col(self) -> OptTensor:\n        return local_to_global_node_idx(\n            self.node, self.orig_col) if self.orig_col is not None else None\n\n    def to_bidirectional(\n        self,\n        keep_orig_edges: bool = False,\n    ) -> 'SamplerOutput':\n        r\"\"\"Converts the sampled subgraph into a bidirectional variant, in\n        which all sampled edges are guaranteed to be bidirectional.\n\n        Args:\n        
    keep_orig_edges (bool, optional): If specified, directional edges\n                are still maintained. (default: :obj:`False`)\n        \"\"\"\n        out = copy.copy(self)\n\n        if keep_orig_edges:\n            out.orig_row = self.row\n            out.orig_col = self.col\n        else:\n            out.num_sampled_nodes = out.num_sampled_edges = None\n\n        out.row, out.col, out.edge = to_bidirectional(\n            row=self.row,\n            col=self.col,\n            rev_row=self.row,\n            rev_col=self.col,\n            edge_id=self.edge,\n            rev_edge_id=self.edge,\n        )\n\n        return out\n\n    @classmethod\n    def collate(cls, outputs: List['SamplerOutput'],\n                replace: bool = True) -> 'SamplerOutput':\n        r\"\"\"Collate a list of :class:`~torch_geometric.sampler.SamplerOutput`\n        objects into a single :class:`~torch_geometric.sampler.SamplerOutput`\n        object. Requires that they all have the same fields.\n        \"\"\"\n        if len(outputs) == 0:\n            raise ValueError(\"Cannot collate an empty list of SamplerOutputs\")\n        out = outputs[0]\n        has_edge = out.edge is not None\n        has_orig_row = out.orig_row is not None\n        has_orig_col = out.orig_col is not None\n        has_batch = out.batch is not None\n        has_num_sampled_nodes = out.num_sampled_nodes is not None\n        has_num_sampled_edges = out.num_sampled_edges is not None\n\n        try:\n            for i, sample_output in enumerate(outputs):  # noqa\n                assert not has_edge == (sample_output.edge is None)\n                assert not has_orig_row == (sample_output.orig_row is None)\n                assert not has_orig_col == (sample_output.orig_col is None)\n                assert not has_batch == (sample_output.batch is None)\n                assert not has_num_sampled_nodes == (\n                    sample_output.num_sampled_nodes is None)\n                assert not 
has_num_sampled_edges == (\n                    sample_output.num_sampled_edges is None)\n        except AssertionError:\n            error_str = f\"Output {i+1} has a different field than the first output\"  # noqa\n            raise ValueError(error_str)  # noqa\n\n        for other in outputs[1:]:\n            out = out.merge_with(other, replace=replace)\n        return out\n\n    def merge_with(self, other: 'SamplerOutput',\n                   replace: bool = True) -> 'SamplerOutput':\n        \"\"\"Merges two SamplerOutputs.\n        If replace is True, self's nodes and edges take precedence.\n        \"\"\"\n        if not replace:\n            return SamplerOutput(\n                node=torch.cat([self.node, other.node], dim=0),\n                row=torch.cat([self.row, len(self.node) + other.row], dim=0),\n                col=torch.cat([self.col, len(self.node) + other.col], dim=0),\n                edge=torch.cat([self.edge, other.edge], dim=0)\n                if self.edge is not None and other.edge is not None else None,\n                batch=torch.cat(\n                    [self.batch, len(self.node) + other.batch], dim=0) if\n                self.batch is not None and other.batch is not None else None,\n                num_sampled_nodes=self.num_sampled_nodes +\n                other.num_sampled_nodes if self.num_sampled_nodes is not None\n                and other.num_sampled_nodes is not None else None,\n                num_sampled_edges=self.num_sampled_edges +\n                other.num_sampled_edges if self.num_sampled_edges is not None\n                and other.num_sampled_edges is not None else None,\n                orig_row=torch.cat(\n                    [self.orig_row,\n                     len(self.node) +\n                     other.orig_row], dim=0) if self.orig_row is not None\n                and other.orig_row is not None else None,\n                orig_col=torch.cat(\n                    [self.orig_col,\n                     
len(self.node) +\n                     other.orig_col], dim=0) if self.orig_col is not None\n                and other.orig_col is not None else None,\n                metadata=[self.metadata, other.metadata],\n            )\n        else:\n\n            # NODES\n            old_nodes, new_nodes = self.node, other.node\n            old_node_uid, new_node_uid = [old_nodes], [new_nodes]\n\n            # batch tracks disjoint subgraph samplings\n            if self.batch is not None and other.batch is not None:\n                # Transform the batch indices to be global node ids\n                old_batch_nodes = self.seed_node\n                new_batch_nodes = other.seed_node\n                old_node_uid.append(old_batch_nodes)\n                new_node_uid.append(new_batch_nodes)\n\n            # NOTE: if any new node fields are added,\n            # they need to be merged here\n\n            old_node_uid = torch.stack(old_node_uid, dim=1)\n            new_node_uid = torch.stack(new_node_uid, dim=1)\n\n            merged_node_uid = unique_unsorted(\n                torch.cat([old_node_uid, new_node_uid], dim=0))\n            num_old_nodes = old_node_uid.shape[0]\n\n            # Recompute num sampled nodes for second output,\n            # subtracting out nodes already seen in first output\n            merged_node_num_sampled_nodes = None\n            if (self.num_sampled_nodes is not None\n                    and other.num_sampled_nodes is not None):\n                merged_node_num_sampled_nodes = copy.copy(\n                    self.num_sampled_nodes)\n                curr_index = 0\n                # NOTE: This assumes that no node is sampled twice within the\n                # same SamplerOutput object\n                for minibatch in other.num_sampled_nodes:\n                    size_of_intersect = torch.cat([\n                        old_node_uid,\n                        new_node_uid[curr_index:curr_index + minibatch]\n                  
  ]).unique(dim=0, sorted=False).shape[0] - num_old_nodes\n                    merged_node_num_sampled_nodes.append(size_of_intersect)\n                    curr_index += minibatch\n\n            merged_nodes = merged_node_uid[:, 0]\n            merged_batch = None\n            if self.batch is not None and other.batch is not None:\n                # Restore the batch indices to be relative to the nodes field\n                ref_merged_batch_nodes = merged_node_uid[:, 1].unsqueeze(\n                    -1).expand(-1, 2)  # num_nodes x 2\n                merged_batch = global_to_local_node_idx(\n                    merged_node_uid, ref_merged_batch_nodes)\n\n            # EDGES\n            is_bidirectional = self.orig_row is not None \\\n                and self.orig_col is not None \\\n                and other.orig_row is not None \\\n                and other.orig_col is not None\n            if is_bidirectional:\n                old_row, old_col = self.orig_row, self.orig_col\n                new_row, new_col = other.orig_row, other.orig_col\n            else:\n                old_row, old_col = self.row, self.col\n                new_row, new_col = other.row, other.col\n\n            # Transform the row and col indices to be global node ids\n            # instead of relative indices to nodes field\n            # Edge uids build off of node uids\n            old_row_idx, old_col_idx = local_to_global_node_idx(\n                old_node_uid,\n                old_row), local_to_global_node_idx(old_node_uid, old_col)\n            new_row_idx, new_col_idx = local_to_global_node_idx(\n                new_node_uid,\n                new_row), local_to_global_node_idx(new_node_uid, new_col)\n\n            old_edge_uid, new_edge_uid = [old_row_idx, old_col_idx\n                                          ], [new_row_idx, new_col_idx]\n\n            row_idx = 0\n            col_idx = old_row_idx.shape[1]\n            edge_idx = old_row_idx.shape[1] + 
old_col_idx.shape[1]\n\n            if self.edge is not None and other.edge is not None:\n                if is_bidirectional:\n                    # bidirectional duplicates edge ids\n                    old_edge_uid_ref = torch.stack([self.row, self.col],\n                                                   dim=1)  # num_edges x 2\n                    old_orig_edge_uid_ref = torch.stack(\n                        [self.orig_row, self.orig_col],\n                        dim=1)  # num_orig_edges x 2\n\n                    old_edge_idx = global_to_local_node_idx(\n                        old_edge_uid_ref, old_orig_edge_uid_ref)\n                    old_edge = self.edge[old_edge_idx]\n\n                    new_edge_uid_ref = torch.stack([other.row, other.col],\n                                                   dim=1)  # num_edges x 2\n                    new_orig_edge_uid_ref = torch.stack(\n                        [other.orig_row, other.orig_col],\n                        dim=1)  # num_orig_edges x 2\n\n                    new_edge_idx = global_to_local_node_idx(\n                        new_edge_uid_ref, new_orig_edge_uid_ref)\n                    new_edge = other.edge[new_edge_idx]\n\n                else:\n                    old_edge, new_edge = self.edge, other.edge\n\n                old_edge_uid.append(old_edge.unsqueeze(-1))\n                new_edge_uid.append(new_edge.unsqueeze(-1))\n\n            old_edge_uid = torch.cat(old_edge_uid, dim=1)\n            new_edge_uid = torch.cat(new_edge_uid, dim=1)\n\n            merged_edge_uid = unique_unsorted(\n                torch.cat([old_edge_uid, new_edge_uid], dim=0))\n            num_old_edges = old_edge_uid.shape[0]\n\n            merged_edge_num_sampled_edges = None\n            if (self.num_sampled_edges is not None\n                    and other.num_sampled_edges is not None):\n                merged_edge_num_sampled_edges = copy.copy(\n                    self.num_sampled_edges)\n                curr_index 
= 0\n                # NOTE: This assumes that no edge is sampled twice within the\n                # same SamplerOutput object\n                for minibatch in other.num_sampled_edges:\n                    size_of_intersect = torch.cat([\n                        old_edge_uid,\n                        new_edge_uid[curr_index:curr_index + minibatch]\n                    ]).unique(dim=0, sorted=False).shape[0] - num_old_edges\n                    merged_edge_num_sampled_edges.append(size_of_intersect)\n                    curr_index += minibatch\n\n            merged_row = merged_edge_uid[:, row_idx:col_idx]\n            merged_col = merged_edge_uid[:, col_idx:edge_idx]\n            merged_edge = merged_edge_uid[:, edge_idx:].squeeze(-1) \\\n                if self.edge is not None and other.edge is not None else None\n\n            # restore to row and col indices relative to nodes field\n            merged_row = global_to_local_node_idx(merged_node_uid, merged_row)\n            merged_col = global_to_local_node_idx(merged_node_uid, merged_col)\n\n            out = SamplerOutput(\n                node=merged_nodes,\n                row=merged_row,\n                col=merged_col,\n                edge=merged_edge,\n                batch=merged_batch,\n                num_sampled_nodes=merged_node_num_sampled_nodes,\n                num_sampled_edges=merged_edge_num_sampled_edges,\n                metadata=[self.metadata, other.metadata],\n            )\n            # Restores orig_row and orig_col if they existed before merging\n            if is_bidirectional:\n                out = out.to_bidirectional(keep_orig_edges=True)\n            return out\n\n\n@dataclass\nclass HeteroSamplerOutput(CastMixin):\n    r\"\"\"The sampling output of a :class:`~torch_geometric.sampler.BaseSampler`\n    on heterogeneous graphs.\n\n    Args:\n        node (Dict[str, torch.Tensor]): The sampled nodes in the original graph\n            for each node type.\n        
row (Dict[Tuple[str, str, str], torch.Tensor]): The source node indices\n            of the sampled subgraph for each edge type.\n            Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`\n            corresponding to the nodes in the :obj:`node` tensor of the source\n            node type.\n        col (Dict[Tuple[str, str, str], torch.Tensor]): The destination node\n            indices of the sampled subgraph for each edge type.\n            Indices must be re-indexed to :obj:`{ 0, ..., num_nodes - 1 }`\n            corresponding to the nodes in the :obj:`node` tensor of the\n            destination node type.\n        edge (Dict[Tuple[str, str, str], torch.Tensor], optional): The sampled\n            edges in the original graph for each edge type.\n            This tensor is used to obtain edge features from the original\n            graph. If no edge attributes are present, it may be omitted.\n        batch (Dict[str, torch.Tensor], optional): The vector to identify the\n            seed node for each sampled node for each node type. Can be present\n            in case of disjoint subgraph sampling per seed node.\n            (default: :obj:`None`)\n        num_sampled_nodes (Dict[str, List[int]], optional): The number of\n            sampled nodes for each node type and each layer.\n            (default: :obj:`None`)\n        num_sampled_edges (Dict[EdgeType, List[int]], optional): The number of\n            sampled edges for each edge type and each layer.\n            (default: :obj:`None`)\n        orig_row (Dict[EdgeType, torch.Tensor], optional): The original source\n            node indices returned by the sampler.\n            Filled in case :meth:`to_bidirectional` is called with the\n            :obj:`keep_orig_edges` option. 
(default: :obj:`None`)\n        orig_col (Dict[EdgeType, torch.Tensor], optional): The original\n            destination node indices returned by the sampler.\n            Filled in case :meth:`to_bidirectional` is called with the\n            :obj:`keep_orig_edges` option. (default: :obj:`None`)\n        metadata (Any, optional): Additional metadata information.\n            (default: :obj:`None`)\n    \"\"\"\n    node: Dict[NodeType, Tensor]\n    row: Dict[EdgeType, Tensor]\n    col: Dict[EdgeType, Tensor]\n    edge: Dict[EdgeType, OptTensor]\n    batch: Optional[Dict[NodeType, Tensor]] = None\n    num_sampled_nodes: Optional[Dict[NodeType, List[int]]] = None\n    num_sampled_edges: Optional[Dict[EdgeType, List[int]]] = None\n    orig_row: Optional[Dict[EdgeType, Tensor]] = None\n    orig_col: Optional[Dict[EdgeType, Tensor]] = None\n    # TODO(manan): refine this further; it does not currently define a proper\n    # API for the expected output of a sampler.\n    metadata: Optional[Any] = None\n\n    @property\n    def global_row(self) -> Dict[EdgeType, Tensor]:\n        return {\n            edge_type: local_to_global_node_idx(self.node[edge_type[0]], row)\n            for edge_type, row in self.row.items()\n        }\n\n    @property\n    def global_col(self) -> Dict[EdgeType, Tensor]:\n        return {\n            edge_type: local_to_global_node_idx(self.node[edge_type[2]], col)\n            for edge_type, col in self.col.items()\n        }\n\n    @property\n    def seed_node(self) -> Optional[Dict[NodeType, Tensor]]:\n        return {\n            node_type: local_to_global_node_idx(self.node[node_type], batch)\n            for node_type, batch in self.batch.items()\n        } if self.batch is not None else None\n\n    @property\n    def global_orig_row(self) -> Optional[Dict[EdgeType, Tensor]]:\n        return {\n            edge_type: local_to_global_node_idx(self.node[edge_type[0]],\n                                                orig_row)\n            
for edge_type, orig_row in self.orig_row.items()\n        } if self.orig_row is not None else None\n\n    @property\n    def global_orig_col(self) -> Optional[Dict[EdgeType, Tensor]]:\n        return {\n            edge_type: local_to_global_node_idx(self.node[edge_type[2]],\n                                                orig_col)\n            for edge_type, orig_col in self.orig_col.items()\n        } if self.orig_col is not None else None\n\n    def to_bidirectional(\n        self,\n        keep_orig_edges: bool = False,\n    ) -> 'HeteroSamplerOutput':\n        r\"\"\"Converts the sampled subgraph into a bidirectional variant, in\n        which all sampled edges are guaranteed to be bidirectional.\n\n        Args:\n            keep_orig_edges (bool, optional): If specified, directional edges\n                are still maintained. (default: :obj:`False`)\n        \"\"\"\n        out = copy.copy(self)\n        out.row = copy.copy(self.row)\n        out.col = copy.copy(self.col)\n        out.edge = copy.copy(self.edge)\n\n        if keep_orig_edges:\n            out.orig_row = {}\n            out.orig_col = {}\n            for key in self.row.keys():\n                out.orig_row[key] = self.row[key]\n                out.orig_col[key] = self.col[key]\n        else:\n            out.num_sampled_nodes = out.num_sampled_edges = None\n\n        src_dst_dict = defaultdict(list)\n        edge_types = self.row.keys()\n        edge_types = [k for k in edge_types if not k[1].startswith('rev_')]\n        for edge_type in edge_types:\n            src, rel, dst = edge_type\n            rev_edge_type = (dst, f'rev_{rel}', src)\n\n            if src == dst and rev_edge_type not in self.row:\n                out.row[edge_type], out.col[edge_type], _ = to_bidirectional(\n                    row=self.row[edge_type],\n                    col=self.col[edge_type],\n                    rev_row=self.row[edge_type],\n                    rev_col=self.col[edge_type],\n                )\n       
         if out.edge is not None:\n                    out.edge[edge_type] = None\n\n            elif rev_edge_type in self.row:\n                out.row[edge_type], out.col[edge_type], _ = to_bidirectional(\n                    row=self.row[edge_type],\n                    col=self.col[edge_type],\n                    rev_row=self.row[rev_edge_type],\n                    rev_col=self.col[rev_edge_type],\n                )\n                out.row[rev_edge_type] = out.col[edge_type]\n                out.col[rev_edge_type] = out.row[edge_type]\n                if out.edge is not None:\n                    out.edge[edge_type] = None\n                    out.edge[rev_edge_type] = None\n\n            else:  # Find the reverse edge type (if it is unique):\n                if len(src_dst_dict) == 0:  # Create mapping lazily.\n                    for key in self.row.keys():\n                        v1, _, v2 = key\n                        src_dst_dict[(v1, v2)].append(key)\n\n                if len(src_dst_dict[(dst, src)]) == 1:\n                    rev_edge_type = src_dst_dict[(dst, src)][0]\n                    row, col, _ = to_bidirectional(\n                        row=self.row[edge_type],\n                        col=self.col[edge_type],\n                        rev_row=self.row[rev_edge_type],\n                        rev_col=self.col[rev_edge_type],\n                    )\n                    out.row[edge_type] = row\n                    out.col[edge_type] = col\n                    if out.edge is not None:\n                        out.edge[edge_type] = None\n\n                else:\n                    warnings.warn(\n                        f\"Cannot convert to bidirectional graph \"\n                        f\"since the edge type {edge_type} does not \"\n                        f\"seem to have a reverse edge type\", stacklevel=2)\n\n        return out\n\n    @classmethod\n    def collate(cls, outputs: List['HeteroSamplerOutput'],\n                replace: bool 
= True) -> 'HeteroSamplerOutput':\n        r\"\"\"Collate a list of\n        :class:`~torch_geometric.sampler.HeteroSamplerOutput` objects into a\n        single :class:`~torch_geometric.sampler.HeteroSamplerOutput` object.\n        Requires that they all have the same fields.\n        \"\"\"\n        # TODO(zaristei)\n        raise NotImplementedError\n\n    def merge_with(self, other: 'HeteroSamplerOutput',\n                   replace: bool = True) -> 'HeteroSamplerOutput':\n        \"\"\"Merges two HeteroSamplerOutputs.\n        If replace is True, self's nodes and edges take precedence.\n        \"\"\"\n        # TODO(zaristei)\n        raise NotImplementedError\n\n\n@dataclass(frozen=True)\nclass NumNeighbors:\n    r\"\"\"The number of neighbors to sample in a homogeneous or heterogeneous\n    graph. In heterogeneous graphs, may also take in a dictionary denoting\n    the number of neighbors to sample for individual edge types.\n\n    Args:\n        values (List[int] or Dict[Tuple[str, str, str], List[int]]): The\n            number of neighbors to sample.\n            If an entry is set to :obj:`-1`, all neighbors will be included.\n            In heterogeneous graphs, may also take in a dictionary denoting\n            the number of neighbors to sample for individual edge types.\n        default (List[int], optional): The default number of neighbors for edge\n            types not specified in :obj:`values`. 
(default: :obj:`None`)\n    \"\"\"\n    values: Union[List[int], Dict[EdgeTypeStr, List[int]]]\n    default: Optional[List[int]] = None\n\n    def __init__(\n        self,\n        values: Union[List[int], Dict[EdgeType, List[int]]],\n        default: Optional[List[int]] = None,\n    ):\n        if isinstance(values, (tuple, list)) and default is not None:\n            raise ValueError(f\"'default' must be set to 'None' in case a \"\n                             f\"single list is given as the number of \"\n                             f\"neighbors (got '{type(default)}')\")\n\n        if isinstance(values, dict):\n            values = {EdgeTypeStr(key): value for key, value in values.items()}\n\n        # Write to `__dict__` since dataclass is annotated with `frozen=True`:\n        self.__dict__['values'] = values\n        self.__dict__['default'] = default\n\n    def _get_values(\n        self,\n        edge_types: Optional[List[EdgeType]] = None,\n        mapped: bool = False,\n    ) -> Union[List[int], Dict[Union[EdgeType, EdgeTypeStr], List[int]]]:\n\n        if edge_types is not None:\n            if isinstance(self.values, (tuple, list)):\n                default = self.values\n            elif isinstance(self.values, dict):\n                default = self.default\n            else:\n                raise AssertionError()\n\n            # Confirm that `values` only holds valid edge types:\n            if isinstance(self.values, dict):\n                edge_types_str = {EdgeTypeStr(key) for key in edge_types}\n                invalid_edge_types = set(self.values.keys()) - edge_types_str\n                if len(invalid_edge_types) > 0:\n                    raise ValueError(\"Not all edge types specified in \"\n                                     \"'num_neighbors' exist in the graph\")\n\n            out = {}\n            for edge_type in edge_types:\n                edge_type_str = EdgeTypeStr(edge_type)\n                if edge_type_str in self.values:\n       
             out[edge_type_str if mapped else edge_type] = (\n                        self.values[edge_type_str])\n                else:\n                    if default is None:\n                        raise ValueError(f\"Missing number of neighbors for \"\n                                         f\"edge type '{edge_type}'\")\n                    out[edge_type_str if mapped else edge_type] = default\n\n        elif isinstance(self.values, dict) and not mapped:\n            out = {key.to_tuple(): value for key, value in self.values.items()}\n\n        else:\n            out = copy.copy(self.values)\n\n        if isinstance(out, dict):\n            num_hops = {len(v) for v in out.values()}\n            if len(num_hops) > 1:\n                raise ValueError(f\"Number of hops must be the same across all \"\n                                 f\"edge types (got {len(num_hops)} different \"\n                                 f\"number of hops)\")\n\n        return out\n\n    def get_values(\n        self,\n        edge_types: Optional[List[EdgeType]] = None,\n    ) -> Union[List[int], Dict[EdgeType, List[int]]]:\n        r\"\"\"Returns the number of neighbors.\n\n        Args:\n            edge_types (List[Tuple[str, str, str]], optional): The edge types\n                to generate the number of neighbors for. 
(default: :obj:`None`)\n        \"\"\"\n        if '_values' in self.__dict__:\n            return self.__dict__['_values']\n\n        values = self._get_values(edge_types, mapped=False)\n\n        self.__dict__['_values'] = values\n        return values\n\n    def get_mapped_values(\n        self,\n        edge_types: Optional[List[EdgeType]] = None,\n    ) -> Union[List[int], Dict[str, List[int]]]:\n        r\"\"\"Returns the number of neighbors.\n        For heterogeneous graphs, a dictionary is returned in which edge type\n        tuples are converted to strings.\n\n        Args:\n            edge_types (List[Tuple[str, str, str]], optional): The edge types\n                to generate the number of neighbors for. (default: :obj:`None`)\n        \"\"\"\n        if '_mapped_values' in self.__dict__:\n            return self.__dict__['_mapped_values']\n\n        values = self._get_values(edge_types, mapped=True)\n\n        self.__dict__['_mapped_values'] = values\n        return values\n\n    @property\n    def num_hops(self) -> int:\n        r\"\"\"Returns the number of hops.\"\"\"\n        if '_num_hops' in self.__dict__:\n            return self.__dict__['_num_hops']\n\n        if isinstance(self.values, (tuple, list)):\n            num_hops = max(len(self.values), len(self.default or []))\n        else:  # isinstance(self.values, dict):\n            num_hops = max([0] + [len(v) for v in self.values.values()])\n            num_hops = max(num_hops, len(self.default or []))\n\n        self.__dict__['_num_hops'] = num_hops\n        return num_hops\n\n    def __len__(self) -> int:\n        r\"\"\"Returns the number of hops.\"\"\"\n        return self.num_hops\n\n\nclass NegativeSamplingMode(Enum):\n    # 'binary': Randomly sample negative edges in the graph.\n    binary = 'binary'\n    # 'triplet': Randomly sample negative destination nodes for each positive\n    # source node.\n    triplet = 'triplet'\n\n\n@dataclass\nclass NegativeSampling(CastMixin):\n    
r\"\"\"The negative sampling configuration of a\n    :class:`~torch_geometric.sampler.BaseSampler` when calling\n    :meth:`~torch_geometric.sampler.BaseSampler.sample_from_edges`.\n\n    Args:\n        mode (str): The negative sampling mode\n            (:obj:`\"binary\"` or :obj:`\"triplet\"`).\n            If set to :obj:`\"binary\"`, will randomly sample negative links\n            from the graph.\n            If set to :obj:`\"triplet\"`, will randomly sample negative\n            destination nodes for each positive source node.\n        amount (int or float, optional): The ratio of sampled negative edges to\n            the number of positive edges. (default: :obj:`1`)\n        src_weight (torch.Tensor, optional): A node-level vector determining\n            the sampling of source nodes. Does not necessarily need to sum up\n            to one. If not given, negative nodes will be sampled uniformly.\n            (default: :obj:`None`)\n        dst_weight (torch.Tensor, optional): A node-level vector determining\n            the sampling of destination nodes. Does not necessarily need to sum\n            up to one. 
If not given, negative nodes will be sampled uniformly.\n            (default: :obj:`None`)\n    \"\"\"\n    mode: NegativeSamplingMode\n    amount: Union[int, float] = 1\n    src_weight: Optional[Tensor] = None\n    dst_weight: Optional[Tensor] = None\n\n    def __init__(\n        self,\n        mode: Union[NegativeSamplingMode, str],\n        amount: Union[int, float] = 1,\n        src_weight: Optional[Tensor] = None,\n        dst_weight: Optional[Tensor] = None,\n    ):\n        self.mode = NegativeSamplingMode(mode)\n        self.amount = amount\n        self.src_weight = src_weight\n        self.dst_weight = dst_weight\n\n        if self.amount <= 0:\n            raise ValueError(f\"The attribute 'amount' needs to be positive \"\n                             f\"for '{self.__class__.__name__}' \"\n                             f\"(got {self.amount})\")\n\n        if self.is_triplet():\n            if self.amount != math.ceil(self.amount):\n                raise ValueError(f\"The attribute 'amount' needs to be an \"\n                                 f\"integer for '{self.__class__.__name__}' \"\n                                 f\"with 'triplet' negative sampling \"\n                                 f\"(got {self.amount}).\")\n            self.amount = math.ceil(self.amount)\n\n    def is_binary(self) -> bool:\n        return self.mode == NegativeSamplingMode.binary\n\n    def is_triplet(self) -> bool:\n        return self.mode == NegativeSamplingMode.triplet\n\n    def sample(\n        self,\n        num_samples: int,\n        endpoint: Literal['src', 'dst'],\n        num_nodes: Optional[int] = None,\n    ) -> Tensor:\n        r\"\"\"Generates :obj:`num_samples` negative samples.\"\"\"\n        weight = self.src_weight if endpoint == 'src' else self.dst_weight\n\n        if weight is None:\n            if num_nodes is None:\n                raise ValueError(\n                    f\"Cannot sample negatives in '{self.__class__.__name__}' \"\n                    
f\"without passing the 'num_nodes' argument\")\n            return torch.randint(num_nodes, (num_samples, ))\n\n        if num_nodes is not None and weight.numel() != num_nodes:\n            raise ValueError(\n                f\"The 'weight' attribute in '{self.__class__.__name__}' \"\n                f\"needs to match the number of nodes {num_nodes} \"\n                f\"(got {weight.numel()})\")\n        return torch.multinomial(weight, num_samples, replacement=True)\n\n\nclass BaseSampler(ABC):\n    r\"\"\"An abstract base class that initializes a graph sampler and provides\n    :meth:`sample_from_nodes` and :meth:`sample_from_edges` routines.\n\n    .. note ::\n\n        Any data stored in the sampler will be *replicated* across data loading\n        workers that use the sampler since each data loading worker holds its\n        own instance of a sampler.\n        As such, it is recommended to limit the amount of information stored in\n        the sampler.\n    \"\"\"\n    @abstractmethod\n    def sample_from_nodes(\n        self,\n        index: NodeSamplerInput,\n        **kwargs,\n    ) -> Union[HeteroSamplerOutput, SamplerOutput]:\n        r\"\"\"Performs sampling from the nodes specified in :obj:`index`,\n        returning a sampled subgraph in the specified output format.\n\n        The :obj:`index` is a tuple holding the following information:\n\n        1. The example indices of the seed nodes\n        2. The node indices to start sampling from\n        3. 
The timestamps of the given seed nodes (optional)\n\n        Args:\n            index (NodeSamplerInput): The node sampler input object.\n            **kwargs (optional): Additional keyword arguments.\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def sample_from_edges(\n        self,\n        index: EdgeSamplerInput,\n        neg_sampling: Optional[NegativeSampling] = None,\n    ) -> Union[HeteroSamplerOutput, SamplerOutput]:\n        r\"\"\"Performs sampling from the edges specified in :obj:`index`,\n        returning a sampled subgraph in the specified output format.\n\n        The :obj:`index` is a tuple holding the following information:\n\n        1. The example indices of the seed links\n        2. The source node indices to start sampling from\n        3. The destination node indices to start sampling from\n        4. The labels of the seed links (optional)\n        5. The timestamps of the given seed nodes (optional)\n\n        Args:\n            index (EdgeSamplerInput): The edge sampler input object.\n            neg_sampling (NegativeSampling, optional): The negative sampling\n                configuration. (default: :obj:`None`)\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:\n        r\"\"\"If the sampler performs any modification of edge ordering in the\n        original graph, this function is expected to return the permutation\n        tensor that defines the permutation from the edges in the original\n        graph to the edges used in the sampler. If no such permutation was\n        applied, :obj:`None` is returned. For heterogeneous graphs, the\n        expected return type is a permutation tensor for each edge type.\n        \"\"\"\n        return None\n"
  },
  {
    "path": "torch_geometric/sampler/hgt_sampler.py",
    "content": "from typing import Dict, List, Optional, Union\n\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.sampler import (\n    BaseSampler,\n    EdgeSamplerInput,\n    HeteroSamplerOutput,\n    NegativeSampling,\n    NodeSamplerInput,\n    SamplerOutput,\n)\nfrom torch_geometric.sampler.utils import remap_keys, to_hetero_csc\nfrom torch_geometric.typing import (\n    WITH_TORCH_SPARSE,\n    EdgeType,\n    NodeType,\n    OptTensor,\n)\n\n\nclass HGTSampler(BaseSampler):\n    r\"\"\"An implementation of an in-memory heterogeneous layer-wise sampler\n    user by :class:`~torch_geometric.loader.HGTLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        data: HeteroData,\n        num_samples: Union[List[int], Dict[NodeType, List[int]]],\n        is_sorted: bool = False,\n        share_memory: bool = False,\n    ):\n        if not WITH_TORCH_SPARSE:\n            raise ImportError(\n                f\"'{self.__class__.__name__}' requires 'torch-sparse'\")\n\n        if isinstance(data, Data) or isinstance(data, tuple):\n            raise NotImplementedError(\n                f'{self.__class__.__name__} does not support a data object of '\n                f'type {type(data)}.')\n\n        if isinstance(num_samples, (list, tuple)):\n            num_samples = {key: num_samples for key in data.node_types}\n\n        self.node_types, self.edge_types = data.metadata()\n        self.num_samples = num_samples\n        self.num_hops = max([len(v) for v in num_samples.values()])\n\n        # Conversion to/from C++ string type (see `NeighborSampler`):\n        self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}\n        self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}\n\n        # Convert the graph data into a suitable format for sampling:\n        colptr_dict, row_dict, self.perm = to_hetero_csc(\n            data, device='cpu', share_memory=share_memory, is_sorted=is_sorted)\n        self.row_dict = 
remap_keys(row_dict, self.to_rel_type)\n        self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)\n\n    def sample_from_nodes(\n        self,\n        inputs: NodeSamplerInput,\n    ) -> HeteroSamplerOutput:\n\n        node, row, col, edge = torch.ops.torch_sparse.hgt_sample(\n            self.colptr_dict,\n            self.row_dict,\n            {inputs.input_type: inputs.node},\n            self.num_samples,\n            self.num_hops,\n        )\n\n        return HeteroSamplerOutput(\n            node=node,\n            row=remap_keys(row, self.to_edge_type),\n            col=remap_keys(col, self.to_edge_type),\n            edge=remap_keys(edge, self.to_edge_type),\n            batch=None,\n            metadata=(inputs.input_id, inputs.time),\n        )\n\n    def sample_from_edges(\n        self,\n        index: EdgeSamplerInput,\n        neg_sampling: Optional[NegativeSampling] = None,\n    ) -> Union[HeteroSamplerOutput, SamplerOutput]:\n        pass\n\n    @property\n    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:\n        return self.perm\n"
  },
  {
    "path": "torch_geometric/sampler/neighbor_sampler.py",
    "content": "import copy\nimport math\nimport sys\nimport warnings\nfrom typing import Callable, Dict, List, Literal, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.data import (\n    Data,\n    FeatureStore,\n    GraphStore,\n    HeteroData,\n    remote_backend_utils,\n)\nfrom torch_geometric.data.graph_store import EdgeLayout\nfrom torch_geometric.sampler import (\n    BaseSampler,\n    EdgeSamplerInput,\n    HeteroSamplerOutput,\n    NegativeSampling,\n    NodeSamplerInput,\n    SamplerOutput,\n)\nfrom torch_geometric.sampler.base import DataType, NumNeighbors, SubgraphType\nfrom torch_geometric.sampler.utils import (\n    global_to_local_node_idx,\n    remap_keys,\n    reverse_edge_type,\n    to_csc,\n    to_hetero_csc,\n)\nfrom torch_geometric.typing import EdgeType, NodeType, OptTensor\n\nNumNeighborsType = Union[NumNeighbors, List[int], Dict[EdgeType, List[int]]]\n\n\nclass NeighborSampler(BaseSampler):\n    r\"\"\"An implementation of an in-memory (heterogeneous) neighbor sampler used\n    by :class:`~torch_geometric.loader.NeighborLoader`.\n    \"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n        num_neighbors: NumNeighborsType,\n        subgraph_type: Union[SubgraphType, str] = 'directional',\n        replace: bool = False,\n        disjoint: bool = False,\n        temporal_strategy: str = 'uniform',\n        time_attr: Optional[str] = None,\n        weight_attr: Optional[str] = None,\n        is_sorted: bool = False,\n        share_memory: bool = False,\n        directed: bool = True,  # Deprecated\n        sample_direction: Literal['forward', 'backward'] = 'forward',\n    ):\n        if not directed:\n            subgraph_type = SubgraphType.induced\n            warnings.warn(\n                f\"The usage of the 'directed' argument in \"\n                f\"'{self.__class__.__name__}' is deprecated. 
Use \"\n                f\"`subgraph_type='induced'` instead.\", stacklevel=2)\n\n        if (not torch_geometric.typing.WITH_PYG_LIB and sys.platform == 'linux'\n                and subgraph_type != SubgraphType.induced):\n            warnings.warn(\n                f\"Using '{self.__class__.__name__}' without a \"\n                f\"'pyg-lib' installation is deprecated and will be \"\n                f\"removed soon. Please install 'pyg-lib' for \"\n                f\"accelerated neighborhood sampling\", stacklevel=2)\n\n        self.data_type = DataType.from_data(data)\n        self.sample_direction = sample_direction\n\n        if self.sample_direction == 'backward':\n            # TODO(zaristei)\n            if time_attr is not None:\n                raise NotImplementedError(\n                    \"Temporal Sampling not yet supported for backward sampling\"\n                )\n\n        if self.data_type == DataType.homogeneous:\n            self.num_nodes = data.num_nodes\n\n            self.node_time: Optional[Tensor] = None\n            self.edge_time: Optional[Tensor] = None\n\n            if time_attr is not None:\n                if data.is_node_attr(time_attr):\n                    self.node_time = data[time_attr]\n                elif data.is_edge_attr(time_attr):\n                    self.edge_time = data[time_attr]\n                else:\n                    raise ValueError(\n                        f\"The time attribute '{time_attr}' is neither a \"\n                        f\"node-level or edge-level attribute\")\n\n            # Convert the graph data into CSC format for sampling:\n            self.colptr, self.row, self.perm = to_csc(\n                data, device='cpu', share_memory=share_memory,\n                is_sorted=is_sorted, src_node_time=self.node_time,\n                edge_time=self.edge_time,\n                to_transpose=self.sample_direction == 'backward')\n\n            if self.edge_time is not None and self.perm is not 
None:\n                self.edge_time = self.edge_time[self.perm]\n\n            self.edge_weight: Optional[Tensor] = None\n            if weight_attr is not None:\n                self.edge_weight = data[weight_attr]\n                if self.perm is not None:\n                    self.edge_weight = self.edge_weight[self.perm]\n\n        elif self.data_type == DataType.heterogeneous:\n            self.node_types, self.edge_types = data.metadata()\n\n            # reverse edge types if sample_direction is backward\n            if self.sample_direction == 'backward':\n                self.edge_types = [\n                    reverse_edge_type(edge_type)\n                    for edge_type in self.edge_types\n                ]\n                self.to_restored_edge_type = {\n                    k: reverse_edge_type(k)\n                    for k in self.edge_types\n                }\n\n            self.num_nodes = {k: data[k].num_nodes for k in self.node_types}\n\n            self.node_time: Optional[Dict[NodeType, Tensor]] = None\n            self.edge_time: Optional[Dict[EdgeType, Tensor]] = None\n\n            if time_attr is not None:\n                is_node_level_time = is_edge_level_time = False\n\n                for store in data.node_stores:\n                    if time_attr in store:\n                        is_node_level_time = True\n                for store in data.edge_stores:\n                    if time_attr in store:\n                        is_edge_level_time = True\n\n                if is_node_level_time and is_edge_level_time:\n                    raise ValueError(\n                        f\"The time attribute '{time_attr}' holds both \"\n                        f\"node-level and edge-level information\")\n\n                if not is_node_level_time and not is_edge_level_time:\n                    raise ValueError(\n                        f\"The time attribute '{time_attr}' is neither a \"\n                        f\"node-level or edge-level 
attribute\")\n\n                if is_node_level_time:\n                    self.node_time = data.collect(time_attr)\n                else:\n                    self.edge_time = data.collect(time_attr)\n\n            # Conversion to/from C++ string type: Since C++ cannot take\n            # dictionaries with tuples as key as input, edge type triplets need\n            # to be converted into single strings.\n            self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}\n            self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}\n\n            # Convert the graph data into CSC format for sampling:\n            colptr_dict, row_dict, self.perm = to_hetero_csc(\n                data, device='cpu', share_memory=share_memory,\n                is_sorted=is_sorted, node_time_dict=self.node_time,\n                edge_time_dict=self.edge_time,\n                to_transpose=self.sample_direction == 'backward')\n\n            self.row_dict = remap_keys(row_dict, self.to_rel_type)\n            self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)\n\n            if self.edge_time is not None:\n                for edge_type, edge_time in self.edge_time.items():\n                    if self.perm.get(edge_type, None) is not None:\n                        edge_time = edge_time[self.perm[edge_type]]\n                        self.edge_time[edge_type] = edge_time\n                self.edge_time = remap_keys(self.edge_time, self.to_rel_type)\n\n            self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None\n            if weight_attr is not None:\n                self.edge_weight = data.collect(weight_attr)\n                for edge_type, edge_weight in self.edge_weight.items():\n                    if self.perm.get(edge_type, None) is not None:\n                        edge_weight = edge_weight[self.perm[edge_type]]\n                        self.edge_weight[edge_type] = edge_weight\n                self.edge_weight = 
remap_keys(self.edge_weight,\n                                              self.to_rel_type)\n\n        else:  # self.data_type == DataType.remote\n            feature_store, graph_store = data\n\n            # Obtain graph metadata:\n            attrs = [attr for attr in feature_store.get_all_tensor_attrs()]\n\n            edge_attrs = graph_store.get_all_edge_attrs()\n            self.edge_types = list({attr.edge_type for attr in edge_attrs})\n\n            # reverse edge types if sample_direction is backward\n            if self.sample_direction == 'backward':\n                self.edge_types = [\n                    reverse_edge_type(edge_type)\n                    for edge_type in self.edge_types\n                ]\n                self.to_restored_edge_type = {\n                    k: reverse_edge_type(k)\n                    for k in self.edge_types\n                }\n                self.to_backward_edge_type = {\n                    v: k\n                    for k, v in self.to_restored_edge_type.items()\n                }\n\n            if weight_attr is not None:\n                raise NotImplementedError(\n                    f\"'weight_attr' argument not yet supported within \"\n                    f\"'{self.__class__.__name__}' for \"\n                    f\"'(FeatureStore, GraphStore)' inputs\")\n\n            if time_attr is not None:\n                # If the `time_attr` is present, we expect that `GraphStore`\n                # holds all edges sorted by destination, and within local\n                # neighborhoods, node indices should be sorted by time.\n                # TODO (matthias, manan) Find an alternative way to ensure.\n                for edge_attr in edge_attrs:\n                    if edge_attr.layout == EdgeLayout.CSR:\n                        raise ValueError(\n                            \"Temporal sampling requires that edges are stored \"\n                            \"in either COO or CSC layout\")\n                    if not 
edge_attr.is_sorted:\n                        raise ValueError(\n                            \"Temporal sampling requires that edges are \"\n                            \"sorted by destination, and by source time \"\n                            \"within local neighborhoods\")\n\n                # We obtain all features with `node_attr.name=time_attr`:\n                time_attrs = [\n                    copy.copy(attr) for attr in attrs\n                    if attr.attr_name == time_attr\n                ]\n\n            if not self.is_hetero:\n                self.node_types = [None]\n                self.num_nodes = max(edge_attrs[0].size)\n                self.edge_weight: Optional[Tensor] = None\n\n                self.node_time: Optional[Tensor] = None\n                self.edge_time: Optional[Tensor] = None\n\n                if time_attr is not None:\n                    if len(time_attrs) != 1:\n                        raise ValueError(\"Temporal sampling specified but did \"\n                                         \"not find any temporal data\")\n                    time_attrs[0].index = None  # Reset index for full data.\n                    time_tensor = feature_store.get_tensor(time_attrs[0])\n                    # Currently, we determine whether to use node-level or\n                    # edge-level temporal sampling based on the attribute name.\n                    if time_attr == 'time':\n                        self.node_time = time_tensor\n                    else:\n                        self.edge_time = time_tensor\n\n                if self.sample_direction == 'forward':\n                    self.row, self.colptr, self.perm = graph_store.csc()\n                elif self.sample_direction == 'backward':\n                    self.colptr, self.row, self.perm = graph_store.csr()\n\n            else:\n                node_types = [\n                    attr.group_name for attr in attrs\n                    if isinstance(attr.group_name, str)\n      
          ]\n                self.node_types = list(set(node_types))\n                self.num_nodes = {\n                    node_type: remote_backend_utils.size(*data, node_type)\n                    for node_type in self.node_types\n                }\n                self.edge_weight: Optional[Dict[EdgeType, Tensor]] = None\n\n                self.node_time: Optional[Dict[NodeType, Tensor]] = None\n                self.edge_time: Optional[Dict[EdgeType, Tensor]] = None\n\n                if time_attr is not None:\n                    for attr in time_attrs:  # Reset index for full data.\n                        attr.index = None\n\n                    time_tensors = feature_store.multi_get_tensor(time_attrs)\n                    time = {\n                        attr.group_name: time_tensor\n                        for attr, time_tensor in zip(time_attrs, time_tensors)\n                    }\n\n                    group_names = [attr.group_name for attr in time_attrs]\n                    if all([isinstance(g, str) for g in group_names]):\n                        self.node_time = time\n                    elif all([isinstance(g, tuple) for g in group_names]):\n                        self.edge_time = time\n                    else:\n                        raise ValueError(\n                            f\"Found time attribute '{time_attr}' for both \"\n                            f\"node-level and edge-level types\")\n\n                # Conversion to/from C++ string type (see above):\n                self.to_rel_type = {k: '__'.join(k) for k in self.edge_types}\n                self.to_edge_type = {v: k for k, v in self.to_rel_type.items()}\n                if self.sample_direction == 'forward':\n                    row_dict, colptr_dict, self.perm = graph_store.csc()\n                elif self.sample_direction == 'backward':\n                    colptr_dict, row_dict, self.perm = graph_store.csr()\n\n                    colptr_dict = remap_keys(colptr_dict,\n  
                                           self.to_backward_edge_type)\n                    row_dict = remap_keys(row_dict, self.to_backward_edge_type)\n                    self.perm = remap_keys(self.perm,\n                                           self.to_backward_edge_type)\n\n                self.row_dict = remap_keys(row_dict, self.to_rel_type)\n                self.colptr_dict = remap_keys(colptr_dict, self.to_rel_type)\n\n        if (self.edge_time is not None\n                and not torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE):\n            raise ImportError(\"Edge-level temporal sampling requires a \"\n                              \"more recent 'pyg-lib' installation\")\n\n        if (self.edge_weight is not None\n                and not torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE):\n            raise ImportError(\"Weighted neighbor sampling requires \"\n                              \"'pyg-lib>=0.3.0'\")\n\n        self.num_neighbors = num_neighbors\n        self.replace = replace\n        self.subgraph_type = SubgraphType(subgraph_type)\n        self.disjoint = disjoint\n        self.temporal_strategy = temporal_strategy\n        self.keep_orig_edges = False\n\n    @property\n    def num_neighbors(self) -> NumNeighbors:\n        if self.sample_direction == 'backward':\n            return self._input_num_neighbors \\\n                if self._input_num_neighbors is not None \\\n                else self._num_neighbors\n        return self._num_neighbors\n\n    @num_neighbors.setter\n    def num_neighbors(self, num_neighbors: NumNeighborsType):\n        # only used if sample direction is backward and num_neighbors has edge\n        # keys\n        self._input_num_neighbors = None\n\n        if isinstance(num_neighbors, NumNeighbors):\n            num_neighbors_values = num_neighbors.values\n            if isinstance(num_neighbors_values,\n                          dict) and self.sample_direction == 'backward':\n                # 
reverse the edge_types if sample_direction is backward\n                self._input_num_neighbors = num_neighbors\n                num_neighbors_values = remap_keys(num_neighbors_values,\n                                                  self.to_backward_edge_type)\n                self._num_neighbors = NumNeighbors(num_neighbors_values)\n            else:\n                self._num_neighbors = num_neighbors\n        else:\n            if isinstance(num_neighbors,\n                          dict) and self.sample_direction == 'backward':\n                # intentionally recursing here to make sure num_neighbors is\n                # set as expected for the user\n                self.num_neighbors = NumNeighbors(\n                    remap_keys(num_neighbors, self.to_backward_edge_type))\n            else:\n                self._num_neighbors = NumNeighbors(num_neighbors)\n\n    @property\n    def is_hetero(self) -> bool:\n        if self.data_type == DataType.homogeneous:\n            return False\n        if self.data_type == DataType.heterogeneous:\n            return True\n\n        # self.data_type == DataType.remote\n        return self.edge_types != [None]\n\n    @property\n    def is_temporal(self) -> bool:\n        return self.node_time is not None or self.edge_time is not None\n\n    @property\n    def disjoint(self) -> bool:\n        return self._disjoint or self.is_temporal\n\n    @disjoint.setter\n    def disjoint(self, disjoint: bool):\n        self._disjoint = disjoint\n\n    # Node-based sampling #####################################################\n\n    def sample_from_nodes(\n        self,\n        inputs: NodeSamplerInput,\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        out = node_sample(inputs, self._sample)\n        if self.subgraph_type == SubgraphType.bidirectional:\n            out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges)\n        return out\n\n    # Edge-based sampling 
#####################################################\n\n    def sample_from_edges(\n        self,\n        inputs: EdgeSamplerInput,\n        neg_sampling: Optional[NegativeSampling] = None,\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        out = edge_sample(inputs, self._sample, self.num_nodes, self.disjoint,\n                          self.node_time, neg_sampling)\n        if self.subgraph_type == SubgraphType.bidirectional:\n            out = out.to_bidirectional(keep_orig_edges=self.keep_orig_edges)\n        return out\n\n    # Other Utilities #########################################################\n\n    @property\n    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:\n        return self.perm\n\n    # Helper functions ########################################################\n\n    def _sample(\n        self,\n        seed: Union[Tensor, Dict[NodeType, Tensor]],\n        seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,\n        **kwargs,\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        r\"\"\"Implements neighbor sampling by calling either :obj:`pyg-lib` (if\n        installed) or :obj:`torch-sparse` (if installed) sampling routines.\n        \"\"\"\n        if isinstance(seed, dict):  # Heterogeneous sampling:\n            # TODO Support induced subgraph sampling in `pyg-lib`.\n            if (torch_geometric.typing.WITH_PYG_LIB\n                    and self.subgraph_type != SubgraphType.induced):\n                # TODO (matthias) Ideally, `seed` inherits dtype from `colptr`\n                colptrs = list(self.colptr_dict.values())\n                dtype = colptrs[0].dtype if len(colptrs) > 0 else torch.int64\n                seed = {k: v.to(dtype) for k, v in seed.items()}\n\n                args = (\n                    self.node_types,\n                    self.edge_types,\n                    self.colptr_dict,\n                    self.row_dict,\n                    seed,\n            
        self.num_neighbors.get_mapped_values(self.edge_types),\n                    self.node_time,\n                )\n                if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:\n                    args += (self.edge_time, )\n                args += (seed_time, )\n                if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:\n                    args += (self.edge_weight, )\n                args += (\n                    True,  # csc\n                    self.replace,\n                    self.subgraph_type != SubgraphType.induced,\n                    self.disjoint,\n                    self.temporal_strategy,\n                    # TODO (matthias) `return_edge_id` if edge features present\n                    True,  # return_edge_id\n                )\n\n                out = torch.ops.pyg.hetero_neighbor_sample(*args)\n                row, col, node, edge, batch = out[:4] + (None, )\n\n                # `pyg-lib>0.1.0` returns sampled number of nodes/edges:\n                num_sampled_nodes = num_sampled_edges = None\n                if len(out) >= 6:\n                    num_sampled_nodes, num_sampled_edges = out[4:6]\n\n                if self.disjoint:\n                    node = {k: v.t().contiguous() for k, v in node.items()}\n                    batch = {k: v[0] for k, v in node.items()}\n                    node = {k: v[1] for k, v in node.items()}\n\n            elif torch_geometric.typing.WITH_TORCH_SPARSE:\n                if self.disjoint:\n                    if self.subgraph_type == SubgraphType.induced:\n                        raise ValueError(\"'disjoint' sampling not supported \"\n                                         \"for neighbor sampling with \"\n                                         \"`subgraph_type='induced'`\")\n                    else:\n                        raise ValueError(\"'disjoint' sampling not supported \"\n                                         \"for neighbor sampling via \"\n               
                          \"'torch-sparse'. Please install \"\n                                         \"'pyg-lib' for improved and \"\n                                         \"optimized sampling routines.\")\n\n                out = torch.ops.torch_sparse.hetero_neighbor_sample(\n                    self.node_types,\n                    self.edge_types,\n                    self.colptr_dict,\n                    self.row_dict,\n                    seed,  # seed_dict\n                    self.num_neighbors.get_mapped_values(self.edge_types),\n                    self.num_neighbors.num_hops,\n                    self.replace,\n                    self.subgraph_type != SubgraphType.induced,\n                )\n                node, row, col, edge, batch = out + (None, )\n                num_sampled_nodes = num_sampled_edges = None\n\n            else:\n                raise ImportError(f\"'{self.__class__.__name__}' requires \"\n                                  f\"either 'pyg-lib' or 'torch-sparse'\")\n\n            if self.sample_direction == 'backward':\n                row, col = col, row\n\n            row = remap_keys(row, self.to_edge_type)\n            col = remap_keys(col, self.to_edge_type)\n            edge = remap_keys(edge, self.to_edge_type)\n\n            # In the case of backward sampling, we need to restore the edges\n            # keys to be forward facing in the HeteroSamplerOutput object.\n            if self.sample_direction == 'backward':\n                row = remap_keys(row, self.to_restored_edge_type)\n                col = remap_keys(col, self.to_restored_edge_type)\n                edge = remap_keys(edge, self.to_restored_edge_type)\n\n            if num_sampled_edges is not None:\n                num_sampled_edges = remap_keys(\n                    num_sampled_edges,\n                    self.to_edge_type,\n                )\n                if self.sample_direction == 'backward':\n                    num_sampled_edges = 
remap_keys(num_sampled_edges,\n                                                   self.to_restored_edge_type)\n\n            return HeteroSamplerOutput(\n                node=node,\n                row=row,\n                col=col,\n                edge=edge,\n                batch=batch,\n                num_sampled_nodes=num_sampled_nodes,\n                num_sampled_edges=num_sampled_edges,\n            )\n\n        else:  # Homogeneous sampling:\n            # TODO Support induced subgraph sampling in `pyg-lib`.\n            if (torch_geometric.typing.WITH_PYG_LIB\n                    and self.subgraph_type != SubgraphType.induced):\n\n                args = (\n                    self.colptr,\n                    self.row,\n                    # TODO (matthias) `seed` should inherit dtype from `colptr`\n                    seed.to(self.colptr.dtype),\n                    self.num_neighbors.get_mapped_values(),\n                    self.node_time,\n                )\n                if torch_geometric.typing.WITH_EDGE_TIME_NEIGHBOR_SAMPLE:\n                    args += (self.edge_time, )\n                args += (seed_time, )\n                if torch_geometric.typing.WITH_WEIGHTED_NEIGHBOR_SAMPLE:\n                    args += (self.edge_weight, )\n                args += (\n                    True,  # csc\n                    self.replace,\n                    self.subgraph_type != SubgraphType.induced,\n                    self.disjoint,\n                    self.temporal_strategy,\n                    # TODO (matthias) `return_edge_id` if edge features present\n                    True,  # return_edge_id\n                )\n\n                out = torch.ops.pyg.neighbor_sample(*args)\n                row, col, node, edge, batch = out[:4] + (None, )\n\n                # `pyg-lib>0.1.0` returns sampled number of nodes/edges:\n                num_sampled_nodes = num_sampled_edges = None\n                if len(out) >= 6:\n                    
num_sampled_nodes, num_sampled_edges = out[4:6]\n\n                if self.disjoint:\n                    batch, node = node.t().contiguous()\n\n            elif torch_geometric.typing.WITH_TORCH_SPARSE:\n                if self.disjoint:\n                    raise ValueError(\"'disjoint' sampling not supported for \"\n                                     \"neighbor sampling via 'torch-sparse'. \"\n                                     \"Please install 'pyg-lib' for improved \"\n                                     \"and optimized sampling routines.\")\n\n                out = torch.ops.torch_sparse.neighbor_sample(\n                    self.colptr,\n                    self.row,\n                    seed,  # seed\n                    self.num_neighbors.get_mapped_values(),\n                    self.replace,\n                    self.subgraph_type != SubgraphType.induced,\n                )\n                node, row, col, edge, batch = out + (None, )\n                num_sampled_nodes = num_sampled_edges = None\n\n            else:\n                raise ImportError(f\"'{self.__class__.__name__}' requires \"\n                                  f\"either 'pyg-lib' or 'torch-sparse'\")\n\n            if self.sample_direction == 'backward':\n                row, col = col, row\n\n            return SamplerOutput(\n                node=node,\n                row=row,\n                col=col,\n                edge=edge,\n                batch=batch,\n                num_sampled_nodes=num_sampled_nodes,\n                num_sampled_edges=num_sampled_edges,\n            )\n\n\nclass BidirectionalNeighborSampler(NeighborSampler):\n    \"\"\"A sampler that allows for both upstream and downstream sampling.\"\"\"\n    def __init__(\n        self,\n        data: Union[Data, HeteroData, Tuple[FeatureStore, GraphStore]],\n        num_neighbors: NumNeighborsType,\n        subgraph_type: Union[SubgraphType, str] = 'directional',\n        replace: bool = False,\n        disjoint: 
bool = False,\n        temporal_strategy: str = 'uniform',\n        time_attr: Optional[str] = None,\n        weight_attr: Optional[str] = None,\n        is_sorted: bool = False,\n        share_memory: bool = False,\n        # Deprecated:\n        directed: bool = True,\n    ):\n\n        # TODO(zaristei)\n        if isinstance(num_neighbors, NumNeighbors) and isinstance(\n                num_neighbors.values, dict) or isinstance(num_neighbors, dict):\n            raise RuntimeError(\n                \"BidirectionalNeighborSampler does not yet support edge \"\n                \"delimited sampling.\")\n\n        self.forward_sampler = NeighborSampler(\n            data, num_neighbors, subgraph_type, replace, disjoint,\n            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,\n            sample_direction='forward', directed=directed)\n        self.backward_sampler = NeighborSampler(\n            data, num_neighbors, subgraph_type, replace, disjoint,\n            temporal_strategy, time_attr, weight_attr, is_sorted, share_memory,\n            sample_direction='backward', directed=directed)\n\n        # Trigger warnings on init if number of hops is greater than 1\n        self.num_neighbors = num_neighbors\n        self.subgraph_type = subgraph_type\n\n    @property\n    def num_neighbors(self) -> NumNeighbors:\n        return self._num_neighbors\n\n    @num_neighbors.setter\n    def num_neighbors(self, num_neighbors: NumNeighborsType):\n        if not isinstance(num_neighbors, NumNeighbors):\n            num_neighbors = NumNeighbors(num_neighbors)\n        if num_neighbors.num_hops > 1:\n            print(\"Warning: Number of hops is greater than 1, resulting in \"\n                  \"memory-expensive recursive calls.\")\n        self._num_neighbors = num_neighbors\n\n    @property\n    def is_hetero(self) -> bool:\n        return self.forward_sampler.is_hetero\n\n    @property\n    def is_temporal(self) -> bool:\n        return 
self.forward_sampler.is_temporal\n\n    @property\n    def disjoint(self) -> bool:\n        return self.forward_sampler.disjoint\n\n    @disjoint.setter\n    def disjoint(self, disjoint: bool):\n        self.forward_sampler.disjoint = disjoint\n        self.backward_sampler.disjoint = disjoint\n\n    def sample_from_nodes(\n        self,\n        inputs: NodeSamplerInput,\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        return super().sample_from_nodes(inputs)\n\n    def sample_from_edges(\n        self,\n        inputs: EdgeSamplerInput,\n        neg_sampling: Optional[NegativeSampling] = None,\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n        # TODO(zaristei) Figure out what exactly regular and negative sampling\n        # imply for bidirectional sampling case\n        if neg_sampling is not None:\n            raise RuntimeError(\n                \"BidirectionalNeighborSampler does not yet support \"\n                \"negative sampling.\")\n        # Not thoroughly tested yet!\n        return super().sample_from_edges(inputs)\n\n    @property\n    def edge_permutation(self) -> Union[OptTensor, Dict[EdgeType, OptTensor]]:\n        return self.forward_sampler.edge_permutation\n\n    def _sample(\n        self,\n        seed: Union[Tensor, Dict[NodeType, Tensor]],\n        seed_time: Optional[Union[Tensor, Dict[NodeType, Tensor]]] = None,\n        **kwargs,\n    ) -> Union[SamplerOutput, HeteroSamplerOutput]:\n\n        if seed_time is not None:\n            raise NotImplementedError(\n                \"BidirectionalNeighborSampler does not yet support \"\n                \"temporal sampling.\")\n\n        if self.is_hetero:\n            raise NotImplementedError(\n                \"BidirectionalNeighborSampler does not yet support \"\n                \"heterogeneous sampling.\")\n        else:\n            current_seed = seed\n            current_seed_batch = None\n            current_seed_time = seed_time\n            seen_seed_set = 
{int(node) for node in current_seed}\n            if self.disjoint:\n                current_seed_batch = torch.arange(len(current_seed))\n                seen_seed_set = {\n                    (int(node), int(batch))\n                    for node, batch in zip(current_seed, current_seed_batch)\n                }\n\n            iter_results = []\n\n            for n_neighbors in self.num_neighbors.values:\n                current_n_neighbors = [n_neighbors]\n                self.forward_sampler.num_neighbors = current_n_neighbors\n                self.backward_sampler.num_neighbors = current_n_neighbors\n\n                fwd_result = self.forward_sampler._sample(\n                    current_seed, current_seed_time, **kwargs)\n                bwd_result = self.backward_sampler._sample(\n                    current_seed, current_seed_time, **kwargs)\n                # The seeds for the next iteration will be the new nodes in\n                # this iteration\n                iter_result = fwd_result.merge_with(bwd_result)\n                iter_results.append(iter_result)\n\n                # Find the nodes not yet seen to set a seed for next iteration\n                if self.disjoint:\n                    iter_seed_global_batch = global_to_local_node_idx(\n                        current_seed_batch, iter_result.batch)\n                    iter_result.seed_node = seed[iter_seed_global_batch]\n\n                    keep_mask = torch.tensor([\n                        (int(node), int(batch)) not in seen_seed_set\n                        for node, batch in zip(iter_result.node,\n                                               iter_seed_global_batch)\n                    ])\n                    next_seed = [(int(node), int(batch))\n                                 for node, batch in zip(\n                                     iter_result.node[keep_mask],\n                                     iter_seed_global_batch[keep_mask])\n                                 ] if 
keep_mask.any() else []\n                    current_seed, current_seed_batch = torch.tensor(\n                        next_seed).reshape(-1, 2).transpose(0, 1).contiguous()\n                else:\n                    keep_mask = torch.tensor([\n                        int(node) not in seen_seed_set\n                        for node in iter_result.node\n                    ])\n                    next_seed = [\n                        int(node) for node in iter_result.node[keep_mask]\n                    ] if keep_mask.any() else []\n                    current_seed = torch.tensor(next_seed)\n\n                seen_seed_set |= set(next_seed)\n\n                # TODO(zaristei) figure out how to update seed times for\n                # temporal sampling\n\n            return SamplerOutput.collate(iter_results)\n\n\n# Sampling Utilities ##########################################################\n\n\ndef node_sample(\n    inputs: NodeSamplerInput,\n    sample_fn: Callable,\n) -> Union[SamplerOutput, HeteroSamplerOutput]:\n    r\"\"\"Performs sampling from a :class:`NodeSamplerInput`, leveraging a\n    sampling function that accepts a seed and (optionally) a seed time as\n    input. 
Returns the output of this sampling procedure.\n    \"\"\"\n    if inputs.input_type is not None:  # Heterogeneous sampling:\n        seed = {inputs.input_type: inputs.node}\n        seed_time = None\n        if inputs.time is not None:\n            seed_time = {inputs.input_type: inputs.time}\n    else:  # Homogeneous sampling:\n        seed = inputs.node\n        seed_time = inputs.time\n\n    out = sample_fn(seed, seed_time)\n    out.metadata = (inputs.input_id, inputs.time)\n\n    return out\n\n\ndef edge_sample(\n    inputs: EdgeSamplerInput,\n    sample_fn: Callable,\n    num_nodes: Union[int, Dict[NodeType, int]],\n    disjoint: bool,\n    node_time: Optional[Union[Tensor, Dict[str, Tensor]]] = None,\n    neg_sampling: Optional[NegativeSampling] = None,\n) -> Union[SamplerOutput, HeteroSamplerOutput]:\n    r\"\"\"Performs sampling from an edge sampler input, leveraging a sampling\n    function of the same signature as `node_sample`.\n    \"\"\"\n    input_id = inputs.input_id\n    src = inputs.row\n    dst = inputs.col\n    edge_label = inputs.label\n    edge_label_time = inputs.time\n    input_type = inputs.input_type\n\n    src_time = dst_time = edge_label_time\n    assert edge_label_time is None or disjoint\n\n    assert isinstance(num_nodes, (dict, int))\n    if not isinstance(num_nodes, dict):\n        num_src_nodes = num_dst_nodes = num_nodes\n    else:\n        num_src_nodes = num_nodes[input_type[0]]\n        num_dst_nodes = num_nodes[input_type[-1]]\n\n    num_pos = src.numel()\n    num_neg = 0\n\n    # Negative Sampling #######################################################\n\n    if neg_sampling is not None:\n        # When we are doing negative sampling, we append negative information\n        # of nodes/edges to `src`, `dst`, `src_time`, `dst_time`.\n        # Later on, we can easily reconstruct what belongs to positive and\n        # negative examples by slicing via `num_pos`.\n        num_neg = math.ceil(num_pos * neg_sampling.amount)\n\n     
   if neg_sampling.is_binary():\n            # In the \"binary\" case, we randomly sample negative pairs of nodes.\n            if isinstance(node_time, dict):\n                src_node_time = node_time.get(input_type[0])\n            else:\n                src_node_time = node_time\n\n            src_neg = neg_sample(src, neg_sampling, num_src_nodes, src_time,\n                                 src_node_time, endpoint='src')\n            src = torch.cat([src, src_neg], dim=0)\n\n            if isinstance(node_time, dict):\n                dst_node_time = node_time.get(input_type[-1])\n            else:\n                dst_node_time = node_time\n\n            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,\n                                 dst_node_time, endpoint='dst')\n            dst = torch.cat([dst, dst_neg], dim=0)\n\n            if edge_label is None:\n                edge_label = torch.ones(num_pos)\n            size = (num_neg, ) + edge_label.size()[1:]\n            edge_neg_label = edge_label.new_zeros(size)\n            edge_label = torch.cat([edge_label, edge_neg_label])\n\n            if edge_label_time is not None:\n                src_time = dst_time = edge_label_time.repeat(\n                    1 + math.ceil(neg_sampling.amount))[:num_pos + num_neg]\n\n        elif neg_sampling.is_triplet():\n            # In the \"triplet\" case, we randomly sample negative destinations.\n            if isinstance(node_time, dict):\n                dst_node_time = node_time.get(input_type[-1])\n            else:\n                dst_node_time = node_time\n\n            dst_neg = neg_sample(dst, neg_sampling, num_dst_nodes, dst_time,\n                                 dst_node_time, endpoint='dst')\n            dst = torch.cat([dst, dst_neg], dim=0)\n\n            assert edge_label is None\n\n            if edge_label_time is not None:\n                dst_time = edge_label_time.repeat(1 + neg_sampling.amount)\n\n    # Heterogeneous Neighborhood 
Sampling #####################################\n\n    if input_type is not None:\n        seed_time_dict = None\n        if input_type[0] != input_type[-1]:  # Two distinct node types:\n\n            if not disjoint:\n                src, inverse_src = src.unique(return_inverse=True)\n                dst, inverse_dst = dst.unique(return_inverse=True)\n\n            seed_dict = {input_type[0]: src, input_type[-1]: dst}\n\n            if edge_label_time is not None:  # Always disjoint.\n                seed_time_dict = {\n                    input_type[0]: src_time,\n                    input_type[-1]: dst_time,\n                }\n\n        else:  # Only a single node type: Merge both source and destination.\n\n            seed = torch.cat([src, dst], dim=0)\n\n            if not disjoint:\n                seed, inverse_seed = seed.unique(return_inverse=True)\n\n            seed_dict = {input_type[0]: seed}\n\n            if edge_label_time is not None:  # Always disjoint.\n                seed_time_dict = {\n                    input_type[0]: torch.cat([src_time, dst_time], dim=0),\n                }\n\n        out = sample_fn(seed_dict, seed_time_dict)\n\n        # Enhance `out` by label information ##################################\n        if disjoint:\n            for key, batch in out.batch.items():\n                out.batch[key] = batch % num_pos\n\n        if neg_sampling is None or neg_sampling.is_binary():\n            if disjoint:\n                if input_type[0] != input_type[-1]:\n                    edge_label_index = torch.arange(num_pos + num_neg)\n                    edge_label_index = edge_label_index.repeat(2).view(2, -1)\n                else:\n                    edge_label_index = torch.arange(2 * (num_pos + num_neg))\n                    edge_label_index = edge_label_index.view(2, -1)\n            else:\n                if input_type[0] != input_type[-1]:\n                    edge_label_index = torch.stack([\n                        
inverse_src,\n                        inverse_dst,\n                    ], dim=0)\n                else:\n                    edge_label_index = inverse_seed.view(2, -1)\n\n            out.metadata = (input_id, edge_label_index, edge_label, src_time)\n\n        elif neg_sampling.is_triplet():\n            if disjoint:\n                src_index = torch.arange(num_pos)\n                if input_type[0] != input_type[-1]:\n                    dst_pos_index = torch.arange(num_pos)\n                    # `dst_neg_index` needs to be offset such that indices with\n                    # offset `num_pos` belong to the same triplet:\n                    dst_neg_index = torch.arange(\n                        num_pos, seed_dict[input_type[-1]].numel())\n                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()\n                else:\n                    dst_pos_index = torch.arange(num_pos, 2 * num_pos)\n                    dst_neg_index = torch.arange(\n                        2 * num_pos, seed_dict[input_type[-1]].numel())\n                    dst_neg_index = dst_neg_index.view(-1, num_pos).t()\n            else:\n                if input_type[0] != input_type[-1]:\n                    src_index = inverse_src\n                    dst_pos_index = inverse_dst[:num_pos]\n                    dst_neg_index = inverse_dst[num_pos:]\n                else:\n                    src_index = inverse_seed[:num_pos]\n                    dst_pos_index = inverse_seed[num_pos:2 * num_pos]\n                    dst_neg_index = inverse_seed[2 * num_pos:]\n\n            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)\n\n            out.metadata = (\n                input_id,\n                src_index,\n                dst_pos_index,\n                dst_neg_index,\n                src_time,\n            )\n\n    # Homogeneous Neighborhood Sampling #######################################\n\n    else:\n\n        seed = torch.cat([src, dst], dim=0)\n        seed_time = 
None\n\n        if not disjoint:\n            seed, inverse_seed = seed.unique(return_inverse=True)\n\n        if edge_label_time is not None:  # Always disjoint.\n            seed_time = torch.cat([src_time, dst_time])\n\n        out = sample_fn(seed, seed_time)\n\n        # Enhance `out` by label information ##################################\n        if neg_sampling is None or neg_sampling.is_binary():\n            if disjoint:\n                out.batch = out.batch % num_pos\n                edge_label_index = torch.arange(seed.numel()).view(2, -1)\n            else:\n                edge_label_index = inverse_seed.view(2, -1)\n\n            out.metadata = (input_id, edge_label_index, edge_label, src_time)\n\n        elif neg_sampling.is_triplet():\n            if disjoint:\n                out.batch = out.batch % num_pos\n                src_index = torch.arange(num_pos)\n                dst_pos_index = torch.arange(num_pos, 2 * num_pos)\n                # `dst_neg_index` needs to be offset such that indices with\n                # offset `num_pos` belong to the same triplet:\n                dst_neg_index = torch.arange(2 * num_pos, seed.numel())\n                dst_neg_index = dst_neg_index.view(-1, num_pos).t()\n            else:\n                src_index = inverse_seed[:num_pos]\n                dst_pos_index = inverse_seed[num_pos:2 * num_pos]\n                dst_neg_index = inverse_seed[2 * num_pos:]\n            dst_neg_index = dst_neg_index.view(num_pos, -1).squeeze(-1)\n\n            out.metadata = (\n                input_id,\n                src_index,\n                dst_pos_index,\n                dst_neg_index,\n                src_time,\n            )\n\n    return out\n\n\ndef neg_sample(\n    seed: Tensor,\n    neg_sampling: NegativeSampling,\n    num_nodes: int,\n    seed_time: Optional[Tensor],\n    node_time: Optional[Tensor],\n    endpoint: Literal['src', 'dst'],\n) -> Tensor:\n    num_neg = math.ceil(seed.numel() * 
neg_sampling.amount)\n\n    # TODO: Do not sample false negatives.\n    if node_time is None:\n        return neg_sampling.sample(num_neg, endpoint, num_nodes)\n\n    # If we are in a temporal-sampling scenario, we need to respect the\n    # timestamp of the given nodes we can use as negative examples.\n    # That is, we can only sample nodes for which `node_time <= seed_time`.\n    # For now, we use a greedy algorithm which randomly samples negative\n    # nodes and discards any which do not respect the temporal constraint.\n    # We iteratively repeat this process until we have sampled a valid node for\n    # each seed.\n    # TODO See if this greedy algorithm here can be improved.\n    assert seed_time is not None\n    num_samples = math.ceil(neg_sampling.amount)\n    seed_time = seed_time.view(1, -1).expand(num_samples, -1)\n\n    out = neg_sampling.sample(num_samples * seed.numel(), endpoint, num_nodes)\n    out = out.view(num_samples, seed.numel())\n    mask = node_time[out] > seed_time  # holds all invalid samples.\n    neg_sampling_complete = False\n    for _ in range(5):  # pragma: no cover\n        num_invalid = int(mask.sum())\n        if num_invalid == 0:\n            neg_sampling_complete = True\n            break\n\n        # Greedily search for alternative negatives.\n        out[mask] = tmp = neg_sampling.sample(num_invalid, endpoint, num_nodes)\n        mask[mask.clone()] = node_time[tmp] >= seed_time[mask]\n\n    if not neg_sampling_complete:  # pragma: no cover\n        # Not many options left. In that case, we set remaining negatives\n        # to the node with minimum timestamp.\n        out[mask] = node_time.argmin()\n\n    return out.view(-1)[:num_neg]\n"
  },
  {
    "path": "torch_geometric/sampler/utils.py",
    "content": "from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.storage import EdgeStorage\nfrom torch_geometric.index import index2ptr\nfrom torch_geometric.typing import EdgeType, NodeType, OptTensor\nfrom torch_geometric.utils import coalesce, index_sort, lexsort\n\n\ndef reverse_edge_type(edge_type: EdgeType) -> EdgeType:\n    \"\"\"Reverses edge types for heterogeneous graphs. Useful in cases of\n    backward sampling.\n    \"\"\"\n    return (edge_type[2], edge_type[1],\n            edge_type[0]) if edge_type is not None else None\n\n\n# Edge Layout Conversion ######################################################\n\n\ndef sort_csc(\n    row: Tensor,\n    col: Tensor,\n    src_node_time: OptTensor = None,\n    edge_time: OptTensor = None,\n) -> Tuple[Tensor, Tensor, Tensor]:\n\n    if src_node_time is None and edge_time is None:\n        col, perm = index_sort(col)\n        return row[perm], col, perm\n\n    elif edge_time is not None:\n        assert src_node_time is None\n        perm = lexsort([edge_time, col])\n        return row[perm], col[perm], perm\n\n    else:  # src_node_time is not None\n        perm = lexsort([src_node_time[row], col])\n        return row[perm], col[perm], perm\n\n\n# TODO(manan) deprecate when FeatureStore / GraphStore unification is complete\ndef to_csc(\n    data: Union[Data, EdgeStorage],\n    device: Optional[torch.device] = None,\n    share_memory: bool = False,\n    is_sorted: bool = False,\n    src_node_time: Optional[Tensor] = None,\n    edge_time: Optional[Tensor] = None,\n    to_transpose: bool = False,\n) -> Tuple[Tensor, Tensor, OptTensor]:\n    # Convert the graph data into a suitable format for sampling (CSC format).\n    # Returns the `colptr` and `row` indices of the graph, as well as an\n    # `perm` vector that denotes the permutation of edges.\n    # Since no permutation 
of edges is applied when using `SparseTensor`,\n    # `perm` can be of type `None`.\n    perm: Optional[Tensor] = None\n\n    if hasattr(data, 'adj'):\n        if src_node_time is not None:\n            raise NotImplementedError(\"Temporal sampling via 'SparseTensor' \"\n                                      \"format not yet supported\")\n        if to_transpose:\n            row, colptr, _ = data.adj.csr()\n        else:\n            colptr, row, _ = data.adj.csc()\n\n    elif hasattr(data, 'adj_t'):\n        if src_node_time is not None:\n            # TODO (matthias) This only works when instantiating a\n            # `SparseTensor` with `is_sorted=True`. Otherwise, the\n            # `SparseTensor` will by default re-sort the neighbors according to\n            # column index.\n            # As such, we probably want to consider re-adding error:\n            # raise NotImplementedError(\"Temporal sampling via 'SparseTensor' \"\n            #                           \"format not yet supported\")\n            pass\n        if to_transpose:\n            row, colptr, _ = data.adj_t.csc()\n        else:\n            colptr, row, _ = data.adj_t.csr()\n\n    elif data.edge_index is not None:\n        if to_transpose:\n            col, row = data.edge_index\n        else:\n            row, col = data.edge_index\n\n        if not is_sorted:\n            row, col, perm = sort_csc(row, col, src_node_time, edge_time)\n        colptr = index2ptr(col,\n                           data.size(1) if not to_transpose else data.size(0))\n    else:\n        row = torch.empty(0, dtype=torch.long, device=device)\n        colptr = torch.zeros(data.num_nodes + 1, dtype=torch.long,\n                             device=device)\n\n    colptr = colptr.to(device)\n    row = row.to(device)\n    perm = perm.to(device) if perm is not None else None\n\n    if not colptr.is_cuda and share_memory:\n        colptr.share_memory_()\n        row.share_memory_()\n        if perm is not None:\n        
    perm.share_memory_()\n\n    return colptr, row, perm\n\n\ndef to_hetero_csc(\n    data: HeteroData,\n    device: Optional[torch.device] = None,\n    share_memory: bool = False,\n    is_sorted: bool = False,\n    node_time_dict: Optional[Dict[NodeType, Tensor]] = None,\n    edge_time_dict: Optional[Dict[EdgeType, Tensor]] = None,\n    to_transpose: bool = False,\n) -> Tuple[Dict[str, Tensor], Dict[str, Tensor], Dict[str, OptTensor]]:\n    # Convert the heterogeneous graph data into a suitable format for sampling\n    # (CSC format).\n    # Returns dictionaries holding `colptr` and `row` indices as well as edge\n    # permutations for each edge type, respectively.\n    colptr_dict, row_dict, perm_dict = {}, {}, {}\n\n    for edge_type, store in data.edge_items():\n        src_node_time = (node_time_dict or {}).get(edge_type[0], None)\n        edge_time = (edge_time_dict or {}).get(edge_type, None)\n        out = to_csc(store, device, share_memory, is_sorted, src_node_time,\n                     edge_time, to_transpose)\n        # Edge types need to be reversed for backward sampling:\n        if to_transpose:\n            edge_type = reverse_edge_type(edge_type)\n\n        colptr_dict[edge_type], row_dict[edge_type], perm_dict[edge_type] = out\n\n    return colptr_dict, row_dict, perm_dict\n\n\ndef to_bidirectional(\n    row: Tensor,\n    col: Tensor,\n    rev_row: Tensor,\n    rev_col: Tensor,\n    edge_id: OptTensor = None,\n    rev_edge_id: OptTensor = None,\n) -> Tuple[Tensor, Tensor, OptTensor]:\n\n    assert row.numel() == col.numel()\n    assert rev_row.numel() == rev_col.numel()\n\n    edge_index = row.new_empty(2, row.numel() + rev_row.numel())\n    edge_index[0, :row.numel()] = row\n    edge_index[1, :row.numel()] = col\n    edge_index[0, row.numel():] = rev_col\n    edge_index[1, row.numel():] = rev_row\n\n    if edge_id is not None:\n        edge_id = torch.cat([edge_id, rev_edge_id], dim=0)\n\n    (row, col), edge_id = coalesce(\n        edge_index,\n 
       edge_id,\n        sort_by_row=False,\n        reduce='any',\n    )\n\n    return row, col, edge_id\n\n\n###############################################################################\n\nX, Y = TypeVar('X'), TypeVar('Y')\n\n\ndef remap_keys(\n    inputs: Dict[X, Any],\n    mapping: Dict[X, Y],\n    exclude: Optional[List[X]] = None,\n) -> Dict[Union[X, Y], Any]:\n    exclude = exclude or []\n    return {\n        k if k in exclude else mapping.get(k, k): v\n        for k, v in inputs.items()\n    }\n\n\ndef local_to_global_node_idx(node_values: Tensor,\n                             local_indices: Tensor) -> Tensor:\n    \"\"\"Convert a tensor of indices referring to elements in the node_values\n    tensor to their values.\n\n    Args:\n        node_values (Tensor): The node values. (num_nodes, feature_dim)\n        local_indices (Tensor): The local indices. (num_indices)\n\n    Returns:\n        Tensor: The values of the node_values tensor at the local indices.\n        (num_indices, feature_dim)\n    \"\"\"\n    return torch.index_select(node_values, dim=0, index=local_indices)\n\n\ndef global_to_local_node_idx(node_values: Tensor,\n                             local_values: Tensor) -> Tensor:\n    \"\"\"Converts a tensor of values that are contained in the node_values\n    tensor to their indices in that tensor.\n\n    Args:\n        node_values (Tensor): The node values. (num_nodes, feature_dim)\n        local_values (Tensor): The local values. 
(num_indices, feature_dim)\n\n    Returns:\n        Tensor: The indices of the local values in the node_values tensor.\n        (num_indices)\n    \"\"\"\n    if node_values.dim() == 1:\n        node_values = node_values.unsqueeze(1)\n    if local_values.dim() == 1:\n        local_values = local_values.unsqueeze(1)\n    node_values_expand = node_values.unsqueeze(-1).expand(\n        *node_values.shape,\n        local_values.shape[0])  # (num_nodes, feature_dim, num_indices)\n    local_values_expand = local_values.transpose(0, 1).unsqueeze(0).expand(\n        *node_values_expand.shape)  # (num_nodes, feature_dim, num_indices)\n    idx_match = torch.all(node_values_expand == local_values_expand,\n                          dim=1).nonzero()  # (num_indices, 2)\n    sort_idx = torch.argsort(idx_match[:, 1])\n\n    return idx_match[:, 0][sort_idx]\n\n\ndef unique_unsorted(tensor: Tensor) -> Tensor:\n    \"\"\"Returns the unique elements of a tensor while preserving the original\n    order.\n\n    Necessary because torch.unique() ignores its `sorted` parameter.\n    \"\"\"\n    seen = set()\n    output = []\n    for val in tensor:\n        val = tuple(val.tolist())\n        if val not in seen:\n            seen.add(val)\n            output.append(val)\n    return torch.tensor(output, dtype=tensor.dtype,\n                        device=tensor.device).reshape((-1, *tensor.shape[1:]))\n"
  },
  {
    "path": "torch_geometric/seed.py",
    "content": "import random\n\nimport numpy as np\nimport torch\n\n\ndef seed_everything(seed: int) -> None:\n    r\"\"\"Sets the seed for generating random numbers in :pytorch:`PyTorch`,\n    :obj:`numpy` and :python:`Python`.\n\n    Args:\n        seed (int): The desired seed.\n    \"\"\"\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n"
  },
  {
    "path": "torch_geometric/template.py",
    "content": "import importlib\nimport os.path as osp\nimport sys\nimport tempfile\nfrom typing import Any\n\nfrom jinja2 import Environment, FileSystemLoader\n\n\ndef module_from_template(\n    module_name: str,\n    template_path: str,\n    tmp_dirname: str,\n    **kwargs: Any,\n) -> Any:\n\n    if module_name in sys.modules:  # If module is already loaded, return it:\n        return sys.modules[module_name]\n\n    env = Environment(loader=FileSystemLoader(osp.dirname(template_path)))\n    template = env.get_template(osp.basename(template_path))\n    module_repr = template.render(**kwargs)\n\n    with tempfile.NamedTemporaryFile(\n            mode='w',\n            prefix=f'{module_name}_',\n            suffix='.py',\n            delete=False,\n    ) as tmp:\n        tmp.write(module_repr)\n        tmp.flush()\n\n    spec = importlib.util.spec_from_file_location(module_name, tmp.name)\n    assert spec is not None\n    module = importlib.util.module_from_spec(spec)\n    sys.modules[module_name] = module\n    assert spec.loader is not None\n    spec.loader.exec_module(module)\n    return module\n"
  },
  {
    "path": "torch_geometric/testing/__init__.py",
    "content": "r\"\"\"Testing package.\n\nThis package provides helper methods and decorators to ease testing.\n\"\"\"\n\nfrom .decorators import (\n    is_full_test,\n    onlyFullTest,\n    is_distributed_test,\n    onlyDistributedTest,\n    onlyLinux,\n    noWindows,\n    noMac,\n    minPython,\n    onlyCUDA,\n    onlyXPU,\n    onlyOnline,\n    onlyGraphviz,\n    onlyNeighborSampler,\n    has_package,\n    withPackage,\n    withDevice,\n    withCUDA,\n    withMETIS,\n    withHashTensor,\n    disableExtensions,\n    withoutExtensions,\n)\nfrom .asserts import assert_module\nfrom .feature_store import MyFeatureStore\nfrom .graph_store import MyGraphStore\nfrom .data import (\n    get_random_edge_index,\n    get_random_tensor_frame,\n    FakeHeteroDataset,\n)\n\n__all__ = [\n    'is_full_test',\n    'onlyFullTest',\n    'is_distributed_test',\n    'onlyDistributedTest',\n    'onlyLinux',\n    'noWindows',\n    'noMac',\n    'minPython',\n    'onlyCUDA',\n    'onlyXPU',\n    'onlyOnline',\n    'onlyGraphviz',\n    'onlyNeighborSampler',\n    'has_package',\n    'withPackage',\n    'withDevice',\n    'withCUDA',\n    'withMETIS',\n    'withHashTensor',\n    'disableExtensions',\n    'withoutExtensions',\n    'assert_module',\n    'MyFeatureStore',\n    'MyGraphStore',\n    'get_random_edge_index',\n    'get_random_tensor_frame',\n    'FakeHeteroDataset',\n]\n"
  },
  {
    "path": "torch_geometric/testing/asserts.py",
    "content": "import copy\nimport warnings\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import WITH_TORCH_SPARSE, SparseTensor\nfrom torch_geometric.utils import to_torch_coo_tensor, to_torch_csc_tensor\n\nSPARSE_LAYOUTS: List[Union[str, torch.layout]] = [\n    'torch_sparse', torch.sparse_csc, torch.sparse_coo\n]\n\n\ndef assert_module(\n    module: torch.nn.Module,\n    x: Any,\n    edge_index: Tensor,\n    *,\n    expected_size: Tuple[int, ...],\n    test_edge_permutation: bool = True,\n    test_node_permutation: bool = False,\n    test_sparse_layouts: Optional[List[Union[str, torch.layout]]] = None,\n    sparse_size: Optional[Tuple[int, int]] = None,\n    atol: float = 1e-08,\n    rtol: float = 1e-05,\n    equal_nan: bool = False,\n    **kwargs: Any,\n) -> Any:\n    r\"\"\"Asserts that the output of a :obj:`module` is correct.\n\n    Specifically, this method tests that:\n\n    1. The module output has the correct shape.\n    2. The module is invariant to the permutation of edges.\n    3. The module is invariant to the permutation of nodes.\n    4. The module is invariant to the layout of :obj:`edge_index`.\n\n    Args:\n        module (torch.nn.Module): The module to test.\n        x (Any): The input features to the module.\n        edge_index (torch.Tensor): The input edge indices.\n        expected_size (Tuple[int, ...]): The expected output size.\n        test_edge_permutation (bool, optional): If set to :obj:`False`, will\n            not test the module for edge permutation invariance.\n        test_node_permutation (bool, optional): If set to :obj:`False`, will\n            not test the module for node permutation invariance.\n        test_sparse_layouts (List[str or int], optional): The sparse layouts to\n            test for module invariance. 
(default: :obj:`[\"torch_sparse\",\n            torch.sparse_csc, torch.sparse_coo]`)\n        sparse_size (Tuple[int, int], optional): The size of the sparse\n            adjacency matrix. If not given, will try to automatically infer it.\n            (default: :obj:`None`)\n        atol (float, optional): Absolute tolerance. (default: :obj:`1e-08`)\n        rtol (float, optional): Relative tolerance. (default: :obj:`1e-05`)\n        equal_nan (bool, optional): If set to :obj:`True`, then two :obj:`NaN`s\n            will be considered equal. (default: :obj:`False`)\n        **kwargs (optional): Additional arguments passed to\n            :meth:`module.forward`.\n    \"\"\"\n    if test_sparse_layouts is None:\n        test_sparse_layouts = SPARSE_LAYOUTS\n\n    if sparse_size is None:\n        if 'size' in kwargs:\n            sparse_size = kwargs['size']\n        elif isinstance(x, Tensor):\n            sparse_size = (x.size(0), x.size(0))\n        elif (isinstance(x, (tuple, list)) and isinstance(x[0], Tensor)\n              and isinstance(x[1], Tensor)):\n            sparse_size = (x[0].size(0), x[1].size(0))\n\n    if len(test_sparse_layouts) > 0 and sparse_size is None:\n        raise ValueError(f\"Got sparse layouts {test_sparse_layouts}, but no \"\n                         \"'sparse_size' was specified\")\n\n    expected = module(x, edge_index=edge_index, **kwargs)\n    assert expected.size() == expected_size\n\n    if test_edge_permutation:\n        perm = torch.randperm(edge_index.size(1))\n        perm_kwargs = copy.copy(kwargs)\n        for key, value in kwargs.items():\n            if isinstance(value, Tensor) and value.size(0) == perm.numel():\n                perm_kwargs[key] = value[perm]\n        out = module(x, edge_index[:, perm], **perm_kwargs)\n        assert torch.allclose(out, expected, rtol, atol, equal_nan)\n\n    if test_node_permutation:\n        raise NotImplementedError\n\n    for layout in (test_sparse_layouts or []):\n        # 
TODO Add support for values.\n        if layout == 'torch_sparse':\n            if not WITH_TORCH_SPARSE:\n                continue\n\n            adj = SparseTensor.from_edge_index(\n                edge_index,\n                sparse_sizes=sparse_size,\n            )\n            adj_t = adj.t()\n\n        elif layout == torch.sparse_csc:\n            adj = to_torch_csc_tensor(edge_index, size=sparse_size)\n            adj_t = adj.t()\n\n        elif layout == torch.sparse_coo:\n            warnings.filterwarnings('ignore', \".*to CSR format.*\")\n            adj = to_torch_coo_tensor(edge_index, size=sparse_size)\n            adj_t = adj.t().coalesce()\n\n        else:\n            raise ValueError(f\"Got invalid sparse layout '{layout}'\")\n\n        out = module(x, adj_t, **kwargs)\n        assert torch.allclose(out, expected, rtol, atol, equal_nan)\n\n    return expected\n"
  },
  {
    "path": "torch_geometric/testing/data.py",
    "content": "from typing import Callable, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import HeteroData, InMemoryDataset\nfrom torch_geometric.typing import TensorFrame, torch_frame\nfrom torch_geometric.utils import coalesce as coalesce_fn\n\n\ndef get_random_edge_index(\n    num_src_nodes: int,\n    num_dst_nodes: int,\n    num_edges: int,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n    coalesce: bool = False,\n) -> Tensor:\n    row = torch.randint(num_src_nodes, (num_edges, ), dtype=dtype,\n                        device=device)\n    col = torch.randint(num_dst_nodes, (num_edges, ), dtype=dtype,\n                        device=device)\n    edge_index = torch.stack([row, col], dim=0)\n\n    if coalesce:\n        edge_index = coalesce_fn(edge_index)\n\n    return edge_index\n\n\ndef get_random_tensor_frame(\n    num_rows: int,\n    device: Optional[torch.device] = None,\n) -> TensorFrame:\n\n    feat_dict = {\n        torch_frame.categorical:\n        torch.randint(0, 3, size=(num_rows, 3), device=device),\n        torch_frame.numerical:\n        torch.randn(size=(num_rows, 2), device=device),\n    }\n    col_names_dict = {\n        torch_frame.categorical: ['a', 'b', 'c'],\n        torch_frame.numerical: ['x', 'y'],\n    }\n    y = torch.randn(num_rows, device=device)\n\n    return torch_frame.TensorFrame(\n        feat_dict=feat_dict,\n        col_names_dict=col_names_dict,\n        y=y,\n    )\n\n\nclass FakeHeteroDataset(InMemoryDataset):\n    def __init__(self, transform: Optional[Callable] = None):\n        super().__init__(transform=transform)\n\n        data = HeteroData()\n\n        num_papers = 100\n        num_authors = 10\n\n        data['paper'].x = torch.randn(num_papers, 16)\n        data['author'].x = torch.randn(num_authors, 8)\n\n        edge_index = get_random_edge_index(\n            num_src_nodes=num_papers,\n            num_dst_nodes=num_authors,\n            
num_edges=300,\n        )\n        data['paper', 'author'].edge_index = edge_index\n        data['author', 'paper'].edge_index = edge_index.flip([0])\n\n        data['paper'].y = torch.randint(0, 4, (num_papers, ))\n\n        perm = torch.randperm(num_papers)\n        data['paper'].train_mask = torch.zeros(num_papers, dtype=torch.bool)\n        data['paper'].train_mask[perm[0:60]] = True\n        data['paper'].val_mask = torch.zeros(num_papers, dtype=torch.bool)\n        data['paper'].val_mask[perm[60:80]] = True\n        data['paper'].test_mask = torch.zeros(num_papers, dtype=torch.bool)\n        data['paper'].test_mask[perm[80:100]] = True\n\n        self.data, self.slices = self.collate([data])\n"
  },
  {
    "path": "torch_geometric/testing/decorators.py",
    "content": "import os\nimport sys\nimport warnings\nfrom importlib import import_module\nfrom importlib.util import find_spec\nfrom typing import Callable\n\nimport torch\nfrom packaging.requirements import Requirement\nfrom packaging.version import Version\n\nimport torch_geometric\nimport torch_geometric.typing\nfrom torch_geometric.typing import WITH_METIS, WITH_PYG_LIB, WITH_TORCH_SPARSE\nfrom torch_geometric.visualization.graph import has_graphviz\n\n\ndef is_full_test() -> bool:\n    r\"\"\"Whether to run the full but time-consuming test suite.\"\"\"\n    return os.getenv('FULL_TEST', '0') == '1'\n\n\ndef onlyFullTest(func: Callable) -> Callable:\n    r\"\"\"A decorator to specify that this function belongs to the full test\n    suite.\n    \"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        not is_full_test(),\n        reason=\"Fast test run\",\n    )(func)\n\n\ndef is_distributed_test() -> bool:\n    r\"\"\"Whether to run the distributed test suite.\"\"\"\n    return (os.getenv('DIST_TEST', '0') == '1' and sys.platform == 'linux'\n            and has_package('pyg_lib'))\n\n\ndef onlyDistributedTest(func: Callable) -> Callable:\n    r\"\"\"A decorator to specify that this function belongs to the distributed\n    test suite.\n    \"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        not is_distributed_test(),\n        reason=\"Fast test run\",\n    )(func)\n\n\ndef onlyLinux(func: Callable) -> Callable:\n    r\"\"\"A decorator to specify that this function should only execute on\n    Linux systems.\n    \"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        sys.platform != 'linux',\n        reason=\"No Linux system\",\n    )(func)\n\n\ndef noWindows(func: Callable) -> Callable:\n    r\"\"\"A decorator to specify that this function should not execute on\n    Windows systems.\n    \"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        os.name == 'nt',\n        reason=\"Windows system\",\n    
)(func)\n\n\ndef noMac(func: Callable) -> Callable:\n    r\"\"\"A decorator to specify that this function should not execute on\n    macOS systems.\n    \"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        sys.platform == 'darwin',\n        reason=\"macOS system\",\n    )(func)\n\n\ndef minPython(version: str) -> Callable:\n    r\"\"\"A decorator to run tests on specific :python:`Python` versions only.\"\"\"\n    def decorator(func: Callable) -> Callable:\n        import pytest\n\n        major, minor = version.split('.')\n\n        skip = False\n        if sys.version_info.major < int(major):\n            skip = True\n        if (sys.version_info.major == int(major)\n                and sys.version_info.minor < int(minor)):\n            skip = True\n\n        return pytest.mark.skipif(\n            skip,\n            reason=f\"Python {version} required\",\n        )(func)\n\n    return decorator\n\n\ndef onlyCUDA(func: Callable) -> Callable:\n    r\"\"\"A decorator to skip tests if CUDA is not found.\"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        not torch.cuda.is_available(),\n        reason=\"CUDA not available\",\n    )(func)\n\n\ndef onlyXPU(func: Callable) -> Callable:\n    r\"\"\"A decorator to skip tests if XPU is not found.\"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        not torch_geometric.is_xpu_available(),\n        reason=\"XPU not available\",\n    )(func)\n\n\ndef onlyOnline(func: Callable) -> Callable:\n    r\"\"\"A decorator to skip tests if there exists no connection to the\n    internet.\n    \"\"\"\n    import http.client as httplib\n\n    import pytest\n\n    has_connection = True\n    connection = httplib.HTTPSConnection('8.8.8.8', timeout=5)\n    try:\n        connection.request('HEAD', '/')\n    except Exception:\n        has_connection = False\n    finally:\n        connection.close()\n\n    return pytest.mark.skipif(\n        not has_connection,\n        reason=\"No internet 
connection\",\n    )(func)\n\n\ndef onlyGraphviz(func: Callable) -> Callable:\n    r\"\"\"A decorator to specify that this function should only execute in case\n    :obj:`graphviz` is installed.\n    \"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        not has_graphviz(),\n        reason=\"Graphviz not installed\",\n    )(func)\n\n\ndef onlyNeighborSampler(func: Callable) -> Callable:\n    r\"\"\"A decorator to skip tests if no neighborhood sampler package is\n    installed.\n    \"\"\"\n    import pytest\n    return pytest.mark.skipif(\n        not WITH_PYG_LIB and not WITH_TORCH_SPARSE,\n        reason=\"No neighbor sampler installed\",\n    )(func)\n\n\ndef has_package(package: str) -> bool:\n    r\"\"\"Returns :obj:`True` in case :obj:`package` is installed.\"\"\"\n    if '|' in package:\n        return any(has_package(p) for p in package.split('|'))\n\n    req = Requirement(package)\n    if find_spec(req.name) is None:\n        return False\n\n    try:\n        module = import_module(req.name)\n        if not hasattr(module, '__version__'):\n            return True\n\n        version = Version(module.__version__).base_version\n        return version in req.specifier\n    except Exception:\n        return False\n\n\ndef withPackage(*args: str) -> Callable:\n    r\"\"\"A decorator to skip tests if certain packages are not installed.\n    Also supports version specification.\n    \"\"\"\n    na_packages = {package for package in args if not has_package(package)}\n\n    if len(na_packages) == 1:\n        reason = f\"Package {list(na_packages)[0]} not found\"\n    else:\n        reason = f\"Packages {na_packages} not found\"\n\n    def decorator(func: Callable) -> Callable:\n        import pytest\n        return pytest.mark.skipif(len(na_packages) > 0, reason=reason)(func)\n\n    return decorator\n\n\ndef withCUDA(func: Callable) -> Callable:\n    r\"\"\"A decorator to test both on CPU and CUDA (if available).\"\"\"\n    import pytest\n\n    devices = 
[pytest.param(torch.device('cpu'), id='cpu')]\n    if torch.cuda.is_available():\n        devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))\n\n    return pytest.mark.parametrize('device', devices)(func)\n\n\ndef withDevice(func: Callable) -> Callable:\n    r\"\"\"A decorator to test on all available tensor processing devices.\"\"\"\n    import pytest\n\n    devices = [pytest.param(torch.device('cpu'), id='cpu')]\n\n    if torch.cuda.is_available():\n        devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0'))\n\n    if torch_geometric.is_mps_available():\n        devices.append(pytest.param(torch.device('mps:0'), id='mps'))\n\n    if torch_geometric.is_xpu_available():\n        devices.append(pytest.param(torch.device('xpu:0'), id='xpu'))\n\n    # Additional devices can be registered through environment variables:\n    device = os.getenv('TORCH_DEVICE')\n    if device:\n        backend = os.getenv('TORCH_BACKEND')\n        if backend is None:\n            warnings.warn(\n                \"Please specify the backend via 'TORCH_BACKEND' in \"\n                f\"order to test against '{device}'\", stacklevel=2)\n        else:\n            import_module(backend)\n            devices.append(pytest.param(torch.device(device), id=device))\n\n    return pytest.mark.parametrize('device', devices)(func)\n\n\ndef withMETIS(func: Callable) -> Callable:\n    r\"\"\"A decorator to only test in case a valid METIS method is available.\"\"\"\n    import pytest\n\n    with_metis = WITH_METIS\n\n    if with_metis:\n        try:  # Test that METIS can successfully execute:\n            # TODO Using `pyg-lib` metis partitioning leads to some weird bugs\n            # in the CI. 
As such, we require `torch-sparse` for now.\n            rowptr = torch.tensor([0, 2, 4, 6])\n            col = torch.tensor([1, 2, 0, 2, 1, 0])\n            torch.ops.torch_sparse.partition(rowptr, col, None, 2, True)\n        except Exception:\n            with_metis = False\n\n    return pytest.mark.skipif(\n        not with_metis,\n        reason=\"METIS not enabled\",\n    )(func)\n\n\ndef withHashTensor(func: Callable) -> Callable:\n    r\"\"\"A decorator to only test in case :class:`HashTensor` is available.\"\"\"\n    import pytest\n\n    return pytest.mark.skipif(\n        not torch_geometric.typing.WITH_CPU_HASH_MAP\n        and not has_package('pandas'),\n        reason=\"HashTensor dependencies not available\",\n    )(func)\n\n\ndef disableExtensions(func: Callable) -> Callable:\n    r\"\"\"A decorator to temporarily disable the usage of the\n    :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension\n    packages.\n    \"\"\"\n    import pytest\n\n    return pytest.mark.usefixtures('disable_extensions')(func)\n\n\ndef withoutExtensions(func: Callable) -> Callable:\n    r\"\"\"A decorator to test both with and without the usage of extension\n    packages such as :obj:`torch_scatter`, :obj:`torch_sparse` and\n    :obj:`pyg_lib`.\n    \"\"\"\n    import pytest\n\n    return pytest.mark.parametrize(\n        'without_extensions',\n        ['enable_extensions', 'disable_extensions'],\n        indirect=True,\n    )(func)\n"
  },
  {
    "path": "torch_geometric/testing/distributed.py",
    "content": "import sys\nimport traceback\nfrom dataclasses import dataclass\nfrom io import StringIO\nfrom typing import Any, Callable, List, Tuple\n\nimport pytest\nfrom torch.multiprocessing import Manager, Queue\nfrom typing_extensions import Self\n\n\n@dataclass\nclass ProcArgs:\n    target: Callable\n    args: Tuple[Any, ...]\n\n\nclass MPCaptOutput:\n    def __enter__(self) -> Self:\n        self.stdout = StringIO()\n        self.stderr = StringIO()\n        self.old_stdout = sys.stdout\n        self.old_stderr = sys.stderr\n\n        sys.stdout = self.stdout\n        sys.stderr = self.stderr\n\n        return self\n\n    def __exit__(self, *args: Any) -> None:\n        sys.stdout = self.old_stdout\n        sys.stderr = self.old_stderr\n\n    @property\n    def stdout_str(self) -> str:\n        return self.stdout.getvalue()\n\n    @property\n    def stderr_str(self) -> str:\n        return self.stderr.getvalue()\n\n\ndef ps_std_capture(\n    func: Callable,\n    queue: Queue,\n    *args: Any,\n    **kwargs: Any,\n) -> None:\n    with MPCaptOutput() as capt:\n        try:\n            func(*args, **kwargs)\n        except Exception as e:\n            traceback.print_exc(file=sys.stderr)\n            raise e\n        finally:\n            queue.put((capt.stdout_str, capt.stderr_str))\n\n\ndef assert_run_mproc(\n    mp_context: Any,\n    pargs: List[ProcArgs],\n    full_trace: bool = False,\n    timeout: int = 5,\n) -> None:\n    manager = Manager()\n    world_size = len(pargs)\n    queues = [manager.Queue() for _ in pargs]\n    procs = [\n        mp_context.Process(\n            target=ps_std_capture,\n            args=[p.target, q, world_size] + list(p.args),\n        ) for p, q in zip(pargs, queues)\n    ]\n    results = []\n\n    for p, _ in zip(procs, queues):\n        p.start()\n\n    for p, q in zip(procs, queues):\n        p.join()\n        stdout, stderr = q.get(timeout=timeout)\n        results.append((p, stdout, stderr))\n\n    for p, stdout, 
stderr in results:\n        if stdout:\n            print(stdout)\n        if stderr:  # can be a warning as well => exitcode == 0\n            print(stderr)\n        if p.exitcode != 0:\n            pytest.fail(\n                pytrace=full_trace, reason=stderr.splitlines()[-1]\n                if stderr else f\"exitcode {p.exitcode}\")\n"
  },
  {
    "path": "torch_geometric/testing/feature_store.py",
    "content": "from typing import Dict, List, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import FeatureStore, TensorAttr\nfrom torch_geometric.typing import FeatureTensorType\n\nKeyType = Tuple[Optional[str], Optional[str]]\n\n\nclass MyFeatureStore(FeatureStore):\n    def __init__(self) -> None:\n        super().__init__()\n        self.store: Dict[KeyType, Tuple[Tensor, Tensor]] = {}\n\n    @staticmethod\n    def key(attr: TensorAttr) -> KeyType:\n        return (attr.group_name, attr.attr_name)\n\n    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:\n        index = attr.index\n\n        # None indices define the obvious index:\n        if index is None:\n            index = torch.arange(0, tensor.shape[0])\n\n        # Store the index:\n        assert isinstance(index, Tensor)\n        assert isinstance(tensor, Tensor)\n        self.store[self.key(attr)] = (index, tensor)\n\n        return True\n\n    def _get_tensor(self, attr: TensorAttr) -> Optional[Tensor]:\n        index, tensor = self.store.get(self.key(attr), (None, None))\n\n        if tensor is None:\n            raise KeyError(f\"Could not find tensor for '{attr}'\")\n\n        assert isinstance(tensor, Tensor)\n\n        # None indices return the whole tensor:\n        if attr.index is None:\n            return tensor\n\n        # Empty slices return the whole tensor:\n        if (isinstance(attr.index, slice)\n                and attr.index == slice(None, None, None)):\n            return tensor\n\n        assert isinstance(attr.index, Tensor)\n\n        if attr.index.numel() == 0:\n            return tensor[attr.index]\n\n        idx = torch.cat([(index == v).nonzero() for v in attr.index]).view(-1)\n        return tensor[idx]\n\n    def _remove_tensor(self, attr: TensorAttr) -> bool:\n        return self.store.pop(self.key(attr), None) is not None\n\n    def _get_tensor_size(self, attr: TensorAttr) -> Optional[Tuple[int, 
...]]:\n        tensor = self._get_tensor(attr)\n        return tensor.size() if tensor is not None else None\n\n    def get_all_tensor_attrs(self) -> List[TensorAttr]:\n        return [self._tensor_attr_cls.cast(*key) for key in self.store.keys()]\n"
  },
  {
    "path": "torch_geometric/testing/graph_store.py",
    "content": "from typing import Dict, List, Optional, Tuple\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import EdgeAttr, GraphStore\nfrom torch_geometric.typing import EdgeTensorType\n\n\nclass MyGraphStore(GraphStore):\n    def __init__(self) -> None:\n        super().__init__()\n        self.store: Dict[Tuple, Tuple[Tensor, Tensor]] = {}\n\n    @staticmethod\n    def key(attr: EdgeAttr) -> Tuple:\n        return (attr.edge_type, attr.layout.value, attr.is_sorted, attr.size)\n\n    def _put_edge_index(\n        self,\n        edge_index: EdgeTensorType,\n        edge_attr: EdgeAttr,\n    ) -> bool:\n        self.store[self.key(edge_attr)] = edge_index\n        return True\n\n    def _get_edge_index(self, edge_attr: EdgeAttr) -> Optional[EdgeTensorType]:\n        return self.store.get(self.key(edge_attr), None)\n\n    def _remove_edge_index(self, edge_attr: EdgeAttr) -> bool:\n        return self.store.pop(self.key(edge_attr), None) is not None\n\n    def get_all_edge_attrs(self) -> List[EdgeAttr]:\n        return [EdgeAttr(*key) for key in self.store.keys()]\n"
  },
  {
    "path": "torch_geometric/transforms/__init__.py",
    "content": "# flake8: noqa\n\nfrom .base_transform import BaseTransform\nfrom .compose import Compose, ComposeFilters\nfrom .to_device import ToDevice\nfrom .to_sparse_tensor import ToSparseTensor\nfrom .constant import Constant\nfrom .normalize_features import NormalizeFeatures\nfrom .svd_feature_reduction import SVDFeatureReduction\nfrom .remove_training_classes import RemoveTrainingClasses\nfrom .random_node_split import RandomNodeSplit\nfrom .random_link_split import RandomLinkSplit\nfrom .node_property_split import NodePropertySplit\nfrom .mask import IndexToMask, MaskToIndex\nfrom .pad import Pad\n\nfrom .to_undirected import ToUndirected\nfrom .one_hot_degree import OneHotDegree\nfrom .target_indegree import TargetIndegree\nfrom .local_degree_profile import LocalDegreeProfile\nfrom .add_self_loops import AddSelfLoops\nfrom .add_remaining_self_loops import AddRemainingSelfLoops\nfrom .remove_self_loops import RemoveSelfLoops\nfrom .remove_isolated_nodes import RemoveIsolatedNodes\nfrom .remove_duplicated_edges import RemoveDuplicatedEdges\nfrom .knn_graph import KNNGraph\nfrom .radius_graph import RadiusGraph\nfrom .to_dense import ToDense\nfrom .two_hop import TwoHop\nfrom .line_graph import LineGraph\nfrom .laplacian_lambda_max import LaplacianLambdaMax\nfrom .gdc import GDC\nfrom .sign import SIGN\nfrom .gcn_norm import GCNNorm\nfrom .add_metapaths import AddMetaPaths, AddRandomMetaPaths\nfrom .rooted_subgraph import RootedEgoNets, RootedRWSubgraph\nfrom .largest_connected_components import LargestConnectedComponents\nfrom .virtual_node import VirtualNode\nfrom .add_positional_encoding import AddLaplacianEigenvectorPE, AddRandomWalkPE\nfrom .add_gpse import AddGPSE\nfrom .feature_propagation import FeaturePropagation\nfrom .half_hop import HalfHop\n\nfrom .distance import Distance\nfrom .cartesian import Cartesian\nfrom .local_cartesian import LocalCartesian\nfrom .polar import Polar\nfrom .spherical import Spherical\nfrom .point_pair_features import 
PointPairFeatures\nfrom .center import Center\nfrom .normalize_rotation import NormalizeRotation\nfrom .normalize_scale import NormalizeScale\nfrom .random_jitter import RandomJitter\nfrom .random_flip import RandomFlip\nfrom .linear_transformation import LinearTransformation\nfrom .random_scale import RandomScale\nfrom .random_rotate import RandomRotate\nfrom .random_shear import RandomShear\nfrom .face_to_edge import FaceToEdge\nfrom .sample_points import SamplePoints\nfrom .fixed_points import FixedPoints\nfrom .generate_mesh_normals import GenerateMeshNormals\nfrom .delaunay import Delaunay\nfrom .to_superpixels import ToSLIC\nfrom .grid_sampling import GridSampling\n\ngeneral_transforms = [\n    'BaseTransform',\n    'Compose',\n    'ComposeFilters',\n    'ToDevice',\n    'ToSparseTensor',\n    'Constant',\n    'NormalizeFeatures',\n    'SVDFeatureReduction',\n    'RemoveTrainingClasses',\n    'RandomNodeSplit',\n    'RandomLinkSplit',\n    'NodePropertySplit',\n    'IndexToMask',\n    'MaskToIndex',\n    'Pad',\n]\n\ngraph_transforms = [\n    'ToUndirected',\n    'OneHotDegree',\n    'TargetIndegree',\n    'LocalDegreeProfile',\n    'AddSelfLoops',\n    'AddRemainingSelfLoops',\n    'RemoveSelfLoops',\n    'RemoveIsolatedNodes',\n    'RemoveDuplicatedEdges',\n    'KNNGraph',\n    'RadiusGraph',\n    'ToDense',\n    'TwoHop',\n    'LineGraph',\n    'LaplacianLambdaMax',\n    'GDC',\n    'SIGN',\n    'GCNNorm',\n    'AddMetaPaths',\n    'AddRandomMetaPaths',\n    'RootedEgoNets',\n    'RootedRWSubgraph',\n    'LargestConnectedComponents',\n    'VirtualNode',\n    'AddLaplacianEigenvectorPE',\n    'AddRandomWalkPE',\n    'AddGPSE',\n    'FeaturePropagation',\n    'HalfHop',\n]\n\nvision_transforms = [\n    'Distance',\n    'Cartesian',\n    'LocalCartesian',\n    'Polar',\n    'Spherical',\n    'PointPairFeatures',\n    'Center',\n    'NormalizeRotation',\n    'NormalizeScale',\n    'RandomJitter',\n    'RandomFlip',\n    'LinearTransformation',\n    
'RandomScale',\n    'RandomRotate',\n    'RandomShear',\n    'FaceToEdge',\n    'SamplePoints',\n    'FixedPoints',\n    'GenerateMeshNormals',\n    'Delaunay',\n    'ToSLIC',\n    'GridSampling',\n]\n\n__all__ = general_transforms + graph_transforms + vision_transforms\n\nfrom torch_geometric.deprecation import deprecated\n\nRandomTranslate = deprecated(\"use 'transforms.RandomJitter' instead\",\n                             'transforms.RandomTranslate')(RandomJitter)\n"
  },
  {
    "path": "torch_geometric/transforms/add_gpse.py",
    "content": "from typing import Any\n\nfrom torch.nn import Module\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform, VirtualNode\n\n\n@functional_transform('add_gpse')\nclass AddGPSE(BaseTransform):\n    r\"\"\"Adds the GPSE encoding from the `\"Graph Positional and Structural\n    Encoder\" <https://arxiv.org/abs/2307.07107>`_ paper to the given graph\n    (functional name: :obj:`add_gpse`).\n    To be used with a :class:`~torch_geometric.nn.GPSE` model, which generates\n    the actual encodings.\n\n    Args:\n        model (Module): The pre-trained GPSE model.\n        use_vn (bool, optional): Whether to use virtual nodes.\n            (default: :obj:`True`)\n        rand_type (str, optional): Type of random features to use. Options are\n            :obj:`NormalSE`, :obj:`UniformSE`, :obj:`BernoulliSE`.\n            (default: :obj:`NormalSE`)\n\n    \"\"\"\n    def __init__(\n        self,\n        model: Module,\n        use_vn: bool = True,\n        rand_type: str = 'NormalSE',\n    ):\n        self.model = model\n        self.use_vn = use_vn\n        self.vn = VirtualNode()\n        self.rand_type = rand_type\n\n    def forward(self, data: Data) -> Any:\n        pass\n\n    def __call__(self, data: Data) -> Data:\n        from torch_geometric.nn.models.gpse import gpse_process\n\n        data_vn = self.vn(data.clone()) if self.use_vn else data.clone()\n        batch_out = gpse_process(self.model, data_vn, 'NormalSE', self.use_vn)\n        batch_out = batch_out.to('cpu', non_blocking=True)\n        data.pestat_GPSE = batch_out[:-1] if self.use_vn else batch_out\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/add_metapaths.py",
    "content": "import warnings\nfrom typing import List, Optional, Tuple, Union, cast\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.data import HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.typing import EdgeType\nfrom torch_geometric.utils import coalesce, degree\n\n\n@functional_transform('add_metapaths')\nclass AddMetaPaths(BaseTransform):\n    r\"\"\"Adds additional edge types to a\n    :class:`~torch_geometric.data.HeteroData` object between the source node\n    type and the destination node type of a given :obj:`metapath`, as described\n    in the `\"Heterogenous Graph Attention Networks\"\n    <https://arxiv.org/abs/1903.07293>`_ paper\n    (functional name: :obj:`add_metapaths`).\n\n    Meta-path based neighbors can exploit different aspects of structure\n    information in heterogeneous graphs.\n    Formally, a metapath is a path of the form\n\n    .. math::\n\n        \\mathcal{V}_1 \\xrightarrow{R_1} \\mathcal{V}_2 \\xrightarrow{R_2} \\ldots\n        \\xrightarrow{R_{\\ell-1}} \\mathcal{V}_{\\ell}\n\n    in which :math:`\\mathcal{V}_i` represents node types, and :math:`R_j`\n    represents the edge type connecting two node types.\n    The added edge type is given by the sequential multiplication  of\n    adjacency matrices along the metapath, and is added to the\n    :class:`~torch_geometric.data.HeteroData` object as edge type\n    :obj:`(src_node_type, \"metapath_*\", dst_node_type)`, where\n    :obj:`src_node_type` and :obj:`dst_node_type` denote :math:`\\mathcal{V}_1`\n    and :math:`\\mathcal{V}_{\\ell}`, respectively.\n\n    In addition, a :obj:`metapath_dict` object is added to the\n    :class:`~torch_geometric.data.HeteroData` object which maps the\n    metapath-based edge type to its original metapath.\n\n    .. 
code-block:: python\n\n        from torch_geometric.datasets import DBLP\n        from torch_geometric.data import HeteroData\n        from torch_geometric.transforms import AddMetaPaths\n\n        data = DBLP(root)[0]\n        # 4 node types: \"paper\", \"author\", \"conference\", and \"term\"\n        # 6 edge types: (\"paper\", \"author\"), (\"author\", \"paper\"),\n        #               (\"paper\", \"term\"), (\"paper\", \"conference\"),\n        #               (\"term\", \"paper\"), (\"conference\", \"paper\")\n\n        # Add two metapaths:\n        # 1. From \"paper\" to \"paper\" through \"conference\"\n        # 2. From \"author\" to \"conference\" through \"paper\"\n        metapaths = [[(\"paper\", \"conference\"), (\"conference\", \"paper\")],\n                     [(\"author\", \"paper\"), (\"paper\", \"conference\")]]\n        data = AddMetaPaths(metapaths)(data)\n\n        print(data.edge_types)\n        >>> [(\"author\", \"to\", \"paper\"), (\"paper\", \"to\", \"author\"),\n             (\"paper\", \"to\", \"term\"), (\"paper\", \"to\", \"conference\"),\n             (\"term\", \"to\", \"paper\"), (\"conference\", \"to\", \"paper\"),\n             (\"paper\", \"metapath_0\", \"paper\"),\n             (\"author\", \"metapath_1\", \"conference\")]\n\n        print(data.metapath_dict)\n        >>> {(\"paper\", \"metapath_0\", \"paper\"): [(\"paper\", \"conference\"),\n                                                (\"conference\", \"paper\")],\n             (\"author\", \"metapath_1\", \"conference\"): [(\"author\", \"paper\"),\n                                                      (\"paper\", \"conference\")]}\n\n    Args:\n        metapaths (List[List[Tuple[str, str, str]]]): The metapaths described\n            by a list of lists of\n            :obj:`(src_node_type, rel_type, dst_node_type)` tuples.\n        drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing\n            edge types will be dropped. 
(default: :obj:`False`)\n        keep_same_node_type (bool, optional): If set to :obj:`True`, existing\n            edge types between the same node type are not dropped even in case\n            :obj:`drop_orig_edge_types` is set to :obj:`True`.\n            (default: :obj:`False`)\n        drop_unconnected_node_types (bool, optional): If set to :obj:`True`,\n            will drop node types not connected by any edge type.\n            (default: :obj:`False`)\n        max_sample (int, optional): If set, will sample at most\n            :obj:`max_sample` neighbors within metapaths. Useful in order to\n            tackle very dense metapath edges. (default: :obj:`None`)\n        weighted (bool, optional): If set to :obj:`True`, computes weights for\n            each metapath edge and stores them in :obj:`edge_weight`. The\n            weight of each metapath edge is computed as the number of metapaths\n            from the start to the end of the metapath edge.\n            (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        metapaths: List[List[EdgeType]],\n        drop_orig_edge_types: bool = False,\n        keep_same_node_type: bool = False,\n        drop_unconnected_node_types: bool = False,\n        max_sample: Optional[int] = None,\n        weighted: bool = False,\n        **kwargs: bool,\n    ) -> None:\n        if 'drop_orig_edges' in kwargs:\n            warnings.warn(\n                \"'drop_orig_edges' is deprecated. Use \"\n                \"'drop_orig_edge_types' instead\", stacklevel=2)\n            drop_orig_edge_types = kwargs['drop_orig_edges']\n\n        if 'drop_unconnected_nodes' in kwargs:\n            warnings.warn(\n                \"'drop_unconnected_nodes' is deprecated. 
Use \"\n                \"'drop_unconnected_node_types' instead\", stacklevel=2)\n            drop_unconnected_node_types = kwargs['drop_unconnected_nodes']\n\n        for path in metapaths:\n            assert len(path) >= 2, f\"Invalid metapath '{path}'\"\n            assert all([\n                j[-1] == path[i + 1][0] for i, j in enumerate(path[:-1])\n            ]), f\"Invalid sequence of node types in '{path}'\"\n\n        self.metapaths = metapaths\n        self.drop_orig_edge_types = drop_orig_edge_types\n        self.keep_same_node_type = keep_same_node_type\n        self.drop_unconnected_node_types = drop_unconnected_node_types\n        self.max_sample = max_sample\n        self.weighted = weighted\n\n    def forward(self, data: HeteroData) -> HeteroData:\n        edge_types = data.edge_types  # Save original edge types.\n        data.metapath_dict = {}\n\n        for j, metapath in enumerate(self.metapaths):\n            for edge_type in metapath:\n                assert data._to_canonical(edge_type) in edge_types\n\n            edge_type = metapath[0]\n            edge_index, edge_weight = self._edge_index(data, edge_type)\n\n            if self.max_sample is not None:\n                edge_index, edge_weight = self._sample(edge_index, edge_weight)\n\n            for edge_type in metapath[1:]:\n                edge_index2, edge_weight2 = self._edge_index(data, edge_type)\n\n                edge_index, edge_weight = edge_index.matmul(\n                    edge_index2, edge_weight, edge_weight2)\n\n                if not self.weighted:\n                    edge_weight = None\n\n                if self.max_sample is not None:\n                    edge_index, edge_weight = self._sample(\n                        edge_index, edge_weight)\n\n            new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])\n            data[new_edge_type].edge_index = edge_index.as_tensor()\n            if self.weighted:\n                
data[new_edge_type].edge_weight = edge_weight\n            data.metapath_dict[new_edge_type] = metapath\n\n        postprocess(data, edge_types, self.drop_orig_edge_types,\n                    self.keep_same_node_type, self.drop_unconnected_node_types)\n\n        return data\n\n    def _edge_index(\n        self,\n        data: HeteroData,\n        edge_type: EdgeType,\n    ) -> Tuple[EdgeIndex, Optional[Tensor]]:\n\n        edge_index = EdgeIndex(\n            data[edge_type].edge_index,\n            sparse_size=data[edge_type].size(),\n        )\n        edge_index, perm = edge_index.sort_by('row')\n\n        if not self.weighted:\n            return edge_index, None\n\n        edge_weight = data[edge_type].get('edge_weight')\n        if edge_weight is not None:\n            assert edge_weight.dim() == 1\n            edge_weight = edge_weight[perm]\n\n        return edge_index, edge_weight\n\n    def _sample(\n        self,\n        edge_index: EdgeIndex,\n        edge_weight: Optional[Tensor],\n    ) -> Tuple[EdgeIndex, Optional[Tensor]]:\n        assert self.max_sample is not None\n\n        deg = degree(edge_index[0], num_nodes=edge_index.get_sparse_size(0))\n        prob = (self.max_sample * (1. 
/ deg))[edge_index[0]]\n        mask = torch.rand_like(prob) < prob\n\n        edge_index = cast(EdgeIndex, edge_index[:, mask])\n        assert isinstance(edge_index, EdgeIndex)\n        if edge_weight is not None:\n            edge_weight = edge_weight[mask]\n\n        return edge_index, edge_weight\n\n\n@functional_transform('add_random_metapaths')\nclass AddRandomMetaPaths(BaseTransform):\n    r\"\"\"Adds additional edge types similar to :class:`AddMetaPaths`.\n    The key difference is that the added edge type is given by\n    multiple random walks along the metapath.\n    One might want to increase the number of random walks\n    via :obj:`walks_per_node` to achieve competitive performance with\n    :class:`AddMetaPaths`.\n\n    Args:\n        metapaths (List[List[Tuple[str, str, str]]]): The metapaths described\n            by a list of lists of\n            :obj:`(src_node_type, rel_type, dst_node_type)` tuples.\n        drop_orig_edge_types (bool, optional): If set to :obj:`True`, existing\n            edge types will be dropped. (default: :obj:`False`)\n        keep_same_node_type (bool, optional): If set to :obj:`True`, existing\n            edge types between the same node type are not dropped even in case\n            :obj:`drop_orig_edge_types` is set to :obj:`True`.\n            (default: :obj:`False`)\n        drop_unconnected_node_types (bool, optional): If set to :obj:`True`,\n            will drop node types not connected by any edge type.\n            (default: :obj:`False`)\n        walks_per_node (int, List[int], optional): The number of random walks\n            for each starting node in a metapath. (default: :obj:`1`)\n        sample_ratio (float, optional): The ratio of source nodes to start\n            random walks from. 
(default: :obj:`1.0`)\n    \"\"\"\n    def __init__(\n        self,\n        metapaths: List[List[EdgeType]],\n        drop_orig_edge_types: bool = False,\n        keep_same_node_type: bool = False,\n        drop_unconnected_node_types: bool = False,\n        walks_per_node: Union[int, List[int]] = 1,\n        sample_ratio: float = 1.0,\n    ):\n\n        for path in metapaths:\n            assert len(path) >= 2, f\"Invalid metapath '{path}'\"\n            assert all([\n                j[-1] == path[i + 1][0] for i, j in enumerate(path[:-1])\n            ]), f\"Invalid sequence of node types in '{path}'\"\n\n        self.metapaths = metapaths\n        self.drop_orig_edge_types = drop_orig_edge_types\n        self.keep_same_node_type = keep_same_node_type\n        self.drop_unconnected_node_types = drop_unconnected_node_types\n        self.sample_ratio = sample_ratio\n        if isinstance(walks_per_node, int):\n            walks_per_node = [walks_per_node] * len(metapaths)\n        assert len(walks_per_node) == len(metapaths)\n        self.walks_per_node = walks_per_node\n\n    def forward(self, data: HeteroData) -> HeteroData:\n        edge_types = data.edge_types  # save original edge types\n        data.metapath_dict = {}\n\n        for j, metapath in enumerate(self.metapaths):\n            for edge_type in metapath:\n                assert data._to_canonical(\n                    edge_type) in edge_types, f\"'{edge_type}' not present\"\n\n            src_node = metapath[0][0]\n            num_nodes = data[src_node].num_nodes\n            num_starts = round(num_nodes * self.sample_ratio)\n            row = start = torch.randperm(num_nodes)[:num_starts].repeat(\n                self.walks_per_node[j])\n\n            for edge_type in metapath:\n                edge_index = EdgeIndex(\n                    data[edge_type].edge_index,\n                    sparse_size=data[edge_type].size(),\n                )\n                col, mask = self.sample(edge_index, 
start)\n                row, col = row[mask], col[mask]\n                start = col\n\n            new_edge_type = (metapath[0][0], f'metapath_{j}', metapath[-1][-1])\n            data[new_edge_type].edge_index = coalesce(torch.vstack([row, col]))\n            data.metapath_dict[new_edge_type] = metapath\n\n        postprocess(data, edge_types, self.drop_orig_edge_types,\n                    self.keep_same_node_type, self.drop_unconnected_node_types)\n\n        return data\n\n    @staticmethod\n    def sample(edge_index: EdgeIndex, subset: Tensor) -> Tuple[Tensor, Tensor]:\n        \"\"\"Sample neighbors from :obj:`edge_index` for each node in\n        :obj:`subset`.\n        \"\"\"\n        edge_index, _ = edge_index.sort_by('row')\n        rowptr = edge_index.get_indptr()\n        rowcount = rowptr.diff()[subset]\n\n        mask = rowcount > 0\n        offset = torch.zeros_like(subset)\n        offset[mask] = rowptr[subset[mask]]\n\n        rand = torch.rand((rowcount.size(0), 1), device=subset.device)\n        rand.mul_(rowcount.to(rand.dtype).view(-1, 1))\n        rand = rand.to(torch.long)\n        rand.add_(offset.view(-1, 1))\n        col = edge_index[1][rand].squeeze()\n        return col, mask\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'sample_ratio={self.sample_ratio}, '\n                f'walks_per_node={self.walks_per_node})')\n\n\ndef postprocess(\n    data: HeteroData,\n    edge_types: List[EdgeType],\n    drop_orig_edge_types: bool,\n    keep_same_node_type: bool,\n    drop_unconnected_node_types: bool,\n) -> None:\n\n    if drop_orig_edge_types:\n        for i in edge_types:\n            if keep_same_node_type and i[0] == i[-1]:\n                continue\n            else:\n                del data[i]\n\n    # Remove nodes not connected by any edge type:\n    if drop_unconnected_node_types:\n        new_edge_types = data.edge_types\n        node_types = data.node_types\n        connected_nodes 
= set()\n        for i in new_edge_types:\n            connected_nodes.add(i[0])\n            connected_nodes.add(i[-1])\n        for node in node_types:\n            if node not in connected_nodes:\n                del data[node]\n"
  },
  {
    "path": "torch_geometric/transforms/add_positional_encoding.py",
    "content": "from typing import Any, Optional\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import (\n    get_laplacian,\n    get_self_loop_attr,\n    is_torch_sparse_tensor,\n    scatter,\n    to_edge_index,\n    to_scipy_sparse_matrix,\n    to_torch_coo_tensor,\n    to_torch_csr_tensor,\n)\n\n\ndef add_node_attr(\n    data: Data,\n    value: Any,\n    attr_name: Optional[str] = None,\n) -> Data:\n    # TODO Move to `BaseTransform`.\n    if attr_name is None:\n        if data.x is not None:\n            x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x\n            data.x = torch.cat([x, value.to(x.device, x.dtype)], dim=-1)\n        else:\n            data.x = value\n    else:\n        data[attr_name] = value\n\n    return data\n\n\n@functional_transform('add_laplacian_eigenvector_pe')\nclass AddLaplacianEigenvectorPE(BaseTransform):\n    r\"\"\"Adds the Laplacian eigenvector positional encoding from the\n    `\"Benchmarking Graph Neural Networks\" <https://arxiv.org/abs/2003.00982>`_\n    paper to the given graph\n    (functional name: :obj:`add_laplacian_eigenvector_pe`).\n\n    Args:\n        k (int): The number of non-trivial eigenvectors to consider.\n        attr_name (str, optional): The attribute name of the data object to add\n            positional encodings to. If set to :obj:`None`, will be\n            concatenated to :obj:`data.x`.\n            (default: :obj:`\"laplacian_eigenvector_pe\"`)\n        is_undirected (bool, optional): If set to :obj:`True`, this transform\n            expects undirected graphs as input, and can hence speed up the\n            computation of eigenvectors. 
(default: :obj:`False`)\n        **kwargs (optional): Additional arguments of\n            :meth:`scipy.sparse.linalg.eigs` (when :attr:`is_undirected` is\n            :obj:`False`) or :meth:`scipy.sparse.linalg.eigsh` (when\n            :attr:`is_undirected` is :obj:`True`).\n    \"\"\"\n    # Number of nodes from which to use sparse eigenvector computation:\n    SPARSE_THRESHOLD: int = 100\n\n    def __init__(\n        self,\n        k: int,\n        attr_name: Optional[str] = 'laplacian_eigenvector_pe',\n        is_undirected: bool = False,\n        **kwargs: Any,\n    ) -> None:\n        self.k = k\n        self.attr_name = attr_name\n        self.is_undirected = is_undirected\n        self.kwargs = kwargs\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        num_nodes = data.num_nodes\n        assert num_nodes is not None\n\n        edge_index, edge_weight = get_laplacian(\n            data.edge_index,\n            data.edge_weight,\n            normalization='sym',\n            num_nodes=num_nodes,\n        )\n\n        L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)\n\n        if num_nodes < self.SPARSE_THRESHOLD:\n            from numpy.linalg import eig, eigh\n            eig_fn = eig if not self.is_undirected else eigh\n\n            eig_vals, eig_vecs = eig_fn(L.todense())\n        else:\n            from scipy.sparse.linalg import eigs, eigsh\n            eig_fn = eigs if not self.is_undirected else eigsh\n\n            eig_vals, eig_vecs = eig_fn(\n                L,\n                k=self.k + 1,\n                which='SR' if not self.is_undirected else 'SA',\n                return_eigenvectors=True,\n                **self.kwargs,\n            )\n\n        eig_vecs = np.real(eig_vecs[:, eig_vals.argsort()])\n        pe = torch.from_numpy(eig_vecs[:, 1:self.k + 1])\n        sign = -1 + 2 * torch.randint(0, 2, (self.k, ))\n        pe *= sign\n\n        data = add_node_attr(data, pe, 
attr_name=self.attr_name)\n        return data\n\n\n@functional_transform('add_random_walk_pe')\nclass AddRandomWalkPE(BaseTransform):\n    r\"\"\"Adds the random walk positional encoding from the `\"Graph Neural\n    Networks with Learnable Structural and Positional Representations\"\n    <https://arxiv.org/abs/2110.07875>`_ paper to the given graph\n    (functional name: :obj:`add_random_walk_pe`).\n\n    Args:\n        walk_length (int): The number of random walk steps.\n        attr_name (str, optional): The attribute name of the data object to add\n            positional encodings to. If set to :obj:`None`, will be\n            concatenated to :obj:`data.x`.\n            (default: :obj:`\"random_walk_pe\"`)\n    \"\"\"\n    def __init__(\n        self,\n        walk_length: int,\n        attr_name: Optional[str] = 'random_walk_pe',\n    ) -> None:\n        self.walk_length = walk_length\n        self.attr_name = attr_name\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        row, col = data.edge_index\n        N = data.num_nodes\n        assert N is not None\n\n        if data.edge_weight is None:\n            value = torch.ones(data.num_edges, device=row.device)\n        else:\n            value = data.edge_weight\n        value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row]\n        value = 1.0 / value\n\n        if N <= 2_000:  # Dense code path for faster computation:\n            adj = torch.zeros((N, N), device=row.device)\n            adj[row, col] = value\n            loop_index = torch.arange(N, device=row.device)\n        elif torch_geometric.typing.NO_MKL:  # pragma: no cover\n            adj = to_torch_coo_tensor(data.edge_index, value, size=data.size())\n        else:\n            adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())\n\n        def get_pe(out: Tensor) -> Tensor:\n            if is_torch_sparse_tensor(out):\n                return 
get_self_loop_attr(*to_edge_index(out), num_nodes=N)\n            return out[loop_index, loop_index]\n\n        out = adj\n        pe_list = [get_pe(out)]\n        for _ in range(self.walk_length - 1):\n            out = out @ adj\n            pe_list.append(get_pe(out))\n\n        pe = torch.stack(pe_list, dim=-1)\n        data = add_node_attr(data, pe, attr_name=self.attr_name)\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/add_remaining_self_loops.py",
    "content": "from typing import Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import add_remaining_self_loops\n\n\n@functional_transform('add_remaining_self_loops')\nclass AddRemainingSelfLoops(BaseTransform):\n    r\"\"\"Adds remaining self-loops to the given homogeneous or heterogeneous\n    graph (functional name: :obj:`add_remaining_self_loops`).\n\n    Args:\n        attr (str, optional): The name of the attribute of edge weights\n            or multi-dimensional edge features to pass to\n            :meth:`torch_geometric.utils.add_remaining_self_loops`.\n            (default: :obj:`\"edge_weight\"`)\n        fill_value (float or Tensor or str, optional): The way to generate\n            edge features of self-loops (in case :obj:`attr != None`).\n            If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n            self-loops will be directly given by :obj:`fill_value`.\n            If given as :obj:`str`, edge features of self-loops are computed by\n            aggregating all features of edges that point to the specific node,\n            according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n            :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). 
(default: :obj:`1.`)\n    \"\"\"\n    def __init__(\n        self,\n        attr: str = 'edge_weight',\n        fill_value: Union[float, Tensor, str] = 1.0,\n    ):\n        self.attr = attr\n        self.fill_value = fill_value\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.edge_stores:\n            if store.is_bipartite() or 'edge_index' not in store:\n                continue\n\n            store.edge_index, store[self.attr] = add_remaining_self_loops(\n                store.edge_index,\n                edge_attr=store.get(self.attr, None),\n                fill_value=self.fill_value,\n                num_nodes=store.size(0),\n            )\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/add_self_loops.py",
    "content": "from typing import Union\n\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import add_self_loops\n\n\n@functional_transform('add_self_loops')\nclass AddSelfLoops(BaseTransform):\n    r\"\"\"Adds self-loops to the given homogeneous or heterogeneous graph\n    (functional name: :obj:`add_self_loops`).\n\n    Args:\n        attr (str, optional): The name of the attribute of edge weights\n            or multi-dimensional edge features to pass to\n            :meth:`torch_geometric.utils.add_self_loops`.\n            (default: :obj:`\"edge_weight\"`)\n        fill_value (float or Tensor or str, optional): The way to generate\n            edge features of self-loops (in case :obj:`attr != None`).\n            If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n            self-loops will be directly given by :obj:`fill_value`.\n            If given as :obj:`str`, edge features of self-loops are computed by\n            aggregating all features of edges that point to the specific node,\n            according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n            :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). 
(default: :obj:`1.`)\n    \"\"\"\n    def __init__(\n        self,\n        attr: str = 'edge_weight',\n        fill_value: Union[float, Tensor, str] = 1.0,\n    ) -> None:\n        self.attr = attr\n        self.fill_value = fill_value\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.edge_stores:\n            if store.is_bipartite() or 'edge_index' not in store:\n                continue\n\n            store.edge_index, store[self.attr] = add_self_loops(\n                store.edge_index,\n                edge_attr=store.get(self.attr, None),\n                fill_value=self.fill_value,\n                num_nodes=store.size(0),\n            )\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/base_transform.py",
    "content": "import copy\nfrom abc import ABC, abstractmethod\nfrom typing import Any\n\n\nclass BaseTransform(ABC):\n    r\"\"\"An abstract base class for writing transforms.\n\n    Transforms are a general way to modify and customize\n    :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData` objects, either by implicitly\n    passing them as an argument to a :class:`~torch_geometric.data.Dataset`, or\n    by applying them explicitly to individual\n    :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData` objects:\n\n    .. code-block:: python\n\n        import torch_geometric.transforms as T\n        from torch_geometric.datasets import TUDataset\n\n        transform = T.Compose([T.ToUndirected(), T.AddSelfLoops()])\n\n        dataset = TUDataset(path, name='MUTAG', transform=transform)\n        data = dataset[0]  # Implicitly transform data on every access.\n\n        data = TUDataset(path, name='MUTAG')[0]\n        data = transform(data)  # Explicitly transform data.\n    \"\"\"\n    def __call__(self, data: Any) -> Any:\n        # Shallow-copy the data so that we prevent in-place data modification.\n        return self.forward(copy.copy(data))\n\n    @abstractmethod\n    def forward(self, data: Any) -> Any:\n        pass\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}()'\n"
  },
  {
    "path": "torch_geometric/transforms/cartesian.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('cartesian')\nclass Cartesian(BaseTransform):\n    r\"\"\"Saves the relative Cartesian coordinates of linked nodes in its edge\n    attributes (functional name: :obj:`cartesian`). Each coordinate gets\n    globally normalized to a specified interval (:math:`[0, 1]` by default).\n\n    Args:\n        norm (bool, optional): If set to :obj:`False`, the output will not be\n            normalized. (default: :obj:`True`)\n        max_value (float, optional): If set and :obj:`norm=True`, normalization\n            will be performed based on this value instead of the maximum value\n            found in the data. (default: :obj:`None`)\n        cat (bool, optional): If set to :obj:`False`, all existing edge\n            attributes will be replaced. (default: :obj:`True`)\n        interval ((float, float), optional): A tuple specifying the lower and\n            upper bound for normalization. 
(default: :obj:`(0.0, 1.0)`)\n    \"\"\"\n    def __init__(\n            self,\n            norm: bool = True,\n            max_value: Optional[float] = None,\n            cat: bool = True,\n            interval: Tuple[float, float] = (0.0, 1.0),\n    ):\n        self.norm = norm\n        self.max = max_value\n        self.cat = cat\n        self.interval = interval\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        assert data.edge_index is not None\n        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr\n\n        cart = pos[row] - pos[col]\n        cart = cart.view(-1, 1) if cart.dim() == 1 else cart\n\n        if self.norm and cart.numel() > 0:\n            max_val = float(cart.abs().max()) if self.max is None else self.max\n\n            length = self.interval[1] - self.interval[0]\n            center = (self.interval[0] + self.interval[1]) / 2\n            cart = length * cart / (2 * max_val) + center\n\n        if pseudo is not None and self.cat:\n            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo\n            data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)\n        else:\n            data.edge_attr = cart\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(norm={self.norm}, '\n                f'max_value={self.max})')\n"
  },
  {
    "path": "torch_geometric/transforms/center.py",
    "content": "from typing import Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('center')\nclass Center(BaseTransform):\n    r\"\"\"Centers node positions :obj:`data.pos` around the origin\n    (functional name: :obj:`center`).\n    \"\"\"\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.node_stores:\n            if hasattr(store, 'pos'):\n                store.pos = store.pos - store.pos.mean(dim=-2, keepdim=True)\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/compose.py",
    "content": "from typing import Callable, List, Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.transforms import BaseTransform\n\n\nclass Compose(BaseTransform):\n    r\"\"\"Composes several transforms together.\n\n    Args:\n        transforms (List[Callable]): List of transforms to compose.\n    \"\"\"\n    def __init__(self, transforms: List[Callable]):\n        self.transforms = transforms\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for transform in self.transforms:\n            if isinstance(data, (list, tuple)):\n                data = [transform(d) for d in data]\n            else:\n                data = transform(data)\n        return data\n\n    def __repr__(self) -> str:\n        args = [f'  {transform}' for transform in self.transforms]\n        return '{}([\\n{}\\n])'.format(self.__class__.__name__, ',\\n'.join(args))\n\n\nclass ComposeFilters:\n    r\"\"\"Composes several filters together.\n\n    Args:\n        filters (List[Callable]): List of filters to compose.\n    \"\"\"\n    def __init__(self, filters: List[Callable]):\n        self.filters = filters\n\n    def __call__(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> bool:\n        for filter_fn in self.filters:\n            if isinstance(data, (list, tuple)):\n                if not all([filter_fn(d) for d in data]):\n                    return False\n            elif not filter_fn(data):\n                return False\n        return True\n\n    def __repr__(self) -> str:\n        args = [f'  {filter_fn}' for filter_fn in self.filters]\n        return '{}([\\n{}\\n])'.format(self.__class__.__name__, ',\\n'.join(args))\n"
  },
  {
    "path": "torch_geometric/transforms/constant.py",
    "content": "from typing import List, Optional, Union\n\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('constant')\nclass Constant(BaseTransform):\n    r\"\"\"Appends a constant value to each node feature :obj:`x`\n    (functional name: :obj:`constant`).\n\n    Args:\n        value (float, optional): The value to add. (default: :obj:`1.0`)\n        cat (bool, optional): If set to :obj:`False`, existing node features\n            will be replaced. (default: :obj:`True`)\n        node_types (str or List[str], optional): The specified node type(s) to\n            append constant values for if used on heterogeneous graphs.\n            If set to :obj:`None`, constants will be added to each node feature\n            :obj:`x` for all existing node types. (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        value: float = 1.0,\n        cat: bool = True,\n        node_types: Optional[Union[str, List[str]]] = None,\n    ):\n        if isinstance(node_types, str):\n            node_types = [node_types]\n\n        self.value = value\n        self.cat = cat\n        self.node_types = node_types\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n\n        for store in data.node_stores:\n            if self.node_types is None or store._key in self.node_types:\n                num_nodes = store.num_nodes\n                assert num_nodes is not None\n                c = torch.full((num_nodes, 1), self.value, dtype=torch.float)\n\n                if hasattr(store, 'x') and self.cat:\n                    x = store.x.view(-1, 1) if store.x.dim() == 1 else store.x\n                    store.x = torch.cat([x, c.to(x.device, x.dtype)], dim=-1)\n                else:\n                    store.x = c\n\n        return data\n\n    def 
__repr__(self) -> str:\n        return f'{self.__class__.__name__}(value={self.value})'\n"
  },
  {
    "path": "torch_geometric/transforms/delaunay.py",
    "content": "from typing import List\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\nclass _QhullTransform(BaseTransform):\n    r\"\"\"Q-hull implementation of delaunay triangulation.\"\"\"\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        import scipy.spatial\n\n        pos = data.pos.cpu().numpy()\n        tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')\n        face = torch.from_numpy(tri.simplices)\n\n        data.face = face.t().contiguous().to(data.pos.device, torch.long)\n        return data\n\n\nclass _ShullTransform(BaseTransform):\n    r\"\"\"Sweep-hull implementation of delaunay triangulation.\"\"\"\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        from torch_delaunay.functional import shull2d\n\n        face = shull2d(data.pos.cpu())\n        data.face = face.t().contiguous().to(data.pos.device)\n        return data\n\n\nclass _SequentialTransform(BaseTransform):\n    r\"\"\"Runs the first successful transformation.\n\n    All intermediate exceptions are suppressed except the last.\n    \"\"\"\n    def __init__(self, transforms: List[BaseTransform]) -> None:\n        assert len(transforms) > 0\n        self.transforms = transforms\n\n    def forward(self, data: Data) -> Data:\n        for i, transform in enumerate(self.transforms):\n            try:\n                return transform.forward(data)\n            except ImportError as e:\n                if i == len(self.transforms) - 1:\n                    raise e\n        return data\n\n\n@functional_transform('delaunay')\nclass Delaunay(BaseTransform):\n    r\"\"\"Computes the delaunay triangulation of a set of points\n    (functional name: :obj:`delaunay`).\n\n    .. 
hint::\n        Consider installing the\n        `torch_delaunay <https://github.com/ybubnov/torch_delaunay>`_ package\n        to speed up computation.\n    \"\"\"\n    def __init__(self) -> None:\n        self._transform = _SequentialTransform([\n            _ShullTransform(),\n            _QhullTransform(),\n        ])\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        device = data.pos.device\n\n        if data.pos.size(0) < 2:\n            data.edge_index = torch.empty(2, 0, dtype=torch.long,\n                                          device=device)\n        elif data.pos.size(0) == 2:\n            data.edge_index = torch.tensor([[0, 1], [1, 0]], device=device)\n        elif data.pos.size(0) == 3:\n            data.face = torch.tensor([[0], [1], [2]], device=device)\n        else:\n            data = self._transform.forward(data)\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/distance.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('distance')\nclass Distance(BaseTransform):\n    r\"\"\"Saves the Euclidean distance of linked nodes in its edge attributes\n    (functional name: :obj:`distance`). Each distance gets globally normalized\n    to a specified interval (:math:`[0, 1]` by default).\n\n    Args:\n        norm (bool, optional): If set to :obj:`False`, the output will not be\n            normalized. (default: :obj:`True`)\n        max_value (float, optional): If set and :obj:`norm=True`, normalization\n            will be performed based on this value instead of the maximum value\n            found in the data. (default: :obj:`None`)\n        cat (bool, optional): If set to :obj:`False`, all existing edge\n            attributes will be replaced. (default: :obj:`True`)\n        interval ((float, float), optional): A tuple specifying the lower and\n            upper bound for normalization. 
(default: :obj:`(0.0, 1.0)`)\n    \"\"\"\n    def __init__(\n            self,\n            norm: bool = True,\n            max_value: Optional[float] = None,\n            cat: bool = True,\n            interval: Tuple[float, float] = (0.0, 1.0),\n    ):\n        self.norm = norm\n        self.max = max_value\n        self.cat = cat\n        self.interval = interval\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        assert data.edge_index is not None\n        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr\n\n        dist = torch.norm(pos[col] - pos[row], p=2, dim=-1).view(-1, 1)\n\n        if self.norm and dist.numel() > 0:\n            max_val = float(dist.max()) if self.max is None else self.max\n\n            length = self.interval[1] - self.interval[0]\n            dist = length * (dist / max_val) + self.interval[0]\n\n        if pseudo is not None and self.cat:\n            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo\n            data.edge_attr = torch.cat([pseudo, dist.type_as(pseudo)], dim=-1)\n        else:\n            data.edge_attr = dist\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(norm={self.norm}, '\n                f'max_value={self.max})')\n"
  },
  {
    "path": "torch_geometric/transforms/face_to_edge.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import to_undirected\n\n\n@functional_transform('face_to_edge')\nclass FaceToEdge(BaseTransform):\n    r\"\"\"Converts mesh faces of shape :obj:`[3, num_faces]` or\n    :obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`\n    (functional name: :obj:`face_to_edge`).\n\n    This transform supports both 2D triangular faces, represented by a\n    tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,\n    represented by a tensor of shape :obj:`[4, num_faces]`. It will convert\n    these faces into edge indices, where each edge is defined by the indices\n    of its two endpoints.\n\n    Args:\n        remove_faces (bool, optional): If set to :obj:`False`, the face tensor\n            will not be removed.\n    \"\"\"\n    def __init__(self, remove_faces: bool = True) -> None:\n        self.remove_faces = remove_faces\n\n    def forward(self, data: Data) -> Data:\n        if hasattr(data, 'face'):\n            assert data.face is not None\n            face = data.face\n\n            if face.size(0) not in [3, 4]:\n                raise RuntimeError(f\"Expected 'face' tensor with shape \"\n                                   f\"[3, num_faces] or [4, num_faces] \"\n                                   f\"(got {list(face.size())})\")\n\n            if face.size()[0] == 3:\n                edge_index = torch.cat([\n                    face[:2],\n                    face[1:],\n                    face[::2],\n                ], dim=1)\n            else:\n                assert face.size()[0] == 4\n                edge_index = torch.cat([\n                    face[:2],\n                    face[1:3],\n                    face[2:4],\n                    face[::2],\n                    face[1::2],\n                    face[::3],\n          
      ], dim=1)\n\n            edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)\n\n            data.edge_index = edge_index\n            if self.remove_faces:\n                data.face = None\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/feature_propagation.py",
    "content": "from torch import Tensor\n\nimport torch_geometric\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import is_torch_sparse_tensor, to_torch_csc_tensor\n\n\n@functional_transform('feature_propagation')\nclass FeaturePropagation(BaseTransform):\n    r\"\"\"The feature propagation operator from the `\"On the Unreasonable\n    Effectiveness of Feature propagation in Learning on Graphs with Missing\n    Node Features\" <https://arxiv.org/abs/2111.12128>`_ paper\n    (functional name: :obj:`feature_propagation`).\n\n    .. math::\n        \\mathbf{X}^{(0)} &= (1 - \\mathbf{M}) \\cdot \\mathbf{X}\n\n        \\mathbf{X}^{(\\ell + 1)} &= \\mathbf{X}^{(0)} + \\mathbf{M} \\cdot\n        (\\mathbf{D}^{-1/2} \\mathbf{A} \\mathbf{D}^{-1/2} \\mathbf{X}^{(\\ell)})\n\n    where missing node features are inferred by known features via propagation.\n\n    .. 
code-block:: python\n\n        import torch\n\n        from torch_geometric.transforms import FeaturePropagation\n\n        transform = FeaturePropagation(missing_mask=torch.isnan(data.x))\n        data = transform(data)\n\n    Args:\n        missing_mask (torch.Tensor): Mask matrix\n            :math:`\\mathbf{M} \\in {\\{ 0, 1 \\}}^{N\\times F}` indicating missing\n            node features.\n        num_iterations (int, optional): The number of propagations.\n            (default: :obj:`40`)\n    \"\"\"\n    def __init__(self, missing_mask: Tensor, num_iterations: int = 40) -> None:\n        self.missing_mask = missing_mask\n        self.num_iterations = num_iterations\n\n    def forward(self, data: Data) -> Data:\n        assert data.x is not None\n        assert data.edge_index is not None or data.adj_t is not None\n\n        assert data.x.size() == self.missing_mask.size()\n        gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm\n\n        missing_mask = self.missing_mask.to(data.x.device)\n        known_mask = ~missing_mask\n\n        if data.edge_index is not None:\n            edge_weight = data.edge_attr\n            if 'edge_weight' in data:\n                edge_weight = data.edge_weight\n            adj_t = to_torch_csc_tensor(\n                edge_index=data.edge_index,\n                edge_attr=edge_weight,\n                size=data.size(0),\n            ).t()\n            adj_t, _ = gcn_norm(adj_t, add_self_loops=False)\n        elif is_torch_sparse_tensor(data.adj_t):\n            adj_t, _ = gcn_norm(data.adj_t, add_self_loops=False)\n        else:\n            adj_t = gcn_norm(data.adj_t, add_self_loops=False)\n\n        x = data.x.clone()\n        x[missing_mask] = 0.\n\n        out = x\n        for _ in range(self.num_iterations):\n            out = adj_t @ out\n            out[known_mask] = x[known_mask]  # Reset.\n        data.x = out\n\n        return data\n\n    def __repr__(self) -> str:\n        na_values = int(self.missing_mask.sum()) / 
self.missing_mask.numel()\n        return (f'{self.__class__.__name__}('\n                f'missing_features={100 * na_values:.1f}%, '\n                f'num_iterations={self.num_iterations})')\n"
  },
  {
    "path": "torch_geometric/transforms/fixed_points.py",
    "content": "import math\nimport re\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('fixed_points')\nclass FixedPoints(BaseTransform):\n    r\"\"\"Samples a fixed number of points and features from a point cloud\n    (functional name: :obj:`fixed_points`).\n\n    Args:\n        num (int): The number of points to sample.\n        replace (bool, optional): If set to :obj:`False`, samples points\n            without replacement. (default: :obj:`True`)\n        allow_duplicates (bool, optional): In case :obj:`replace` is\n            :obj`False` and :obj:`num` is greater than the number of points,\n            this option determines whether to add duplicated nodes to the\n            output points or not.\n            In case :obj:`allow_duplicates` is :obj:`False`, the number of\n            output points might be smaller than :obj:`num`.\n            In case :obj:`allow_duplicates` is :obj:`True`, the number of\n            duplicated points are kept to a minimum. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num: int,\n        replace: bool = True,\n        allow_duplicates: bool = False,\n    ):\n        self.num = num\n        self.replace = replace\n        self.allow_duplicates = allow_duplicates\n\n    def forward(self, data: Data) -> Data:\n        num_nodes = data.num_nodes\n        assert num_nodes is not None\n\n        if self.replace:\n            choice = torch.from_numpy(\n                np.random.choice(num_nodes, self.num, replace=True)).long()\n        elif not self.allow_duplicates:\n            choice = torch.randperm(num_nodes)[:self.num]\n        else:\n            choice = torch.cat([\n                torch.randperm(num_nodes)\n                for _ in range(math.ceil(self.num / num_nodes))\n            ], dim=0)[:self.num]\n\n        for key, value in data.items():\n            if key == 'num_nodes':\n                data.num_nodes = choice.size(0)\n            elif bool(re.search('edge', key)):\n                continue\n            elif (isinstance(value, Tensor) and value.size(0) == num_nodes\n                  and value.size(0) != 1):\n                data[key] = value[choice]\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.num}, replace={self.replace})'\n"
  },
  {
    "path": "torch_geometric/transforms/gcn_norm.py",
    "content": "import torch_geometric\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('gcn_norm')\nclass GCNNorm(BaseTransform):\n    r\"\"\"Applies the GCN normalization from the `\"Semi-supervised Classification\n    with Graph Convolutional Networks\" <https://arxiv.org/abs/1609.02907>`_\n    paper (functional name: :obj:`gcn_norm`).\n\n    .. math::\n        \\mathbf{\\hat{A}} = \\mathbf{\\hat{D}}^{-1/2} (\\mathbf{A} + \\mathbf{I})\n        \\mathbf{\\hat{D}}^{-1/2}\n\n    where :math:`\\hat{D}_{ii} = \\sum_{j=0} \\hat{A}_{ij} + 1`.\n    \"\"\"\n    def __init__(self, add_self_loops: bool = True):\n        self.add_self_loops = add_self_loops\n\n    def forward(self, data: Data) -> Data:\n        gcn_norm = torch_geometric.nn.conv.gcn_conv.gcn_norm\n        assert 'edge_index' in data or 'adj_t' in data\n\n        if 'edge_index' in data:\n            data.edge_index, data.edge_weight = gcn_norm(\n                data.edge_index, data.edge_weight, data.num_nodes,\n                add_self_loops=self.add_self_loops)\n        else:\n            data.adj_t = gcn_norm(data.adj_t,\n                                  add_self_loops=self.add_self_loops)\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}('\n                f'add_self_loops={self.add_self_loops})')\n"
  },
  {
    "path": "torch_geometric/transforms/gdc.py",
    "content": "from typing import Any, Dict, Optional, Tuple\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import (\n    add_self_loops,\n    coalesce,\n    get_ppr,\n    is_undirected,\n    scatter,\n    sort_edge_index,\n    to_dense_adj,\n)\n\n\n@functional_transform('gdc')\nclass GDC(BaseTransform):\n    r\"\"\"Processes the graph via Graph Diffusion Convolution (GDC) from the\n    `\"Diffusion Improves Graph Learning\" <https://arxiv.org/abs/1911.05485>`_\n    paper (functional name: :obj:`gdc`).\n\n    .. note::\n\n        The paper offers additional advice on how to choose the\n        hyperparameters.\n        For an example of using GCN with GDC, see `examples/gcn.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        gcn.py>`_.\n\n    Args:\n        self_loop_weight (float, optional): Weight of the added self-loop.\n            Set to :obj:`None` to add no self-loops. (default: :obj:`1`)\n        normalization_in (str, optional): Normalization of the transition\n            matrix on the original (input) graph. Possible values:\n            :obj:`\"sym\"`, :obj:`\"col\"`, and :obj:`\"row\"`.\n            See :func:`GDC.transition_matrix` for details.\n            (default: :obj:`\"sym\"`)\n        normalization_out (str, optional): Normalization of the transition\n            matrix on the transformed GDC (output) graph. 
Possible values:\n            :obj:`\"sym\"`, :obj:`\"col\"`, :obj:`\"row\"`, and :obj:`None`.\n            See :func:`GDC.transition_matrix` for details.\n            (default: :obj:`\"col\"`)\n        diffusion_kwargs (dict, optional): Dictionary containing the parameters\n            for diffusion.\n            `method` specifies the diffusion method (:obj:`\"ppr\"`,\n            :obj:`\"heat\"` or :obj:`\"coeff\"`).\n            Each diffusion method requires different additional parameters.\n            See :func:`GDC.diffusion_matrix_exact` or\n            :func:`GDC.diffusion_matrix_approx` for details.\n            (default: :obj:`dict(method='ppr', alpha=0.15)`)\n        sparsification_kwargs (dict, optional): Dictionary containing the\n            parameters for sparsification.\n            `method` specifies the sparsification method (:obj:`\"threshold\"` or\n            :obj:`\"topk\"`).\n            Each sparsification method requires different additional\n            parameters.\n            See :func:`GDC.sparsify_dense` for details.\n            (default: :obj:`dict(method='threshold', avg_degree=64)`)\n        exact (bool, optional): Whether to exactly calculate the diffusion\n            matrix.\n            Note that the exact variants are not scalable.\n            They densify the adjacency matrix and calculate either its inverse\n            or its matrix exponential.\n            However, the approximate variants do not support edge weights and\n            currently only personalized PageRank and sparsification by\n            threshold are implemented as fast, approximate versions.\n            (default: :obj:`True`)\n\n    :rtype: :class:`torch_geometric.data.Data`\n    \"\"\"\n    def __init__(\n        self,\n        self_loop_weight: float = 1.,\n        normalization_in: str = 'sym',\n        normalization_out: str = 'col',\n        diffusion_kwargs: Optional[Dict[str, Any]] = None,\n        sparsification_kwargs: Optional[Dict[str, 
Any]] = None,\n        exact: bool = True,\n    ) -> None:\n        self.self_loop_weight = self_loop_weight\n        self.normalization_in = normalization_in\n        self.normalization_out = normalization_out\n        self.diffusion_kwargs = diffusion_kwargs or dict(\n            method='ppr', alpha=0.15)\n        self.sparsification_kwargs = sparsification_kwargs or dict(\n            method='threshold', avg_degree=64)\n        self.exact = exact\n\n        if self_loop_weight:\n            assert exact or self_loop_weight == 1\n\n    @torch.no_grad()\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        edge_index = data.edge_index\n        N = data.num_nodes\n        assert N is not None\n\n        if data.edge_attr is None:\n            edge_weight = torch.ones(edge_index.size(1),\n                                     device=edge_index.device)\n        else:\n            edge_weight = data.edge_attr\n            assert self.exact\n            assert edge_weight.dim() == 1\n\n        if self.self_loop_weight:\n            edge_index, edge_weight = add_self_loops(\n                edge_index, edge_weight, fill_value=self.self_loop_weight,\n                num_nodes=N)\n\n        edge_index, edge_weight = coalesce(edge_index, edge_weight, N)\n\n        if self.exact:\n            edge_index, edge_weight = self.transition_matrix(\n                edge_index, edge_weight, N, self.normalization_in)\n            diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N,\n                                                   **self.diffusion_kwargs)\n            edge_index, edge_weight = self.sparsify_dense(\n                diff_mat, **self.sparsification_kwargs)\n        else:\n            edge_index, edge_weight = self.diffusion_matrix_approx(\n                edge_index, edge_weight, N, self.normalization_in,\n                **self.diffusion_kwargs)\n            edge_index, edge_weight = self.sparsify_sparse(\n     
           edge_index, edge_weight, N, **self.sparsification_kwargs)\n\n        edge_index, edge_weight = coalesce(edge_index, edge_weight, N)\n        edge_index, edge_weight = self.transition_matrix(\n            edge_index, edge_weight, N, self.normalization_out)\n\n        data.edge_index = edge_index\n        data.edge_attr = edge_weight\n\n        return data\n\n    def transition_matrix(\n        self,\n        edge_index: Tensor,\n        edge_weight: Tensor,\n        num_nodes: int,\n        normalization: str,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Calculate the transition matrix of a given sparse graph by\n        normalizing its edge weights.\n\n        Args:\n            edge_index (LongTensor): The edge indices.\n            edge_weight (Tensor): One-dimensional edge weights.\n            num_nodes (int): Number of nodes.\n            normalization (str): Normalization scheme:\n\n                1. :obj:`\"sym\"`: Symmetric normalization\n                   :math:`\\mathbf{T} = \\mathbf{D}^{-1/2} \\mathbf{A}\n                   \\mathbf{D}^{-1/2}`.\n                2. :obj:`\"col\"`: Column-wise normalization\n                   :math:`\\mathbf{T} = \\mathbf{A} \\mathbf{D}^{-1}`.\n                3. :obj:`\"row\"`: Row-wise normalization\n                   :math:`\\mathbf{T} = \\mathbf{D}^{-1} \\mathbf{A}`.\n                4. :obj:`None`: No normalization.\n\n        :rtype: (:class:`LongTensor`, :class:`Tensor`)\n        \"\"\"\n        if normalization == 'sym':\n            row, col = edge_index\n            deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum')\n            deg_inv_sqrt = deg.pow(-0.5)\n            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n            edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n        elif normalization == 'col':\n            _, col = edge_index\n            deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum')\n            deg_inv = 1. 
/ deg\n            deg_inv[deg_inv == float('inf')] = 0\n            edge_weight = edge_weight * deg_inv[col]\n        elif normalization == 'row':\n            row, _ = edge_index\n            deg = scatter(edge_weight, row, 0, num_nodes, reduce='sum')\n            deg_inv = 1. / deg\n            deg_inv[deg_inv == float('inf')] = 0\n            edge_weight = edge_weight * deg_inv[row]\n        elif normalization is None:\n            pass\n        else:\n            raise ValueError(\n                f\"Transition matrix normalization '{normalization}' unknown\")\n\n        return edge_index, edge_weight\n\n    def diffusion_matrix_exact(  # noqa: D417\n        self,\n        edge_index: Tensor,\n        edge_weight: Tensor,\n        num_nodes: int,\n        method: str,\n        **kwargs: Any,\n    ) -> Tensor:\n        r\"\"\"Calculate the (dense) diffusion on a given sparse graph.\n        Note that these exact variants are not scalable. They densify the\n        adjacency matrix and calculate either its inverse or its matrix\n        exponential.\n\n        Args:\n            edge_index (LongTensor): The edge indices.\n            edge_weight (Tensor): One-dimensional edge weights.\n            num_nodes (int): Number of nodes.\n            method (str): Diffusion method:\n\n                1. :obj:`\"ppr\"`: Use personalized PageRank as diffusion.\n                   Additionally expects the parameter:\n\n                   - **alpha** (*float*) - Return probability in PPR.\n                     Commonly lies in :obj:`[0.05, 0.2]`.\n\n                2. :obj:`\"heat\"`: Use heat kernel diffusion.\n                   Additionally expects the parameter:\n\n                   - **t** (*float*) - Time of diffusion. Commonly lies in\n                     :obj:`[2, 10]`.\n\n                3. 
:obj:`\"coeff\"`: Freely choose diffusion coefficients.\n                   Additionally expects the parameter:\n\n                   - **coeffs** (*List[float]*) - List of coefficients\n                     :obj:`theta_k` for each power of the transition matrix\n                     (starting at :obj:`0`).\n\n        :rtype: (:class:`Tensor`)\n        \"\"\"\n        if method == 'ppr':\n            # α (I_n + (α - 1) A)^-1\n            edge_weight = (kwargs['alpha'] - 1) * edge_weight\n            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,\n                                                     fill_value=1,\n                                                     num_nodes=num_nodes)\n            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()\n            diff_matrix = kwargs['alpha'] * torch.inverse(mat)\n\n        elif method == 'heat':\n            # exp(t (A - I_n))\n            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,\n                                                     fill_value=-1,\n                                                     num_nodes=num_nodes)\n            edge_weight = kwargs['t'] * edge_weight\n            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()\n            undirected = is_undirected(edge_index, edge_weight, num_nodes)\n            diff_matrix = self.__expm__(mat, undirected)\n\n        elif method == 'coeff':\n            adj_matrix = to_dense_adj(edge_index,\n                                      edge_attr=edge_weight).squeeze()\n            mat = torch.eye(num_nodes, device=edge_index.device)\n\n            diff_matrix = kwargs['coeffs'][0] * mat\n            for coeff in kwargs['coeffs'][1:]:\n                mat = mat @ adj_matrix\n                diff_matrix += coeff * mat\n        else:\n            raise ValueError(f\"Exact GDC diffusion '{method}' unknown\")\n\n        return diff_matrix\n\n    def diffusion_matrix_approx(  # noqa: D417\n        
self,\n        edge_index: Tensor,\n        edge_weight: Tensor,\n        num_nodes: int,\n        normalization: str,\n        method: str,\n        **kwargs: Any,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Calculate the approximate, sparse diffusion on a given sparse\n        graph.\n\n        Args:\n            edge_index (LongTensor): The edge indices.\n            edge_weight (Tensor): One-dimensional edge weights.\n            num_nodes (int): Number of nodes.\n            normalization (str): Transition matrix normalization scheme\n                (:obj:`\"sym\"`, :obj:`\"row\"`, or :obj:`\"col\"`).\n                See :func:`GDC.transition_matrix` for details.\n            method (str): Diffusion method:\n\n                1. :obj:`\"ppr\"`: Use personalized PageRank as diffusion.\n                   Additionally expects the parameters:\n\n                   - **alpha** (*float*) - Return probability in PPR.\n                     Commonly lies in :obj:`[0.05, 0.2]`.\n\n                   - **eps** (*float*) - Threshold for PPR calculation stopping\n                     criterion (:obj:`edge_weight >= eps * out_degree`).\n                     Recommended default: :obj:`1e-4`.\n\n        :rtype: (:class:`LongTensor`, :class:`Tensor`)\n        \"\"\"\n        if method == 'ppr':\n            if normalization == 'sym':\n                # Calculate original degrees.\n                _, col = edge_index\n                deg = scatter(edge_weight, col, 0, num_nodes, reduce='sum')\n\n            edge_index, edge_weight = get_ppr(\n                edge_index,\n                alpha=kwargs['alpha'],\n                eps=kwargs['eps'],\n                num_nodes=num_nodes,\n            )\n\n            if normalization == 'col':\n                edge_index, edge_weight = sort_edge_index(\n                    edge_index.flip([0]), edge_weight, num_nodes)\n\n            if normalization == 'sym':\n                # We can change the normalization from 
row-normalized to\n                # symmetric by multiplying the resulting matrix with D^{1/2}\n                # from the left and D^{-1/2} from the right.\n                # Since we use the original degrees for this it will be like\n                # we had used symmetric normalization from the beginning\n                # (except for errors due to approximation).\n                row, col = edge_index\n                deg_inv = deg.sqrt()\n                deg_inv_sqrt = deg.pow(-0.5)\n                deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0\n                edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col]\n            elif normalization in ['col', 'row']:\n                pass\n            else:\n                raise ValueError(\n                    f\"Transition matrix normalization '{normalization}' not \"\n                    f\"implemented for non-exact GDC computation\")\n\n        elif method == 'heat':\n            raise NotImplementedError(\n                'Currently no fast heat kernel is implemented. You are '\n                'welcome to create one yourself, e.g., based on '\n                '\"Kloster and Gleich: Heat kernel based community detection '\n                '(KDD 2014).\"')\n        else:\n            raise ValueError(f\"Approximate GDC diffusion '{method}' unknown\")\n\n        return edge_index, edge_weight\n\n    def sparsify_dense(  # noqa: D417\n        self,\n        matrix: Tensor,\n        method: str,\n        **kwargs: Any,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Sparsifies the given dense matrix.\n\n        Args:\n            matrix (Tensor): Matrix to sparsify.\n            method (str): Method of sparsification. Options:\n\n                1. 
:obj:`\"threshold\"`: Remove all edges with weights smaller\n                   than :obj:`eps`.\n                   Additionally expects one of these parameters:\n\n                   - **eps** (*float*) - Threshold to bound edges at.\n\n                   - **avg_degree** (*int*) - If :obj:`eps` is not given,\n                     it can optionally be calculated by calculating the\n                     :obj:`eps` required to achieve a given :obj:`avg_degree`.\n\n                2. :obj:`\"topk\"`: Keep edges with top :obj:`k` edge weights per\n                   node (column).\n                   Additionally expects the following parameters:\n\n                   - **k** (*int*) - Specifies the number of edges to keep.\n\n                   - **dim** (*int*) - The axis along which to take the top\n                     :obj:`k`.\n\n        :rtype: (:class:`LongTensor`, :class:`Tensor`)\n        \"\"\"\n        assert matrix.shape[0] == matrix.shape[1]\n        N = matrix.shape[1]\n\n        if method == 'threshold':\n            if 'eps' not in kwargs.keys():\n                kwargs['eps'] = self.__calculate_eps__(matrix, N,\n                                                       kwargs['avg_degree'])\n\n            edge_index = (matrix >= kwargs['eps']).nonzero(as_tuple=False).t()\n            edge_index_flat = edge_index[0] * N + edge_index[1]\n            edge_weight = matrix.flatten()[edge_index_flat]\n\n        elif method == 'topk':\n            k, dim = min(N, kwargs['k']), kwargs['dim']\n            assert dim in [0, 1]\n            sort_idx = torch.argsort(matrix, dim=dim, descending=True)\n            if dim == 0:\n                top_idx = sort_idx[:k]\n                edge_weight = torch.gather(matrix, dim=dim,\n                                           index=top_idx).flatten()\n\n                row_idx = torch.arange(0, N, device=matrix.device).repeat(k)\n                edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0)\n            
else:\n                top_idx = sort_idx[:, :k]\n                edge_weight = torch.gather(matrix, dim=dim,\n                                           index=top_idx).flatten()\n\n                col_idx = torch.arange(\n                    0, N, device=matrix.device).repeat_interleave(k)\n                edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0)\n        else:\n            raise ValueError(f\"GDC sparsification '{method}' unknown\")\n\n        return edge_index, edge_weight\n\n    def sparsify_sparse(  # noqa: D417\n        self,\n        edge_index: Tensor,\n        edge_weight: Tensor,\n        num_nodes: int,\n        method: str,\n        **kwargs: Any,\n    ) -> Tuple[Tensor, Tensor]:\n        r\"\"\"Sparsifies a given sparse graph further.\n\n        Args:\n            edge_index (torch.Tensor): The edge indices.\n            edge_weight (torch.Tensor): One-dimensional edge weights.\n            num_nodes (int): Number of nodes.\n            method (str): Method of sparsification:\n\n                1. 
:obj:`\"threshold\"`: Remove all edges with weights smaller\n                   than :obj:`eps`.\n                   Additionally expects one of these parameters:\n\n                   - **eps** (*float*) - Threshold to bound edges at.\n\n                   - **avg_degree** (*int*) - If :obj:`eps` is not given,\n                     it can optionally be calculated by calculating the\n                     :obj:`eps` required to achieve a given :obj:`avg_degree`.\n\n        :rtype: (:class:`LongTensor`, :class:`Tensor`)\n        \"\"\"\n        if method == 'threshold':\n            if 'eps' not in kwargs.keys():\n                kwargs['eps'] = self.__calculate_eps__(\n                    edge_weight,\n                    num_nodes,\n                    kwargs['avg_degree'],\n                )\n\n            remaining_edge_idx = (edge_weight >= kwargs['eps']).nonzero(\n                as_tuple=False).flatten()\n            edge_index = edge_index[:, remaining_edge_idx]\n            edge_weight = edge_weight[remaining_edge_idx]\n        elif method == 'topk':\n            raise NotImplementedError(\n                'Sparse topk sparsification not implemented')\n        else:\n            raise ValueError(f\"GDC sparsification '{method}' unknown\")\n\n        return edge_index, edge_weight\n\n    def __expm__(self, matrix: Tensor, symmetric: bool) -> Tensor:\n        r\"\"\"Calculates matrix exponential.\n\n        Args:\n            matrix (Tensor): Matrix to take exponential of.\n            symmetric (bool): Specifies whether the matrix is symmetric.\n\n        :rtype: (:class:`Tensor`)\n        \"\"\"\n        from scipy.linalg import expm\n\n        if symmetric:\n            e, V = torch.linalg.eigh(matrix, UPLO='U')\n            diff_mat = V @ torch.diag(e.exp()) @ V.t()\n        else:\n            diff_mat = torch.from_numpy(expm(matrix.cpu().numpy()))\n            diff_mat = diff_mat.to(matrix.device, matrix.dtype)\n        return diff_mat\n\n    def 
__calculate_eps__(\n        self,\n        matrix: Tensor,\n        num_nodes: int,\n        avg_degree: int,\n    ) -> float:\n        r\"\"\"Calculates threshold necessary to achieve a given average degree.\n\n        Args:\n            matrix (Tensor): Adjacency matrix or edge weights.\n            num_nodes (int): Number of nodes.\n            avg_degree (int): Target average degree.\n\n        :rtype: (:class:`float`)\n        \"\"\"\n        sorted_edges = torch.sort(matrix.flatten(), descending=True).values\n        if avg_degree * num_nodes > len(sorted_edges):\n            return -np.inf\n\n        left = sorted_edges[avg_degree * num_nodes - 1]\n        right = sorted_edges[avg_degree * num_nodes]\n        return float(left + right) / 2.0\n"
  },
  {
    "path": "torch_geometric/transforms/generate_mesh_normals.py",
    "content": "import torch.nn.functional as F\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import scatter\n\n\n@functional_transform('generate_mesh_normals')\nclass GenerateMeshNormals(BaseTransform):\n    r\"\"\"Generate normal vectors for each mesh node based on neighboring\n    faces (functional name: :obj:`generate_mesh_normals`).\n    \"\"\"\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        assert data.face is not None\n        pos, face = data.pos, data.face\n\n        vec1 = pos[face[1]] - pos[face[0]]\n        vec2 = pos[face[2]] - pos[face[0]]\n        face_norm = F.normalize(vec1.cross(vec2, dim=1), p=2, dim=-1)  # [F, 3]\n\n        face_norm = face_norm.repeat(3, 1)\n        idx = face.view(-1)\n\n        norm = scatter(face_norm, idx, 0, pos.size(0), reduce='sum')\n        norm = F.normalize(norm, p=2, dim=-1)  # [N, 3]\n\n        data.norm = norm\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/grid_sampling.py",
    "content": "import re\nfrom typing import List, Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import one_hot, scatter\n\n\n@functional_transform('grid_sampling')\nclass GridSampling(BaseTransform):\n    r\"\"\"Clusters points into fixed-sized voxels\n    (functional name: :obj:`grid_sampling`).\n    Each cluster returned is a new point based on the mean of all points\n    inside the given cluster.\n\n    Args:\n        size (float or [float] or Tensor): Size of a voxel (in each dimension).\n        start (float or [float] or Tensor, optional): Start coordinates of the\n            grid (in each dimension). If set to :obj:`None`, will be set to the\n            minimum coordinates found in :obj:`data.pos`.\n            (default: :obj:`None`)\n        end (float or [float] or Tensor, optional): End coordinates of the grid\n            (in each dimension). 
If set to :obj:`None`, will be set to the\n            maximum coordinates found in :obj:`data.pos`.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        size: Union[float, List[float], Tensor],\n        start: Optional[Union[float, List[float], Tensor]] = None,\n        end: Optional[Union[float, List[float], Tensor]] = None,\n    ) -> None:\n        self.size = size\n        self.start = start\n        self.end = end\n\n    def forward(self, data: Data) -> Data:\n        num_nodes = data.num_nodes\n\n        assert data.pos is not None\n        c = torch_geometric.nn.voxel_grid(data.pos, self.size, data.batch,\n                                          self.start, self.end)\n        c, perm = torch_geometric.nn.pool.consecutive.consecutive_cluster(c)\n\n        for key, item in data.items():\n            if bool(re.search('edge', key)):\n                raise ValueError(f\"'{self.__class__.__name__}' does not \"\n                                 f\"support coarsening of edges\")\n\n            if torch.is_tensor(item) and item.size(0) == num_nodes:\n                if key == 'y':\n                    item = scatter(one_hot(item), c, dim=0, reduce='sum')\n                    data[key] = item.argmax(dim=-1)\n                elif key == 'batch':\n                    data[key] = item[perm]\n                else:\n                    data[key] = scatter(item, c, dim=0, reduce='mean')\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(size={self.size})'\n"
  },
  {
    "path": "torch_geometric/transforms/half_hop.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('half_hop')\nclass HalfHop(BaseTransform):\n    r\"\"\"The graph upsampling augmentation from the\n    `\"Half-Hop: A Graph Upsampling Approach for Slowing Down Message Passing\"\n    <https://openreview.net/forum?id=lXczFIwQkv>`_ paper.\n    The graph is augmented by adding artificial slow nodes between neighbors\n    to slow down message propagation. (functional name: :obj:`half_hop`).\n\n    .. note::\n        :class:`HalfHop` augmentation is not supported if :obj:`data` has\n        :attr:`edge_weight` or :attr:`edge_attr`.\n\n    Args:\n        alpha (float, optional): The interpolation factor\n            used to compute slow node features\n            :math:`x = \\alpha*x_src + (1-\\alpha)*x_dst` (default: :obj:`0.5`)\n        p (float, optional): The probability of half-hopping\n            an edge. (default: :obj:`1.0`)\n\n    .. code-block:: python\n\n        import torch_geometric.transforms as T\n\n        transform = T.HalfHop(alpha=0.5)\n        data = transform(data)  # Apply transformation.\n        out = model(data.x, data.edge_index)  # Feed-forward.\n        out = out[~data.slow_node_mask]  # Get rid of slow nodes.\n    \"\"\"\n    def __init__(self, alpha: float = 0.5, p: float = 1.0) -> None:\n        if alpha < 0. or alpha > 1.:\n            raise ValueError(f\"Interpolation factor has to be between 0 and 1 \"\n                             f\"(got '{alpha}'\")\n        if p < 0. 
or p > 1.:\n            raise ValueError(f\"Ratio of half-hopped edges has to be between \"\n                             f\"0 and 1 (got '{p}')\")\n\n        self.p = p\n        self.alpha = alpha\n\n    def forward(self, data: Data) -> Data:\n        if data.edge_weight is not None or data.edge_attr is not None:\n            raise ValueError(\"'HalfHop' augmentation is not supported if \"\n                             \"'data' contains 'edge_weight' or 'edge_attr'\")\n\n        assert data.x is not None\n        assert data.edge_index is not None\n        x, edge_index = data.x, data.edge_index\n        num_nodes = data.num_nodes\n        assert num_nodes is not None\n\n        # isolate self loops which are not half-hopped\n        self_loop_mask = edge_index[0] == edge_index[1]\n        edge_index_self_loop = edge_index[:, self_loop_mask]\n        edge_index = edge_index[:, ~self_loop_mask]\n\n        # randomly sample nodes and half-hop their edges\n        node_mask = torch.rand(num_nodes, device=x.device) < self.p\n        edge_mask = node_mask[edge_index[1]]\n        edge_index_to_halfhop = edge_index[:, edge_mask]\n        edge_index_to_keep = edge_index[:, ~edge_mask]\n\n        # add new slow nodes of which features are initialized\n        # by linear interpolation\n        num_halfhop_edges = edge_index_to_halfhop.size(1)\n        slow_node_ids = torch.arange(num_halfhop_edges,\n                                     device=x.device) + num_nodes\n        x_src = x[edge_index_to_halfhop[0]]\n        x_dst = x[edge_index_to_halfhop[1]]\n        x_slow_node = self.alpha * x_src + (1 - self.alpha) * x_dst\n        new_x = torch.cat([x, x_slow_node], dim=0)\n\n        # add new edges between slow nodes and the original nodes\n        edge_index_slow = [\n            torch.stack([edge_index_to_halfhop[0], slow_node_ids]),\n            torch.stack([slow_node_ids, edge_index_to_halfhop[1]]),\n            torch.stack([edge_index_to_halfhop[1], slow_node_ids])\n    
    ]\n        new_edge_index = torch.cat(\n            [edge_index_to_keep, edge_index_self_loop, *edge_index_slow],\n            dim=1)\n\n        # prepare a mask that distinguishes between original nodes & slow nodes\n        slow_node_mask = torch.cat(\n            [x.new_zeros(x.size(0)),\n             x.new_ones(slow_node_ids.size(0))], dim=0).bool()\n\n        data.x, data.edge_index = new_x, new_edge_index\n        data.slow_node_mask = slow_node_mask\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(alpha={self.alpha}, p={self.p})'\n"
  },
  {
    "path": "torch_geometric/transforms/knn_graph.py",
    "content": "import torch_geometric\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import to_undirected\n\n\n@functional_transform('knn_graph')\nclass KNNGraph(BaseTransform):\n    r\"\"\"Creates a k-NN graph based on node positions :obj:`data.pos`\n    (functional name: :obj:`knn_graph`).\n\n    Args:\n        k (int, optional): The number of neighbors. (default: :obj:`6`)\n        loop (bool, optional): If :obj:`True`, the graph will contain\n            self-loops. (default: :obj:`False`)\n        force_undirected (bool, optional): If set to :obj:`True`, new edges\n            will be undirected. (default: :obj:`False`)\n        flow (str, optional): The flow direction when used in combination with\n            message passing (:obj:`\"source_to_target\"` or\n            :obj:`\"target_to_source\"`).\n            If set to :obj:`\"source_to_target\"`, every target node will have\n            exactly :math:`k` source nodes pointing to it.\n            (default: :obj:`\"source_to_target\"`)\n        cosine (bool, optional): If :obj:`True`, will use the cosine\n            distance instead of euclidean distance to find nearest neighbors.\n            (default: :obj:`False`)\n        num_workers (int): Number of workers to use for computation. Has no\n            effect in case :obj:`batch` is not :obj:`None`, or the input lies\n            on the GPU. 
(default: :obj:`1`)\n    \"\"\"\n    def __init__(\n        self,\n        k: int = 6,\n        loop: bool = False,\n        force_undirected: bool = False,\n        flow: str = 'source_to_target',\n        cosine: bool = False,\n        num_workers: int = 1,\n    ) -> None:\n        self.k = k\n        self.loop = loop\n        self.force_undirected = force_undirected\n        self.flow = flow\n        self.cosine = cosine\n        self.num_workers = num_workers\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n\n        edge_index = torch_geometric.nn.knn_graph(\n            data.pos,\n            self.k,\n            data.batch,\n            loop=self.loop,\n            flow=self.flow,\n            cosine=self.cosine,\n            num_workers=self.num_workers,\n        )\n\n        if self.force_undirected:\n            edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)\n\n        data.edge_index = edge_index\n        data.edge_attr = None\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(k={self.k})'\n"
  },
  {
    "path": "torch_geometric/transforms/laplacian_lambda_max.py",
    "content": "from typing import Optional\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import get_laplacian, to_scipy_sparse_matrix\n\n\n@functional_transform('laplacian_lambda_max')\nclass LaplacianLambdaMax(BaseTransform):\n    r\"\"\"Computes the highest eigenvalue of the graph Laplacian given by\n    :meth:`torch_geometric.utils.get_laplacian`\n    (functional name: :obj:`laplacian_lambda_max`).\n\n    Args:\n        normalization (str, optional): The normalization scheme for the graph\n            Laplacian (default: :obj:`None`):\n\n            1. :obj:`None`: No normalization\n            :math:`\\mathbf{L} = \\mathbf{D} - \\mathbf{A}`\n\n            2. :obj:`\"sym\"`: Symmetric normalization\n            :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1/2} \\mathbf{A}\n            \\mathbf{D}^{-1/2}`\n\n            3. :obj:`\"rw\"`: Random-walk normalization\n            :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1} \\mathbf{A}`\n        is_undirected (bool, optional): If set to :obj:`True`, this transform\n            expects undirected graphs as input, and can hence speed up the\n            computation of the largest eigenvalue. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        normalization: Optional[str] = None,\n        is_undirected: bool = False,\n    ):\n        assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'\n        self.normalization = normalization\n        self.is_undirected = is_undirected\n\n    def forward(self, data: Data) -> Data:\n        from scipy.sparse.linalg import eigs, eigsh\n\n        assert data.edge_index is not None\n        num_nodes = data.num_nodes\n\n        edge_weight = data.edge_attr\n        if edge_weight is not None and edge_weight.numel() != data.num_edges:\n            edge_weight = None\n\n        edge_index, edge_weight = get_laplacian(\n            data.edge_index,\n            edge_weight,\n            self.normalization,\n            num_nodes=num_nodes,\n        )\n\n        L = to_scipy_sparse_matrix(edge_index, edge_weight, num_nodes)\n\n        eig_fn = eigs\n        if self.is_undirected and self.normalization != 'rw':\n            eig_fn = eigsh\n\n        lambda_max = eig_fn(L, k=1, which='LM', return_eigenvectors=False)\n        data.lambda_max = lambda_max.real.item()\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(normalization={self.normalization})'\n"
  },
  {
    "path": "torch_geometric/transforms/largest_connected_components.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import to_scipy_sparse_matrix\n\n\n@functional_transform('largest_connected_components')\nclass LargestConnectedComponents(BaseTransform):\n    r\"\"\"Selects the subgraph that corresponds to the\n    largest connected components in the graph\n    (functional name: :obj:`largest_connected_components`).\n\n    Args:\n        num_components (int, optional): Number of largest components to keep\n            (default: :obj:`1`)\n        connection (str, optional): Type of connection to use for directed\n            graphs, can be either :obj:`'strong'` or :obj:`'weak'`.\n            Nodes `i` and `j` are strongly connected if a path\n            exists both from `i` to `j` and from `j` to `i`. A directed graph\n            is weakly connected if replacing all of its directed edges with\n            undirected edges produces a connected (undirected) graph.\n            (default: :obj:`'weak'`)\n    \"\"\"\n    def __init__(\n        self,\n        num_components: int = 1,\n        connection: str = 'weak',\n    ) -> None:\n        assert connection in ['strong', 'weak'], 'Unknown connection type'\n        self.num_components = num_components\n        self.connection = connection\n\n    def forward(self, data: Data) -> Data:\n        import numpy as np\n        import scipy.sparse as sp\n\n        assert data.edge_index is not None\n\n        adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=data.num_nodes)\n\n        num_components, component = sp.csgraph.connected_components(\n            adj, connection=self.connection)\n\n        if num_components <= self.num_components:\n            return data\n\n        _, count = np.unique(component, return_counts=True)\n        subset_np = np.isin(component, count.argsort()[-self.num_components:])\n        
subset = torch.from_numpy(subset_np)\n        subset = subset.to(data.edge_index.device, torch.bool)\n\n        return data.subgraph(subset)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.num_components})'\n"
  },
  {
    "path": "torch_geometric/transforms/line_graph.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import coalesce, cumsum, remove_self_loops, scatter\n\n\n@functional_transform('line_graph')\nclass LineGraph(BaseTransform):\n    r\"\"\"Converts a graph to its corresponding line-graph\n    (functional name: :obj:`line_graph`).\n\n    .. math::\n        L(\\mathcal{G}) &= (\\mathcal{V}^{\\prime}, \\mathcal{E}^{\\prime})\n\n        \\mathcal{V}^{\\prime} &= \\mathcal{E}\n\n        \\mathcal{E}^{\\prime} &= \\{ (e_1, e_2) : e_1 \\cap e_2 \\neq \\emptyset \\}\n\n    Line-graph node indices are equal to indices in the original graph's\n    coalesced :obj:`edge_index`.\n    For undirected graphs, the maximum line-graph node index is\n    :obj:`(data.edge_index.size(1) // 2) - 1`.\n\n    New node features are given by old edge attributes.\n    For undirected graphs, edge attributes for reciprocal edges\n    :obj:`(row, col)` and :obj:`(col, row)` get summed together.\n\n    Args:\n        force_directed (bool, optional): If set to :obj:`True`, the graph will\n            be always treated as a directed graph. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(self, force_directed: bool = False) -> None:\n        self.force_directed = force_directed\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        edge_index, edge_attr = data.edge_index, data.edge_attr\n        N = data.num_nodes\n\n        edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes=N)\n        row, col = edge_index\n\n        if self.force_directed or data.is_directed():\n            i = torch.arange(row.size(0), dtype=torch.long, device=row.device)\n\n            count = scatter(torch.ones_like(row), row, dim=0,\n                            dim_size=data.num_nodes, reduce='sum')\n            ptr = cumsum(count)\n\n            cols = [i[ptr[col[j]]:ptr[col[j] + 1]] for j in range(col.size(0))]\n            rows = [row.new_full((c.numel(), ), j) for j, c in enumerate(cols)]\n\n            row, col = torch.cat(rows, dim=0), torch.cat(cols, dim=0)\n\n            data.edge_index = torch.stack([row, col], dim=0)\n            data.x = data.edge_attr\n            data.num_nodes = edge_index.size(1)\n\n        else:\n            # Compute node indices.\n            mask = row < col\n            row, col = row[mask], col[mask]\n            i = torch.arange(row.size(0), dtype=torch.long, device=row.device)\n\n            (row, col), i = coalesce(\n                torch.stack([\n                    torch.cat([row, col], dim=0),\n                    torch.cat([col, row], dim=0)\n                ], dim=0),\n                torch.cat([i, i], dim=0),\n                N,\n            )\n\n            # Compute new edge indices according to `i`.\n            count = scatter(torch.ones_like(row), row, dim=0,\n                            dim_size=data.num_nodes, reduce='sum')\n            joints = list(torch.split(i, count.tolist()))\n\n            def generate_grid(x: Tensor) -> Tensor:\n                row = x.view(-1, 1).repeat(1, x.numel()).view(-1)\n          
      col = x.repeat(x.numel())\n                return torch.stack([row, col], dim=0)\n\n            joints = [generate_grid(joint) for joint in joints]\n            joint = torch.cat(joints, dim=1)\n            joint, _ = remove_self_loops(joint)\n            N = row.size(0) // 2\n            joint = coalesce(joint, num_nodes=N)\n\n            if edge_attr is not None:\n                data.x = scatter(edge_attr, i, dim=0, dim_size=N, reduce='sum')\n            data.edge_index = joint\n            data.num_nodes = edge_index.size(1) // 2\n\n        data.edge_attr = None\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/linear_transformation.py",
    "content": "from typing import Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('linear_transformation')\nclass LinearTransformation(BaseTransform):\n    r\"\"\"Transforms node positions :obj:`data.pos` with a square transformation\n    matrix computed offline (functional name: :obj:`linear_transformation`).\n\n    Args:\n        matrix (Tensor): Tensor with shape :obj:`[D, D]` where :obj:`D`\n            corresponds to the dimensionality of node positions.\n    \"\"\"\n    def __init__(self, matrix: Tensor):\n        if not isinstance(matrix, Tensor):\n            matrix = torch.tensor(matrix)\n        assert matrix.dim() == 2, (\n            'Transformation matrix should be two-dimensional.')\n        assert matrix.size(0) == matrix.size(1), (\n            f'Transformation matrix should be square (got {matrix.size()})')\n\n        # Store the matrix as its transpose.\n        # We do this to enable post-multiplication in `forward`.\n        self.matrix = matrix.t()\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.node_stores:\n            if not hasattr(store, 'pos'):\n                continue\n\n            pos = store.pos.view(-1, 1) if store.pos.dim() == 1 else store.pos\n            assert pos.size(-1) == self.matrix.size(-2), (\n                'Node position matrix and transformation matrix have '\n                'incompatible shape')\n            # We post-multiply the points by the transformation matrix instead\n            # of pre-multiplying, because `pos` attribute has shape `[N, D]`,\n            # and we want to preserve this shape.\n            store.pos = pos @ self.matrix.to(pos.device, pos.dtype)\n\n        return data\n\n    def __repr__(self) -> str:\n       
 return f'{self.__class__.__name__}(\\n{self.matrix.cpu().numpy()}\\n)'\n"
  },
  {
    "path": "torch_geometric/transforms/local_cartesian.py",
    "content": "from typing import Tuple\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import scatter\n\n\n@functional_transform('local_cartesian')\nclass LocalCartesian(BaseTransform):\n    r\"\"\"Saves the relative Cartesian coordinates of linked nodes in its edge\n    attributes (functional name: :obj:`local_cartesian`). Each coordinate gets\n    *neighborhood-normalized* to a specified interval\n    (:math:`[0, 1]` by default).\n\n    Args:\n        norm (bool, optional): If set to :obj:`False`, the output will not be\n            normalized. (default: :obj:`True`)\n        cat (bool, optional): If set to :obj:`False`, all existing edge\n            attributes will be replaced. (default: :obj:`True`)\n        interval ((float, float), optional): A tuple specifying the lower and\n            upper bound for normalization. (default: :obj:`(0.0, 1.0)`)\n    \"\"\"\n    def __init__(\n            self,\n            norm: bool = True,\n            cat: bool = True,\n            interval: Tuple[float, float] = (0.0, 1.0),\n    ):\n        self.norm = norm\n        self.cat = cat\n        self.interval = interval\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        assert data.edge_index is not None\n        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr\n\n        cart = pos[row] - pos[col]\n        cart = cart.view(-1, 1) if cart.dim() == 1 else cart\n\n        if self.norm:\n            max_value = scatter(cart.abs(), col, 0, pos.size(0), reduce='max')\n            max_value = max_value.max(dim=-1, keepdim=True)[0]\n\n            length = self.interval[1] - self.interval[0]\n            center = (self.interval[0] + self.interval[1]) / 2\n            cart = length * cart / (2 * max_value[col]) + center\n\n        if pseudo is not None and self.cat:\n   
         pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo\n            data.edge_attr = torch.cat([pseudo, cart.type_as(pseudo)], dim=-1)\n        else:\n            data.edge_attr = cart\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/local_degree_profile.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import degree\n\n\n@functional_transform('local_degree_profile')\nclass LocalDegreeProfile(BaseTransform):\n    r\"\"\"Appends the Local Degree Profile (LDP) from the `\"A Simple yet\n    Effective Baseline for Non-attribute Graph Classification\"\n    <https://arxiv.org/abs/1811.03508>`_ paper\n    (functional name: :obj:`local_degree_profile`).\n\n    .. math::\n        \\mathbf{x}_i = \\mathbf{x}_i \\, \\Vert \\, (\\deg(i), \\min(DN(i)),\n        \\max(DN(i)), \\textrm{mean}(DN(i)), \\textrm{std}(DN(i)))\n\n    to the node features, where :math:`DN(i) = \\{ \\deg(j) \\mid j \\in\n    \\mathcal{N}(i) \\}`.\n    \"\"\"\n    def __init__(self) -> None:\n        from torch_geometric.nn.aggr.fused import FusedAggregation\n        self.aggr = FusedAggregation(['min', 'max', 'mean', 'std'])\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        row, col = data.edge_index\n        num_nodes = data.num_nodes\n\n        deg = degree(row, num_nodes, dtype=torch.float).view(-1, 1)\n        xs = [deg] + self.aggr(deg[col], row, dim_size=num_nodes)\n\n        if data.x is not None:\n            data.x = data.x.view(-1, 1) if data.x.dim() == 1 else data.x\n            data.x = torch.cat([data.x] + xs, dim=-1)\n        else:\n            data.x = torch.cat(xs, dim=-1)\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/mask.py",
    "content": "from typing import List, Optional, Sequence, Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.data.storage import BaseStorage\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import index_to_mask, mask_to_index\n\nAnyData = Union[Data, HeteroData]\n\n\ndef get_attrs_with_suffix(\n    attrs: Optional[List[str]],\n    store: BaseStorage,\n    suffix: str,\n) -> List[str]:\n    if attrs is not None:\n        return attrs\n    return [key for key in store.keys() if key.endswith(suffix)]\n\n\ndef get_mask_size(\n    attr: str,\n    store: BaseStorage,\n    size: Optional[int],\n) -> Optional[int]:\n    if size is not None:\n        return size\n    return store.num_edges if store.is_edge_attr(attr) else store.num_nodes\n\n\n@functional_transform('index_to_mask')\nclass IndexToMask(BaseTransform):\n    r\"\"\"Converts indices to a mask representation\n    (functional name: :obj:`index_to_mask`).\n\n    Args:\n        attrs (str, [str], optional): If given, will only perform index to mask\n            conversion for the given attributes. If omitted, will infer the\n            attributes from the suffix :obj:`_index`. (default: :obj:`None`)\n        sizes (int, [int], optional): The size of the mask. If set to\n            :obj:`None`, an automatically sized tensor is returned. The number\n            of nodes will be used by default, except for edge attributes which\n            will use the number of edges as the mask size.\n            (default: :obj:`None`)\n        replace (bool, optional): if set to :obj:`True` replaces the index\n            attributes with mask tensors. 
(default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        attrs: Optional[Union[str, List[str]]] = None,\n        sizes: Optional[Union[int, List[int]]] = None,\n        replace: bool = False,\n    ) -> None:\n        self.attrs = [attrs] if isinstance(attrs, str) else attrs\n        self.sizes = sizes\n        self.replace = replace\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.stores:\n            attrs = get_attrs_with_suffix(self.attrs, store, '_index')\n\n            sizes: Sequence[Optional[int]]\n            if isinstance(self.sizes, int):\n                sizes = [self.sizes] * len(attrs)\n            elif isinstance(self.sizes, (list, tuple)):\n                if len(attrs) != len(self.sizes):\n                    raise ValueError(\n                        f\"The number of attributes (got {len(attrs)}) must \"\n                        f\"match the number of sizes provided \"\n                        f\"(got {len(self.sizes)})\")\n                sizes = self.sizes\n            else:\n                sizes = [None] * len(attrs)\n\n            for attr, size in zip(attrs, sizes):\n                if 'edge_index' in attr:\n                    continue\n                if attr not in store:\n                    continue\n                size = get_mask_size(attr, store, size)\n                mask = index_to_mask(store[attr], size=size)\n                store[f'{attr[:-6]}_mask'] = mask\n                if self.replace:\n                    del store[attr]\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(attrs={self.attrs}, '\n                f'sizes={self.sizes}, replace={self.replace})')\n\n\n@functional_transform('mask_to_index')\nclass MaskToIndex(BaseTransform):\n    r\"\"\"Converts a mask to an index representation\n    (functional name: :obj:`mask_to_index`).\n\n    Args:\n        attrs 
(str, [str], optional): If given, will only perform mask to index\n            conversion for the given attributes. If omitted, will infer the\n            attributes from the suffix :obj:`_mask`. (default: :obj:`None`)\n        replace (bool, optional): If set to :obj:`True`, replaces the mask\n            attributes with index tensors. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        attrs: Optional[Union[str, List[str]]] = None,\n        replace: bool = False,\n    ) -> None:\n        self.attrs = [attrs] if isinstance(attrs, str) else attrs\n        self.replace = replace\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.stores:\n            attrs = get_attrs_with_suffix(self.attrs, store, '_mask')\n\n            for attr in attrs:\n                if attr not in store:\n                    continue\n                index = mask_to_index(store[attr])\n                store[f'{attr[:-5]}_index'] = index\n                if self.replace:\n                    del store[attr]\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(attrs={self.attrs}, '\n                f'replace={self.replace})')\n"
  },
  {
    "path": "torch_geometric/transforms/node_property_split.py",
    "content": "from typing import Any, Dict, List\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import to_networkx\n\n\n@functional_transform('node_property_split')\nclass NodePropertySplit(BaseTransform):\n    r\"\"\"Creates a node-level split with distributional shift based on a given\n    node property, as proposed in the `\"Evaluating Robustness and Uncertainty\n    of Graph Models Under Structural Distributional Shifts\"\n    <https://arxiv.org/abs/2302.13875>`__ paper\n    (functional name: :obj:`node_property_split`).\n\n    It splits the nodes in a given graph into five non-intersecting parts\n    based on their structural properties.\n    This can be used for transductive node prediction tasks with distributional\n    shifts.\n    It considers the in-distribution (ID) and out-of-distribution (OOD) subsets\n    of nodes.\n    The ID subset includes training, validation and testing parts, while\n    the OOD subset includes validation and testing parts.\n    As a result, it creates five associated node mask vectors for each graph,\n    three which are for the ID nodes (:obj:`id_train_mask`,\n    :obj:`id_val_mask`, :obj:`id_test_mask`), and two which are for the OOD\n    nodes (:obj:`ood_val_mask`, :obj:`ood_test_mask`).\n\n    This class implements three particular strategies for inducing\n    distributional shifts in a graph — based on **popularity**, **locality**\n    or **density**.\n\n    Args:\n        property_name (str): The name of the node property to be used\n            (:obj:`\"popularity\"`, :obj:`\"locality\"`, :obj:`\"density\"`).\n        ratios ([float]): A list of five ratio values for ID training,\n            ID validation, ID test, OOD validation and OOD test parts.\n            The values must sum to :obj:`1.0`.\n        ascending (bool, optional): 
Whether to sort nodes in ascending order\n            of the node property, so that nodes with greater values of the\n            property are considered to be OOD (default: :obj:`True`)\n\n    .. code-block:: python\n\n        from torch_geometric.transforms import NodePropertySplit\n        from torch_geometric.datasets.graph_generator import ERGraph\n\n        data = ERGraph(num_nodes=1000, edge_prob=0.4)()\n\n        property_name = 'popularity'\n        ratios = [0.3, 0.1, 0.1, 0.3, 0.2]\n        transform = NodePropertySplit(property_name, ratios)\n\n        data = transform(data)\n    \"\"\"\n    def __init__(\n        self,\n        property_name: str,\n        ratios: List[float],\n        ascending: bool = True,\n    ):\n        if property_name not in {'popularity', 'locality', 'density'}:\n            raise ValueError(f\"Unexpected 'property_name' \"\n                             f\"(got '{property_name}')\")\n\n        if len(ratios) != 5:\n            raise ValueError(f\"'ratios' must contain 5 values \"\n                             f\"(got {len(ratios)})\")\n\n        if sum(ratios) != 1.0:\n            raise ValueError(f\"'ratios' must sum to 1.0 (got {sum(ratios)})\")\n\n        self.property_name = property_name\n        self.compute_fn = _property_name_to_compute_fn[property_name]\n        self.ratios = ratios\n        self.ascending = ascending\n\n    def forward(self, data: Data) -> Data:\n        G = to_networkx(data, to_undirected=True, remove_self_loops=True)\n        property_values = self.compute_fn(G, self.ascending)\n        mask_dict = self._mask_nodes_by_property(property_values, self.ratios)\n\n        for key, mask in mask_dict.items():\n            data[key] = mask\n\n        return data\n\n    @staticmethod\n    def _compute_popularity_property(G: Any, ascending: bool = True) -> Tensor:\n        import networkx.algorithms as A\n\n        property_values = torch.tensor(list(A.pagerank(G).values()))\n        property_values *= -1 if 
ascending else 1\n        return property_values\n\n    @staticmethod\n    def _compute_locality_property(G: Any, ascending: bool = True) -> Tensor:\n        import networkx.algorithms as A\n\n        pagerank_values = torch.tensor(list(A.pagerank(G).values()))\n\n        num_nodes = G.number_of_nodes()\n        personalization = dict(zip(range(num_nodes), [0.0] * num_nodes))\n        personalization[int(pagerank_values.argmax())] = 1.0\n\n        property_values = torch.tensor(\n            list(A.pagerank(G, personalization=personalization).values()))\n        property_values *= -1 if ascending else 1\n        return property_values\n\n    @staticmethod\n    def _compute_density_property(G: Any, ascending: bool = True) -> Tensor:\n        import networkx.algorithms as A\n\n        property_values = torch.tensor(list(A.clustering(G).values()))\n        property_values *= -1 if ascending else 1\n        return property_values\n\n    @staticmethod\n    def _mask_nodes_by_property(\n        property_values: Tensor,\n        ratios: List[float],\n    ) -> Dict[str, Tensor]:\n\n        num_nodes = property_values.size(0)\n        sizes = (num_nodes * torch.tensor(ratios)).round().long()\n        sizes[-1] -= sizes.sum() - num_nodes\n\n        perm = torch.randperm(num_nodes)\n        id_size = int(sizes[:3].sum())\n        perm = perm[property_values[perm].argsort()]\n        perm[:id_size] = perm[:id_size][torch.randperm(id_size)]\n\n        node_splits = perm.split(sizes.tolist())\n        names = [\n            'id_train_mask',\n            'id_val_mask',\n            'id_test_mask',\n            'ood_val_mask',\n            'ood_test_mask',\n        ]\n\n        split_masks = {}\n        for name, node_split in zip(names, node_splits):\n            split_mask = torch.zeros(num_nodes, dtype=torch.bool)\n            split_mask[node_split] = True\n            split_masks[name] = split_mask\n        return split_masks\n\n    def __repr__(self) -> str:\n        return 
f'{self.__class__.__name__}({self.property_name})'\n\n\n_property_name_to_compute_fn = {\n    'popularity': NodePropertySplit._compute_popularity_property,\n    'locality': NodePropertySplit._compute_locality_property,\n    'density': NodePropertySplit._compute_density_property,\n}\n"
  },
  {
    "path": "torch_geometric/transforms/normalize_features.py",
    "content": "from typing import List, Optional, Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('normalize_features')\nclass NormalizeFeatures(BaseTransform):\n    r\"\"\"Row-normalizes the attributes given in :obj:`attrs` to sum-up to one\n    (functional name: :obj:`normalize_features`).\n\n    Args:\n        attrs (List[str]): The names of attributes to normalize.\n            (default: :obj:`[\"x\"]`)\n    \"\"\"\n    def __init__(self, attrs: Optional[List[str]] = None) -> None:\n        self.attrs = attrs or [\"x\"]\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.stores:\n            for key, value in store.items(*self.attrs):\n                if value.numel() > 0:\n                    value = value - value.min()\n                    value.div_(value.sum(dim=-1, keepdim=True).clamp_(min=1.))\n                    store[key] = value\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/normalize_rotation.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('normalize_rotation')\nclass NormalizeRotation(BaseTransform):\n    r\"\"\"Rotates all points according to the eigenvectors of the point cloud\n    (functional name: :obj:`normalize_rotation`).\n    If the data additionally holds normals saved in :obj:`data.normal`, these\n    will be rotated accordingly.\n\n    Args:\n        max_points (int, optional): If set to a value greater than :obj:`0`,\n            only a random number of :obj:`max_points` points are sampled and\n            used to compute eigenvectors. (default: :obj:`-1`)\n        sort (bool, optional): If set to :obj:`True`, will sort eigenvectors\n            according to their eigenvalues. (default: :obj:`False`)\n    \"\"\"\n    def __init__(self, max_points: int = -1, sort: bool = False) -> None:\n        self.max_points = max_points\n        self.sort = sort\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        pos = data.pos\n\n        if self.max_points > 0 and pos.size(0) > self.max_points:\n            perm = torch.randperm(pos.size(0))\n            pos = pos[perm[:self.max_points]]\n\n        pos = pos - pos.mean(dim=0, keepdim=True)\n        C = torch.matmul(pos.t(), pos)\n        e, v = torch.linalg.eig(C)  # v[:,j] is j-th eigenvector\n        e, v = torch.view_as_real(e), v.real\n\n        if self.sort:\n            indices = e[:, 0].argsort(descending=True)\n            v = v.t()[indices].t()\n\n        data.pos = torch.matmul(data.pos, v)\n\n        if 'normal' in data:\n            data.normal = F.normalize(torch.matmul(data.normal, v))\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/normalize_scale.py",
    "content": "from torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform, Center\n\n\n@functional_transform('normalize_scale')\nclass NormalizeScale(BaseTransform):\n    r\"\"\"Centers and normalizes node positions to the interval :math:`(-1, 1)`\n    (functional name: :obj:`normalize_scale`).\n    \"\"\"\n    def __init__(self) -> None:\n        self.center = Center()\n\n    def forward(self, data: Data) -> Data:\n        data = self.center(data)\n\n        assert data.pos is not None\n        scale = (1.0 / data.pos.abs().max()) * 0.999999\n        data.pos = data.pos * scale\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/one_hot_degree.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import degree, one_hot\n\n\n@functional_transform('one_hot_degree')\nclass OneHotDegree(BaseTransform):\n    r\"\"\"Adds the node degree as one hot encodings to the node features\n    (functional name: :obj:`one_hot_degree`).\n\n    Args:\n        max_degree (int): Maximum degree.\n        in_degree (bool, optional): If set to :obj:`True`, will compute the\n            in-degree of nodes instead of the out-degree.\n            (default: :obj:`False`)\n        cat (bool, optional): Concat node degrees to node features instead\n            of replacing them. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        max_degree: int,\n        in_degree: bool = False,\n        cat: bool = True,\n    ) -> None:\n        self.max_degree = max_degree\n        self.in_degree = in_degree\n        self.cat = cat\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        idx, x = data.edge_index[1 if self.in_degree else 0], data.x\n        deg = degree(idx, data.num_nodes, dtype=torch.long)\n        deg = one_hot(deg, num_classes=self.max_degree + 1)\n\n        if x is not None and self.cat:\n            x = x.view(-1, 1) if x.dim() == 1 else x\n            data.x = torch.cat([x, deg.to(x.dtype)], dim=-1)\n        else:\n            data.x = deg\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.max_degree})'\n"
  },
  {
    "path": "torch_geometric/transforms/pad.py",
    "content": "from abc import ABC, abstractmethod\nfrom dataclasses import dataclass\nfrom typing import Any, Callable, Dict, List, Optional, Tuple, Union\n\nimport torch\nimport torch.nn.functional as F\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.data.storage import EdgeStorage, NodeStorage\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.typing import EdgeType, NodeType\n\n\nclass Padding(ABC):\n    r\"\"\"An abstract class for specifying padding values.\"\"\"\n    @abstractmethod\n    def get_value(\n        self,\n        store_type: Optional[Union[NodeType, EdgeType]] = None,\n        attr_name: Optional[str] = None,\n    ) -> Union[int, float]:\n        pass\n\n\n@dataclass(init=False)\nclass UniformPadding(Padding):\n    r\"\"\"Uniform padding independent of attribute name or node/edge type.\n\n    Args:\n        value (int or float, optional): The value to be used for padding.\n            (default: :obj:`0.0`)\n    \"\"\"\n    value: Union[int, float] = 0.0\n\n    def __init__(self, value: Union[int, float] = 0.0):\n        self.value = value\n\n        if not isinstance(self.value, (int, float)):\n            raise ValueError(f\"Expected 'value' to be an integer or float \"\n                             f\"(got '{type(value)}'\")\n\n    def get_value(\n        self,\n        store_type: Optional[Union[NodeType, EdgeType]] = None,\n        attr_name: Optional[str] = None,\n    ) -> Union[int, float]:\n        return self.value\n\n\n@dataclass(init=False)\nclass MappingPadding(Padding):\n    r\"\"\"An abstract class for specifying different padding values.\"\"\"\n    values: Dict[Any, Padding]\n    default: UniformPadding\n\n    def __init__(\n        self,\n        values: Dict[Any, Union[int, float, Padding]],\n        default: Union[int, float] = 0.0,\n    ):\n        if not isinstance(values, dict):\n            raise 
ValueError(f\"Expected 'values' to be a dictionary \"\n                             f\"(got '{type(values)}')\")\n\n        self.values = {\n            key: UniformPadding(val) if isinstance(val, (int, float)) else val\n            for key, val in values.items()\n        }\n        self.default = UniformPadding(default)\n\n        for key, value in self.values.items():\n            self.validate_key_value(key, value)\n\n    def validate_key_value(self, key: Any, value: Any) -> None:\n        pass\n\n\nclass AttrNamePadding(MappingPadding):\n    r\"\"\"Padding dependent on attribute names.\n\n    Args:\n        values (dict): The mapping from attribute names to padding values.\n        default (int or float, optional): The padding value to use for\n            attribute names not specified in :obj:`values`.\n            (default: :obj:`0.0`)\n    \"\"\"\n    def validate_key_value(self, key: Any, value: Any) -> None:\n        if not isinstance(key, str):\n            raise ValueError(f\"Expected the attribute name '{key}' to be a \"\n                             f\"string (got '{type(key)}')\")\n\n        if not isinstance(value, UniformPadding):\n            raise ValueError(f\"Expected the value of '{key}' to be of \"\n                             f\"type 'UniformPadding' (got '{type(value)}')\")\n\n    def get_value(\n        self,\n        store_type: Optional[Union[NodeType, EdgeType]] = None,\n        attr_name: Optional[str] = None,\n    ) -> Union[int, float]:\n        padding = self.values.get(attr_name, self.default)\n        return padding.get_value()\n\n\nclass NodeTypePadding(MappingPadding):\n    r\"\"\"Padding dependent on node types.\n\n    Args:\n        values (dict): The mapping from node types to padding values.\n        default (int or float, optional): The padding value to use for node\n            types not specified in :obj:`values`. 
(default: :obj:`0.0`)\n    \"\"\"\n    def validate_key_value(self, key: Any, value: Any) -> None:\n        if not isinstance(key, str):\n            raise ValueError(f\"Expected the node type '{key}' to be a string \"\n                             f\"(got '{type(key)}')\")\n\n        if not isinstance(value, (UniformPadding, AttrNamePadding)):\n            raise ValueError(f\"Expected the value of '{key}' to be of \"\n                             f\"type 'UniformPadding' or 'AttrNamePadding' \"\n                             f\"(got '{type(value)}')\")\n\n    def get_value(\n        self,\n        store_type: Optional[Union[NodeType, EdgeType]] = None,\n        attr_name: Optional[str] = None,\n    ) -> Union[int, float]:\n        padding = self.values.get(store_type, self.default)\n        return padding.get_value(attr_name=attr_name)\n\n\nclass EdgeTypePadding(MappingPadding):\n    r\"\"\"Padding dependent on edge types.\n\n    Args:\n        values (dict): The mapping from edge types to padding values.\n        default (int or float, optional): The padding value to use for edge\n            types not specified in :obj:`values`. 
(default: :obj:`0.0`)\n    \"\"\"\n    def validate_key_value(self, key: Any, value: Any) -> None:\n        if not isinstance(key, tuple):\n            raise ValueError(f\"Expected the edge type '{key}' to be a tuple \"\n                             f\"(got '{type(key)}')\")\n\n        if len(key) != 3:\n            raise ValueError(f\"Expected the edge type '{key}' to hold exactly \"\n                             f\"three elements (got {len(key)})\")\n\n        if not isinstance(value, (UniformPadding, AttrNamePadding)):\n            raise ValueError(f\"Expected the value of '{key}' to be of \"\n                             f\"type 'UniformPadding' or 'AttrNamePadding' \"\n                             f\"(got '{type(value)}')\")\n\n    def get_value(\n        self,\n        store_type: Optional[Union[NodeType, EdgeType]] = None,\n        attr_name: Optional[str] = None,\n    ) -> Union[int, float]:\n        padding = self.values.get(store_type, self.default)\n        return padding.get_value(attr_name=attr_name)\n\n\nclass _NumNodes:\n    def __init__(\n        self,\n        value: Union[int, Dict[NodeType, int], None],\n    ) -> None:\n        self.value = value\n\n    def get_value(self, key: Optional[NodeType] = None) -> Optional[int]:\n        if self.value is None or isinstance(self.value, int):\n            return self.value\n        assert isinstance(key, str)\n        return self.value[key]\n\n\nclass _NumEdges:\n    def __init__(\n        self,\n        value: Union[int, Dict[EdgeType, int], None],\n        num_nodes: _NumNodes,\n    ) -> None:\n\n        if value is None:\n            if isinstance(num_nodes.value, int):\n                value = num_nodes.value * num_nodes.value\n            else:\n                value = {}\n\n        self.value = value\n        self.num_nodes = num_nodes\n\n    def get_value(self, key: Optional[EdgeType] = None) -> Optional[int]:\n        if self.value is None or isinstance(self.value, int):\n            return 
self.value\n\n        assert isinstance(key, tuple) and len(key) == 3\n        if key not in self.value:\n            num_src_nodes = self.num_nodes.get_value(key[0])\n            num_dst_nodes = self.num_nodes.get_value(key[-1])\n            assert num_src_nodes is not None and num_dst_nodes is not None\n            self.value[key] = num_src_nodes * num_dst_nodes\n\n        return self.value[key]\n\n\n@functional_transform('pad')\nclass Pad(BaseTransform):\n    r\"\"\"Applies padding to enforce consistent tensor shapes\n    (functional name: :obj:`pad`).\n\n    This transform will pad node and edge features up to a maximum allowed size\n    in the node or edge feature dimension. By default :obj:`0.0` is used as the\n    padding value and can be configured by setting :obj:`node_pad_value` and\n    :obj:`edge_pad_value`.\n\n    In case of applying :class:`Pad` to a :class:`~torch_geometric.data.Data`\n    object, the :obj:`node_pad_value` value (or :obj:`edge_pad_value`) can be\n    either:\n\n    * an int, float or object of :class:`UniformPadding` class for cases when\n      all attributes are going to be padded with the same value;\n    * an object of :class:`AttrNamePadding` class for cases when padding is\n      going to differ based on attribute names.\n\n    In case of applying :class:`Pad` to a\n    :class:`~torch_geometric.data.HeteroData` object, the :obj:`node_pad_value`\n    value (or :obj:`edge_pad_value`) can be either:\n\n    * an int, float or object of :class:`UniformPadding` class for cases when\n      all attributes of all node (or edge) stores are going to be padded with\n      the same value;\n    * an object of :class:`AttrNamePadding` class for cases when padding is\n      going to differ based on attribute names (but not based on node or edge\n      types);\n    * an object of class :class:`NodeTypePadding` or :class:`EdgeTypePadding`\n      for cases when padding values are going to differ based on node or edge\n      types. 
Padding values can also differ based on attribute names for a\n      given node or edge type by using :class:`AttrNamePadding` objects as\n      values of its :obj:`values` argument.\n\n    Note that in order to allow for consistent padding across all graphs in a\n    dataset, the conditions below must be met:\n\n    * if :obj:`max_num_nodes` is a single value, it must be greater than or\n      equal to the maximum number of nodes of any graph in the dataset;\n    * if :obj:`max_num_nodes` is a dictionary, the value for every node type\n      must be greater than or equal to the maximum number of nodes of this\n      type in any graph in the dataset.\n\n    The example below shows how to create a :class:`Pad` transform for an\n    :class:`~torch_geometric.data.HeteroData` object. The object is padded to\n    have :obj:`10` nodes of type :obj:`v0`, :obj:`20` nodes of type :obj:`v1`\n    and :obj:`30` nodes of type :obj:`v2`.\n    It is padded to have :obj:`80` edges of type :obj:`('v0', 'e0', 'v1')`.\n    All the attributes of the :obj:`v0` nodes are padded using a value of\n    :obj:`3.0`.\n    The :obj:`x` attribute of the :obj:`v1` node type is padded using a value\n    of :obj:`-1.0`, and the other attributes of this node type are padded using\n    a value of :obj:`0.5`.\n    All the attributes of node types other than :obj:`v0` and :obj:`v1` are\n    padded using a value of :obj:`1.0`.\n    All the attributes of the :obj:`('v0', 'e0', 'v1')` edge type are padded\n    using a value of :obj:`3.5`.\n    The :obj:`edge_attr` attributes of the :obj:`('v1', 'e0', 'v0')` edge type\n    are padded using a value of :obj:`-1.5`, and any other attributes of this\n    edge type are padded using a value of :obj:`5.5`.\n    All the attributes of edge types other than these two are padded using a\n    value of :obj:`1.5`.\n\n    .. 
code-block:: python\n\n        num_nodes = {'v0': 10, 'v1': 20, 'v2': 30}\n        num_edges = {('v0', 'e0', 'v1'): 80}\n\n        node_padding = NodeTypePadding({\n            'v0': 3.0,\n            'v1': AttrNamePadding({'x': -1.0}, default=0.5),\n        }, default=1.0)\n\n        edge_padding = EdgeTypePadding({\n            ('v0', 'e0', 'v1'): 3.5,\n            ('v1', 'e0', 'v0'): AttrNamePadding({'edge_attr': -1.5},\n                                                default=5.5),\n        }, default=1.5)\n\n        transform = Pad(num_nodes, num_edges, node_padding, edge_padding)\n\n    Args:\n        max_num_nodes (int or dict): The number of nodes after padding.\n            In heterogeneous graphs, may also take in a dictionary denoting the\n            number of nodes for specific node types.\n        max_num_edges (int or dict, optional): The number of edges after\n            padding.\n            In heterogeneous graphs, may also take in a dictionary denoting the\n            number of edges for specific edge types. (default: :obj:`None`)\n        node_pad_value (int or float or Padding, optional): The fill value to\n            use for node features. (default: :obj:`0.0`)\n        edge_pad_value (int or float or Padding, optional): The fill value to\n            use for edge features.\n            The :obj:`edge_index` tensor is padded with the index of the\n            first padded node (which represents a set of self-loops on the\n            padded node). 
(default: :obj:`0.0`)\n        mask_pad_value (bool, optional): The fill value to use for\n            :obj:`train_mask`, :obj:`val_mask` and :obj:`test_mask` attributes\n            (default: :obj:`False`).\n        add_pad_mask (bool, optional): If set to :obj:`True`, will attach\n            node-level :obj:`pad_node_mask` and edge-level :obj:`pad_edge_mask`\n            attributes to the output which indicates which elements in the data\n            are real (represented by :obj:`True`) and which were added as a\n            result of padding (represented by :obj:`False`).\n            (default: :obj:`False`)\n        exclude_keys ([str], optional): Keys to be removed\n            from the input data object. (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        max_num_nodes: Union[int, Dict[NodeType, int]],\n        max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None,\n        node_pad_value: Union[int, float, Padding] = 0.0,\n        edge_pad_value: Union[int, float, Padding] = 0.0,\n        mask_pad_value: bool = False,\n        add_pad_mask: bool = False,\n        exclude_keys: Optional[List[str]] = None,\n    ):\n        self.max_num_nodes = _NumNodes(max_num_nodes)\n        self.max_num_edges = _NumEdges(max_num_edges, self.max_num_nodes)\n\n        self.node_pad: Padding\n        if not isinstance(node_pad_value, Padding):\n            self.node_pad = UniformPadding(node_pad_value)\n        else:\n            self.node_pad = node_pad_value\n\n        self.edge_pad: Padding\n        if not isinstance(edge_pad_value, Padding):\n            self.edge_pad = UniformPadding(edge_pad_value)\n        else:\n            self.edge_pad = edge_pad_value\n\n        self.node_additional_attrs_pad = {\n            key: mask_pad_value\n            for key in ['train_mask', 'val_mask', 'test_mask']\n        }\n\n        self.add_pad_mask = add_pad_mask\n        self.exclude_keys = set(exclude_keys or [])\n\n    def 
__should_pad_node_attr(self, attr_name: str) -> bool:\n        if attr_name in self.node_additional_attrs_pad:\n            return True\n        if self.exclude_keys is None or attr_name not in self.exclude_keys:\n            return True\n        return False\n\n    def __should_pad_edge_attr(self, attr_name: str) -> bool:\n        if self.max_num_edges.value is None:\n            return False\n        if attr_name == 'edge_index':\n            return True\n        if self.exclude_keys is None or attr_name not in self.exclude_keys:\n            return True\n        return False\n\n    def __get_node_padding(\n        self,\n        attr_name: str,\n        node_type: Optional[NodeType] = None,\n    ) -> Union[int, float]:\n        if attr_name in self.node_additional_attrs_pad:\n            return self.node_additional_attrs_pad[attr_name]\n        return self.node_pad.get_value(node_type, attr_name)\n\n    def __get_edge_padding(\n        self,\n        attr_name: str,\n        edge_type: Optional[EdgeType] = None,\n    ) -> Union[int, float]:\n        return self.edge_pad.get_value(edge_type, attr_name)\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n\n        if isinstance(data, Data):\n            assert isinstance(self.node_pad, (UniformPadding, AttrNamePadding))\n            assert isinstance(self.edge_pad, (UniformPadding, AttrNamePadding))\n\n            for key in self.exclude_keys:\n                del data[key]\n\n            num_nodes = data.num_nodes\n            assert num_nodes is not None\n            self.__pad_edge_store(data._store, data.__cat_dim__, num_nodes)\n            self.__pad_node_store(data._store, data.__cat_dim__)\n            data.num_nodes = self.max_num_nodes.get_value()\n        else:\n            assert isinstance(\n                self.node_pad,\n                (UniformPadding, AttrNamePadding, NodeTypePadding))\n            assert isinstance(\n                
self.edge_pad,\n                (UniformPadding, AttrNamePadding, EdgeTypePadding))\n\n            for edge_type, edge_store in data.edge_items():\n                for key in self.exclude_keys:\n                    del edge_store[key]\n\n                src_node_type, _, dst_node_type = edge_type\n                num_src_nodes = data[src_node_type].num_nodes\n                num_dst_nodes = data[dst_node_type].num_nodes\n                assert num_src_nodes is not None and num_dst_nodes is not None\n                self.__pad_edge_store(edge_store, data.__cat_dim__,\n                                      (num_src_nodes, num_dst_nodes),\n                                      edge_type)\n\n            for node_type, node_store in data.node_items():\n                for key in self.exclude_keys:\n                    del node_store[key]\n                self.__pad_node_store(node_store, data.__cat_dim__, node_type)\n                data[node_type].num_nodes = self.max_num_nodes.get_value(\n                    node_type)\n\n        return data\n\n    def __pad_node_store(\n        self,\n        store: NodeStorage,\n        get_dim_fn: Callable,\n        node_type: Optional[NodeType] = None,\n    ) -> None:\n\n        attrs_to_pad = [key for key in store.keys() if store.is_node_attr(key)]\n\n        if len(attrs_to_pad) == 0:\n            return\n\n        num_target_nodes = self.max_num_nodes.get_value(node_type)\n        assert num_target_nodes is not None\n        assert store.num_nodes is not None\n        assert num_target_nodes >= store.num_nodes, \\\n            f'The number of nodes after padding ({num_target_nodes}) cannot ' \\\n            f'be lower than the number of nodes in the data object ' \\\n            f'({store.num_nodes}).'\n        num_pad_nodes = num_target_nodes - store.num_nodes\n\n        if self.add_pad_mask:\n            pad_node_mask = torch.ones(num_target_nodes, dtype=torch.bool)\n            pad_node_mask[store.num_nodes:] = False\n       
     store.pad_node_mask = pad_node_mask\n\n        for attr_name in attrs_to_pad:\n            attr = store[attr_name]\n            pad_value = self.__get_node_padding(attr_name, node_type)\n            dim = get_dim_fn(attr_name, attr)\n            store[attr_name] = self._pad_tensor_dim(attr, dim, num_pad_nodes,\n                                                    pad_value)\n\n    def __pad_edge_store(\n        self,\n        store: EdgeStorage,\n        get_dim_fn: Callable,\n        num_nodes: Union[int, Tuple[int, int]],\n        edge_type: Optional[EdgeType] = None,\n    ) -> None:\n\n        attrs_to_pad = {\n            attr\n            for attr in store.keys()\n            if store.is_edge_attr(attr) and self.__should_pad_edge_attr(attr)\n        }\n        if not attrs_to_pad:\n            return\n        num_target_edges = self.max_num_edges.get_value(edge_type)\n        assert num_target_edges is not None\n        assert num_target_edges >= store.num_edges, \\\n            f'The number of edges after padding ({num_target_edges}) cannot ' \\\n            f'be lower than the number of edges in the data object ' \\\n            f'({store.num_edges}).'\n        num_pad_edges = num_target_edges - store.num_edges\n\n        if self.add_pad_mask:\n            pad_edge_mask = torch.ones(num_target_edges, dtype=torch.bool)\n            pad_edge_mask[store.num_edges:] = False\n            store.pad_edge_mask = pad_edge_mask\n\n        if isinstance(num_nodes, tuple):\n            src_pad_value, dst_pad_value = num_nodes\n        else:\n            src_pad_value = dst_pad_value = num_nodes\n\n        for attr_name in attrs_to_pad:\n            attr = store[attr_name]\n            dim = get_dim_fn(attr_name, attr)\n            if attr_name == 'edge_index':\n                store[attr_name] = self._pad_edge_index(\n                    attr, num_pad_edges, src_pad_value, dst_pad_value)\n            else:\n                pad_value = 
self.__get_edge_padding(attr_name, edge_type)\n                store[attr_name] = self._pad_tensor_dim(\n                    attr, dim, num_pad_edges, pad_value)\n\n    @staticmethod\n    def _pad_tensor_dim(input: torch.Tensor, dim: int, length: int,\n                        pad_value: float) -> torch.Tensor:\n        r\"\"\"Pads the input tensor in the specified dim with a constant value of\n        the given length.\n        \"\"\"\n        pads = [0] * (2 * input.ndim)\n        pads[-2 * dim - 1] = length\n        return F.pad(input, pads, 'constant', pad_value)\n\n    @staticmethod\n    def _pad_edge_index(input: torch.Tensor, length: int, src_pad_value: float,\n                        dst_pad_value: float) -> torch.Tensor:\n        r\"\"\"Pads the edges :obj:`edge_index` feature with values specified\n        separately for src and dst nodes.\n        \"\"\"\n        pads = [0, length, 0, 0]\n        padded = F.pad(input, pads, 'constant', src_pad_value)\n        if src_pad_value != dst_pad_value:\n            padded[1, input.shape[1]:] = dst_pad_value\n        return padded\n\n    def __repr__(self) -> str:\n        s = f'{self.__class__.__name__}('\n        s += f'max_num_nodes={self.max_num_nodes.value}, '\n        s += f'max_num_edges={self.max_num_edges.value}, '\n        s += f'node_pad_value={self.node_pad}, '\n        s += f'edge_pad_value={self.edge_pad})'\n        return s\n"
  },
  {
    "path": "torch_geometric/transforms/point_pair_features.py",
    "content": "import torch\n\nimport torch_geometric\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('point_pair_features')\nclass PointPairFeatures(BaseTransform):\n    r\"\"\"Computes the rotation-invariant Point Pair Features\n    (functional name: :obj:`point_pair_features`).\n\n    .. math::\n        \\left( \\| \\mathbf{d_{j,i}} \\|, \\angle(\\mathbf{n}_i, \\mathbf{d_{j,i}}),\n        \\angle(\\mathbf{n}_j, \\mathbf{d_{j,i}}), \\angle(\\mathbf{n}_i,\n        \\mathbf{n}_j) \\right)\n\n    of linked nodes in its edge attributes, where :math:`\\mathbf{d}_{j,i}`\n    denotes the difference vector between, and :math:`\\mathbf{n}_i` and\n    :math:`\\mathbf{n}_j` denote the surface normals of node :math:`i` and\n    :math:`j` respectively.\n\n    Args:\n        cat (bool, optional): If set to :obj:`False`, all existing edge\n            attributes will be replaced. (default: :obj:`True`)\n    \"\"\"\n    def __init__(self, cat: bool = True):\n        self.cat = cat\n\n    def forward(self, data: Data) -> Data:\n        ppf_func = torch_geometric.nn.conv.ppf_conv.point_pair_features\n\n        assert data.edge_index is not None\n        assert data.pos is not None and data.norm is not None\n        assert data.pos.size(-1) == 3\n        assert data.pos.size() == data.norm.size()\n\n        row, col = data.edge_index\n        pos, norm, pseudo = data.pos, data.norm, data.edge_attr\n\n        ppf = ppf_func(pos[row], pos[col], norm[row], norm[col])\n\n        if pseudo is not None and self.cat:\n            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo\n            data.edge_attr = torch.cat([pseudo, ppf.type_as(pseudo)], dim=-1)\n        else:\n            data.edge_attr = ppf\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/polar.py",
    "content": "from math import pi as PI\nfrom typing import Optional\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('polar')\nclass Polar(BaseTransform):\n    r\"\"\"Saves the polar coordinates of linked nodes in its edge attributes\n    (functional name: :obj:`polar`).\n\n    Args:\n        norm (bool, optional): If set to :obj:`False`, the output will not be\n            normalized to the interval :math:`{[0, 1]}^2`.\n            (default: :obj:`True`)\n        max_value (float, optional): If set and :obj:`norm=True`, normalization\n            will be performed based on this value instead of the maximum value\n            found in the data. (default: :obj:`None`)\n        cat (bool, optional): If set to :obj:`False`, all existing edge\n            attributes will be replaced. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        norm: bool = True,\n        max_value: Optional[float] = None,\n        cat: bool = True,\n    ) -> None:\n        self.norm = norm\n        self.max = max_value\n        self.cat = cat\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        assert data.edge_index is not None\n        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr\n        assert pos.dim() == 2 and pos.size(1) == 2\n\n        cart = pos[col] - pos[row]\n\n        rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)\n\n        theta = torch.atan2(cart[..., 1], cart[..., 0]).view(-1, 1)\n        theta = theta + (theta < 0).type_as(theta) * (2 * PI)\n\n        if self.norm:\n            rho = rho / (rho.max() if self.max is None else self.max)\n            theta = theta / (2 * PI)\n\n        polar = torch.cat([rho, theta], dim=-1)\n\n        if pseudo is not None and self.cat:\n            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else 
pseudo\n            data.edge_attr = torch.cat([pseudo, polar.type_as(pos)], dim=-1)\n        else:\n            data.edge_attr = polar\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(norm={self.norm}, '\n                f'max_value={self.max})')\n"
  },
  {
    "path": "torch_geometric/transforms/radius_graph.py",
    "content": "import torch_geometric\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('radius_graph')\nclass RadiusGraph(BaseTransform):\n    r\"\"\"Creates edges based on node positions :obj:`data.pos` to all points\n    within a given distance (functional name: :obj:`radius_graph`).\n\n    Args:\n        r (float): The distance.\n        loop (bool, optional): If :obj:`True`, the graph will contain\n            self-loops. (default: :obj:`False`)\n        max_num_neighbors (int, optional): The maximum number of neighbors to\n            return for each element in :obj:`y`.\n            This flag is only needed for CUDA tensors. (default: :obj:`32`)\n        flow (str, optional): The flow direction when using in combination with\n            message passing (:obj:`\"source_to_target\"` or\n            :obj:`\"target_to_source\"`). (default: :obj:`\"source_to_target\"`)\n        num_workers (int): Number of workers to use for computation. Has no\n            effect in case :obj:`batch` is not :obj:`None`, or the input lies\n            on the GPU. 
(default: :obj:`1`)\n    \"\"\"\n    def __init__(\n        self,\n        r: float,\n        loop: bool = False,\n        max_num_neighbors: int = 32,\n        flow: str = 'source_to_target',\n        num_workers: int = 1,\n    ) -> None:\n        self.r = r\n        self.loop = loop\n        self.max_num_neighbors = max_num_neighbors\n        self.flow = flow\n        self.num_workers = num_workers\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n\n        data.edge_index = torch_geometric.nn.radius_graph(\n            data.pos,\n            self.r,\n            data.batch,\n            self.loop,\n            max_num_neighbors=self.max_num_neighbors,\n            flow=self.flow,\n            num_workers=self.num_workers,\n        )\n        data.edge_attr = None\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(r={self.r})'\n"
  },
  {
    "path": "torch_geometric/transforms/random_flip.py",
    "content": "import random\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('random_flip')\nclass RandomFlip(BaseTransform):\n    \"\"\"Flips node positions along a given axis randomly with a given\n    probability (functional name: :obj:`random_flip`).\n\n    Args:\n        axis (int): The axis along the position of nodes being flipped.\n        p (float, optional): Probability that node positions will be flipped.\n            (default: :obj:`0.5`)\n    \"\"\"\n    def __init__(self, axis: int, p: float = 0.5) -> None:\n        self.axis = axis\n        self.p = p\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n\n        if random.random() < self.p:\n            pos = data.pos.clone()\n            pos[..., self.axis] = -pos[..., self.axis]\n            data.pos = pos\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(axis={self.axis}, p={self.p})'\n"
  },
  {
    "path": "torch_geometric/transforms/random_jitter.py",
    "content": "from itertools import repeat\nfrom typing import Sequence, Union\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('random_jitter')\nclass RandomJitter(BaseTransform):\n    r\"\"\"Translates node positions by randomly sampled translation values\n    within a given interval (functional name: :obj:`random_jitter`).\n    In contrast to other random transformations,\n    translation is applied separately at each position.\n\n    Args:\n        translate (sequence or float or int): Maximum translation in each\n            dimension, defining the range\n            :math:`(-\\mathrm{translate}, +\\mathrm{translate})` to sample from.\n            If :obj:`translate` is a number instead of a sequence, the same\n            range is used for each dimension.\n    \"\"\"\n    def __init__(\n        self,\n        translate: Union[float, int, Sequence[Union[float, int]]],\n    ) -> None:\n        self.translate = translate\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        num_nodes, dim = data.pos.size()\n\n        translate: Sequence[Union[float, int]]\n        if isinstance(self.translate, (int, float)):\n            translate = list(repeat(self.translate, times=dim))\n        else:\n            assert len(self.translate) == dim\n            translate = self.translate\n\n        jitter = data.pos.new_empty(num_nodes, dim)\n        for d in range(dim):\n            jitter[:, d].uniform_(-abs(translate[d]), abs(translate[d]))\n\n        data.pos = data.pos + jitter\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.translate})'\n"
  },
  {
    "path": "torch_geometric/transforms/random_link_split.py",
    "content": "import copy\nimport warnings\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.data.storage import EdgeStorage\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.typing import EdgeType\nfrom torch_geometric.utils import negative_sampling\n\n\n@functional_transform('random_link_split')\nclass RandomLinkSplit(BaseTransform):\n    r\"\"\"Performs an edge-level random split into training, validation and test\n    sets of a :class:`~torch_geometric.data.Data` or a\n    :class:`~torch_geometric.data.HeteroData` object\n    (functional name: :obj:`random_link_split`).\n    The split is performed such that the training split does not include edges\n    in validation and test splits; and the validation split does not include\n    edges in the test split.\n\n    .. code-block:: python\n\n        from torch_geometric.transforms import RandomLinkSplit\n\n        transform = RandomLinkSplit(is_undirected=True)\n        train_data, val_data, test_data = transform(data)\n\n    Args:\n        num_val (int or float, optional): The number of validation edges.\n            If set to a floating-point value in :math:`[0, 1]`, it represents\n            the ratio of edges to include in the validation set.\n            (default: :obj:`0.1`)\n        num_test (int or float, optional): The number of test edges.\n            If set to a floating-point value in :math:`[0, 1]`, it represents\n            the ratio of edges to include in the test set.\n            (default: :obj:`0.2`)\n        is_undirected (bool): If set to :obj:`True`, the graph is assumed to be\n            undirected, and positive and negative samples will not leak\n            (reverse) edge connectivity across different splits. 
This only\n            affects the graph split, label data will not be returned\n            undirected. This option is ignored for bipartite edge types or\n            whenever :obj:`edge_type != rev_edge_type`. (default: :obj:`False`)\n        key (str, optional): The name of the attribute holding\n            ground-truth labels.\n            If :obj:`data[key]` does not exist, it will be automatically\n            created and represents a binary classification task\n            (:obj:`1` = edge, :obj:`0` = no edge).\n            If :obj:`data[key]` exists, it has to be a categorical label from\n            :obj:`0` to :obj:`num_classes - 1`.\n            After negative sampling, label :obj:`0` represents negative edges,\n            and labels :obj:`1` to :obj:`num_classes` represent the labels of\n            positive edges. (default: :obj:`\"edge_label\"`)\n        split_labels (bool, optional): If set to :obj:`True`, will split\n            positive and negative labels and save them in distinct attributes\n            :obj:`\"pos_edge_label\"` and :obj:`\"neg_edge_label\"`, respectively.\n            (default: :obj:`False`)\n        add_negative_train_samples (bool, optional): Whether to add negative\n            training samples for link prediction.\n            If the model already performs negative sampling, then the option\n            should be set to :obj:`False`.\n            Otherwise, the added negative samples will be the same across\n            training iterations unless negative sampling is performed again.\n            (default: :obj:`True`)\n        neg_sampling_ratio (float, optional): The ratio of sampled negative\n            edges to the number of positive edges. (default: :obj:`1.0`)\n        disjoint_train_ratio (int or float, optional): If set to a value\n            greater than :obj:`0.0`, training edges will not be shared for\n            message passing and supervision. 
Instead,\n            :obj:`disjoint_train_ratio` edges are used as ground-truth labels\n            for supervision during training. (default: :obj:`0.0`)\n        edge_types (Tuple[EdgeType] or List[EdgeType], optional): The edge\n            types used for performing edge-level splitting in case of\n            operating on :class:`~torch_geometric.data.HeteroData` objects.\n            (default: :obj:`None`)\n        rev_edge_types (Tuple[EdgeType] or List[Tuple[EdgeType]], optional):\n            The reverse edge types of :obj:`edge_types` in case of operating\n            on :class:`~torch_geometric.data.HeteroData` objects.\n            This will ensure that edges of the reverse direction will be\n            split accordingly to prevent any data leakage.\n            Can be :obj:`None` in case no reverse connection exists.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        num_val: Union[int, float] = 0.1,\n        num_test: Union[int, float] = 0.2,\n        is_undirected: bool = False,\n        key: str = 'edge_label',\n        split_labels: bool = False,\n        add_negative_train_samples: bool = True,\n        neg_sampling_ratio: float = 1.0,\n        disjoint_train_ratio: Union[int, float] = 0.0,\n        edge_types: Optional[Union[EdgeType, List[EdgeType]]] = None,\n        rev_edge_types: Optional[Union[\n            EdgeType,\n            List[Optional[EdgeType]],\n        ]] = None,\n    ) -> None:\n        if isinstance(edge_types, list):\n            if rev_edge_types is None:\n                rev_edge_types = [None] * len(edge_types)\n\n            assert isinstance(rev_edge_types, list)\n            assert len(edge_types) == len(rev_edge_types)\n\n        self.num_val = num_val\n        self.num_test = num_test\n        self.is_undirected = is_undirected\n        self.key = key\n        self.split_labels = split_labels\n        self.add_negative_train_samples = add_negative_train_samples\n        
self.neg_sampling_ratio = neg_sampling_ratio\n        self.disjoint_train_ratio = disjoint_train_ratio\n        self.edge_types = edge_types\n        self.rev_edge_types = rev_edge_types\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Tuple[\n            Union[Data, HeteroData],\n            Union[Data, HeteroData],\n            Union[Data, HeteroData],\n    ]:\n        edge_types = self.edge_types\n        rev_edge_types = self.rev_edge_types\n\n        train_data = copy.copy(data)\n        val_data = copy.copy(data)\n        test_data = copy.copy(data)\n\n        if isinstance(data, HeteroData):\n            assert isinstance(train_data, HeteroData)\n            assert isinstance(val_data, HeteroData)\n            assert isinstance(test_data, HeteroData)\n\n            if edge_types is None:\n                raise ValueError(\n                    \"The 'RandomLinkSplit' transform expects 'edge_types' to \"\n                    \"be specified when operating on 'HeteroData' objects\")\n\n            if not isinstance(edge_types, list):\n                assert not isinstance(rev_edge_types, list)\n                edge_types = [edge_types]\n                rev_edge_types = [rev_edge_types]\n\n            stores = [data[edge_type] for edge_type in edge_types]\n            train_stores = [train_data[edge_type] for edge_type in edge_types]\n            val_stores = [val_data[edge_type] for edge_type in edge_types]\n            test_stores = [test_data[edge_type] for edge_type in edge_types]\n        else:\n            assert isinstance(train_data, Data)\n            assert isinstance(val_data, Data)\n            assert isinstance(test_data, Data)\n\n            rev_edge_types = [None]\n\n            train_data = copy.copy(data)\n            val_data = copy.copy(data)\n            test_data = copy.copy(data)\n\n            stores = [data._store]\n            train_stores = [train_data._store]\n            val_stores = 
[val_data._store]\n            test_stores = [test_data._store]\n\n        assert isinstance(rev_edge_types, list)\n        for item in zip(stores, train_stores, val_stores, test_stores,\n                        rev_edge_types):\n            store, train_store, val_store, test_store, rev_edge_type = item\n\n            is_undirected = self.is_undirected\n            is_undirected &= not store.is_bipartite()\n            is_undirected &= (rev_edge_type is None\n                              or (isinstance(data, HeteroData)\n                                  and store._key == data[rev_edge_type]._key))\n\n            edge_index = store.edge_index\n            if is_undirected:\n                mask = edge_index[0] <= edge_index[1]\n                perm = mask.nonzero(as_tuple=False).view(-1)\n                perm = perm[torch.randperm(perm.size(0), device=perm.device)]\n            else:\n                device = edge_index.device\n                perm = torch.randperm(edge_index.size(1), device=device)\n\n            num_val = self.num_val\n            if isinstance(num_val, float):\n                num_val = int(num_val * perm.numel())\n            num_test = self.num_test\n            if isinstance(num_test, float):\n                num_test = int(num_test * perm.numel())\n\n            num_train = perm.numel() - num_val - num_test\n\n            if num_train <= 0:\n                raise ValueError(\"Insufficient number of edges for training\")\n\n            train_edges = perm[:num_train]\n            val_edges = perm[num_train:num_train + num_val]\n            test_edges = perm[num_train + num_val:]\n            train_val_edges = perm[:num_train + num_val]\n\n            num_disjoint = self.disjoint_train_ratio\n            if isinstance(num_disjoint, float):\n                num_disjoint = int(num_disjoint * train_edges.numel())\n            if num_train - num_disjoint <= 0:\n                raise ValueError(\"Insufficient number of edges for training\")\n\n    
        # Create data splits:\n            self._split(train_store, train_edges[num_disjoint:], is_undirected,\n                        rev_edge_type)\n            self._split(val_store, train_edges, is_undirected, rev_edge_type)\n            self._split(test_store, train_val_edges, is_undirected,\n                        rev_edge_type)\n\n            # Create negative samples:\n            num_neg_train = 0\n            if self.add_negative_train_samples:\n                if num_disjoint > 0:\n                    num_neg_train = int(num_disjoint * self.neg_sampling_ratio)\n                else:\n                    num_neg_train = int(num_train * self.neg_sampling_ratio)\n            num_neg_val = int(num_val * self.neg_sampling_ratio)\n            num_neg_test = int(num_test * self.neg_sampling_ratio)\n\n            num_neg = num_neg_train + num_neg_val + num_neg_test\n\n            size = store.size()\n            if store._key is None or store._key[0] == store._key[-1]:\n                size = size[0]\n            neg_edge_index = negative_sampling(edge_index, size,\n                                               num_neg_samples=num_neg,\n                                               method='sparse')\n\n            # Adjust ratio if not enough negative edges exist\n            if neg_edge_index.size(1) < num_neg:\n                num_neg_found = neg_edge_index.size(1)\n                ratio = num_neg_found / num_neg\n                warnings.warn(\n                    f\"There are not enough negative edges to satisfy \"\n                    \"the provided sampling ratio. 
The ratio will be \"\n                    f\"adjusted to {ratio:.2f}.\", stacklevel=2)\n                num_neg_train = int((num_neg_train / num_neg) * num_neg_found)\n                num_neg_val = int((num_neg_val / num_neg) * num_neg_found)\n                num_neg_test = num_neg_found - num_neg_train - num_neg_val\n\n            # Create labels:\n            if num_disjoint > 0:\n                train_edges = train_edges[:num_disjoint]\n            self._create_label(\n                store,\n                train_edges,\n                neg_edge_index[:, num_neg_val + num_neg_test:],\n                out=train_store,\n            )\n            self._create_label(\n                store,\n                val_edges,\n                neg_edge_index[:, :num_neg_val],\n                out=val_store,\n            )\n            self._create_label(\n                store,\n                test_edges,\n                neg_edge_index[:, num_neg_val:num_neg_val + num_neg_test],\n                out=test_store,\n            )\n\n        return train_data, val_data, test_data\n\n    def _split(\n        self,\n        store: EdgeStorage,\n        index: Tensor,\n        is_undirected: bool,\n        rev_edge_type: Optional[EdgeType],\n    ) -> EdgeStorage:\n\n        edge_attrs = {key for key in store.keys() if store.is_edge_attr(key)}\n        for key, value in store.items():\n            if key == 'edge_index':\n                continue\n\n            if key in edge_attrs:\n                value = value[index]\n                if is_undirected:\n                    value = torch.cat([value, value], dim=0)\n                store[key] = value\n\n        edge_index = store.edge_index[:, index]\n        if is_undirected:\n            edge_index = torch.cat([edge_index, edge_index.flip([0])], dim=-1)\n        store.edge_index = edge_index\n\n        if rev_edge_type is not None:\n            rev_store = store._parent()[rev_edge_type]\n            for key in 
rev_store.keys():\n                if key not in store:\n                    del rev_store[key]  # We delete all outdated attributes.\n                elif key == 'edge_index':\n                    rev_store.edge_index = store.edge_index.flip([0])\n                else:\n                    rev_store[key] = store[key]\n\n        return store\n\n    def _create_label(\n        self,\n        store: EdgeStorage,\n        index: Tensor,\n        neg_edge_index: Tensor,\n        out: EdgeStorage,\n    ) -> EdgeStorage:\n\n        edge_index = store.edge_index[:, index]\n\n        if hasattr(store, self.key):\n            edge_label = store[self.key]\n            edge_label = edge_label[index]\n            # Increment labels by one. Note that there is no need to increment\n            # in case no negative edges are added.\n            if neg_edge_index.numel() > 0:\n                assert edge_label.dtype == torch.long\n                assert edge_label.size(0) == edge_index.size(1)\n                edge_label.add_(1)\n            if hasattr(out, self.key):\n                delattr(out, self.key)\n        else:\n            edge_label = torch.ones(index.numel(), device=index.device)\n\n        if neg_edge_index.numel() > 0:\n            neg_edge_label = edge_label.new_zeros((neg_edge_index.size(1), ) +\n                                                  edge_label.size()[1:])\n\n        if self.split_labels:\n            out[f'pos_{self.key}'] = edge_label\n            out[f'pos_{self.key}_index'] = edge_index\n            if neg_edge_index.numel() > 0:\n                out[f'neg_{self.key}'] = neg_edge_label\n                out[f'neg_{self.key}_index'] = neg_edge_index\n\n        else:\n            if neg_edge_index.numel() > 0:\n                edge_label = torch.cat([edge_label, neg_edge_label], dim=0)\n                edge_index = torch.cat([edge_index, neg_edge_index], dim=-1)\n            out[self.key] = edge_label\n            out[f'{self.key}_index'] = 
edge_index\n\n        return out\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(num_val={self.num_val}, '\n                f'num_test={self.num_test})')\n"
  },
  {
    "path": "torch_geometric/transforms/random_node_split.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.data.storage import NodeStorage\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('random_node_split')\nclass RandomNodeSplit(BaseTransform):\n    r\"\"\"Performs a node-level random split by adding :obj:`train_mask`,\n    :obj:`val_mask` and :obj:`test_mask` attributes to the\n    :class:`~torch_geometric.data.Data` or\n    :class:`~torch_geometric.data.HeteroData` object\n    (functional name: :obj:`random_node_split`).\n\n    Args:\n        split (str, optional): The type of dataset split (:obj:`\"train_rest\"`,\n            :obj:`\"test_rest\"`, :obj:`\"random\"`).\n            If set to :obj:`\"train_rest\"`, all nodes except those in the\n            validation and test sets will be used for training (as in the\n            `\"FastGCN: Fast Learning with Graph Convolutional Networks via\n            Importance Sampling\" <https://arxiv.org/abs/1801.10247>`_ paper).\n            If set to :obj:`\"test_rest\"`, all nodes except those in the\n            training and validation sets will be used for test (as in the\n            `\"Pitfalls of Graph Neural Network Evaluation\"\n            <https://arxiv.org/abs/1811.05868>`_ paper).\n            If set to :obj:`\"random\"`, train, validation, and test sets will be\n            randomly generated, according to :obj:`num_train_per_class`,\n            :obj:`num_val` and :obj:`num_test` (as in the `\"Semi-supervised\n            Classification with Graph Convolutional Networks\"\n            <https://arxiv.org/abs/1609.02907>`_ paper).\n            (default: :obj:`\"train_rest\"`)\n        num_splits (int, optional): The number of splits to add. 
If bigger\n            than :obj:`1`, the shape of masks will be\n            :obj:`[num_nodes, num_splits]`, and :obj:`[num_nodes]` otherwise.\n            (default: :obj:`1`)\n        num_train_per_class (int, optional): The number of training samples\n            per class in case of :obj:`\"test_rest\"` and :obj:`\"random\"` split.\n            (default: :obj:`20`)\n        num_val (int or float, optional): The number of validation samples.\n            If float, it represents the ratio of samples to include in the\n            validation set. (default: :obj:`500`)\n        num_test (int or float, optional): The number of test samples in case\n            of :obj:`\"train_rest\"` and :obj:`\"random\"` split. If float, it\n            represents the ratio of samples to include in the test set.\n            (default: :obj:`1000`)\n        key (str, optional): The name of the attribute holding ground-truth\n            labels. By default, will only add node-level splits for node-level\n            storages in which :obj:`key` is present. 
(default: :obj:`\"y\"`).\n    \"\"\"\n    def __init__(\n        self,\n        split: str = \"train_rest\",\n        num_splits: int = 1,\n        num_train_per_class: int = 20,\n        num_val: Union[int, float] = 500,\n        num_test: Union[int, float] = 1000,\n        key: Optional[str] = \"y\",\n    ) -> None:\n        assert split in ['train_rest', 'test_rest', 'random']\n        self.split = split\n        self.num_splits = num_splits\n        self.num_train_per_class = num_train_per_class\n        self.num_val = num_val\n        self.num_test = num_test\n        self.key = key\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.node_stores:\n            if self.key is not None and not hasattr(store, self.key):\n                continue\n\n            train_masks, val_masks, test_masks = zip(\n                *[self._split(store) for _ in range(self.num_splits)])\n\n            store.train_mask = torch.stack(train_masks, dim=-1).squeeze(-1)\n            store.val_mask = torch.stack(val_masks, dim=-1).squeeze(-1)\n            store.test_mask = torch.stack(test_masks, dim=-1).squeeze(-1)\n\n        return data\n\n    def _split(self, store: NodeStorage) -> Tuple[Tensor, Tensor, Tensor]:\n        num_nodes = store.num_nodes\n        assert num_nodes is not None\n\n        train_mask = torch.zeros(num_nodes, dtype=torch.bool)\n        val_mask = torch.zeros(num_nodes, dtype=torch.bool)\n        test_mask = torch.zeros(num_nodes, dtype=torch.bool)\n\n        if isinstance(self.num_val, float):\n            num_val = round(num_nodes * self.num_val)\n        else:\n            num_val = self.num_val\n\n        if isinstance(self.num_test, float):\n            num_test = round(num_nodes * self.num_test)\n        else:\n            num_test = self.num_test\n\n        if self.split == 'train_rest':\n            perm = torch.randperm(num_nodes)\n            val_mask[perm[:num_val]] 
= True\n            test_mask[perm[num_val:num_val + num_test]] = True\n            train_mask[perm[num_val + num_test:]] = True\n        else:\n            assert self.key is not None\n            y = getattr(store, self.key)\n            num_classes = int(y.max().item()) + 1\n            for c in range(num_classes):\n                idx = (y == c).nonzero(as_tuple=False).view(-1)\n                idx = idx[torch.randperm(idx.size(0))]\n                idx = idx[:self.num_train_per_class]\n                train_mask[idx] = True\n\n            remaining = (~train_mask).nonzero(as_tuple=False).view(-1)\n            remaining = remaining[torch.randperm(remaining.size(0))]\n\n            val_mask[remaining[:num_val]] = True\n\n            if self.split == 'test_rest':\n                test_mask[remaining[num_val:]] = True\n            elif self.split == 'random':\n                test_mask[remaining[num_val:num_val + num_test]] = True\n\n        return train_mask, val_mask, test_mask\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(split={self.split})'\n"
  },
  {
    "path": "torch_geometric/transforms/random_rotate.py",
    "content": "import math\nimport random\nfrom typing import Tuple, Union\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform, LinearTransformation\n\n\n@functional_transform('random_rotate')\nclass RandomRotate(BaseTransform):\n    r\"\"\"Rotates node positions around a specific axis by a randomly sampled\n    factor within a given interval (functional name: :obj:`random_rotate`).\n\n    Args:\n        degrees (tuple or float): Rotation interval from which the rotation\n            angle is sampled. If :obj:`degrees` is a number instead of a\n            tuple, the interval is given by :math:`[-\\mathrm{degrees},\n            \\mathrm{degrees}]`.\n        axis (int, optional): The rotation axis. (default: :obj:`0`)\n    \"\"\"\n    def __init__(\n        self,\n        degrees: Union[Tuple[float, float], float],\n        axis: int = 0,\n    ) -> None:\n        if isinstance(degrees, (int, float)):\n            degrees = (-abs(degrees), abs(degrees))\n        assert isinstance(degrees, (tuple, list)) and len(degrees) == 2\n        self.degrees = degrees\n        self.axis = axis\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n\n        degree = math.pi * random.uniform(*self.degrees) / 180.0\n        sin, cos = math.sin(degree), math.cos(degree)\n\n        if data.pos.size(-1) == 2:\n            matrix = [[cos, sin], [-sin, cos]]\n        else:\n            if self.axis == 0:\n                matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]]\n            elif self.axis == 1:\n                matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]]\n            else:\n                matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]]\n\n        return LinearTransformation(torch.tensor(matrix))(data)\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}({self.degrees}, '\n              
  f'axis={self.axis})')\n"
  },
  {
    "path": "torch_geometric/transforms/random_scale.py",
    "content": "import random\nfrom typing import Tuple\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('random_scale')\nclass RandomScale(BaseTransform):\n    r\"\"\"Scales node positions by a randomly sampled factor :math:`s` within a\n    given interval, *e.g.*, resulting in the transformation matrix\n    (functional name: :obj:`random_scale`).\n\n    .. math::\n        \\begin{bmatrix}\n            s & 0 & 0 \\\\\n            0 & s & 0 \\\\\n            0 & 0 & s \\\\\n        \\end{bmatrix}\n\n    for three-dimensional positions.\n\n    Args:\n        scales (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then scale\n            is randomly sampled from the range\n            :math:`a \\leq \\mathrm{scale} \\leq b`.\n    \"\"\"\n    def __init__(self, scales: Tuple[float, float]) -> None:\n        assert isinstance(scales, (tuple, list)) and len(scales) == 2\n        self.scales = scales\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n\n        scale = random.uniform(*self.scales)\n        data.pos = data.pos * scale\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.scales})'\n"
  },
  {
    "path": "torch_geometric/transforms/random_shear.py",
    "content": "from typing import Union\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform, LinearTransformation\n\n\n@functional_transform('random_shear')\nclass RandomShear(BaseTransform):\n    r\"\"\"Shears node positions by randomly sampled factors :math:`s` within a\n    given interval, *e.g.*, resulting in the transformation matrix\n    (functional name: :obj:`random_shear`).\n\n    .. math::\n        \\begin{bmatrix}\n            1      & s_{xy} & s_{xz} \\\\\n            s_{yx} & 1      & s_{yz} \\\\\n            s_{zx} & z_{zy} & 1      \\\\\n        \\end{bmatrix}\n\n    for three-dimensional positions.\n\n    Args:\n        shear (float or int): maximum shearing factor defining the range\n            :math:`(-\\mathrm{shear}, +\\mathrm{shear})` to sample from.\n    \"\"\"\n    def __init__(self, shear: Union[float, int]) -> None:\n        self.shear = abs(shear)\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n\n        dim = data.pos.size(-1)\n\n        matrix = data.pos.new_empty(dim, dim).uniform_(-self.shear, self.shear)\n        eye = torch.arange(dim, dtype=torch.long)\n        matrix[eye, eye] = 1\n\n        return LinearTransformation(matrix)(data)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.shear})'\n"
  },
  {
    "path": "torch_geometric/transforms/remove_duplicated_edges.py",
    "content": "from typing import List, Optional, Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import coalesce\n\n\n@functional_transform('remove_duplicated_edges')\nclass RemoveDuplicatedEdges(BaseTransform):\n    r\"\"\"Removes duplicated edges from a given homogeneous or heterogeneous\n    graph. Useful to clean-up known repeated edges/self-loops in common\n    benchmark datasets, *e.g.*, in :obj:`ogbn-products`.\n    (functional name: :obj:`remove_duplicated_edges`).\n\n    Args:\n        key (str or [str], optional): The name of edge attribute(s) to merge in\n            case of duplication. (default: :obj:`[\"edge_weight\", \"edge_attr\"]`)\n        reduce (str, optional): The reduce operation to use for merging edge\n            attributes (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"min\"`,\n            :obj:`\"max\"`, :obj:`\"mul\"`). 
(default: :obj:`\"add\"`)\n    \"\"\"\n    def __init__(\n        self,\n        key: Optional[Union[str, List[str]]] = None,\n        reduce: str = \"add\",\n    ) -> None:\n        key = key or ['edge_attr', 'edge_weight']\n\n        if isinstance(key, str):\n            key = [key]\n\n        self.keys = key\n        self.reduce = reduce\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n\n        for store in data.edge_stores:\n            keys = [key for key in self.keys if key in store]\n\n            size = [s for s in store.size() if s is not None]\n            num_nodes = max(size) if len(size) > 0 else None\n\n            store.edge_index, edge_attrs = coalesce(\n                edge_index=store.edge_index,\n                edge_attr=[store[key] for key in keys],\n                num_nodes=num_nodes,\n                reduce=self.reduce,\n            )\n\n            for key, edge_attr in zip(keys, edge_attrs):\n                store[key] = edge_attr\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/remove_isolated_nodes.py",
    "content": "import copy\nfrom collections import defaultdict\nfrom typing import Union\n\nimport torch\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('remove_isolated_nodes')\nclass RemoveIsolatedNodes(BaseTransform):\n    r\"\"\"Removes isolated nodes from the graph\n    (functional name: :obj:`remove_isolated_nodes`).\n    \"\"\"\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        # Gather all nodes that occur in at least one edge (across all types):\n        n_ids_dict = defaultdict(list)\n        for edge_store in data.edge_stores:\n            if 'edge_index' not in edge_store:\n                continue\n\n            if edge_store._key is None:\n                src = dst = None\n            else:\n                src, _, dst = edge_store._key\n\n            n_ids_dict[src].append(edge_store.edge_index[0])\n            n_ids_dict[dst].append(edge_store.edge_index[1])\n\n        n_id_dict = {k: torch.cat(v).unique() for k, v in n_ids_dict.items()}\n\n        n_map_dict = {}\n        for node_store in data.node_stores:\n            if node_store._key not in n_id_dict:\n                n_id_dict[node_store._key] = torch.empty(0, dtype=torch.long)\n\n            idx = n_id_dict[node_store._key]\n            assert data.num_nodes is not None\n            mapping = idx.new_zeros(data.num_nodes)\n            mapping[idx] = torch.arange(idx.numel(), device=mapping.device)\n            n_map_dict[node_store._key] = mapping\n\n        for edge_store in data.edge_stores:\n            if 'edge_index' not in edge_store:\n                continue\n\n            if edge_store._key is None:\n                src = dst = None\n            else:\n                src, _, dst = edge_store._key\n\n            row = n_map_dict[src][edge_store.edge_index[0]]\n        
    col = n_map_dict[dst][edge_store.edge_index[1]]\n            edge_store.edge_index = torch.stack([row, col], dim=0)\n\n        old_data = copy.copy(data)\n        for out, node_store in zip(data.node_stores, old_data.node_stores):\n            for key, value in node_store.items():\n                if key == 'num_nodes':\n                    out.num_nodes = n_id_dict[node_store._key].numel()\n                elif node_store.is_node_attr(key):\n                    out[key] = value[n_id_dict[node_store._key]]\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/remove_self_loops.py",
    "content": "from typing import Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import remove_self_loops\n\n\n@functional_transform('remove_self_loops')\nclass RemoveSelfLoops(BaseTransform):\n    r\"\"\"Removes all self-loops in the given homogeneous or heterogeneous\n    graph (functional name: :obj:`remove_self_loops`).\n\n    Args:\n        attr (str, optional): The name of the attribute of edge weights\n            or multi-dimensional edge features to pass to\n            :meth:`torch_geometric.utils.remove_self_loops`.\n            (default: :obj:`\"edge_weight\"`)\n    \"\"\"\n    def __init__(self, attr: str = 'edge_weight') -> None:\n        self.attr = attr\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.edge_stores:\n            if store.is_bipartite() or 'edge_index' not in store:\n                continue\n\n            store.edge_index, store[self.attr] = remove_self_loops(\n                store.edge_index,\n                edge_attr=store.get(self.attr, None),\n            )\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/remove_training_classes.py",
    "content": "from typing import List\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('remove_training_classes')\nclass RemoveTrainingClasses(BaseTransform):\n    r\"\"\"Removes classes from the node-level training set as given by\n    :obj:`data.train_mask`, *e.g.*, in order to get a zero-shot label scenario\n    (functional name: :obj:`remove_training_classes`).\n\n    Args:\n        classes (List[int]): The classes to remove from the training set.\n    \"\"\"\n    def __init__(self, classes: List[int]):\n        self.classes = classes\n\n    def forward(self, data: Data) -> Data:\n        data.train_mask = data.train_mask.clone()\n        for i in self.classes:\n            data.train_mask[data.y == i] = False\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.classes})'\n"
  },
  {
    "path": "torch_geometric/transforms/rooted_subgraph.py",
    "content": "import copy\nfrom abc import ABC, abstractmethod\nfrom typing import Any, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import to_torch_csc_tensor\n\n\nclass RootedSubgraphData(Data):\n    r\"\"\"A data object describing a homogeneous graph together with each node's\n    rooted subgraph.\n\n    It contains several additional properties that hold the information to map\n    to batch of every node's rooted subgraph:\n\n    * :obj:`sub_edge_index` (Tensor): The edge indices of all combined rooted\n      subgraphs.\n    * :obj:`n_id` (Tensor): The indices of nodes in all combined rooted\n      subgraphs.\n    * :obj:`e_id` (Tensor): The indices of edges in all combined rooted\n      subgraphs.\n    * :obj:`n_sub_batch` (Tensor): The batch vector to distinguish nodes across\n      different subgraphs.\n    * :obj:`e_sub_batch` (Tensor): The batch vector to distinguish edges across\n      different subgraphs.\n    \"\"\"\n    def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any:\n        if key == 'sub_edge_index':\n            return self.n_id.size(0)\n        if key in ['n_sub_batch', 'e_sub_batch']:\n            return 1 + int(self.n_sub_batch[-1])\n        elif key == 'n_id':\n            return self.num_nodes\n        elif key == 'e_id':\n            assert self.edge_index is not None\n            return self.edge_index.size(1)\n        return super().__inc__(key, value, *args, **kwargs)\n\n    def map_data(self) -> Data:\n        # Maps all feature information of the :class:`Data` object to each\n        # rooted subgraph.\n        data = copy.copy(self)\n\n        for key, value in self.items():\n            if key in ['sub_edge_index', 'n_id', 'e_id', 'e_sub_batch']:\n                del data[key]\n            elif key == 'n_sub_batch':\n                continue\n            elif key == 'num_nodes':\n    
            data.num_nodes = self.n_id.size(0)\n            elif key == 'edge_index':\n                data.edge_index = self.sub_edge_index\n            elif self.is_node_attr(key):\n                dim = self.__cat_dim__(key, value)\n                data[key] = value.index_select(dim, self.n_id)\n            elif self.is_edge_attr(key):\n                dim = self.__cat_dim__(key, value)\n                data[key] = value.index_select(dim, self.e_id)\n\n        return data\n\n\nclass RootedSubgraph(BaseTransform, ABC):\n    r\"\"\"Base class for implementing rooted subgraph transformations.\"\"\"\n    @abstractmethod\n    def extract(\n        self,\n        data: Data,\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n        # Returns the tuple:\n        # :obj:`(sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch)`\n        # of the :class:`RootedSubgraphData` object.\n        pass\n\n    def map(\n        self,\n        data: Data,\n        n_mask: Tensor,\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n\n        assert data.edge_index is not None\n        num_nodes = data.num_nodes\n        assert num_nodes is not None\n\n        n_sub_batch, n_id = n_mask.nonzero().t()\n        e_mask = n_mask[:, data.edge_index[0]] & n_mask[:, data.edge_index[1]]\n        e_sub_batch, e_id = e_mask.nonzero().t()\n\n        sub_edge_index = data.edge_index[:, e_id]\n        arange = torch.arange(n_id.size(0), device=data.edge_index.device)\n        node_map = data.edge_index.new_ones(num_nodes, num_nodes)\n        node_map[n_sub_batch, n_id] = arange\n        sub_edge_index += (arange * num_nodes)[e_sub_batch]\n        sub_edge_index = node_map.view(-1)[sub_edge_index]\n\n        return sub_edge_index, n_id, e_id, n_sub_batch, e_sub_batch\n\n    def forward(self, data: Data) -> RootedSubgraphData:\n        out = self.extract(data)\n        d = RootedSubgraphData.from_dict(data.to_dict())\n        d.sub_edge_index, d.n_id, d.e_id, d.n_sub_batch, 
d.e_sub_batch = out\n        return d\n\n\nclass RootedEgoNets(RootedSubgraph):\n    r\"\"\"Collects rooted :math:`k`-hop EgoNets for each node in the graph, as\n    described in the `\"From Stars to Subgraphs: Uplifting Any GNN with Local\n    Structure Awareness\" <https://arxiv.org/abs/2110.03753>`_ paper.\n\n    Args:\n        num_hops (int): the number of hops :math:`k`.\n    \"\"\"\n    def __init__(self, num_hops: int) -> None:\n        super().__init__()\n        self.num_hops = num_hops\n\n    def extract(\n        self,\n        data: Data,\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n\n        assert data.edge_index is not None\n        num_nodes = data.num_nodes\n        assert num_nodes is not None\n\n        adj_t = to_torch_csc_tensor(data.edge_index, size=data.size()).t()\n        n_mask = torch.eye(num_nodes, device=data.edge_index.device)\n        for _ in range(self.num_hops):\n            n_mask += adj_t @ n_mask\n\n        return self.map(data, n_mask > 0)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(num_hops={self.num_hops})'\n\n\nclass RootedRWSubgraph(RootedSubgraph):\n    \"\"\"Collects rooted random-walk based subgraphs for each node in the graph,\n    as described in the `\"From Stars to Subgraphs: Uplifting Any GNN with Local\n    Structure Awareness\" <https://arxiv.org/abs/2110.03753>`_ paper.\n\n    Args:\n        walk_length (int): the length of the random walk.\n        repeat (int, optional): The number of times of repeating the random\n            walk to reduce randomness. 
(default: :obj:`1`)\n    \"\"\"\n    def __init__(self, walk_length: int, repeat: int = 1):\n        super().__init__()\n        self.walk_length = walk_length\n        self.repeat = repeat\n\n    def extract(\n        self,\n        data: Data,\n    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:\n        from torch_cluster import random_walk\n\n        assert data.edge_index is not None\n        num_nodes = data.num_nodes\n        assert num_nodes is not None\n\n        start = torch.arange(num_nodes, device=data.edge_index.device)\n        start = start.view(-1, 1).repeat(1, self.repeat).view(-1)\n        walk = random_walk(data.edge_index[0], data.edge_index[1], start,\n                           self.walk_length, num_nodes=data.num_nodes)\n\n        n_mask = torch.zeros((num_nodes, num_nodes), dtype=torch.bool,\n                             device=walk.device)\n        start = start.view(-1, 1).repeat(1, (self.walk_length + 1)).view(-1)\n        n_mask[start, walk.view(-1)] = True\n\n        return self.map(data, n_mask)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(walk_length={self.walk_length})'\n"
  },
  {
    "path": "torch_geometric/transforms/sample_points.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('sample_points')\nclass SamplePoints(BaseTransform):\n    r\"\"\"Uniformly samples a fixed number of points on the mesh faces according\n    to their face area (functional name: :obj:`sample_points`).\n\n    Args:\n        num (int): The number of points to sample.\n        remove_faces (bool, optional): If set to :obj:`False`, the face tensor\n            will not be removed. (default: :obj:`True`)\n        include_normals (bool, optional): If set to :obj:`True`, then compute\n            normals for each sampled point. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        num: int,\n        remove_faces: bool = True,\n        include_normals: bool = False,\n    ):\n        self.num = num\n        self.remove_faces = remove_faces\n        self.include_normals = include_normals\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        assert data.face is not None\n\n        pos, face = data.pos, data.face\n        assert pos.size(1) == 3 and face.size(0) == 3\n\n        pos_max = pos.abs().max()\n        pos = pos / pos_max\n\n        area = (pos[face[1]] - pos[face[0]]).cross(\n            pos[face[2]] - pos[face[0]],\n            dim=1,\n        )\n        area = area.norm(p=2, dim=1).abs() / 2\n\n        prob = area / area.sum()\n        sample = torch.multinomial(prob, self.num, replacement=True)\n        face = face[:, sample]\n\n        frac = torch.rand(self.num, 2, device=pos.device)\n        mask = frac.sum(dim=-1) > 1\n        frac[mask] = 1 - frac[mask]\n\n        vec1 = pos[face[1]] - pos[face[0]]\n        vec2 = pos[face[2]] - pos[face[0]]\n\n        if self.include_normals:\n            data.normal = torch.nn.functional.normalize(\n                vec1.cross(vec2, dim=1), p=2)\n\n  
      pos_sampled = pos[face[0]]\n        pos_sampled += frac[:, :1] * vec1\n        pos_sampled += frac[:, 1:] * vec2\n\n        pos_sampled = pos_sampled * pos_max\n        data.pos = pos_sampled\n\n        if self.remove_faces:\n            data.face = None\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.num})'\n"
  },
  {
    "path": "torch_geometric/transforms/sign.py",
    "content": "import torch\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import scatter\n\n\n@functional_transform('sign')\nclass SIGN(BaseTransform):\n    r\"\"\"The Scalable Inception Graph Neural Network module (SIGN) from the\n    `\"SIGN: Scalable Inception Graph Neural Networks\"\n    <https://arxiv.org/abs/2004.11198>`_ paper (functional name: :obj:`sign`),\n    which precomputes the fixed representations.\n\n    .. math::\n        \\mathbf{X}^{(i)} = {\\left( \\mathbf{D}^{-1/2} \\mathbf{A}\n        \\mathbf{D}^{-1/2} \\right)}^i \\mathbf{X}\n\n    for :math:`i \\in \\{ 1, \\ldots, K \\}` and saves them in\n    :obj:`data.x1`, :obj:`data.x2`, ...\n\n    .. note::\n\n        Since intermediate node representations are pre-computed, this operator\n        is able to scale well to large graphs via classic mini-batching.\n        For an example of using SIGN, see `examples/sign.py\n        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/\n        sign.py>`_.\n\n    Args:\n        K (int): The number of hops/layer.\n    \"\"\"\n    def __init__(self, K: int) -> None:\n        self.K = K\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        edge_index = data.edge_index\n        row, col = data.edge_index\n        num_nodes = data.num_nodes\n\n        edge_weight = data.edge_weight\n        if edge_weight is None:\n            edge_weight = torch.ones(data.num_edges, device=edge_index.device)\n\n        deg = scatter(edge_weight, col, dim_size=num_nodes, reduce='sum')\n        deg_inv_sqrt = deg.pow_(-0.5)\n        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)\n        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n\n        edge_index = EdgeIndex(edge_index, sparse_size=(num_nodes, 
num_nodes))\n        edge_index, perm = edge_index.sort_by('col')\n        edge_weight = edge_weight[perm]\n\n        assert data.x is not None\n        xs = [data.x]\n        for i in range(1, self.K + 1):\n            xs.append(edge_index.matmul(xs[-1], edge_weight, transpose=True))\n            data[f'x{i}'] = xs[-1]\n\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}(K={self.K})'\n"
  },
  {
    "path": "torch_geometric/transforms/spherical.py",
    "content": "from math import pi as PI\nfrom typing import Optional\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('spherical')\nclass Spherical(BaseTransform):\n    r\"\"\"Saves the spherical coordinates of linked nodes in its edge attributes\n    (functional name: :obj:`spherical`).\n\n    Args:\n        norm (bool, optional): If set to :obj:`False`, the output will not be\n            normalized to the interval :math:`{[0, 1]}^3`.\n            (default: :obj:`True`)\n        max_value (float, optional): If set and :obj:`norm=True`, normalization\n            will be performed based on this value instead of the maximum value\n            found in the data. (default: :obj:`None`)\n        cat (bool, optional): If set to :obj:`False`, all existing edge\n            attributes will be replaced. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        norm: bool = True,\n        max_value: Optional[float] = None,\n        cat: bool = True,\n    ):\n        self.norm = norm\n        self.max = max_value\n        self.cat = cat\n\n    def forward(self, data: Data) -> Data:\n        assert data.pos is not None\n        assert data.edge_index is not None\n        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr\n        assert pos.dim() == 2 and pos.size(1) == 3\n\n        cart = pos[col] - pos[row]\n\n        rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)\n\n        theta = torch.atan2(cart[..., 1], cart[..., 0]).view(-1, 1)\n        theta = theta + (theta < 0).type_as(theta) * (2 * PI)\n\n        phi = torch.acos(cart[..., 2] / rho.view(-1)).view(-1, 1)\n\n        if self.norm:\n            rho = rho / (rho.max() if self.max is None else self.max)\n            theta = theta / (2 * PI)\n            phi = phi / PI\n\n        spher = torch.cat([rho, theta, phi], dim=-1)\n\n   
     if pseudo is not None and self.cat:\n            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo\n            data.edge_attr = torch.cat([pseudo, spher.type_as(pos)], dim=-1)\n        else:\n            data.edge_attr = spher\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(norm={self.norm}, '\n                f'max_value={self.max})')\n"
  },
  {
    "path": "torch_geometric/transforms/svd_feature_reduction.py",
    "content": "import torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('svd_feature_reduction')\nclass SVDFeatureReduction(BaseTransform):\n    r\"\"\"Dimensionality reduction of node features via Singular Value\n    Decomposition (SVD) (functional name: :obj:`svd_feature_reduction`).\n\n    Args:\n        out_channels (int): The dimensionality of node features after\n            reduction.\n    \"\"\"\n    def __init__(self, out_channels: int):\n        self.out_channels = out_channels\n\n    def forward(self, data: Data) -> Data:\n        assert data.x is not None\n\n        if data.x.size(-1) > self.out_channels:\n            U, S, _ = torch.linalg.svd(data.x)\n            data.x = torch.mm(U[:, :self.out_channels],\n                              torch.diag(S[:self.out_channels]))\n        return data\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.out_channels})'\n"
  },
  {
    "path": "torch_geometric/transforms/target_indegree.py",
    "content": "from typing import Optional\n\nimport torch\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import degree\n\n\n@functional_transform('target_indegree')\nclass TargetIndegree(BaseTransform):\n    r\"\"\"Saves the globally normalized degree of target nodes\n    (functional name: :obj:`target_indegree`).\n\n    .. math::\n\n        \\mathbf{u}(i,j) = \\frac{\\deg(j)}{\\max_{v \\in \\mathcal{V}} \\deg(v)}\n\n    in its edge attributes.\n\n    Args:\n        cat (bool, optional): Concat pseudo-coordinates to edge attributes\n            instead of replacing them. (default: :obj:`True`)\n    \"\"\"\n    def __init__(\n        self,\n        norm: bool = True,\n        max_value: Optional[float] = None,\n        cat: bool = True,\n    ) -> None:\n        self.norm = norm\n        self.max = max_value\n        self.cat = cat\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        col, pseudo = data.edge_index[1], data.edge_attr\n\n        deg = degree(col, data.num_nodes)\n\n        if self.norm:\n            deg = deg / (deg.max() if self.max is None else self.max)\n\n        deg = deg[col]\n        deg = deg.view(-1, 1)\n\n        if pseudo is not None and self.cat:\n            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo\n            data.edge_attr = torch.cat([pseudo, deg.type_as(pseudo)], dim=-1)\n        else:\n            data.edge_attr = deg\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(norm={self.norm}, '\n                f'max_value={self.max})')\n"
  },
  {
    "path": "torch_geometric/transforms/to_dense.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('to_dense')\nclass ToDense(BaseTransform):\n    r\"\"\"Converts a sparse adjacency matrix to a dense adjacency matrix with\n    shape :obj:`[num_nodes, num_nodes, *]` (functional name: :obj:`to_dense`).\n\n    Args:\n        num_nodes (int, optional): The number of nodes. If set to :obj:`None`,\n            the number of nodes will get automatically inferred.\n            (default: :obj:`None`)\n    \"\"\"\n    def __init__(self, num_nodes: Optional[int] = None) -> None:\n        self.num_nodes = num_nodes\n\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n\n        orig_num_nodes = data.num_nodes\n        assert orig_num_nodes is not None\n\n        if self.num_nodes is None:\n            num_nodes = orig_num_nodes\n        else:\n            assert orig_num_nodes <= self.num_nodes\n            num_nodes = self.num_nodes\n\n        if data.edge_attr is None:\n            edge_attr = torch.ones(data.edge_index.size(1), dtype=torch.float)\n        else:\n            edge_attr = data.edge_attr\n\n        size = torch.Size([num_nodes, num_nodes] + list(edge_attr.size())[1:])\n        adj = torch.sparse_coo_tensor(data.edge_index, edge_attr, size)\n        data.adj = adj.to_dense()\n        data.edge_index = None\n        data.edge_attr = None\n\n        data.mask = torch.zeros(num_nodes, dtype=torch.bool)\n        data.mask[:orig_num_nodes] = 1\n\n        if data.x is not None:\n            _size = [num_nodes - data.x.size(0)] + list(data.x.size())[1:]\n            data.x = torch.cat([data.x, data.x.new_zeros(_size)], dim=0)\n\n        if data.pos is not None:\n            _size = [num_nodes - data.pos.size(0)] + list(data.pos.size())[1:]\n            data.pos 
= torch.cat([data.pos, data.pos.new_zeros(_size)], dim=0)\n\n        if (data.y is not None and isinstance(data.y, Tensor)\n                and data.y.size(0) == orig_num_nodes):\n            _size = [num_nodes - data.y.size(0)] + list(data.y.size())[1:]\n            data.y = torch.cat([data.y, data.y.new_zeros(_size)], dim=0)\n\n        return data\n\n    def __repr__(self) -> str:\n        if self.num_nodes is None:\n            return f'{self.__class__.__name__}()'\n        return f'{self.__class__.__name__}(num_nodes={self.num_nodes})'\n"
  },
  {
    "path": "torch_geometric/transforms/to_device.py",
    "content": "from typing import List, Optional, Union\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('to_device')\nclass ToDevice(BaseTransform):\n    r\"\"\"Performs tensor device conversion, either for all attributes of the\n    :obj:`~torch_geometric.data.Data` object or only the ones given by\n    :obj:`attrs` (functional name: :obj:`to_device`).\n\n    Args:\n        device (torch.device): The destination device.\n        attrs (List[str], optional): If given, will only perform tensor device\n            conversion for the given attributes. (default: :obj:`None`)\n        non_blocking (bool, optional): If set to :obj:`True` and tensor\n            values are in pinned memory, the copy will be asynchronous with\n            respect to the host. (default: :obj:`False`)\n    \"\"\"\n    def __init__(\n        self,\n        device: Union[int, str],\n        attrs: Optional[List[str]] = None,\n        non_blocking: bool = False,\n    ) -> None:\n        self.device = device\n        self.attrs = attrs or []\n        self.non_blocking = non_blocking\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        return data.to(self.device, *self.attrs,\n                       non_blocking=self.non_blocking)\n\n    def __repr__(self) -> str:\n        return f'{self.__class__.__name__}({self.device})'\n"
  },
  {
    "path": "torch_geometric/transforms/to_sparse_tensor.py",
    "content": "from typing import Optional, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import (\n    sort_edge_index,\n    to_torch_coo_tensor,\n    to_torch_csr_tensor,\n)\n\n\n@functional_transform('to_sparse_tensor')\nclass ToSparseTensor(BaseTransform):\n    r\"\"\"Converts the :obj:`edge_index` attributes of a homogeneous or\n    heterogeneous data object into a **transposed**\n    :class:`torch_sparse.SparseTensor` or :pytorch:`PyTorch`\n    :class:`torch.sparse.Tensor` object with key :obj:`adj_t`\n    (functional name: :obj:`to_sparse_tensor`).\n\n    .. note::\n\n        In case of composing multiple transforms, it is best to convert the\n        :obj:`data` object via :class:`ToSparseTensor` as late as possible,\n        since there exist some transforms that are only able to operate on\n        :obj:`data.edge_index` for now.\n\n    Args:\n        attr (str, optional): The name of the attribute to add as a value to\n            the :class:`~torch_sparse.SparseTensor` or\n            :class:`torch.sparse.Tensor` object (if present).\n            (default: :obj:`edge_weight`)\n        remove_edge_index (bool, optional): If set to :obj:`False`, the\n            :obj:`edge_index` tensor will not be removed.\n            (default: :obj:`True`)\n        fill_cache (bool, optional): If set to :obj:`True`, will fill the\n            underlying :class:`torch_sparse.SparseTensor` cache (if used).\n            (default: :obj:`True`)\n        layout (torch.layout, optional): Specifies the layout of the returned\n            sparse tensor (:obj:`None`, :obj:`torch.sparse_coo` or\n            :obj:`torch.sparse_csr`).\n            If set to :obj:`None` and the :obj:`torch_sparse` 
dependency is\n            installed, will convert :obj:`edge_index` into a\n            :class:`torch_sparse.SparseTensor` object.\n            If set to :obj:`None` and the :obj:`torch_sparse` dependency is\n            not installed, will convert :obj:`edge_index` into a\n            :class:`torch.sparse.Tensor` object with layout\n            :obj:`torch.sparse_csr`. (default: :obj:`None`)\n    \"\"\"\n    def __init__(\n        self,\n        attr: Optional[str] = 'edge_weight',\n        remove_edge_index: bool = True,\n        fill_cache: bool = True,\n        layout: Optional[int] = None,\n    ) -> None:\n        if layout not in {None, torch.sparse_coo, torch.sparse_csr}:\n            raise ValueError(f\"Unexpected sparse tensor layout \"\n                             f\"(got '{layout}')\")\n\n        self.attr = attr\n        self.remove_edge_index = remove_edge_index\n        self.fill_cache = fill_cache\n        self.layout = layout\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n\n        for store in data.edge_stores:\n            if 'edge_index' not in store:\n                continue\n\n            keys, values = [], []\n            for key, value in store.items():\n                if key in {'edge_index', 'edge_label', 'edge_label_index'}:\n                    continue\n\n                if store.is_edge_attr(key):\n                    keys.append(key)\n                    values.append(value)\n\n            store.edge_index, values = sort_edge_index(\n                store.edge_index,\n                values,\n                sort_by_row=False,\n            )\n\n            for key, value in zip(keys, values):\n                store[key] = value\n\n            layout = self.layout\n            size = store.size()[::-1]\n            edge_weight: Optional[Tensor] = None\n            if self.attr is not None and self.attr in store:\n                edge_weight = store[self.attr]\n\n        
    if layout is None and torch_geometric.typing.WITH_TORCH_SPARSE:\n                store.adj_t = SparseTensor(\n                    row=store.edge_index[1],\n                    col=store.edge_index[0],\n                    value=edge_weight,\n                    sparse_sizes=size,\n                    is_sorted=True,\n                    trust_data=True,\n                )\n\n            # TODO Multi-dimensional edge attributes only supported for COO.\n            elif ((edge_weight is not None and edge_weight.dim() > 1)\n                  or layout == torch.sparse_coo):\n                assert size[0] is not None and size[1] is not None\n                store.adj_t = to_torch_coo_tensor(\n                    store.edge_index.flip([0]),\n                    edge_attr=edge_weight,\n                    size=size,\n                )\n\n            elif layout is None or layout == torch.sparse_csr:\n                assert size[0] is not None and size[1] is not None\n                store.adj_t = to_torch_csr_tensor(\n                    store.edge_index.flip([0]),\n                    edge_attr=edge_weight,\n                    size=size,\n                )\n\n            if self.remove_edge_index:\n                del store['edge_index']\n                if self.attr is not None and self.attr in store:\n                    del store[self.attr]\n\n            if self.fill_cache and isinstance(store.adj_t, SparseTensor):\n                # Pre-process some important attributes.\n                store.adj_t.storage.rowptr()\n                store.adj_t.storage.csr2csc()\n\n        return data\n\n    def __repr__(self) -> str:\n        return (f'{self.__class__.__name__}(attr={self.attr}, '\n                f'layout={self.layout})')\n"
  },
  {
    "path": "torch_geometric/transforms/to_superpixels.py",
    "content": "from typing import Any\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import scatter\n\n\n@functional_transform('to_slic')\nclass ToSLIC(BaseTransform):\n    r\"\"\"Converts an image to a superpixel representation using the\n    :meth:`skimage.segmentation.slic` algorithm, resulting in a\n    :obj:`torch_geometric.data.Data` object holding the centroids of\n    superpixels in :obj:`data.pos` and their mean color in :obj:`data.x`\n    (functional name: :obj:`to_slic`).\n\n    This transform can be used with any :obj:`torchvision` dataset.\n\n    .. code-block:: python\n\n        from torchvision.datasets import MNIST\n        import torchvision.transforms as T\n        from torch_geometric.transforms import ToSLIC\n\n        transform = T.Compose([T.ToTensor(), ToSLIC(n_segments=75)])\n        dataset = MNIST('/tmp/MNIST', download=True, transform=transform)\n\n    Args:\n        add_seg (bool, optional): If set to `True`, will add the segmentation\n            result to the data object. (default: :obj:`False`)\n        add_img (bool, optional): If set to `True`, will add the input image\n            to the data object. (default: :obj:`False`)\n        **kwargs (optional): Arguments to adjust the output of the SLIC\n            algorithm. 
See the `SLIC documentation\n            <https://scikit-image.org/docs/dev/api/skimage.segmentation.html\n            #skimage.segmentation.slic>`_ for an overview.\n    \"\"\"\n    def __init__(\n        self,\n        add_seg: bool = False,\n        add_img: bool = False,\n        **kwargs: Any,\n    ) -> None:\n        self.add_seg = add_seg\n        self.add_img = add_img\n        self.kwargs = kwargs\n\n    def forward(self, img: Tensor) -> Data:\n        from skimage.segmentation import slic\n\n        img = img.permute(1, 2, 0)\n        h, w, c = img.size()\n\n        seg = slic(img.to(torch.double).numpy(), start_label=0, **self.kwargs)\n        seg = torch.from_numpy(seg)\n\n        x = scatter(img.view(h * w, c), seg.view(h * w), dim=0, reduce='mean')\n\n        pos_y = torch.arange(h, dtype=torch.float)\n        pos_y = pos_y.view(-1, 1).repeat(1, w).view(h * w)\n        pos_x = torch.arange(w, dtype=torch.float)\n        pos_x = pos_x.view(1, -1).repeat(h, 1).view(h * w)\n\n        pos = torch.stack([pos_x, pos_y], dim=-1)\n        pos = scatter(pos, seg.view(h * w), dim=0, reduce='mean')\n\n        data = Data(x=x, pos=pos)\n\n        if self.add_seg:\n            data.seg = seg.view(1, h, w)\n\n        if self.add_img:\n            data.img = img.permute(2, 0, 1).view(1, c, h, w)\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/to_undirected.py",
    "content": "from typing import Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data, HeteroData\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import to_undirected\n\n\n@functional_transform('to_undirected')\nclass ToUndirected(BaseTransform):\n    r\"\"\"Converts a homogeneous or heterogeneous graph to an undirected graph\n    such that :math:`(j,i) \\in \\mathcal{E}` for every edge\n    :math:`(i,j) \\in \\mathcal{E}` (functional name: :obj:`to_undirected`).\n    In heterogeneous graphs, will add \"reverse\" connections for *all* existing\n    edge types.\n\n    Args:\n        reduce (str, optional): The reduce operation to use for merging edge\n            features (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`,\n            :obj:`\"mul\"`). (default: :obj:`\"add\"`)\n        merge (bool, optional): If set to :obj:`False`, will create reverse\n            edge types for connections pointing to the same source and target\n            node type.\n            If set to :obj:`True`, reverse edges will be merged into the\n            original relation.\n            This option only has effects in\n            :class:`~torch_geometric.data.HeteroData` graph data.\n            (default: :obj:`True`)\n    \"\"\"\n    def __init__(self, reduce: str = \"add\", merge: bool = True):\n        self.reduce = reduce\n        self.merge = merge\n\n    def forward(\n        self,\n        data: Union[Data, HeteroData],\n    ) -> Union[Data, HeteroData]:\n        for store in data.edge_stores:\n            if 'edge_index' not in store:\n                continue\n\n            nnz = store.edge_index.size(1)\n\n            if isinstance(data, HeteroData) and (store.is_bipartite()\n                                                 or not self.merge):\n                src, rel, dst = store._key\n\n                # Just reverse 
the connectivity and add edge attributes:\n                row, col = store.edge_index\n                rev_edge_index = torch.stack([col, row], dim=0)\n\n                inv_store = data[dst, f'rev_{rel}', src]\n                inv_store.edge_index = rev_edge_index\n                for key, value in store.items():\n                    if key == 'edge_index':\n                        continue\n                    if isinstance(value, Tensor) and value.size(0) == nnz:\n                        inv_store[key] = value\n\n            else:\n                keys, values = [], []\n                for key, value in store.items():\n                    if key == 'edge_index':\n                        continue\n\n                    if store.is_edge_attr(key):\n                        keys.append(key)\n                        values.append(value)\n\n                store.edge_index, values = to_undirected(\n                    store.edge_index, values, reduce=self.reduce)\n\n                for key, value in zip(keys, values):\n                    store[key] = value\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/two_hop.py",
    "content": "import torch\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\nfrom torch_geometric.utils import coalesce, remove_self_loops\n\n\n@functional_transform('two_hop')\nclass TwoHop(BaseTransform):\n    r\"\"\"Adds the two hop edges to the edge indices\n    (functional name: :obj:`two_hop`).\n    \"\"\"\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        edge_index, edge_attr = data.edge_index, data.edge_attr\n        N = data.num_nodes\n\n        edge_index = EdgeIndex(edge_index, sparse_size=(N, N))\n        edge_index = edge_index.sort_by('row')[0]\n        edge_index2 = edge_index.matmul(edge_index)[0].as_tensor()\n        edge_index2, _ = remove_self_loops(edge_index2)\n        edge_index = torch.cat([edge_index, edge_index2], dim=1)\n\n        if edge_attr is not None:\n            # We treat newly added edge features as \"zero-features\":\n            edge_attr2 = edge_attr.new_zeros(edge_index2.size(1),\n                                             *edge_attr.size()[1:])\n            edge_attr = torch.cat([edge_attr, edge_attr2], dim=0)\n\n        data.edge_index, data.edge_attr = coalesce(edge_index, edge_attr, N)\n\n        return data\n"
  },
  {
    "path": "torch_geometric/transforms/virtual_node.py",
    "content": "import copy\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.data.datapipes import functional_transform\nfrom torch_geometric.transforms import BaseTransform\n\n\n@functional_transform('virtual_node')\nclass VirtualNode(BaseTransform):\n    r\"\"\"Appends a virtual node to the given homogeneous graph that is connected\n    to all other nodes, as described in the `\"Neural Message Passing for\n    Quantum Chemistry\" <https://arxiv.org/abs/1704.01212>`_ paper\n    (functional name: :obj:`virtual_node`).\n    The virtual node serves as a global scratch space that each node both reads\n    from and writes to in every step of message passing.\n    This allows information to travel long distances during the propagation\n    phase.\n\n    Node and edge features of the virtual node are added as zero-filled input\n    features.\n    Furthermore, special edge types will be added both for in-coming and\n    out-going information to and from the virtual node.\n    \"\"\"\n    def forward(self, data: Data) -> Data:\n        assert data.edge_index is not None\n        row, col = data.edge_index\n        edge_type = data.get('edge_type', torch.zeros_like(row))\n        num_nodes = data.num_nodes\n        assert num_nodes is not None\n\n        arange = torch.arange(num_nodes, device=row.device)\n        full = row.new_full((num_nodes, ), num_nodes)\n        row = torch.cat([row, arange, full], dim=0)\n        col = torch.cat([col, full, arange], dim=0)\n        edge_index = torch.stack([row, col], dim=0)\n\n        num_edge_types = int(edge_type.max()) if edge_type.numel() > 0 else 0\n        new_type = edge_type.new_full((num_nodes, ), num_edge_types + 1)\n        edge_type = torch.cat([edge_type, new_type, new_type + 1], dim=0)\n\n        old_data = copy.copy(data)\n        for key, value in old_data.items():\n            if key == 'edge_index' or key == 'edge_type':\n                continue\n\n           
 if isinstance(value, Tensor):\n                dim = old_data.__cat_dim__(key, value)\n                size = list(value.size())\n\n                fill_value = None\n                if key == 'edge_weight':\n                    size[dim] = 2 * num_nodes\n                    fill_value = 1.\n                elif key == 'batch':\n                    size[dim] = 1\n                    fill_value = int(value[0])\n                elif old_data.is_edge_attr(key):\n                    size[dim] = 2 * num_nodes\n                    fill_value = 0.\n                elif old_data.is_node_attr(key):\n                    size[dim] = 1\n                    fill_value = 0.\n\n                if fill_value is not None:\n                    new_value = value.new_full(size, fill_value)\n                    data[key] = torch.cat([value, new_value], dim=dim)\n\n        data.edge_index = edge_index\n        data.edge_type = edge_type\n\n        if 'num_nodes' in data:\n            data.num_nodes = num_nodes + 1\n\n        return data\n"
  },
  {
    "path": "torch_geometric/typing.py",
    "content": "import importlib.util\nimport inspect\nimport os\nimport typing\nimport warnings\nfrom typing import Any, Dict, List, Optional, Set, Tuple, TypeAlias, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nWITH_PT20 = int(torch.__version__.split('.')[0]) >= 2\nWITH_PT21 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 1\nWITH_PT22 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 2\nWITH_PT23 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 3\nWITH_PT24 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 4\nWITH_PT25 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 5\nWITH_PT26 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 6\nWITH_PT27 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 7\nWITH_PT28 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 8\nWITH_PT29 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 9\nWITH_PT210 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 10\nWITH_PT211 = WITH_PT20 and int(torch.__version__.split('.')[1]) >= 11\nWITH_PT113 = WITH_PT20 or int(torch.__version__.split('.')[1]) >= 13\n\nWITH_WINDOWS = os.name == 'nt'\nNO_MKL = 'USE_MKL=OFF' in torch.__config__.show() or WITH_WINDOWS\n\nMAX_INT64 = torch.iinfo(torch.int64).max\n\nif WITH_PT20:\n    INDEX_DTYPES: Set[torch.dtype] = {\n        torch.int32,\n        torch.int64,\n    }\nelif not typing.TYPE_CHECKING:  # pragma: no cover\n    INDEX_DTYPES: Set[torch.dtype] = {\n        torch.int64,\n    }\n\nif not hasattr(torch, 'sparse_csc'):\n    torch.sparse_csc = torch.sparse_coo\n\ntry:\n    import pyg_lib  # noqa\n    WITH_PYG_LIB = True\n    WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, 'grouped_matmul')\n    WITH_SEGMM = hasattr(pyg_lib.ops, 'segment_matmul')\n    if (WITH_SEGMM and 'PYTEST_CURRENT_TEST' in os.environ\n            and torch.cuda.is_available()):\n        # NOTE `segment_matmul` is currently bugged on older NVIDIA cards which\n        # let our GPU tests on CI crash. 
Check if this error is present on the\n        # current GPU and disable `WITH_SEGMM`/`WITH_GMM` if necessary.\n        # TODO Drop this code block once `segment_matmul` is fixed.\n        try:\n            x = torch.randn(3, 4, device='cuda')\n            ptr = torch.tensor([0, 2, 3], device='cuda')\n            weight = torch.randn(2, 4, 4, device='cuda')\n            out = pyg_lib.ops.segment_matmul(x, ptr, weight)\n        except RuntimeError:\n            WITH_GMM = False\n            WITH_SEGMM = False\n    WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add')\n    WITH_SPLINE = hasattr(pyg_lib.ops, 'spline_basis')\n    WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr')\n    WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')\n    WITH_METIS = hasattr(pyg_lib, 'partition')\n    WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature(\n        pyg_lib.sampler.neighbor_sample).parameters)\n    WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(\n        pyg_lib.sampler.neighbor_sample).parameters)\n    try:\n        torch.classes.pyg.CPUHashMap  # noqa: B018\n        WITH_CPU_HASH_MAP = True\n    except Exception:\n        WITH_CPU_HASH_MAP = False\n    try:\n        torch.classes.pyg.CUDAHashMap  # noqa: B018\n        WITH_CUDA_HASH_MAP = True\n    except Exception:\n        WITH_CUDA_HASH_MAP = False\nexcept Exception as e:\n    if not isinstance(e, ImportError):  # pragma: no cover\n        warnings.warn(\n            f\"An issue occurred while importing 'pyg-lib'. \"\n            f\"Disabling its usage. 
Stacktrace: {e}\", stacklevel=2)\n    pyg_lib = object\n    WITH_PYG_LIB = False\n    WITH_GMM = False\n    WITH_SEGMM = False\n    WITH_SAMPLED_OP = False\n    WITH_SPLINE = False\n    WITH_SOFTMAX = False\n    WITH_INDEX_SORT = False\n    WITH_METIS = False\n    WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False\n    WITH_WEIGHTED_NEIGHBOR_SAMPLE = False\n    WITH_CPU_HASH_MAP = False\n    WITH_CUDA_HASH_MAP = False\n\nif WITH_CPU_HASH_MAP:\n    CPUHashMap: TypeAlias = torch.classes.pyg.CPUHashMap  # type: ignore[name-defined]  # noqa: E501\nelse:\n\n    class CPUHashMap:  # type: ignore\n        def __init__(self, key: Tensor) -> None:\n            raise ImportError(\"'CPUHashMap' requires 'pyg-lib'\")\n\n        def get(self, query: Tensor) -> Tensor:\n            raise ImportError(\"'CPUHashMap' requires 'pyg-lib'\")\n\n\nif WITH_CUDA_HASH_MAP:\n    CUDAHashMap: TypeAlias = torch.classes.pyg.CUDAHashMap  # type: ignore[name-defined]  # noqa: E501\nelse:\n\n    class CUDAHashMap:  # type: ignore\n        def __init__(self, key: Tensor) -> None:\n            raise ImportError(\"'CUDAHashMap' requires 'pyg-lib'\")\n\n        def get(self, query: Tensor) -> Tensor:\n            raise ImportError(\"'CUDAHashMap' requires 'pyg-lib'\")\n\n\ntry:\n    import torch_scatter  # noqa\n    WITH_TORCH_SCATTER = True\nexcept Exception as e:\n    if not isinstance(e, ImportError):  # pragma: no cover\n        warnings.warn(\n            f\"An issue occurred while importing 'torch-scatter'. \"\n            f\"Disabling its usage. Stacktrace: {e}\", stacklevel=2)\n    torch_scatter = object\n    WITH_TORCH_SCATTER = False\n\ntry:\n    import torch_cluster  # noqa\n    WITH_TORCH_CLUSTER = True\n    WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__\nexcept Exception as e:\n    if not isinstance(e, ImportError):  # pragma: no cover\n        warnings.warn(\n            f\"An issue occurred while importing 'torch-cluster'. \"\n            f\"Disabling its usage. 
Stacktrace: {e}\", stacklevel=2)\n    WITH_TORCH_CLUSTER = False\n    WITH_TORCH_CLUSTER_BATCH_SIZE = False\n\n    class TorchCluster:\n        def __getattr__(self, key: str) -> Any:\n            raise ImportError(f\"'{key}' requires 'torch-cluster'\")\n\n    torch_cluster = TorchCluster()\n\nif importlib.util.find_spec('torch_spline_conv') is not None:\n    warnings.warn(\n        \"'torch-spline-conv' is no longer necessary and is being ignored. \"\n        \"Its functionality has been migrated to 'pyg-lib>=0.6.0'.\",\n        DeprecationWarning,\n        stacklevel=2,\n    )\n\ntry:\n    import torch_sparse  # noqa\n    from torch_sparse import SparseStorage, SparseTensor\n    WITH_TORCH_SPARSE = True\nexcept Exception as e:\n    if not isinstance(e, ImportError):  # pragma: no cover\n        warnings.warn(\n            f\"An issue occurred while importing 'torch-sparse'. \"\n            f\"Disabling its usage. Stacktrace: {e}\", stacklevel=2)\n    WITH_TORCH_SPARSE = False\n\n    class SparseStorage:  # type: ignore\n        def __init__(\n            self,\n            row: Optional[Tensor] = None,\n            rowptr: Optional[Tensor] = None,\n            col: Optional[Tensor] = None,\n            value: Optional[Tensor] = None,\n            sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,\n            rowcount: Optional[Tensor] = None,\n            colptr: Optional[Tensor] = None,\n            colcount: Optional[Tensor] = None,\n            csr2csc: Optional[Tensor] = None,\n            csc2csr: Optional[Tensor] = None,\n            is_sorted: bool = False,\n            trust_data: bool = False,\n        ):\n            raise ImportError(\"'SparseStorage' requires 'torch-sparse'\")\n\n        def value(self) -> Optional[Tensor]:\n            raise ImportError(\"'SparseStorage' requires 'torch-sparse'\")\n\n        def rowcount(self) -> Tensor:\n            raise ImportError(\"'SparseStorage' requires 'torch-sparse'\")\n\n    class 
SparseTensor:  # type: ignore\n        def __init__(\n            self,\n            row: Optional[Tensor] = None,\n            rowptr: Optional[Tensor] = None,\n            col: Optional[Tensor] = None,\n            value: Optional[Tensor] = None,\n            sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,\n            is_sorted: bool = False,\n            trust_data: bool = False,\n        ):\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        @classmethod\n        def from_edge_index(\n            self,\n            edge_index: Tensor,\n            edge_attr: Optional[Tensor] = None,\n            sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,\n            is_sorted: bool = False,\n            trust_data: bool = False,\n        ) -> 'SparseTensor':\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        @property\n        def storage(self) -> SparseStorage:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        @classmethod\n        def from_dense(self, mat: Tensor,\n                       has_value: bool = True) -> 'SparseTensor':\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def size(self, dim: int) -> int:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def nnz(self) -> int:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def is_cuda(self) -> bool:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def has_value(self) -> bool:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def set_value(self, value: Optional[Tensor],\n                      layout: Optional[str] = None) -> 'SparseTensor':\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def fill_value(self, fill_value: float,\n                       
dtype: Optional[torch.dtype] = None) -> 'SparseTensor':\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def coo(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def csr(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def requires_grad(self) -> bool:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n        def to_torch_sparse_csr_tensor(\n            self,\n            dtype: Optional[torch.dtype] = None,\n        ) -> Tensor:\n            raise ImportError(\"'SparseTensor' requires 'torch-sparse'\")\n\n    class torch_sparse:  # type: ignore\n        @staticmethod\n        def matmul(src: SparseTensor, other: Tensor,\n                   reduce: str = \"sum\") -> Tensor:\n            raise ImportError(\"'matmul' requires 'torch-sparse'\")\n\n        @staticmethod\n        def sum(src: SparseTensor, dim: Optional[int] = None) -> Tensor:\n            raise ImportError(\"'sum' requires 'torch-sparse'\")\n\n        @staticmethod\n        def mul(src: SparseTensor, other: Tensor) -> SparseTensor:\n            raise ImportError(\"'mul' requires 'torch-sparse'\")\n\n        @staticmethod\n        def set_diag(src: SparseTensor, values: Optional[Tensor] = None,\n                     k: int = 0) -> SparseTensor:\n            raise ImportError(\"'set_diag' requires 'torch-sparse'\")\n\n        @staticmethod\n        def fill_diag(src: SparseTensor, fill_value: float,\n                      k: int = 0) -> SparseTensor:\n            raise ImportError(\"'fill_diag' requires 'torch-sparse'\")\n\n        @staticmethod\n        def masked_select_nnz(src: SparseTensor, mask: Tensor,\n                              layout: Optional[str] = None) -> SparseTensor:\n            raise ImportError(\"'masked_select_nnz' requires 
'torch-sparse'\")\n\n\ntry:\n    import torch_frame  # noqa\n    WITH_TORCH_FRAME = True\n    from torch_frame import TensorFrame\nexcept Exception:\n    torch_frame = object\n    WITH_TORCH_FRAME = False\n\n    class TensorFrame:  # type: ignore\n        pass\n\n\ntry:\n    import intel_extension_for_pytorch  # noqa\n    WITH_IPEX = True\nexcept Exception:\n    WITH_IPEX = False\n\n\nclass MockTorchCSCTensor:\n    def __init__(\n        self,\n        edge_index: Tensor,\n        edge_attr: Optional[Tensor] = None,\n        size: Optional[Union[int, Tuple[int, int]]] = None,\n    ):\n        self.edge_index = edge_index\n        self.edge_attr = edge_attr\n        self.size = size\n\n    def t(self) -> Tensor:  # Only support accessing its transpose:\n        from torch_geometric.utils import to_torch_csr_tensor\n        size = self.size\n        return to_torch_csr_tensor(\n            self.edge_index.flip([0]),\n            self.edge_attr,\n            size[::-1] if isinstance(size, (tuple, list)) else size,\n        )\n\n\n# Types for accessing data ####################################################\n\n# Node-types are denoted by a single string, e.g.: `data['paper']`:\nNodeType = str\n\n# Edge-types are denotes by a triplet of strings, e.g.:\n# `data[('author', 'writes', 'paper')]\nEdgeType = Tuple[str, str, str]\n\nNodeOrEdgeType = Union[NodeType, EdgeType]\n\nDEFAULT_REL = 'to'\nEDGE_TYPE_STR_SPLIT = '__'\n\n\nclass EdgeTypeStr(str):\n    r\"\"\"A helper class to construct serializable edge types by merging an edge\n    type tuple into a single string.\n    \"\"\"\n    edge_type: tuple[str, str, str]\n\n    def __new__(cls, *args: Any) -> 'EdgeTypeStr':\n        if isinstance(args[0], (list, tuple)):\n            # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:\n            args = tuple(args[0])\n\n        if len(args) == 1 and isinstance(args[0], str):\n            arg = args[0]  # An edge type string was passed.\n            edge_type 
= tuple(arg.split(EDGE_TYPE_STR_SPLIT))\n            if len(edge_type) != 3:\n                raise ValueError(f\"Cannot convert the edge type '{arg}' to a \"\n                                 f\"tuple since it holds invalid characters\")\n\n        elif len(args) == 2 and all(isinstance(arg, str) for arg in args):\n            # A `(src, dst)` edge type was passed - add `DEFAULT_REL`:\n            edge_type = (args[0], DEFAULT_REL, args[1])\n            arg = EDGE_TYPE_STR_SPLIT.join(edge_type)\n\n        elif len(args) == 3 and all(isinstance(arg, str) for arg in args):\n            # A `(src, rel, dst)` edge type was passed:\n            edge_type = tuple(args)\n            arg = EDGE_TYPE_STR_SPLIT.join(args)\n\n        else:\n            raise ValueError(f\"Encountered invalid edge type '{args}'\")\n\n        out = str.__new__(cls, arg)\n        out.edge_type = edge_type  # type: ignore\n        return out\n\n    def to_tuple(self) -> EdgeType:\n        r\"\"\"Returns the original edge type.\"\"\"\n        if len(self.edge_type) != 3:\n            raise ValueError(f\"Cannot convert the edge type '{self}' to a \"\n                             f\"tuple since it holds invalid characters\")\n        return self.edge_type\n\n    def __reduce__(self) -> tuple[Any, Any]:\n        return (self.__class__, (self.edge_type, ))\n\n\n# There exist some short-cuts to query edge-types (given that the full triplet\n# can be uniquely reconstructed), e.g.:\n# * via str: `data['writes']`\n# * via Tuple[str, str]: `data[('author', 'paper')]`\nQueryType = Union[NodeType, EdgeType, str, Tuple[str, str]]\n\nMetadata = Tuple[List[NodeType], List[EdgeType]]\n\n# A representation of a feature tensor\nFeatureTensorType = Union[Tensor, np.ndarray]\n\n# A representation of an edge index, following the possible formats:\n#   * COO: (row, col)\n#   * CSC: (row, colptr)\n#   * CSR: (rowptr, col)\nEdgeTensorType = Tuple[Tensor, Tensor]\n\n# Types for message passing 
###################################################\n\nAdj = Union[Tensor, SparseTensor]\nOptTensor = Optional[Tensor]\nPairTensor = Tuple[Tensor, Tensor]\nOptPairTensor = Tuple[Tensor, Optional[Tensor]]\nPairOptTensor = Tuple[Optional[Tensor], Optional[Tensor]]\nSize = Optional[Tuple[int, int]]\nNoneType = Optional[Tensor]\n\nMaybeHeteroNodeTensor = Union[Tensor, Dict[NodeType, Tensor]]\nMaybeHeteroAdjTensor = Union[Tensor, Dict[EdgeType, Adj]]\nMaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]]\n\n# Types for sampling ##########################################################\n\nInputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]]\nInputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]]\n\n# Serialization ###############################################################\n\nif WITH_PT24:\n    torch.serialization.add_safe_globals([\n        SparseTensor,\n        SparseStorage,\n        TensorFrame,\n        MockTorchCSCTensor,\n        EdgeTypeStr,\n    ])\n"
  },
  {
    "path": "torch_geometric/utils/__init__.py",
    "content": "r\"\"\"Utility package.\"\"\"\n\nimport copy\n\nfrom ._scatter import scatter, group_argsort, group_cat\nfrom ._segment import segment, segment_logsumexp\nfrom ._index_sort import index_sort\nfrom .functions import cumsum\nfrom ._degree import degree\nfrom ._softmax import softmax\nfrom ._lexsort import lexsort\nfrom ._sort_edge_index import sort_edge_index\nfrom ._coalesce import coalesce\nfrom .undirected import is_undirected, to_undirected\nfrom .loop import (contains_self_loops, remove_self_loops,\n                   segregate_self_loops, add_self_loops,\n                   add_remaining_self_loops, get_self_loop_attr)\nfrom .isolated import contains_isolated_nodes, remove_isolated_nodes\nfrom ._subgraph import (get_num_hops, subgraph, k_hop_subgraph,\n                        bipartite_subgraph)\nfrom .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path\nfrom ._homophily import homophily\nfrom ._assortativity import assortativity\nfrom ._normalize_edge_index import normalize_edge_index\nfrom .laplacian import get_laplacian\nfrom .mesh_laplacian import get_mesh_laplacian\nfrom .mask import mask_select, index_to_mask, mask_to_index\nfrom ._select import select, narrow\nfrom ._to_dense_batch import to_dense_batch\nfrom ._to_dense_adj import to_dense_adj\nfrom .nested import to_nested_tensor, from_nested_tensor\nfrom .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor,\n                     to_torch_coo_tensor, to_torch_csr_tensor,\n                     to_torch_csc_tensor, to_torch_sparse_tensor,\n                     to_edge_index)\nfrom ._spmm import spmm\nfrom ._unbatch import unbatch, unbatch_edge_index\nfrom ._one_hot import one_hot\nfrom ._normalized_cut import normalized_cut\nfrom ._grid import grid\nfrom .geodesic import geodesic_distance\nfrom .convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix\nfrom .convert import to_networkx, from_networkx\nfrom .convert import to_networkit, 
from_networkit\nfrom .convert import to_trimesh, from_trimesh\nfrom .convert import to_cugraph, from_cugraph\nfrom .convert import to_dgl, from_dgl\nfrom .smiles import from_rdmol, to_rdmol, from_smiles, to_smiles\nfrom .random import (erdos_renyi_graph, stochastic_blockmodel_graph,\n                     barabasi_albert_graph)\nfrom ._negative_sampling import (negative_sampling, batched_negative_sampling,\n                                 structured_negative_sampling,\n                                 structured_negative_sampling_feasible)\nfrom .augmentation import shuffle_node, mask_feature, add_random_edge\nfrom ._tree_decomposition import tree_decomposition\nfrom .embedding import get_embeddings, get_embeddings_hetero\nfrom ._trim_to_layer import trim_to_layer\nfrom .ppr import get_ppr\nfrom ._train_test_split_edges import train_test_split_edges\nfrom .influence import total_influence\n\n__all__ = [\n    'scatter',\n    'group_argsort',\n    'group_cat',\n    'segment',\n    'segment_logsumexp',\n    'index_sort',\n    'cumsum',\n    'degree',\n    'softmax',\n    'lexsort',\n    'sort_edge_index',\n    'coalesce',\n    'is_undirected',\n    'to_undirected',\n    'contains_self_loops',\n    'remove_self_loops',\n    'segregate_self_loops',\n    'add_self_loops',\n    'add_remaining_self_loops',\n    'get_self_loop_attr',\n    'contains_isolated_nodes',\n    'remove_isolated_nodes',\n    'get_num_hops',\n    'subgraph',\n    'bipartite_subgraph',\n    'k_hop_subgraph',\n    'dropout_node',\n    'dropout_edge',\n    'dropout_path',\n    'dropout_adj',\n    'homophily',\n    'assortativity',\n    'normalize_edge_index',\n    'get_laplacian',\n    'get_mesh_laplacian',\n    'mask_select',\n    'index_to_mask',\n    'mask_to_index',\n    'select',\n    'narrow',\n    'to_dense_batch',\n    'to_dense_adj',\n    'to_nested_tensor',\n    'from_nested_tensor',\n    'dense_to_sparse',\n    'is_torch_sparse_tensor',\n    'is_sparse',\n    'to_torch_coo_tensor',\n    
'to_torch_csr_tensor',\n    'to_torch_csc_tensor',\n    'to_torch_sparse_tensor',\n    'to_edge_index',\n    'spmm',\n    'unbatch',\n    'unbatch_edge_index',\n    'one_hot',\n    'normalized_cut',\n    'grid',\n    'geodesic_distance',\n    'to_scipy_sparse_matrix',\n    'from_scipy_sparse_matrix',\n    'to_networkx',\n    'from_networkx',\n    'to_networkit',\n    'from_networkit',\n    'to_trimesh',\n    'from_trimesh',\n    'to_cugraph',\n    'from_cugraph',\n    'to_dgl',\n    'from_dgl',\n    'from_rdmol',\n    'to_rdmol',\n    'from_smiles',\n    'to_smiles',\n    'erdos_renyi_graph',\n    'stochastic_blockmodel_graph',\n    'barabasi_albert_graph',\n    'negative_sampling',\n    'batched_negative_sampling',\n    'structured_negative_sampling',\n    'structured_negative_sampling_feasible',\n    'shuffle_node',\n    'mask_feature',\n    'add_random_edge',\n    'tree_decomposition',\n    'get_embeddings',\n    'get_embeddings_hetero',\n    'trim_to_layer',\n    'get_ppr',\n    'train_test_split_edges',\n    'total_influence',\n]\n\n# `structured_negative_sampling_feasible` is a long name and thus destroys the\n# documentation rendering. We remove it for now from the documentation:\nclasses = copy.copy(__all__)\nclasses.remove('structured_negative_sampling_feasible')\n"
  },
  {
    "path": "torch_geometric/utils/_assortativity.py",
    "content": "import torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import Adj, SparseTensor\nfrom torch_geometric.utils import coalesce, degree\nfrom torch_geometric.utils._to_dense_adj import to_dense_adj\n\n\ndef assortativity(edge_index: Adj) -> float:\n    r\"\"\"The degree assortativity coefficient from the\n    `\"Mixing patterns in networks\"\n    <https://arxiv.org/abs/cond-mat/0209450>`_ paper.\n    Assortativity in a network refers to the tendency of nodes to\n    connect with other similar nodes over dissimilar nodes.\n    It is computed from Pearson correlation coefficient of the node degrees.\n\n    Args:\n        edge_index (Tensor or SparseTensor): The graph connectivity.\n\n    Returns:\n        The value of the degree assortativity coefficient for the input\n        graph :math:`\\in [-1, 1]`\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 2, 3, 2],\n        ...                            [1, 2, 0, 1, 3]])\n        >>> assortativity(edge_index)\n        -0.666667640209198\n    \"\"\"\n    if isinstance(edge_index, SparseTensor):\n        adj: SparseTensor = edge_index\n        row, col, _ = adj.coo()\n    else:\n        assert isinstance(edge_index, Tensor)\n        row, col = edge_index\n\n    device = row.device\n    out_deg = degree(row, dtype=torch.long)\n    in_deg = degree(col, dtype=torch.long)\n    degrees = torch.unique(torch.cat([out_deg, in_deg]))\n    mapping = row.new_zeros(degrees.max().item() + 1)\n    mapping[degrees] = torch.arange(degrees.size(0), device=device)\n\n    # Compute degree mixing matrix (joint probability distribution) `M`\n    num_degrees = degrees.size(0)\n    src_deg = mapping[out_deg[row]]\n    dst_deg = mapping[in_deg[col]]\n\n    pairs = torch.stack([src_deg, dst_deg], dim=0)\n    occurrence = torch.ones(pairs.size(1), device=device)\n    pairs, occurrence = coalesce(pairs, occurrence)\n    M = to_dense_adj(pairs, edge_attr=occurrence, max_num_nodes=num_degrees)[0]\n    # 
normalization\n    M /= M.sum()\n\n    # numeric assortativity coefficient, computed by\n    # Pearson correlation coefficient of the node degrees\n    x = y = degrees.float()\n    a, b = M.sum(0), M.sum(1)\n\n    vara = (a * x**2).sum() - ((a * x).sum())**2\n    varb = (b * x**2).sum() - ((b * x).sum())**2\n    xy = torch.outer(x, y)\n    ab = torch.outer(a, b)\n    out = (xy * (M - ab)).sum() / (vara * varb).sqrt()\n    return out.item()\n"
  },
  {
    "path": "torch_geometric/utils/_coalesce.py",
    "content": "import typing\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.edge_index import SortOrder\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import index_sort, scatter\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload as overload\n\nMISSING = '???'\n\n\n@overload\ndef coalesce(\n    edge_index: Tensor,\n    edge_attr: str = MISSING,\n    num_nodes: Optional[int] = None,\n    reduce: str = 'sum',\n    is_sorted: bool = False,\n    sort_by_row: bool = True,\n) -> Tensor:\n    pass\n\n\n@overload\ndef coalesce(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: Optional[int] = None,\n    reduce: str = 'sum',\n    is_sorted: bool = False,\n    sort_by_row: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef coalesce(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: OptTensor,\n    num_nodes: Optional[int] = None,\n    reduce: str = 'sum',\n    is_sorted: bool = False,\n    sort_by_row: bool = True,\n) -> Tuple[Tensor, OptTensor]:\n    pass\n\n\n@overload\ndef coalesce(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: List[Tensor],\n    num_nodes: Optional[int] = None,\n    reduce: str = 'sum',\n    is_sorted: bool = False,\n    sort_by_row: bool = True,\n) -> Tuple[Tensor, List[Tensor]]:\n    pass\n\n\ndef coalesce(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Union[OptTensor, List[Tensor], str] = MISSING,\n    num_nodes: Optional[int] = None,\n    reduce: str = 'sum',\n    is_sorted: bool = False,\n    sort_by_row: bool = True,\n) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]:\n    \"\"\"Row-wise sorts :obj:`edge_index` and removes its duplicated entries.\n    Duplicate entries in 
:obj:`edge_attr` are merged by scattering them\n    together according to the given :obj:`reduce` option.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices.\n        edge_attr (torch.Tensor or List[torch.Tensor], optional): Edge weights\n            or multi-dimensional edge features.\n            If given as a list, will re-shuffle and remove duplicates for all\n            its entries. (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n        reduce (str, optional): The reduce operation to use for merging edge\n            features (:obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`,\n            :obj:`\"mul\"`, :obj:`\"any\"`). (default: :obj:`\"sum\"`)\n        is_sorted (bool, optional): If set to :obj:`True`, will expect\n            :obj:`edge_index` to be already sorted row-wise.\n        sort_by_row (bool, optional): If set to :obj:`False`, will sort\n            :obj:`edge_index` column-wise.\n\n    :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else\n        (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`)\n\n    .. warning::\n\n        From :pyg:`PyG >= 2.3.0` onwards, this function will always return a\n        tuple whenever :obj:`edge_attr` is passed as an argument (even in case\n        it is set to :obj:`None`).\n\n    Example:\n        >>> edge_index = torch.tensor([[1, 1, 2, 3],\n        ...                            
[3, 3, 1, 2]])\n        >>> edge_attr = torch.tensor([1., 1., 1., 1.])\n        >>> coalesce(edge_index)\n        tensor([[1, 2, 3],\n                [3, 1, 2]])\n\n        >>> # Sort `edge_index` column-wise\n        >>> coalesce(edge_index, sort_by_row=False)\n        tensor([[2, 3, 1],\n                [1, 2, 3]])\n\n        >>> coalesce(edge_index, edge_attr)\n        (tensor([[1, 2, 3],\n                [3, 1, 2]]),\n        tensor([2., 1., 1.]))\n\n        >>> # Use 'mean' operation to merge edge features\n        >>> coalesce(edge_index, edge_attr, reduce='mean')\n        (tensor([[1, 2, 3],\n                [3, 1, 2]]),\n        tensor([1., 1., 1.]))\n    \"\"\"\n    num_edges = edge_index[0].size(0)\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64:\n        raise ValueError(\"'coalesce' will result in an overflow\")\n\n    idx = edge_index[0].new_empty(num_edges + 1)\n    idx[0] = -1\n    idx[1:] = edge_index[1 - int(sort_by_row)]\n    idx[1:].mul_(num_nodes).add_(edge_index[int(sort_by_row)])\n\n    is_undirected = False\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        is_undirected = edge_index.is_undirected\n\n    if not is_sorted:\n        idx[1:], perm = index_sort(idx[1:], max_value=num_nodes * num_nodes)\n        if isinstance(edge_index, Tensor):\n            edge_index = edge_index[:, perm]\n        elif isinstance(edge_index, tuple):\n            edge_index = (edge_index[0][perm], edge_index[1][perm])\n        else:\n            raise NotImplementedError\n        if isinstance(edge_attr, Tensor):\n            edge_attr = edge_attr[perm]\n        elif isinstance(edge_attr, (list, tuple)):\n            edge_attr = [e[perm] for e in edge_attr]\n\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        edge_index._sort_order = SortOrder('row' if sort_by_row else 'col')\n        edge_index._is_undirected = 
is_undirected\n\n    mask = idx[1:] > idx[:-1]\n\n    # Only perform expensive merging if duplicates exist:\n    if mask.all():\n        if edge_attr is None or isinstance(edge_attr, Tensor):\n            return edge_index, edge_attr\n        if isinstance(edge_attr, (list, tuple)):\n            return edge_index, edge_attr\n        return edge_index\n\n    if isinstance(edge_index, Tensor):\n        edge_index = edge_index[:, mask]\n        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n            edge_index._is_undirected = is_undirected\n    elif isinstance(edge_index, tuple):\n        edge_index = (edge_index[0][mask], edge_index[1][mask])\n    else:\n        raise NotImplementedError\n\n    dim_size: Optional[int] = None\n    if isinstance(edge_attr, (Tensor, list, tuple)) and len(edge_attr) > 0:\n        dim_size = edge_index.size(1)\n        idx = torch.arange(0, num_edges, device=edge_index.device)\n        idx.sub_(mask.logical_not_().cumsum(dim=0))\n\n    if edge_attr is None:\n        return edge_index, None\n    if isinstance(edge_attr, Tensor):\n        edge_attr = scatter(edge_attr, idx, 0, dim_size, reduce)\n        return edge_index, edge_attr\n    if isinstance(edge_attr, (list, tuple)):\n        if len(edge_attr) == 0:\n            return edge_index, edge_attr\n        edge_attr = [scatter(e, idx, 0, dim_size, reduce) for e in edge_attr]\n        return edge_index, edge_attr\n\n    return edge_index\n"
  },
  {
    "path": "torch_geometric/utils/_degree.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef degree(index: Tensor, num_nodes: Optional[int] = None,\n           dtype: Optional[torch.dtype] = None) -> Tensor:\n    r\"\"\"Computes the (unweighted) degree of a given one-dimensional index\n    tensor.\n\n    Args:\n        index (LongTensor): Index tensor.\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)\n        dtype (:obj:`torch.dtype`, optional): The desired data type of the\n            returned tensor.\n\n    :rtype: :class:`Tensor`\n\n    Example:\n        >>> row = torch.tensor([0, 1, 0, 2, 0])\n        >>> degree(row, dtype=torch.long)\n        tensor([3, 1, 1])\n    \"\"\"\n    N = maybe_num_nodes(index, num_nodes)\n    out = torch.zeros((N, ), dtype=dtype, device=index.device)\n    one = torch.ones((index.size(0), ), dtype=out.dtype, device=out.device)\n    return out.scatter_add_(0, index, one)\n"
  },
  {
    "path": "torch_geometric/utils/_grid.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import coalesce\n\n\ndef grid(\n    height: int,\n    width: int,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Returns the edge indices of a two-dimensional grid graph with height\n    :attr:`height` and width :attr:`width` and its node positions.\n\n    Args:\n        height (int): The height of the grid.\n        width (int): The width of the grid.\n        dtype (torch.dtype, optional): The desired data type of the returned\n            position tensor. (default: :obj:`None`)\n        device (torch.device, optional): The desired device of the returned\n            tensors. (default: :obj:`None`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Example:\n        >>> (row, col), pos = grid(height=2, width=2)\n        >>> row\n        tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3])\n        >>> col\n        tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3])\n        >>> pos\n        tensor([[0., 1.],\n                [1., 1.],\n                [0., 0.],\n                [1., 0.]])\n    \"\"\"\n    edge_index = grid_index(height, width, device)\n    pos = grid_pos(height, width, dtype, device)\n    return edge_index, pos\n\n\ndef grid_index(\n    height: int,\n    width: int,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n\n    w = width\n    kernel = torch.tensor(\n        [-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1],\n        device=device,\n    )\n\n    row = torch.arange(height * width, dtype=torch.long, device=device)\n    row = row.view(-1, 1).repeat(1, kernel.size(0))\n    col = row + kernel.view(1, -1)\n    row, col = row.view(height, -1), col.view(height, -1)\n    index = torch.arange(3, row.size(1) - 3, dtype=torch.long, device=device)\n    row, col = row[:, index].view(-1), col[:, index].view(-1)\n\n    mask = 
(col >= 0) & (col < height * width)\n    row, col = row[mask], col[mask]\n\n    edge_index = torch.stack([row, col], dim=0)\n    edge_index = coalesce(edge_index, num_nodes=height * width)\n    return edge_index\n\n\ndef grid_pos(\n    height: int,\n    width: int,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n\n    dtype = torch.float if dtype is None else dtype\n    x = torch.arange(width, dtype=dtype, device=device)\n    y = (height - 1) - torch.arange(height, dtype=dtype, device=device)\n\n    x = x.repeat(height)\n    y = y.unsqueeze(-1).repeat(1, width).view(-1)\n\n    return torch.stack([x, y], dim=-1)\n"
  },
  {
    "path": "torch_geometric/utils/_homophily.py",
    "content": "from typing import Union, overload\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import Adj, OptTensor, SparseTensor\nfrom torch_geometric.utils import degree, scatter\n\n\n@overload\ndef homophily(\n    edge_index: Adj,\n    y: Tensor,\n    batch: None = ...,\n    method: str = ...,\n) -> float:\n    pass\n\n\n@overload\ndef homophily(\n    edge_index: Adj,\n    y: Tensor,\n    batch: Tensor,\n    method: str = ...,\n) -> Tensor:\n    pass\n\n\ndef homophily(\n    edge_index: Adj,\n    y: Tensor,\n    batch: OptTensor = None,\n    method: str = 'edge',\n) -> Union[float, Tensor]:\n    r\"\"\"The homophily of a graph characterizes how likely nodes with the same\n    label are near each other in a graph.\n\n    There are many measures of homophily that fits this definition.\n    In particular:\n\n    - In the `\"Beyond Homophily in Graph Neural Networks: Current Limitations\n      and Effective Designs\" <https://arxiv.org/abs/2006.11468>`_ paper, the\n      homophily is the fraction of edges in a graph which connects nodes\n      that have the same class label:\n\n      .. math::\n        \\frac{| \\{ (v,w) : (v,w) \\in \\mathcal{E} \\wedge y_v = y_w \\} | }\n        {|\\mathcal{E}|}\n\n      That measure is called the *edge homophily ratio*.\n\n    - In the `\"Geom-GCN: Geometric Graph Convolutional Networks\"\n      <https://arxiv.org/abs/2002.05287>`_ paper, edge homophily is normalized\n      across neighborhoods:\n\n      .. 
math::\n        \\frac{1}{|\\mathcal{V}|} \\sum_{v \\in \\mathcal{V}} \\frac{ | \\{ (w,v) : w\n        \\in \\mathcal{N}(v) \\wedge y_v = y_w \\} |  } { |\\mathcal{N}(v)| }\n\n      That measure is called the *node homophily ratio*.\n\n    - In the `\"Large-Scale Learning on Non-Homophilous Graphs: New Benchmarks\n      and Strong Simple Methods\" <https://arxiv.org/abs/2110.14446>`_ paper,\n      edge homophily is modified to be insensitive to the number of classes\n      and size of each class:\n\n      .. math::\n        \\frac{1}{C-1} \\sum_{k=1}^{C} \\max \\left(0, h_k - \\frac{|\\mathcal{C}_k|}\n        {|\\mathcal{V}|} \\right),\n\n      where :math:`C` denotes the number of classes, :math:`|\\mathcal{C}_k|`\n      denotes the number of nodes of class :math:`k`, and :math:`h_k` denotes\n      the edge homophily ratio of nodes of class :math:`k`.\n\n      Thus, that measure is called the *class insensitive edge homophily\n      ratio*.\n\n    Args:\n        edge_index (Tensor or SparseTensor): The graph connectivity.\n        y (Tensor): The labels.\n        batch (LongTensor, optional): Batch vector\\\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots,B-1\\}}^N`, which assigns\n            each node to a specific example. (default: :obj:`None`)\n        method (str, optional): The method used to calculate the homophily,\n            either :obj:`\"edge\"` (first formula), :obj:`\"node\"` (second\n            formula) or :obj:`\"edge_insensitive\"` (third formula).\n            (default: :obj:`\"edge\"`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 2, 3],\n        ...                            
[1, 2, 0, 4]])\n        >>> y = torch.tensor([0, 0, 0, 0, 1])\n        >>> # Edge homophily ratio\n        >>> homophily(edge_index, y, method='edge')\n        0.75\n\n        >>> # Node homophily ratio\n        >>> homophily(edge_index, y, method='node')\n        0.6000000238418579\n\n        >>> # Class insensitive edge homophily ratio\n        >>> homophily(edge_index, y, method='edge_insensitive')\n        0.19999998807907104\n    \"\"\"\n    assert method in {'edge', 'node', 'edge_insensitive'}\n    y = y.squeeze(-1) if y.dim() > 1 else y\n\n    if isinstance(edge_index, SparseTensor):\n        row, col, _ = edge_index.coo()\n    else:\n        row, col = edge_index\n\n    if method == 'edge':\n        out = torch.zeros(row.size(0), device=row.device)\n        out[y[row] == y[col]] = 1.\n        if batch is None:\n            return float(out.mean())\n        else:\n            dim_size = int(batch.max()) + 1\n            return scatter(out, batch[col], 0, dim_size, reduce='mean')\n\n    elif method == 'node':\n        out = torch.zeros(row.size(0), device=row.device)\n        out[y[row] == y[col]] = 1.\n        out = scatter(out, col, 0, dim_size=y.size(0), reduce='mean')\n        if batch is None:\n            return float(out.mean())\n        else:\n            return scatter(out, batch, dim=0, reduce='mean')\n\n    elif method == 'edge_insensitive':\n        assert y.dim() == 1\n        num_classes = int(y.max()) + 1\n        assert num_classes >= 2\n        batch = torch.zeros_like(y) if batch is None else batch\n        num_nodes = degree(batch, dtype=torch.int64)\n        num_graphs = num_nodes.numel()\n        batch = num_classes * batch + y\n\n        h = homophily(edge_index, y, batch, method='edge')\n        h = h.view(num_graphs, num_classes)\n\n        counts = batch.bincount(minlength=num_classes * num_graphs)\n        counts = counts.view(num_graphs, num_classes)\n        proportions = counts / num_nodes.view(-1, 1)\n\n        out = (h - 
proportions).clamp_(min=0).sum(dim=-1)\n        out /= num_classes - 1\n        return out if out.numel() > 1 else float(out)\n\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "torch_geometric/utils/_index_sort.py",
    "content": "from typing import Optional, Tuple\n\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.typing import pyg_lib\n\n\ndef index_sort(\n    inputs: Tensor,\n    max_value: Optional[int] = None,\n    stable: bool = False,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Sorts the elements of the :obj:`inputs` tensor in ascending order.\n    It is expected that :obj:`inputs` is one-dimensional and that it only\n    contains positive integer values. If :obj:`max_value` is given, it can\n    be used by the underlying algorithm for better performance.\n\n    Args:\n        inputs (torch.Tensor): A vector with positive integer values.\n        max_value (int, optional): The maximum value stored inside\n            :obj:`inputs`. This value can be an estimate, but needs to be\n            greater than or equal to the real maximum.\n            (default: :obj:`None`)\n        stable (bool, optional): Makes the sorting routine stable, which\n            guarantees that the order of equivalent elements is preserved.\n            (default: :obj:`False`)\n    \"\"\"\n    if stable or not torch_geometric.typing.WITH_INDEX_SORT or is_compiling():\n        return inputs.sort(stable=stable)\n    return pyg_lib.ops.index_sort(inputs, max_value=max_value)\n"
  },
  {
    "path": "torch_geometric/utils/_lexsort.py",
    "content": "from typing import List\n\nfrom torch import Tensor\n\n\ndef lexsort(\n    keys: List[Tensor],\n    dim: int = -1,\n    descending: bool = False,\n) -> Tensor:\n    r\"\"\"Performs an indirect stable sort using a sequence of keys.\n\n    Given multiple sorting keys, returns an array of integer indices that\n    describe their sort order.\n    The last key in the sequence is used for the primary sort order, the\n    second-to-last key for the secondary sort order, and so on.\n\n    Args:\n        keys ([torch.Tensor]): The :math:`k` different columns to be sorted.\n            The last key is the primary sort key.\n        dim (int, optional): The dimension to sort along. (default: :obj:`-1`)\n        descending (bool, optional): Controls the sorting order (ascending or\n            descending). (default: :obj:`False`)\n    \"\"\"\n    assert len(keys) >= 1\n\n    out = keys[0].argsort(dim=dim, descending=descending, stable=True)\n    for k in keys[1:]:\n        index = k.gather(dim, out)\n        index = index.argsort(dim=dim, descending=descending, stable=True)\n        out = out.gather(dim, index)\n\n    return out\n"
  },
  {
    "path": "torch_geometric/utils/_negative_sampling.py",
    "content": "import random\nfrom typing import Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import coalesce, cumsum, degree, remove_self_loops\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef negative_sampling(\n    edge_index: Tensor,\n    num_nodes: Optional[Union[int, Tuple[int, int]]] = None,\n    num_neg_samples: Optional[Union[int, float]] = None,\n    method: str = \"sparse\",\n    force_undirected: bool = False,\n) -> Tensor:\n    r\"\"\"Samples random negative edges of a graph given by :attr:`edge_index`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        num_nodes (int or Tuple[int, int], optional): The number of nodes,\n            *i.e.* :obj:`max_val + 1` of :attr:`edge_index`.\n            If given as a tuple, then :obj:`edge_index` is interpreted as a\n            bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.\n            (default: :obj:`None`)\n        num_neg_samples (int or float, optional): The (approximate) number of\n            negative samples to return. If set to a floating-point value, it\n            represents the ratio of negative samples to generate based on the\n            number of positive edges. If set to :obj:`None`, will try to\n            return a negative edge for every positive edge.\n            (default: :obj:`None`)\n        method (str, optional): The method to use for negative sampling,\n            *i.e.* :obj:`\"sparse\"` or :obj:`\"dense\"`.\n            This is a memory/runtime trade-off.\n            :obj:`\"sparse\"` will work on any graph of any size, while\n            :obj:`\"dense\"` can perform faster true-negative checks.\n            (default: :obj:`\"sparse\"`)\n        force_undirected (bool, optional): If set to :obj:`True`, sampled\n            negative edges will be undirected. 
(default: :obj:`False`)\n\n    :rtype: LongTensor\n\n    Examples:\n        >>> # Standard usage\n        >>> edge_index = torch.as_tensor([[0, 0, 1, 2],\n        ...                               [0, 1, 2, 3]])\n        >>> negative_sampling(edge_index)\n        tensor([[3, 0, 0, 3],\n                [2, 3, 2, 1]])\n\n        >>> negative_sampling(edge_index, num_nodes=(3, 4),\n        ...                   num_neg_samples=0.5)  # 50% of positive edges\n        tensor([[0, 3],\n                [3, 0]])\n\n        >>> # For bipartite graph\n        >>> negative_sampling(edge_index, num_nodes=(3, 4))\n        tensor([[0, 2, 2, 1],\n                [2, 2, 1, 3]])\n    \"\"\"\n    assert method in ['sparse', 'dense']\n\n    if num_nodes is None:\n        num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    if isinstance(num_nodes, int):\n        size = (num_nodes, num_nodes)\n        bipartite = False\n    else:\n        size = num_nodes\n        bipartite = True\n        force_undirected = False\n\n    idx, population = edge_index_to_vector(edge_index, size, bipartite,\n                                           force_undirected)\n\n    if idx.numel() >= population:\n        return edge_index.new_empty((2, 0))\n\n    if num_neg_samples is None:\n        num_neg_samples = edge_index.size(1)\n    elif isinstance(num_neg_samples, float):\n        num_neg_samples = int(num_neg_samples * edge_index.size(1))\n    if force_undirected:\n        num_neg_samples = num_neg_samples // 2\n\n    prob = 1. 
- idx.numel() / population  # Probability to sample a negative.\n    sample_size = int(1.1 * num_neg_samples / prob)  # (Over)-sample size.\n\n    neg_idx: Optional[Tensor] = None\n    if method == 'dense':\n        # The dense version creates a mask of shape `population` to check for\n        # invalid samples.\n        mask = idx.new_ones(population, dtype=torch.bool)\n        mask[idx] = False\n        for _ in range(3):  # Number of tries to sample negative indices.\n            rnd = sample(population, sample_size, idx.device)\n            rnd = rnd[mask[rnd]]  # Filter true negatives.\n            neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])\n            if neg_idx.numel() >= num_neg_samples:\n                neg_idx = neg_idx[:num_neg_samples]\n                break\n            mask[neg_idx] = False\n\n    else:  # 'sparse'\n        # The sparse version checks for invalid samples via `np.isin`.\n        idx = idx.to('cpu')\n        for _ in range(3):  # Number of tries to sample negative indices.\n            rnd = sample(population, sample_size, device='cpu')\n            mask = torch.from_numpy(np.isin(rnd.numpy(), idx.numpy())).bool()\n            if neg_idx is not None:\n                mask |= torch.from_numpy(np.isin(rnd, neg_idx.cpu())).bool()\n            rnd = rnd[~mask].to(edge_index.device)\n            neg_idx = rnd if neg_idx is None else torch.cat([neg_idx, rnd])\n            if neg_idx.numel() >= num_neg_samples:\n                neg_idx = neg_idx[:num_neg_samples]\n                break\n\n    assert neg_idx is not None\n    return vector_to_edge_index(neg_idx, size, bipartite, force_undirected)\n\n\ndef batched_negative_sampling(\n    edge_index: Tensor,\n    batch: Union[Tensor, Tuple[Tensor, Tensor]],\n    num_neg_samples: Optional[Union[int, float]] = None,\n    method: str = \"sparse\",\n    force_undirected: bool = False,\n) -> Tensor:\n    r\"\"\"Samples random negative edges of multiple graphs given by\n    
:attr:`edge_index` and :attr:`batch`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        batch (LongTensor or Tuple[LongTensor, LongTensor]): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example.\n            If given as a tuple, then :obj:`edge_index` is interpreted as a\n            bipartite graph connecting two different node types.\n        num_neg_samples (int or float, optional): The number of negative\n            samples to return. If set to :obj:`None`, will try to return a\n            negative edge for every positive edge. If float, it will generate\n            :obj:`num_neg_samples * num_edges` negative samples.\n            (default: :obj:`None`)\n        method (str, optional): The method to use for negative sampling,\n            *i.e.* :obj:`\"sparse\"` or :obj:`\"dense\"`.\n            This is a memory/runtime trade-off.\n            :obj:`\"sparse\"` will work on any graph of any size, while\n            :obj:`\"dense\"` can perform faster true-negative checks.\n            (default: :obj:`\"sparse\"`)\n        force_undirected (bool, optional): If set to :obj:`True`, sampled\n            negative edges will be undirected. 
(default: :obj:`False`)\n\n    :rtype: LongTensor\n\n    Examples:\n        >>> # Standard usage\n        >>> edge_index = torch.as_tensor([[0, 0, 1, 2], [0, 1, 2, 3]])\n        >>> edge_index = torch.cat([edge_index, edge_index + 4], dim=1)\n        >>> edge_index\n        tensor([[0, 0, 1, 2, 4, 4, 5, 6],\n                [0, 1, 2, 3, 4, 5, 6, 7]])\n        >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])\n        >>> batched_negative_sampling(edge_index, batch)\n        tensor([[3, 1, 3, 2, 7, 7, 6, 5],\n                [2, 0, 1, 1, 5, 6, 4, 4]])\n\n        >>> # Using float multiplier for negative samples\n        >>> batched_negative_sampling(edge_index, batch, num_neg_samples=1.5)\n        tensor([[3, 1, 3, 2, 7, 7, 6, 5, 2, 0, 1, 1],\n                [2, 0, 1, 1, 5, 6, 4, 4, 3, 2, 3, 0]])\n\n        >>> # For bipartite graph\n        >>> edge_index1 = torch.as_tensor([[0, 0, 1, 1], [0, 1, 2, 3]])\n        >>> edge_index2 = edge_index1 + torch.tensor([[2], [4]])\n        >>> edge_index3 = edge_index2 + torch.tensor([[2], [4]])\n        >>> edge_index = torch.cat([edge_index1, edge_index2,\n        ...                         edge_index3], dim=1)\n        >>> edge_index\n        tensor([[ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5],\n                [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11]])\n        >>> src_batch = torch.tensor([0, 0, 1, 1, 2, 2])\n        >>> dst_batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])\n        >>> batched_negative_sampling(edge_index,\n        ...                           
(src_batch, dst_batch))\n        tensor([[ 0,  0,  1,  1,  2,  2,  3,  3,  4,  4,  5,  5],\n                [ 2,  3,  0,  1,  6,  7,  4,  5, 10, 11,  8,  9]])\n    \"\"\"\n    if isinstance(batch, Tensor):\n        src_batch, dst_batch = batch, batch\n    else:\n        src_batch, dst_batch = batch[0], batch[1]\n\n    split = degree(src_batch[edge_index[0]], dtype=torch.long).tolist()\n    edge_indices = torch.split(edge_index, split, dim=1)\n\n    num_src = degree(src_batch, dtype=torch.long)\n    cum_src = cumsum(num_src)[:-1]\n\n    if isinstance(batch, Tensor):\n        num_nodes = num_src.tolist()\n        ptr = cum_src\n    else:\n        num_dst = degree(dst_batch, dtype=torch.long)\n        cum_dst = cumsum(num_dst)[:-1]\n\n        num_nodes = torch.stack([num_src, num_dst], dim=1).tolist()\n        ptr = torch.stack([cum_src, cum_dst], dim=1).unsqueeze(-1)\n\n    neg_edge_indices = []\n    for i, edge_index in enumerate(edge_indices):\n        edge_index = edge_index - ptr[i]\n        neg_edge_index = negative_sampling(edge_index, num_nodes[i],\n                                           num_neg_samples, method,\n                                           force_undirected)\n        neg_edge_index += ptr[i]\n        neg_edge_indices.append(neg_edge_index)\n\n    return torch.cat(neg_edge_indices, dim=1)\n\n\ndef structured_negative_sampling(\n    edge_index: Tensor,\n    num_nodes: Optional[int] = None,\n    contains_neg_self_loops: bool = True,\n) -> Tuple[Tensor, Tensor, Tensor]:\n    r\"\"\"Samples a negative edge :obj:`(i,k)` for every positive edge\n    :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a\n    tuple of the form :obj:`(i,j,k)`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`)\n        contains_neg_self_loops (bool, optional): If set to\n            :obj:`False`, sampled negative edges will not contain self loops.\n            (default: :obj:`True`)\n\n    :rtype: (LongTensor, LongTensor, LongTensor)\n\n    Example:\n        >>> edge_index = torch.as_tensor([[0, 0, 1, 2],\n        ...                               [0, 1, 2, 3]])\n        >>> structured_negative_sampling(edge_index)\n        (tensor([0, 0, 1, 2]), tensor([0, 1, 2, 3]), tensor([2, 3, 0, 2]))\n\n    \"\"\"\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    row, col = edge_index.cpu()\n    pos_idx = row * num_nodes + col\n    if not contains_neg_self_loops:\n        loop_idx = torch.arange(num_nodes) * (num_nodes + 1)\n        pos_idx = torch.cat([pos_idx, loop_idx], dim=0)\n\n    rand = torch.randint(num_nodes, (row.size(0), ), dtype=torch.long)\n    neg_idx = row * num_nodes + rand\n\n    mask = torch.from_numpy(np.isin(neg_idx, pos_idx)).to(torch.bool)\n    rest = mask.nonzero(as_tuple=False).view(-1)\n    while rest.numel() > 0:  # pragma: no cover\n        tmp = torch.randint(num_nodes, (rest.size(0), ), dtype=torch.long)\n        rand[rest] = tmp\n        neg_idx = row[rest] * num_nodes + tmp\n\n        mask = torch.from_numpy(np.isin(neg_idx, pos_idx)).to(torch.bool)\n        rest = rest[mask]\n\n    return edge_index[0], edge_index[1], rand.to(edge_index.device)\n\n\ndef structured_negative_sampling_feasible(\n    edge_index: Tensor,\n    num_nodes: Optional[int] = None,\n    contains_neg_self_loops: bool = True,\n) -> bool:\n    r\"\"\"Returns :obj:`True` if\n    :meth:`~torch_geometric.utils.structured_negative_sampling` is feasible\n    on the graph given by :obj:`edge_index`.\n    :meth:`~torch_geometric.utils.structured_negative_sampling` is infeasible\n    if at least one node is connected to all other nodes.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        num_nodes (int, optional): The number of 
nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n        contains_neg_self_loops (bool, optional): If set to\n            :obj:`False`, sampled negative edges will not contain self loops.\n            (default: :obj:`True`)\n\n    :rtype: bool\n\n    Examples:\n        >>> edge_index = torch.LongTensor([[0, 0, 1, 1, 2, 2, 2],\n        ...                                [1, 2, 0, 2, 0, 1, 1]])\n        >>> structured_negative_sampling_feasible(edge_index, 3, False)\n        False\n\n        >>> structured_negative_sampling_feasible(edge_index, 3, True)\n        True\n    \"\"\"\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n    max_num_neighbors = num_nodes\n\n    edge_index = coalesce(edge_index, num_nodes=num_nodes)\n\n    if not contains_neg_self_loops:\n        edge_index, _ = remove_self_loops(edge_index)\n        max_num_neighbors -= 1  # Reduce number of valid neighbors\n\n    deg = degree(edge_index[0], num_nodes)\n    # True if there exists no node that is connected to all other nodes.\n    return bool(torch.all(deg < max_num_neighbors))\n\n\n###############################################################################\n\n\ndef sample(\n    population: int,\n    k: int,\n    device: Optional[Union[torch.device, str]] = None,\n) -> Tensor:\n    if population <= k:\n        return torch.arange(population, device=device)\n    else:\n        return torch.tensor(random.sample(range(population), k), device=device)\n\n\ndef edge_index_to_vector(\n    edge_index: Tensor,\n    size: Tuple[int, int],\n    bipartite: bool,\n    force_undirected: bool = False,\n) -> Tuple[Tensor, int]:\n\n    row, col = edge_index\n\n    if bipartite:  # No need to account for self-loops.\n        idx = (row * size[1]).add_(col)\n        population = size[0] * size[1]\n        return idx, population\n\n    elif force_undirected:\n        assert size[0] == size[1]\n        num_nodes = size[0]\n\n        # We only operate on the 
upper triangular matrix:\n        mask = row < col\n        row, col = row[mask], col[mask]\n        offset = torch.arange(1, num_nodes, device=row.device).cumsum(0)[row]\n        idx = row.mul_(num_nodes).add_(col).sub_(offset)\n        population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes\n        return idx, population\n\n    else:\n        assert size[0] == size[1]\n        num_nodes = size[0]\n\n        # We remove self-loops as we do not want to take them into account\n        # when sampling negative values.\n        mask = row != col\n        row, col = row[mask], col[mask]\n        col[row < col] -= 1\n        idx = row.mul_(num_nodes - 1).add_(col)\n        population = num_nodes * num_nodes - num_nodes\n        return idx, population\n\n\ndef vector_to_edge_index(\n    idx: Tensor,\n    size: Tuple[int, int],\n    bipartite: bool,\n    force_undirected: bool = False,\n) -> Tensor:\n\n    if bipartite:  # No need to account for self-loops.\n        row = idx.div(size[1], rounding_mode='floor')\n        col = idx % size[1]\n        return torch.stack([row, col], dim=0)\n\n    elif force_undirected:\n        assert size[0] == size[1]\n        num_nodes = size[0]\n\n        offset = torch.arange(1, num_nodes, device=idx.device).cumsum(0)\n        end = torch.arange(num_nodes, num_nodes * num_nodes, num_nodes,\n                           device=idx.device)\n        row = torch.bucketize(idx, end.sub_(offset), right=True)\n        col = offset[row].add_(idx) % num_nodes\n        return torch.stack([torch.cat([row, col]), torch.cat([col, row])], 0)\n\n    else:\n        assert size[0] == size[1]\n        num_nodes = size[0]\n\n        row = idx.div(num_nodes - 1, rounding_mode='floor')\n        col = idx % (num_nodes - 1)\n        col[row <= col] += 1\n        return torch.stack([row, col], dim=0)\n"
  },
  {
    "path": "torch_geometric/utils/_normalize_edge_index.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import add_self_loops as add_self_loops_fn\nfrom torch_geometric.utils import degree\n\n\ndef normalize_edge_index(\n    edge_index: Tensor,\n    num_nodes: Optional[int] = None,\n    add_self_loops: bool = True,\n    symmetric: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    \"\"\"Applies normalization to the edges of a graph.\n\n    This function can add self-loops to the graph and apply either symmetric or\n    asymmetric normalization based on the node degrees.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        num_nodes (int, int], optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n        add_self_loops (bool, optional): If set to :obj:`False`, will not add\n            self-loops to the input graph. (default: :obj:`True`)\n        symmetric (bool, optional):  If set to :obj:`True`, symmetric\n            normalization (:math:`D^{-1/2} A D^{-1/2}`) is used, otherwise\n            asymmetric normalization (:math:`D^{-1} A`).\n    \"\"\"\n    if add_self_loops:\n        edge_index, _ = add_self_loops_fn(edge_index, num_nodes=num_nodes)\n\n    row, col = edge_index[0], edge_index[1]\n    deg = degree(row, num_nodes, dtype=torch.get_default_dtype())\n\n    if symmetric:  # D^-1/2 * A * D^-1/2\n        deg_inv_sqrt = deg.pow(-0.5)\n        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0\n        edge_weight = deg_inv_sqrt[row] * deg_inv_sqrt[col]\n    else:  # D^-1 * A\n        deg_inv = deg.pow(-1)\n        deg_inv[torch.isinf(deg_inv)] = 0\n        edge_weight = deg_inv[row]\n\n    return edge_index, edge_weight\n"
  },
  {
    "path": "torch_geometric/utils/_normalized_cut.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nfrom torch_geometric.utils import degree\n\n\ndef normalized_cut(\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Computes the normalized cut :math:`\\mathbf{e}_{i,j} \\cdot\n    \\left( \\frac{1}{\\deg(i)} + \\frac{1}{\\deg(j)} \\right)` of a weighted graph\n    given by edge indices and edge attributes.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor): Edge weights or multi-dimensional edge features.\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n    :rtype: :class:`Tensor`\n\n    Example:\n        >>> edge_index = torch.tensor([[1, 1, 2, 3],\n        ...                            [3, 3, 1, 2]])\n        >>> edge_attr = torch.tensor([1., 1., 1., 1.])\n        >>> normalized_cut(edge_index, edge_attr)\n        tensor([1.5000, 1.5000, 2.0000, 1.5000])\n    \"\"\"\n    row, col = edge_index[0], edge_index[1]\n    deg = 1. / degree(col, num_nodes, edge_attr.dtype)\n    deg = deg[row] + deg[col]\n    cut = edge_attr * deg\n    return cut\n"
  },
  {
    "path": "torch_geometric/utils/_one_hot.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\n\ndef one_hot(\n    index: Tensor,\n    num_classes: Optional[int] = None,\n    dtype: Optional[torch.dtype] = None,\n) -> Tensor:\n    r\"\"\"Taskes a one-dimensional :obj:`index` tensor and returns a one-hot\n    encoded representation of it with shape :obj:`[*, num_classes]` that has\n    zeros everywhere except where the index of last dimension matches the\n    corresponding value of the input tensor, in which case it will be :obj:`1`.\n\n    .. note::\n        This is a more memory-efficient version of\n        :meth:`torch.nn.functional.one_hot` as you can customize the output\n        :obj:`dtype`.\n\n    Args:\n        index (torch.Tensor): The one-dimensional input tensor.\n        num_classes (int, optional): The total number of classes. If set to\n            :obj:`None`, the number of classes will be inferred as one greater\n            than the largest class value in the input tensor.\n            (default: :obj:`None`)\n        dtype (torch.dtype, optional): The :obj:`dtype` of the output tensor.\n    \"\"\"\n    if index.dim() != 1:\n        raise ValueError(\"'index' tensor needs to be one-dimensional\")\n\n    if num_classes is None:\n        num_classes = int(index.max()) + 1\n\n    out = torch.zeros((index.size(0), num_classes), dtype=dtype,\n                      device=index.device)\n    return out.scatter_(1, index.unsqueeze(1), 1)\n"
  },
  {
    "path": "torch_geometric/utils/_scatter.py",
    "content": "from typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling, is_in_onnx_export, warnings\nfrom torch_geometric.typing import torch_scatter\nfrom torch_geometric.utils.functions import cumsum\n\nwarnings.filterwarnings('ignore', '.*is in beta and the API may change.*')\n\n\ndef scatter(\n    src: Tensor,\n    index: Tensor,\n    dim: int = 0,\n    dim_size: Optional[int] = None,\n    reduce: str = 'sum',\n) -> Tensor:\n    r\"\"\"Reduces all values from the :obj:`src` tensor at the indices specified\n    in the :obj:`index` tensor along a given dimension ``dim``. See the\n    `documentation <https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html>`__  # noqa: E501\n    of the ``torch_scatter`` package for more information.\n\n    Args:\n        src (torch.Tensor): The source tensor.\n        index (torch.Tensor): The index tensor.\n        dim (int, optional): The dimension along which to index.\n            (default: ``0``)\n        dim_size (int, optional): The size of the output tensor at dimension\n            ``dim``. If set to :obj:`None`, will create a minimal-sized output\n            tensor according to ``index.max() + 1``. (default: :obj:`None`)\n        reduce (str, optional): The reduce operation (``\"sum\"``, ``\"mean\"``,\n            ``\"mul\"``, ``\"min\"``, ``\"max\"`` or ``\"any\"``). 
(default: ``\"sum\"``)\n    \"\"\"\n    if isinstance(index, Tensor) and index.dim() != 1:\n        raise ValueError(f\"The `index` argument must be one-dimensional \"\n                         f\"(got {index.dim()} dimensions)\")\n\n    dim = src.dim() + dim if dim < 0 else dim\n\n    if isinstance(src, Tensor) and (dim < 0 or dim >= src.dim()):\n        raise ValueError(f\"The `dim` argument must lay between 0 and \"\n                         f\"{src.dim() - 1} (got {dim})\")\n\n    if dim_size is None:\n        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0\n\n    # For now, we maintain various different code paths, based on whether\n    # the input requires gradients and whether it lays on the CPU/GPU.\n    # For example, `torch_scatter` is usually faster than\n    # `torch.scatter_reduce` on GPU, while `torch.scatter_reduce` is faster\n    # on CPU.\n    # `torch.scatter_reduce` has a faster forward implementation for\n    # \"min\"/\"max\" reductions since it does not compute additional arg\n    # indices, but is therefore way slower in its backward implementation.\n    # More insights can be found in `test/utils/test_scatter.py`.\n\n    size = src.size()[:dim] + (dim_size, ) + src.size()[dim + 1:]\n\n    # For \"any\" reduction, we use regular `scatter_`:\n    if reduce == 'any':\n        index = broadcast(index, src, dim)\n        return src.new_zeros(size).scatter_(dim, index, src)\n\n    # For \"sum\" and \"mean\" reduction, we make use of `scatter_add_`:\n    if reduce == 'sum' or reduce == 'add':\n        index = broadcast(index, src, dim)\n        return src.new_zeros(size).scatter_add_(dim, index, src)\n\n    if reduce == 'mean':\n        count = src.new_zeros(dim_size)\n        count.scatter_add_(0, index, src.new_ones(src.size(dim)))\n        count = count.clamp(min=1)\n\n        index = broadcast(index, src, dim)\n        out = src.new_zeros(size).scatter_add_(dim, index, src)\n\n        return out / broadcast(count, out, dim)\n\n    # 
For \"min\" and \"max\" reduction, we prefer `scatter_reduce_` on CPU or\n    # in case the input does not require gradients:\n    if reduce in ['min', 'max', 'amin', 'amax']:\n        if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()\n                or is_in_onnx_export() or not src.is_cuda\n                or not src.requires_grad):\n\n            if (src.is_cuda and src.requires_grad and not is_compiling()\n                    and not is_in_onnx_export()):\n                warnings.warn(\n                    f\"The usage of `scatter(reduce='{reduce}')` \"\n                    f\"can be accelerated via the 'torch-scatter'\"\n                    f\" package, but it was not found\", stacklevel=2)\n\n            index = broadcast(index, src, dim)\n            if not is_in_onnx_export():\n                return src.new_zeros(size).scatter_reduce_(\n                    dim, index, src, reduce=f'a{reduce[-3:]}',\n                    include_self=False)\n\n            fill = torch.full(  # type: ignore\n                size=(1, ),\n                fill_value=src.min() if 'max' in reduce else src.max(),\n                dtype=src.dtype,\n                device=src.device,\n            ).expand_as(src)\n            out = src.new_zeros(size).scatter_reduce_(dim, index, fill,\n                                                      reduce=f'a{reduce[-3:]}',\n                                                      include_self=True)\n            return out.scatter_reduce_(dim, index, src,\n                                       reduce=f'a{reduce[-3:]}',\n                                       include_self=True)\n\n        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,\n                                     reduce=reduce[-3:])\n\n    # For \"mul\" reduction, we prefer `scatter_reduce_` on CPU:\n    if reduce == 'mul':\n        if (not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling()\n                or not src.is_cuda):\n\n           
 if src.is_cuda and not is_compiling():\n                warnings.warn(\n                    f\"The usage of `scatter(reduce='{reduce}')` \"\n                    f\"can be accelerated via the 'torch-scatter'\"\n                    f\" package, but it was not found\", stacklevel=2)\n\n            index = broadcast(index, src, dim)\n            # We initialize with `one` here to match `scatter_mul` output:\n            return src.new_ones(size).scatter_reduce_(dim, index, src,\n                                                      reduce='prod',\n                                                      include_self=True)\n\n        return torch_scatter.scatter(src, index, dim, dim_size=dim_size,\n                                     reduce='mul')\n\n    raise ValueError(f\"Encountered invalid `reduce` argument '{reduce}'\")\n\n\ndef broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor:\n    dim = ref.dim() + dim if dim < 0 else dim\n    size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1))\n    return src.view(size).expand_as(ref)\n\n\ndef scatter_argmax(\n    src: Tensor,\n    index: Tensor,\n    dim: int = 0,\n    dim_size: Optional[int] = None,\n) -> Tensor:\n\n    if (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()\n            and not is_in_onnx_export()):\n        out = torch_scatter.scatter_max(src, index, dim=dim, dim_size=dim_size)\n        return out[1]\n\n    # Only implemented under certain conditions for now :(\n    assert src.dim() == 1 and index.dim() == 1\n    assert dim == 0 or dim == -1\n    assert src.numel() == index.numel()\n\n    if dim_size is None:\n        dim_size = int(index.max()) + 1 if index.numel() > 0 else 0\n\n    if not is_in_onnx_export():\n        res = src.new_empty(dim_size)\n        res.scatter_reduce_(0, index, src.detach(), reduce='amax',\n                            include_self=False)\n    else:\n        # `include_self=False` is currently not supported by ONNX:\n        res = src.new_full(\n       
     size=(dim_size, ),\n            fill_value=src.min(),  # type: ignore\n        )\n        res.scatter_reduce_(0, index, src.detach(), reduce=\"amax\",\n                            include_self=True)\n\n    out = index.new_full((dim_size, ), fill_value=dim_size - 1)\n    nonzero = (src == res[index]).nonzero().view(-1)\n    out[index[nonzero]] = nonzero\n\n    return out\n\n\ndef group_argsort(\n    src: Tensor,\n    index: Tensor,\n    dim: int = 0,\n    num_groups: Optional[int] = None,\n    descending: bool = False,\n    return_consecutive: bool = False,\n    stable: bool = False,\n) -> Tensor:\n    r\"\"\"Returns the indices that sort the tensor :obj:`src` along a given\n    dimension in ascending order by value.\n    In contrast to :meth:`torch.argsort`, sorting is performed in groups\n    according to the values in :obj:`index`.\n\n    Args:\n        src (torch.Tensor): The source tensor.\n        index (torch.Tensor): The index tensor.\n        dim (int, optional): The dimension along which to index.\n            (default: :obj:`0`)\n        num_groups (int, optional): The number of groups.\n            (default: :obj:`None`)\n        descending (bool, optional): Controls the sorting order (ascending or\n            descending). (default: :obj:`False`)\n        return_consecutive (bool, optional): If set to :obj:`True`, will not\n            offset the output to start from :obj:`0` for each group.\n            (default: :obj:`False`)\n        stable (bool, optional): Controls the relative order of equivalent\n            elements. 
(default: :obj:`False`)\n\n    Example:\n        >>> src = torch.tensor([0, 1, 5, 4, 3, 2, 6, 7, 8])\n        >>> index = torch.tensor([0, 0, 1, 1, 1, 1, 2, 2, 2])\n        >>> group_argsort(src, index)\n        tensor([0, 1, 3, 2, 1, 0, 0, 1, 2])\n    \"\"\"\n    # Only implemented under certain conditions for now :(\n    assert src.dim() == 1 and index.dim() == 1\n    assert dim == 0 or dim == -1\n    assert src.numel() == index.numel()\n\n    if src.numel() == 0:\n        return torch.zeros_like(src)\n\n    # Normalize `src` to range [0, 1]:\n    src = src - src.min()\n    src = src / src.max()\n\n    # Compute `grouped_argsort`:\n    src = src - 2 * index if descending else src + 2 * index\n    perm = src.argsort(descending=descending, stable=stable)\n    out = torch.empty_like(index)\n    out[perm] = torch.arange(index.numel(), device=index.device)\n\n    if return_consecutive:\n        return out\n\n    # Compute cumulative sum of number of entries with the same index:\n    count = scatter(torch.ones_like(index), index, dim=dim,\n                    dim_size=num_groups, reduce='sum')\n    ptr = cumsum(count)\n\n    return out - ptr[index]\n\n\ndef group_cat(\n    tensors: Union[List[Tensor], Tuple[Tensor, ...]],\n    indices: Union[List[Tensor], Tuple[Tensor, ...]],\n    dim: int = 0,\n    return_index: bool = False,\n) -> Union[Tensor, Tuple[Tensor, Tensor]]:\n    r\"\"\"Concatenates the given sequence of tensors :obj:`tensors` in the given\n    dimension :obj:`dim`.\n    Different from :meth:`torch.cat`, values along the concatenating dimension\n    are grouped according to the index tensors given in :obj:`indices`.\n    All tensors must have the same shape (except in the concatenating\n    dimension).\n\n    Args:\n        tensors ([Tensor]): Sequence of tensors.\n        indices ([Tensor]): Sequence of index tensors.\n        dim (int, optional): The dimension along which the tensors are\n            concatenated. 
(default: :obj:`0`)\n        return_index (bool, optional): If set to :obj:`True`, will return the\n            new index tensor. (default: :obj:`False`)\n\n    Example:\n        >>> x1 = torch.tensor([[0.2716, 0.4233],\n        ...                    [0.3166, 0.0142],\n        ...                    [0.2371, 0.3839],\n        ...                    [0.4100, 0.0012]])\n        >>> x2 = torch.tensor([[0.3752, 0.5782],\n        ...                    [0.7757, 0.5999]])\n        >>> index1 = torch.tensor([0, 0, 1, 2])\n        >>> index2 = torch.tensor([0, 2])\n        >>> group_cat([x1, x2], [index1, index2], dim=0)\n        tensor([[0.2716, 0.4233],\n                [0.3166, 0.0142],\n                [0.3752, 0.5782],\n                [0.2371, 0.3839],\n                [0.4100, 0.0012],\n                [0.7757, 0.5999]])\n    \"\"\"\n    assert len(tensors) == len(indices)\n    index, perm = torch.cat(indices).sort(stable=True)\n    out = torch.cat(tensors, dim=dim).index_select(dim, perm)\n    return (out, index) if return_index else out\n"
  },
  {
    "path": "torch_geometric/utils/_segment.py",
    "content": "import torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.index import ptr2index\nfrom torch_geometric.typing import torch_scatter\nfrom torch_geometric.utils import scatter\n\n\ndef segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:\n    r\"\"\"Reduces all values in the first dimension of the :obj:`src` tensor\n    within the ranges specified in the :obj:`ptr`. See the `documentation\n    <https://pytorch-scatter.readthedocs.io/en/latest/functions/\n    segment_csr.html>`__ of the :obj:`torch_scatter` package for more\n    information.\n\n    Args:\n        src (torch.Tensor): The source tensor.\n        ptr (torch.Tensor): A monotonically increasing pointer tensor that\n            refers to the boundaries of segments such that :obj:`ptr[0] = 0`\n            and :obj:`ptr[-1] = src.size(0)`.\n        reduce (str, optional): The reduce operation (:obj:`\"sum\"`,\n            :obj:`\"mean\"`, :obj:`\"min\"` or :obj:`\"max\"`).\n            (default: :obj:`\"sum\"`)\n    \"\"\"\n    if not torch_geometric.typing.WITH_TORCH_SCATTER or is_compiling():\n        return _torch_segment(src, ptr, reduce)\n\n    if (ptr.dim() == 1 and torch_geometric.typing.WITH_PT20 and src.is_cuda\n            and reduce == 'mean'):\n        return _torch_segment(src, ptr, reduce)\n\n    return torch_scatter.segment_csr(src, ptr, reduce=reduce)\n\n\ndef _torch_segment(src: Tensor, ptr: Tensor, reduce: str = 'sum') -> Tensor:\n    if not torch_geometric.typing.WITH_PT20:\n        raise ImportError(\"'segment' requires the 'torch-scatter' package\")\n    if ptr.dim() > 1:\n        raise ImportError(\"'segment' in an arbitrary dimension \"\n                          \"requires the 'torch-scatter' package\")\n\n    if reduce == 'min' or reduce == 'max':\n        reduce = f'a{reduce}'  # `amin` or `amax`\n    initial = 0 if reduce == 'mean' else None\n    out = 
torch._segment_reduce(src, reduce, offsets=ptr, initial=initial)\n    if reduce == 'amin' or reduce == 'amax':\n        out = torch.where(out.isinf(), 0, out)\n    return out\n\n\ndef segment_logsumexp(\n    src: Tensor,\n    ptr: Tensor,\n    dim: int,\n) -> Tensor:\n    r\"\"\"Returns the log summed exponentials of each row of the :obj:`src`\n    tensor within the ranges specified in the :obj:`ptr`.\n\n    Args:\n        src (torch.Tensor): The source tensor.\n        ptr (torch.Tensor): A monotonically increasing pointer tensor that\n            refers to the boundaries of segments such that :obj:`ptr[0] = 0`\n            and :obj:`ptr[-1] = src.size(dim)`.\n        dim (int): The dimension to reduce.\n    \"\"\"\n    src = src.transpose(0, dim)  # Move reduction dimension to first dimension.\n\n    index = ptr2index(ptr, output_size=src.size(0))\n    max_src = scatter(src, index, dim_size=ptr.numel() - 1, reduce='max')\n    src = src - max_src[index]\n\n    out = src.exp()\n    out = segment(out, ptr, reduce='sum')\n    out = out.log().nan_to_num(neginf=0.0) + max_src\n\n    out = out.transpose(0, dim)\n\n    return out\n"
  },
  {
    "path": "torch_geometric/utils/_select.py",
    "content": "from typing import Any, List, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import TensorFrame\nfrom torch_geometric.utils.mask import mask_select\nfrom torch_geometric.utils.sparse import is_torch_sparse_tensor\n\n\ndef select(\n    src: Union[Tensor, List[Any], TensorFrame],\n    index_or_mask: Tensor,\n    dim: int,\n) -> Union[Tensor, List[Any]]:\n    r\"\"\"Selects the input tensor or input list according to a given index or\n    mask vector.\n\n    Args:\n        src (torch.Tensor or list): The input tensor or list.\n        index_or_mask (torch.Tensor): The index or mask vector.\n        dim (int): The dimension along which to select.\n    \"\"\"\n    if isinstance(src, Tensor):\n        if index_or_mask.dtype == torch.bool:\n            return mask_select(src, dim, index_or_mask)\n        return src.index_select(dim, index_or_mask)\n\n    if isinstance(src, (tuple, list)):\n        if dim != 0:\n            raise ValueError(\"Cannot select along dimension other than 0\")\n        if index_or_mask.dtype == torch.bool:\n            return [src[i] for i, m in enumerate(index_or_mask) if m]\n        return [src[i] for i in index_or_mask]\n\n    if isinstance(src, TensorFrame):\n        assert dim == 0\n        if index_or_mask.dtype == torch.bool:\n            return mask_select(src, dim, index_or_mask)\n        return src[index_or_mask]\n\n    raise ValueError(f\"Encountered invalid input type (got '{type(src)}')\")\n\n\ndef narrow(src: Union[Tensor, List[Any]], dim: int, start: int,\n           length: int) -> Union[Tensor, List[Any]]:\n    r\"\"\"Narrows the input tensor or input list to the specified range.\n\n    Args:\n        src (torch.Tensor or list): The input tensor or list.\n        dim (int): The dimension along which to narrow.\n        start (int): The starting dimension.\n        length (int): The distance to the ending dimension.\n    \"\"\"\n    if isinstance(src, Tensor) and 
is_torch_sparse_tensor(src):\n        # TODO Sparse tensors in `torch.sparse` do not yet support `narrow`.\n        index = torch.arange(start, start + length, device=src.device)\n        return src.index_select(dim, index)\n\n    if isinstance(src, Tensor):\n        return src.narrow(dim, start, length)\n\n    if isinstance(src, list):\n        if dim != 0:\n            raise ValueError(\"Cannot narrow along dimension other than 0\")\n        return src[start:start + length]\n\n    raise ValueError(f\"Encountered invalid input type (got '{type(src)}')\")\n"
  },
  {
    "path": "torch_geometric/utils/_softmax.py",
    "content": "from typing import Optional\n\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.typing import pyg_lib\nfrom torch_geometric.utils import scatter, segment\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef softmax(\n    src: Tensor,\n    index: Optional[Tensor] = None,\n    ptr: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n    dim: int = 0,\n) -> Tensor:\n    r\"\"\"Computes a sparsely evaluated softmax.\n    Given a value tensor :attr:`src`, this function first groups the values\n    along the first dimension based on the indices specified in :attr:`index`,\n    and then proceeds to compute the softmax individually for each group.\n\n    Args:\n        src (Tensor): The source tensor.\n        index (LongTensor, optional): The indices of elements for applying the\n            softmax. (default: :obj:`None`)\n        ptr (LongTensor, optional): If given, computes the softmax based on\n            sorted inputs in CSR representation. (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`index`. 
(default: :obj:`None`)\n        dim (int, optional): The dimension in which to normalize.\n            (default: :obj:`0`)\n\n    :rtype: :class:`Tensor`\n\n    Examples:\n        >>> src = torch.tensor([1., 1., 1., 1.])\n        >>> index = torch.tensor([0, 0, 1, 2])\n        >>> ptr = torch.tensor([0, 2, 3, 4])\n        >>> softmax(src, index)\n        tensor([0.5000, 0.5000, 1.0000, 1.0000])\n\n        >>> softmax(src, None, ptr)\n        tensor([0.5000, 0.5000, 1.0000, 1.0000])\n\n        >>> src = torch.randn(4, 4)\n        >>> ptr = torch.tensor([0, 4])\n        >>> softmax(src, index, dim=-1)\n        tensor([[0.7404, 0.2596, 1.0000, 1.0000],\n                [0.1702, 0.8298, 1.0000, 1.0000],\n                [0.7607, 0.2393, 1.0000, 1.0000],\n                [0.8062, 0.1938, 1.0000, 1.0000]])\n    \"\"\"\n    if (ptr is not None and src.device.type == 'cpu'\n            and torch_geometric.typing.WITH_SOFTMAX\n            and not is_compiling()):  # pragma: no cover\n        return pyg_lib.ops.softmax_csr(src, ptr, dim)\n\n    if (ptr is not None and\n        (ptr.dim() == 1 or (ptr.dim() > 1 and index is None) or\n         (torch_geometric.typing.WITH_TORCH_SCATTER and not is_compiling()))):\n\n        dim = dim + src.dim() if dim < 0 else dim\n        size = ([1] * dim) + [-1]\n        count = ptr[1:] - ptr[:-1]\n        ptr = ptr.view(size)\n        src_max = segment(src.detach(), ptr, reduce='max')\n        src_max = src_max.repeat_interleave(count, dim=dim)\n        out = (src - src_max).exp()\n        out_sum = segment(out, ptr, reduce='sum') + 1e-16\n        out_sum = out_sum.repeat_interleave(count, dim=dim)\n    elif index is not None:\n        N = maybe_num_nodes(index, num_nodes)\n        src_max = scatter(src.detach(), index, dim, dim_size=N, reduce='max')\n        out = src - src_max.index_select(dim, index)\n        out = out.exp()\n        out_sum = scatter(out, index, dim, dim_size=N, reduce='sum') + 1e-16\n        out_sum = 
out_sum.index_select(dim, index)\n    else:\n        raise NotImplementedError(\"'softmax' requires 'index' to be specified\")\n\n    return out / out_sum\n"
  },
  {
    "path": "torch_geometric/utils/_sort_edge_index.py",
    "content": "import typing\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.edge_index import SortOrder\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import index_sort, lexsort\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload as overload\n\nMISSING = '???'\n\n\n@overload\ndef sort_edge_index(\n    edge_index: Tensor,\n    edge_attr: str = MISSING,\n    num_nodes: Optional[int] = None,\n    sort_by_row: bool = True,\n) -> Tensor:\n    pass\n\n\n@overload\ndef sort_edge_index(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: Optional[int] = None,\n    sort_by_row: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef sort_edge_index(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: OptTensor,\n    num_nodes: Optional[int] = None,\n    sort_by_row: bool = True,\n) -> Tuple[Tensor, OptTensor]:\n    pass\n\n\n@overload\ndef sort_edge_index(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: List[Tensor],\n    num_nodes: Optional[int] = None,\n    sort_by_row: bool = True,\n) -> Tuple[Tensor, List[Tensor]]:\n    pass\n\n\ndef sort_edge_index(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Union[OptTensor, List[Tensor], str] = MISSING,\n    num_nodes: Optional[int] = None,\n    sort_by_row: bool = True,\n) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]:\n    \"\"\"Row-wise sorts :obj:`edge_index`.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices.\n        edge_attr (torch.Tensor or List[torch.Tensor], optional): Edge weights\n            or multi-dimensional edge features.\n            If given as a list, will re-shuffle and remove duplicates for all\n            its entries. 
(default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n        sort_by_row (bool, optional): If set to :obj:`False`, will sort\n            :obj:`edge_index` column-wise/by destination node.\n            (default: :obj:`True`)\n\n    :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else\n        (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]`)\n\n    .. warning::\n\n        From :pyg:`PyG >= 2.3.0` onwards, this function will always return a\n        tuple whenever :obj:`edge_attr` is passed as an argument (even in case\n        it is set to :obj:`None`).\n\n    Examples:\n        >>> edge_index = torch.tensor([[2, 1, 1, 0],\n        ...                            [1, 2, 0, 1]])\n        >>> edge_attr = torch.tensor([[1], [2], [3], [4]])\n        >>> sort_edge_index(edge_index)\n        tensor([[0, 1, 1, 2],\n                [1, 0, 2, 1]])\n\n        >>> sort_edge_index(edge_index, edge_attr)\n        (tensor([[0, 1, 1, 2],\n                [1, 0, 2, 1]]),\n        tensor([[4],\n                [3],\n                [2],\n                [1]]))\n    \"\"\"\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    if num_nodes * num_nodes > torch_geometric.typing.MAX_INT64:\n        perm = lexsort(keys=[\n            edge_index[int(sort_by_row)],\n            edge_index[1 - int(sort_by_row)],\n        ])\n    else:\n        idx = edge_index[1 - int(sort_by_row)] * num_nodes\n        idx += edge_index[int(sort_by_row)]\n        _, perm = index_sort(idx, max_value=num_nodes * num_nodes)\n\n    if isinstance(edge_index, Tensor):\n        is_undirected = False\n        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n            is_undirected = edge_index.is_undirected\n        edge_index = edge_index[:, perm]\n        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n            
edge_index._sort_order = SortOrder('row' if sort_by_row else 'col')\n            edge_index._is_undirected = is_undirected\n    elif isinstance(edge_index, tuple):\n        edge_index = (edge_index[0][perm], edge_index[1][perm])\n    else:\n        raise NotImplementedError\n\n    if edge_attr is None:\n        return edge_index, None\n    if isinstance(edge_attr, Tensor):\n        return edge_index, edge_attr[perm]\n    if isinstance(edge_attr, (list, tuple)):\n        return edge_index, [e[perm] for e in edge_attr]\n\n    return edge_index\n"
  },
  {
    "path": "torch_geometric/utils/_spmm.py",
    "content": "import warnings\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.typing import Adj, SparseTensor, torch_sparse\nfrom torch_geometric.utils import is_torch_sparse_tensor, scatter\n\n\ndef spmm(\n    src: Adj,\n    other: Tensor,\n    reduce: str = 'sum',\n) -> Tensor:\n    r\"\"\"Matrix product of sparse matrix with dense matrix.\n\n    Args:\n        src (torch.Tensor or torch_sparse.SparseTensor or EdgeIndex):\n            The input sparse matrix which can be a\n            :pyg:`PyG` :class:`torch_sparse.SparseTensor`,\n            a :pytorch:`PyTorch` :class:`torch.sparse.Tensor` or\n            a :pyg:`PyG` :class:`EdgeIndex`.\n        other (torch.Tensor): The input dense matrix.\n        reduce (str, optional): The reduce operation to use\n            (:obj:`\"sum\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`).\n            (default: :obj:`\"sum\"`)\n\n    :rtype: :class:`Tensor`\n    \"\"\"\n    reduce = 'sum' if reduce == 'add' else reduce\n\n    if reduce not in ['sum', 'mean', 'min', 'max']:\n        raise ValueError(f\"`reduce` argument '{reduce}' not supported\")\n\n    if not torch.jit.is_scripting() and isinstance(src, EdgeIndex):\n        return src.matmul(other=other, reduce=reduce)  # type: ignore\n\n    if isinstance(src, SparseTensor):\n        if src.nnz() == 0:\n            return other.new_zeros(src.size(0), other.size(1))\n\n        if (torch_geometric.typing.WITH_PT20 and other.dim() == 2\n                and not src.is_cuda() and not src.requires_grad()):\n            # Use optimized PyTorch `torch.sparse.mm` path:\n            csr = src.to_torch_sparse_csr_tensor().to(other.dtype)\n            return torch.sparse.mm(csr, other, reduce)\n        return torch_sparse.matmul(src, other, reduce)\n\n    if not is_torch_sparse_tensor(src):\n        raise ValueError(\"'src' must be a 'torch_sparse.SparseTensor' or a \"\n                     
    \"'torch.sparse.Tensor'\")\n\n    # `torch.sparse.mm` only supports reductions on CPU for PyTorch>=2.0.\n    # This will currently throw on error for CUDA tensors.\n    if torch_geometric.typing.WITH_PT20:\n\n        if src.is_cuda and (reduce == 'min' or reduce == 'max'):\n            raise NotImplementedError(f\"`{reduce}` reduction is not yet \"\n                                      f\"supported for 'torch.sparse.Tensor' \"\n                                      f\"on device '{src.device}'\")\n\n        # Always convert COO to CSR for more efficient processing:\n        if src.layout == torch.sparse_coo:\n            warnings.warn(\n                f\"Converting sparse tensor to CSR format for more \"\n                f\"efficient processing. Consider converting your \"\n                f\"sparse tensor to CSR format beforehand to avoid \"\n                f\"repeated conversion (got '{src.layout}')\", stacklevel=2)\n            src = src.to_sparse_csr()\n\n        # Warn in case of CSC format without gradient computation:\n        if src.layout == torch.sparse_csc and not other.requires_grad:\n            warnings.warn(\n                f\"Converting sparse tensor to CSR format for more \"\n                f\"efficient processing. 
Consider converting your \"\n                f\"sparse tensor to CSR format beforehand to avoid \"\n                f\"repeated conversion (got '{src.layout}')\", stacklevel=2)\n\n        # Use the default code path for `sum` reduction (works on CPU/GPU):\n        if reduce == 'sum':\n            return torch.sparse.mm(src, other)\n\n        # Use the default code path with custom reduction (works on CPU):\n        if src.layout == torch.sparse_csr and not src.is_cuda:\n            return torch.sparse.mm(src, other, reduce)\n\n        # Simulate `mean` reduction by dividing by degree:\n        if reduce == 'mean':\n            if src.layout == torch.sparse_csr:\n                ptr = src.crow_indices()\n                deg = ptr[1:] - ptr[:-1]\n            else:\n                assert src.layout == torch.sparse_csc\n                deg = scatter(torch.ones_like(src.values()), src.row_indices(),\n                              dim=0, dim_size=src.size(0), reduce='sum')\n\n            return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)\n\n        # TODO The `torch.sparse.mm` code path with the `reduce` argument does\n        # not yet support CSC :(\n        if src.layout == torch.sparse_csc:\n            warnings.warn(\n                f\"Converting sparse tensor to CSR format for more \"\n                f\"efficient processing. 
Consider converting your \"\n                f\"sparse tensor to CSR format beforehand to avoid \"\n                f\"repeated conversion (got '{src.layout}')\", stacklevel=2)\n            src = src.to_sparse_csr()\n\n        return torch.sparse.mm(src, other, reduce)\n\n    # pragma: no cover\n    # PyTorch < 2.0 only supports sparse COO format:\n    if reduce == 'sum':\n        return torch.sparse.mm(src, other)\n    elif reduce == 'mean':\n        if src.layout == torch.sparse_csr:\n            ptr = src.crow_indices()\n            deg = ptr[1:] - ptr[:-1]\n        elif src.layout == torch.sparse_csc:\n            assert src.layout == torch.sparse_csc\n            ones = torch.ones_like(src.values())\n            index = src.row_indices()\n            deg = scatter(ones, index, 0, dim_size=src.size(0), reduce='sum')\n        else:\n            assert src.layout == torch.sparse_coo\n            src = src.coalesce()\n            ones = torch.ones_like(src.values())\n            index = src.indices()[0]\n            deg = scatter(ones, index, 0, dim_size=src.size(0), reduce='sum')\n\n        return torch.sparse.mm(src, other) / deg.view(-1, 1).clamp_(min=1)\n\n    raise ValueError(f\"`{reduce}` reduction is not supported for \"\n                     f\"'torch.sparse.Tensor' on device '{src.device}'\")\n"
  },
  {
    "path": "torch_geometric/utils/_subgraph.py",
    "content": "from typing import List, Literal, Optional, Tuple, Union, overload\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor, PairTensor\nfrom torch_geometric.utils import scatter\nfrom torch_geometric.utils.map import map_index\nfrom torch_geometric.utils.mask import index_to_mask\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef get_num_hops(model: torch.nn.Module) -> int:\n    r\"\"\"Returns the number of hops the model is aggregating information\n    from.\n\n    .. note::\n\n        This function counts the number of message passing layers as an\n        approximation of the total number of hops covered by the model.\n        Its output may not necessarily be correct in case message passing\n        layers perform multi-hop aggregation, *e.g.*, as in\n        :class:`~torch_geometric.nn.conv.ChebConv`.\n\n    Example:\n        >>> class GNN(torch.nn.Module):\n        ...     def __init__(self):\n        ...         super().__init__()\n        ...         self.conv1 = GCNConv(3, 16)\n        ...         self.conv2 = GCNConv(16, 16)\n        ...         self.lin = Linear(16, 2)\n        ...\n        ...     def forward(self, x, edge_index):\n        ...         x = self.conv1(x, edge_index).relu()\n        ...         x = self.conv2(x, edge_index).relu()\n        ...         
return self.lin(x)\n        >>> get_num_hops(GNN())\n        2\n    \"\"\"\n    from torch_geometric.nn.conv import MessagePassing\n    num_hops = 0\n    for module in model.modules():\n        if isinstance(module, MessagePassing):\n            num_hops += 1\n    return num_hops\n\n\n@overload\ndef subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = ...,\n    relabel_nodes: bool = ...,\n    num_nodes: Optional[int] = ...,\n) -> Tuple[Tensor, OptTensor]:\n    pass\n\n\n@overload\ndef subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = ...,\n    relabel_nodes: bool = ...,\n    num_nodes: Optional[int] = ...,\n    *,\n    return_edge_mask: Literal[False],\n) -> Tuple[Tensor, OptTensor]:\n    pass\n\n\n@overload\ndef subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = ...,\n    relabel_nodes: bool = ...,\n    num_nodes: Optional[int] = ...,\n    *,\n    return_edge_mask: Literal[True],\n) -> Tuple[Tensor, OptTensor, Tensor]:\n    pass\n\n\ndef subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = None,\n    relabel_nodes: bool = False,\n    num_nodes: Optional[int] = None,\n    *,\n    return_edge_mask: bool = False,\n) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]:\n    r\"\"\"Returns the induced subgraph of :obj:`(edge_index, edge_attr)`\n    containing the nodes in :obj:`subset`.\n\n    Args:\n        subset (LongTensor, BoolTensor or [int]): The nodes to keep.\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional\n            edge features. (default: :obj:`None`)\n        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting\n            :obj:`edge_index` will be relabeled to hold consecutive indices\n            starting from zero. 
(default: :obj:`False`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max(edge_index) + 1`. (default: :obj:`None`)\n        return_edge_mask (bool, optional): If set to :obj:`True`, will return\n            the edge mask to filter out additional edge features.\n            (default: :obj:`False`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6],\n        ...                            [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5]])\n        >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])\n        >>> subset = torch.tensor([3, 4, 5])\n        >>> subgraph(subset, edge_index, edge_attr)\n        (tensor([[3, 4, 4, 5],\n                [4, 3, 5, 4]]),\n        tensor([ 7.,  8.,  9., 10.]))\n\n        >>> subgraph(subset, edge_index, edge_attr, return_edge_mask=True)\n        (tensor([[3, 4, 4, 5],\n                [4, 3, 5, 4]]),\n        tensor([ 7.,  8.,  9., 10.]),\n        tensor([False, False, False, False, False, False,  True,\n                True,  True,  True,  False, False]))\n    \"\"\"\n    device = edge_index.device\n\n    if isinstance(subset, (list, tuple)):\n        subset = torch.tensor(subset, dtype=torch.long, device=device)\n\n    if subset.dtype != torch.bool:\n        num_nodes = maybe_num_nodes(edge_index, num_nodes)\n        node_mask = index_to_mask(subset, size=num_nodes)\n    else:\n        num_nodes = subset.size(0)\n        node_mask = subset\n        subset = node_mask.nonzero().view(-1)\n\n    edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]\n    edge_index = edge_index[:, edge_mask]\n    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None\n\n    if relabel_nodes:\n        edge_index, _ = map_index(\n            edge_index.view(-1),\n            subset,\n            max_index=num_nodes,\n            inclusive=True,\n        )\n        edge_index = 
edge_index.view(2, -1)\n\n    if return_edge_mask:\n        return edge_index, edge_attr, edge_mask\n    else:\n        return edge_index, edge_attr\n\n\ndef bipartite_subgraph(\n    subset: Union[PairTensor, Tuple[List[int], List[int]]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = None,\n    relabel_nodes: bool = False,\n    size: Optional[Tuple[int, int]] = None,\n    return_edge_mask: bool = False,\n) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, OptTensor]]:\n    r\"\"\"Returns the induced subgraph of the bipartite graph\n    :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`.\n\n    Args:\n        subset (Tuple[Tensor, Tensor] or tuple([int],[int])): The nodes\n            to keep.\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional\n            edge features. (default: :obj:`None`)\n        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting\n            :obj:`edge_index` will be relabeled to hold consecutive indices\n            starting from zero. (default: :obj:`False`)\n        size (tuple, optional): The number of nodes.\n            (default: :obj:`None`)\n        return_edge_mask (bool, optional): If set to :obj:`True`, will return\n            the edge mask to filter out additional edge features.\n            (default: :obj:`False`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6],\n        ...                            
[0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]])\n        >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])\n        >>> subset = (torch.tensor([2, 3, 5]), torch.tensor([2, 3]))\n        >>> bipartite_subgraph(subset, edge_index, edge_attr)\n        (tensor([[2, 3, 5, 5],\n                [3, 2, 2, 3]]),\n        tensor([ 3,  4,  9, 10]))\n\n        >>> bipartite_subgraph(subset, edge_index, edge_attr,\n        ...                    return_edge_mask=True)\n        (tensor([[2, 3, 5, 5],\n                [3, 2, 2, 3]]),\n        tensor([ 3,  4,  9, 10]),\n        tensor([False, False,  True,  True, False, False, False, False,\n                True,  True,  False]))\n    \"\"\"\n    device = edge_index.device\n\n    src_subset, dst_subset = subset\n    if not isinstance(src_subset, Tensor):\n        src_subset = torch.tensor(src_subset, dtype=torch.long, device=device)\n    if not isinstance(dst_subset, Tensor):\n        dst_subset = torch.tensor(dst_subset, dtype=torch.long, device=device)\n\n    if src_subset.dtype != torch.bool:\n        src_size = int(edge_index[0].max()) + 1 if size is None else size[0]\n        src_node_mask = index_to_mask(src_subset, size=src_size)\n    else:\n        src_size = src_subset.size(0)\n        src_node_mask = src_subset\n        src_subset = src_subset.nonzero().view(-1)\n\n    if dst_subset.dtype != torch.bool:\n        dst_size = int(edge_index[1].max()) + 1 if size is None else size[1]\n        dst_node_mask = index_to_mask(dst_subset, size=dst_size)\n    else:\n        dst_size = dst_subset.size(0)\n        dst_node_mask = dst_subset\n        dst_subset = dst_subset.nonzero().view(-1)\n\n    edge_mask = src_node_mask[edge_index[0]] & dst_node_mask[edge_index[1]]\n    edge_index = edge_index[:, edge_mask]\n    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None\n\n    if relabel_nodes:\n        src_index, _ = map_index(edge_index[0], src_subset, max_index=src_size,\n                                 
inclusive=True)\n        dst_index, _ = map_index(edge_index[1], dst_subset, max_index=dst_size,\n                                 inclusive=True)\n        edge_index = torch.stack([src_index, dst_index], dim=0)\n\n    if return_edge_mask:\n        return edge_index, edge_attr, edge_mask\n    else:\n        return edge_index, edge_attr\n\n\ndef k_hop_subgraph(\n    node_idx: Union[int, List[int], Tensor],\n    num_hops: int,\n    edge_index: Tensor,\n    relabel_nodes: bool = False,\n    num_nodes: Optional[int] = None,\n    flow: str = 'source_to_target',\n    directed: bool = False,\n) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n    r\"\"\"Computes the induced subgraph of :obj:`edge_index` around all nodes in\n    :attr:`node_idx` reachable within :math:`k` hops.\n\n    The :attr:`flow` argument denotes the direction of edges for finding\n    :math:`k`-hop neighbors. If set to :obj:`\"source_to_target\"`, then the\n    method will find all neighbors that point to the initial set of seed nodes\n    in :attr:`node_idx`.\n    This mimics the natural flow of message passing in Graph Neural Networks.\n\n    The method returns (1) the nodes involved in the subgraph, (2) the filtered\n    :obj:`edge_index` connectivity, (3) the mapping from node indices in\n    :obj:`node_idx` to their new location, and (4) the edge mask indicating\n    which edges were preserved.\n\n    Args:\n        node_idx (int, list, tuple or :obj:`torch.Tensor`): The central seed\n            node(s).\n        num_hops (int): The number of hops :math:`k`.\n        edge_index (LongTensor): The edge indices.\n        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting\n            :obj:`edge_index` will be relabeled to hold consecutive indices\n            starting from zero. (default: :obj:`False`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`)\n        flow (str, optional): The flow direction of :math:`k`-hop aggregation\n            (:obj:`\"source_to_target\"` or :obj:`\"target_to_source\"`).\n            (default: :obj:`\"source_to_target\"`)\n        directed (bool, optional): If set to :obj:`True`, will only include\n            directed edges to the seed nodes :obj:`node_idx`.\n            (default: :obj:`False`)\n\n    :rtype: (:class:`LongTensor`, :class:`LongTensor`, :class:`LongTensor`,\n             :class:`BoolTensor`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],\n        ...                            [2, 2, 4, 4, 6, 6]])\n\n        >>> # Center node 6, 2-hops\n        >>> subset, edge_index, mapping, edge_mask = k_hop_subgraph(\n        ...     6, 2, edge_index, relabel_nodes=True)\n        >>> subset\n        tensor([2, 3, 4, 5, 6])\n        >>> edge_index\n        tensor([[0, 1, 2, 3],\n                [2, 2, 4, 4]])\n        >>> mapping\n        tensor([4])\n        >>> edge_mask\n        tensor([False, False,  True,  True,  True,  True])\n        >>> subset[mapping]\n        tensor([6])\n\n        >>> edge_index = torch.tensor([[1, 2, 4, 5],\n        ...                            [0, 1, 5, 6]])\n        >>> (subset, edge_index,\n        ...  mapping, edge_mask) = k_hop_subgraph([0, 6], 2,\n        ...                                       edge_index,\n        ...                                       
relabel_nodes=True)\n        >>> subset\n        tensor([0, 1, 2, 4, 5, 6])\n        >>> edge_index\n        tensor([[1, 2, 3, 4],\n                [0, 1, 4, 5]])\n        >>> mapping\n        tensor([0, 5])\n        >>> edge_mask\n        tensor([True, True, True, True])\n        >>> subset[mapping]\n        tensor([0, 6])\n    \"\"\"\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    assert flow in ['source_to_target', 'target_to_source']\n    if flow == 'target_to_source':\n        row, col = edge_index\n    else:\n        col, row = edge_index\n\n    node_mask = row.new_empty(num_nodes, dtype=torch.bool)\n    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)\n\n    if isinstance(node_idx, int):\n        node_idx = torch.tensor([node_idx], device=row.device)\n    elif isinstance(node_idx, (list, tuple)):\n        node_idx = torch.tensor(node_idx, device=row.device)\n    else:\n        node_idx = node_idx.to(row.device)\n\n    subsets = [node_idx]\n\n    preserved_edge_mask = torch.zeros_like(edge_mask)\n    for _ in range(num_hops):\n        node_mask.fill_(False)\n        node_mask[subsets[-1]] = True\n        torch.index_select(node_mask, 0, row, out=edge_mask)\n        preserved_edge_mask |= edge_mask\n        subsets.append(col[edge_mask])\n\n    subset, inv = torch.cat(subsets).unique(return_inverse=True)\n    inv = inv[:node_idx.numel()]\n\n    node_mask.fill_(False)\n    node_mask[subset] = True\n\n    if not directed:\n        edge_mask = node_mask[row] & node_mask[col]\n    else:\n        edge_mask = preserved_edge_mask\n\n    edge_index = edge_index[:, edge_mask]\n\n    if relabel_nodes:\n        mapping = row.new_full((num_nodes, ), -1)\n        mapping[subset] = torch.arange(subset.size(0), device=row.device)\n        edge_index = mapping[edge_index]\n\n    return subset, edge_index, inv, edge_mask\n\n\n@overload\ndef hyper_subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = ...,\n  
  relabel_nodes: bool = ...,\n    num_nodes: Optional[int] = ...,\n) -> Tuple[Tensor, OptTensor]:\n    pass\n\n\n@overload\ndef hyper_subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = ...,\n    relabel_nodes: bool = ...,\n    num_nodes: Optional[int] = ...,\n    *,\n    return_edge_mask: Literal[False],\n) -> Tuple[Tensor, OptTensor]:\n    pass\n\n\n@overload\ndef hyper_subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = ...,\n    relabel_nodes: bool = ...,\n    num_nodes: Optional[int] = ...,\n    *,\n    return_edge_mask: Literal[True],\n) -> Tuple[Tensor, OptTensor, Tensor]:\n    pass\n\n\ndef hyper_subgraph(\n    subset: Union[Tensor, List[int]],\n    edge_index: Tensor,\n    edge_attr: OptTensor = None,\n    relabel_nodes: bool = False,\n    num_nodes: Optional[int] = None,\n    return_edge_mask: bool = False,\n) -> Union[Tuple[Tensor, OptTensor], Tuple[Tensor, OptTensor, Tensor]]:\n    r\"\"\"Returns the induced subgraph of the hyper graph of\n    :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`.\n\n    Args:\n        subset (torch.Tensor or [int]): The nodes to keep.\n        edge_index (LongTensor): Hyperedge tensor\n            with shape :obj:`[2, num_edges*num_nodes_per_edge]`, where\n            :obj:`edge_index[1]` denotes the hyperedge index and\n            :obj:`edge_index[0]` denotes the node indices that are connected\n            by the hyperedge.\n        edge_attr (torch.Tensor, optional): Edge weights or multi-dimensional\n            edge features of shape :obj:`[num_edges, *]`.\n            (default: :obj:`None`)\n        relabel_nodes (bool, optional): If set to :obj:`True`, the\n            resulting :obj:`edge_index` will be relabeled to hold\n            consecutive indices starting from zero. 
(default: :obj:`False`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max(edge_index[0]) + 1`. (default: :obj:`None`)\n        return_edge_mask (bool, optional): If set to :obj:`True`, will return\n            the edge mask to filter out additional edge features.\n            (default: :obj:`False`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3],\n        ...                            [0, 0, 0, 1, 1, 1, 2, 2, 2]])\n        >>> edge_attr = torch.tensor([3, 2, 6])\n        >>> subset = torch.tensor([0, 3])\n        >>> hyper_subgraph(subset, edge_index, edge_attr)\n        (tensor([[0, 3],\n                [0, 0]]),\n        tensor([ 6.]))\n\n        >>> hyper_subgraph(subset, edge_index, edge_attr,\n        ...                return_edge_mask=True)\n        (tensor([[0, 3],\n                [0, 0]]),\n        tensor([ 6.]),\n        tensor([False, False, True]))\n    \"\"\"\n    device = edge_index.device\n\n    if isinstance(subset, (list, tuple)):\n        subset = torch.tensor(subset, dtype=torch.long, device=device)\n\n    if subset.dtype != torch.bool:\n        num_nodes = maybe_num_nodes(edge_index, num_nodes)\n        node_mask = index_to_mask(subset, size=num_nodes)\n    else:\n        num_nodes = subset.size(0)\n        node_mask = subset\n\n    # Mask all connections that contain a node not in the subset\n    hyper_edge_connection_mask = node_mask[\n        edge_index[0]]  # num_edges*num_nodes_per_edge\n\n    # Mask hyperedges that contain one or fewer nodes from the subset\n    edge_mask = scatter(hyper_edge_connection_mask.to(torch.long),\n                        edge_index[1], reduce='sum') > 1\n\n    # Mask connections if hyperedge contains one or fewer nodes from the subset\n    # or is connected to a node not in the subset\n    hyper_edge_connection_mask = hyper_edge_connection_mask & edge_mask[\n        edge_index[1]]\n\n    edge_index = edge_index[:, 
hyper_edge_connection_mask]\n    edge_attr = edge_attr[edge_mask] if edge_attr is not None else None\n\n    # Relabel edges\n    edge_idx = torch.zeros(edge_mask.size(0), dtype=torch.long, device=device)\n    edge_idx[edge_mask] = torch.arange(edge_mask.sum().item(), device=device)\n    edge_index = torch.cat(\n        [edge_index[0].unsqueeze(0), edge_idx[edge_index[1]].unsqueeze(0)], 0)\n\n    if relabel_nodes:\n        node_idx = torch.zeros(node_mask.size(0), dtype=torch.long,\n                               device=device)\n        node_idx[subset] = torch.arange(node_mask.sum().item(), device=device)\n        edge_index = torch.cat(\n            [node_idx[edge_index[0]].unsqueeze(0), edge_index[1].unsqueeze(0)],\n            0)\n\n    if return_edge_mask:\n        return edge_index, edge_attr, edge_mask\n    else:\n        return edge_index, edge_attr\n"
  },
  {
    "path": "torch_geometric/utils/_to_dense_adj.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import cumsum, scatter\n\n\ndef to_dense_adj(\n    edge_index: Tensor,\n    batch: OptTensor = None,\n    edge_attr: OptTensor = None,\n    max_num_nodes: Optional[int] = None,\n    batch_size: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Converts batched sparse adjacency matrices given by edge indices and\n    edge attributes to a single dense batched adjacency matrix.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        batch (LongTensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. (default: :obj:`None`)\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n            features.\n            If :obj:`edge_index` contains duplicated edges, the dense adjacency\n            matrix output holds the summed up entries of :obj:`edge_attr` for\n            duplicated edges. (default: :obj:`None`)\n        max_num_nodes (int, optional): The size of the output node dimension.\n            (default: :obj:`None`)\n        batch_size (int, optional): The batch size. (default: :obj:`None`)\n\n    :rtype: :class:`Tensor`\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 0, 1, 2, 3],\n        ...                            
[0, 1, 0, 3, 0]])\n        >>> batch = torch.tensor([0, 0, 1, 1])\n        >>> to_dense_adj(edge_index, batch)\n        tensor([[[1., 1.],\n                [1., 0.]],\n                [[0., 1.],\n                [1., 0.]]])\n\n        >>> to_dense_adj(edge_index, batch, max_num_nodes=4)\n        tensor([[[1., 1., 0., 0.],\n                [1., 0., 0., 0.],\n                [0., 0., 0., 0.],\n                [0., 0., 0., 0.]],\n                [[0., 1., 0., 0.],\n                [1., 0., 0., 0.],\n                [0., 0., 0., 0.],\n                [0., 0., 0., 0.]]])\n\n        >>> edge_attr = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])\n        >>> to_dense_adj(edge_index, batch, edge_attr)\n        tensor([[[1., 2.],\n                [3., 0.]],\n                [[0., 4.],\n                [5., 0.]]])\n    \"\"\"\n    if batch is None:\n        max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0\n        batch = edge_index.new_zeros(max_index)\n\n    if batch_size is None:\n        batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1\n\n    one = batch.new_ones(batch.size(0))\n    num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce='sum')\n    cum_nodes = cumsum(num_nodes)\n\n    idx0 = batch[edge_index[0]]\n    idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]]\n    idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]]\n\n    if max_num_nodes is None:\n        max_num_nodes = int(num_nodes.max())\n\n    elif ((idx1.numel() > 0 and idx1.max() >= max_num_nodes)\n          or (idx2.numel() > 0 and idx2.max() >= max_num_nodes)):\n        mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)\n        idx0 = idx0[mask]\n        idx1 = idx1[mask]\n        idx2 = idx2[mask]\n        edge_attr = None if edge_attr is None else edge_attr[mask]\n\n    if edge_attr is None:\n        edge_attr = torch.ones(idx0.numel(), device=edge_index.device)\n\n    size = [batch_size, max_num_nodes, max_num_nodes]\n    size += 
list(edge_attr.size())[1:]\n    flattened_size = batch_size * max_num_nodes * max_num_nodes\n\n    idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2\n    adj = scatter(edge_attr, idx, dim=0, dim_size=flattened_size, reduce='sum')\n    adj = adj.view(size)\n\n    return adj\n"
  },
  {
    "path": "torch_geometric/utils/_to_dense_batch.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.experimental import (\n    disable_dynamic_shapes,\n    is_experimental_mode_enabled,\n)\nfrom torch_geometric.utils import cumsum, scatter\n\n\n@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes'])\ndef to_dense_batch(\n    x: Tensor,\n    batch: Optional[Tensor] = None,\n    fill_value: float = 0.0,\n    max_num_nodes: Optional[int] = None,\n    batch_size: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Given a sparse batch of node features\n    :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times F}` (with\n    :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a\n    dense node feature tensor\n    :math:`\\mathbf{X} \\in \\mathbb{R}^{B \\times N_{\\max} \\times F}` (with\n    :math:`N_{\\max} = \\max_i^B N_i`).\n    In addition, a mask of shape :math:`\\mathbf{M} \\in \\{ 0, 1 \\}^{B \\times\n    N_{\\max}}` is returned, holding information about the existence of\n    fake-nodes in the dense representation.\n\n    Args:\n        x (Tensor): Node feature matrix\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times F}`.\n        batch (LongTensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. Must be ordered. (default: :obj:`None`)\n        fill_value (float, optional): The value for invalid entries in the\n            resulting dense output tensor. (default: :obj:`0`)\n        max_num_nodes (int, optional): The size of the output node dimension.\n            (default: :obj:`None`)\n        batch_size (int, optional): The batch size. 
(default: :obj:`None`)\n\n    :rtype: (:class:`Tensor`, :class:`BoolTensor`)\n\n    Examples:\n        >>> x = torch.arange(12).view(6, 2)\n        >>> x\n        tensor([[ 0,  1],\n                [ 2,  3],\n                [ 4,  5],\n                [ 6,  7],\n                [ 8,  9],\n                [10, 11]])\n\n        >>> out, mask = to_dense_batch(x)\n        >>> mask\n        tensor([[True, True, True, True, True, True]])\n\n        >>> batch = torch.tensor([0, 0, 1, 2, 2, 2])\n        >>> out, mask = to_dense_batch(x, batch)\n        >>> out\n        tensor([[[ 0,  1],\n                [ 2,  3],\n                [ 0,  0]],\n                [[ 4,  5],\n                [ 0,  0],\n                [ 0,  0]],\n                [[ 6,  7],\n                [ 8,  9],\n                [10, 11]]])\n        >>> mask\n        tensor([[ True,  True, False],\n                [ True, False, False],\n                [ True,  True,  True]])\n\n        >>> out, mask = to_dense_batch(x, batch, max_num_nodes=4)\n        >>> out\n        tensor([[[ 0,  1],\n                [ 2,  3],\n                [ 0,  0],\n                [ 0,  0]],\n                [[ 4,  5],\n                [ 0,  0],\n                [ 0,  0],\n                [ 0,  0]],\n                [[ 6,  7],\n                [ 8,  9],\n                [10, 11],\n                [ 0,  0]]])\n\n        >>> mask\n        tensor([[ True,  True, False, False],\n                [ True, False, False, False],\n                [ True,  True,  True, False]])\n    \"\"\"\n    if batch is None and max_num_nodes is None:\n        mask = torch.ones(1, x.size(0), dtype=torch.bool, device=x.device)\n        return x.unsqueeze(0), mask\n\n    if batch is None:\n        batch = x.new_zeros(x.size(0), dtype=torch.long)\n\n    if batch_size is None:\n        batch_size = int(batch.max()) + 1\n\n    num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0,\n                        dim_size=batch_size, reduce='sum')\n    cum_nodes 
= cumsum(num_nodes)\n\n    filter_nodes = False\n    dynamic_shapes_disabled = is_experimental_mode_enabled(\n        'disable_dynamic_shapes')\n\n    if max_num_nodes is None:\n        max_num_nodes = int(num_nodes.max())\n    elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes:\n        filter_nodes = True\n\n    tmp = torch.arange(batch.size(0), device=x.device) - cum_nodes[batch]\n    idx = tmp + (batch * max_num_nodes)\n    if filter_nodes:\n        mask = tmp < max_num_nodes\n        x, idx = x[mask], idx[mask]\n\n    size = [batch_size * max_num_nodes] + list(x.size())[1:]\n    out = torch.as_tensor(fill_value, device=x.device, dtype=x.dtype)\n    out = out.repeat(size)\n    out[idx] = x\n    out = out.view([batch_size, max_num_nodes] + list(x.size())[1:])\n\n    mask = torch.zeros(batch_size * max_num_nodes, dtype=torch.bool,\n                       device=x.device)\n    mask[idx] = 1\n    mask = mask.view(batch_size, max_num_nodes)\n\n    return out, mask\n"
  },
  {
    "path": "torch_geometric/utils/_train_test_split_edges.py",
    "content": "import math\n\nimport torch\n\nimport torch_geometric\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.utils import to_undirected\n\n\n@deprecated(\"use 'transforms.RandomLinkSplit' instead\")\ndef train_test_split_edges(\n    data: 'torch_geometric.data.Data',\n    val_ratio: float = 0.05,\n    test_ratio: float = 0.1,\n) -> 'torch_geometric.data.Data':\n    r\"\"\"Splits the edges of a :class:`torch_geometric.data.Data` object\n    into positive and negative train/val/test edges.\n    As such, it will replace the :obj:`edge_index` attribute with\n    :obj:`train_pos_edge_index`, :obj:`train_neg_adj_mask`,\n    :obj:`val_pos_edge_index`, :obj:`val_neg_edge_index`,\n    :obj:`test_pos_edge_index` and :obj:`test_neg_edge_index` attributes.\n    If :obj:`data` has edge features named :obj:`edge_attr`, then\n    :obj:`train_pos_edge_attr`, :obj:`val_pos_edge_attr` and\n    :obj:`test_pos_edge_attr` will be added as well.\n\n    .. warning::\n\n        :meth:`~torch_geometric.utils.train_test_split_edges` is deprecated and\n        will be removed in a future release.\n        Use :class:`torch_geometric.transforms.RandomLinkSplit` instead.\n\n    Args:\n        data (Data): The data object.\n        val_ratio (float, optional): The ratio of positive validation edges.\n            (default: :obj:`0.05`)\n        test_ratio (float, optional): The ratio of positive test edges.\n            (default: :obj:`0.1`)\n\n    :rtype: :class:`torch_geometric.data.Data`\n    \"\"\"\n    assert 'batch' not in data  # No batch-mode.\n\n    assert data.num_nodes is not None\n    assert data.edge_index is not None\n\n    num_nodes = data.num_nodes\n    row, col = data.edge_index\n    edge_attr = data.edge_attr\n    del data.edge_index\n    del data.edge_attr\n\n    # Return upper triangular portion.\n    mask = row < col\n    row, col = row[mask], col[mask]\n\n    if edge_attr is not None:\n        edge_attr = edge_attr[mask]\n\n    n_v = int(math.floor(val_ratio * 
row.size(0)))\n    n_t = int(math.floor(test_ratio * row.size(0)))\n\n    # Positive edges.\n    perm = torch.randperm(row.size(0))\n    row, col = row[perm], col[perm]\n    if edge_attr is not None:\n        edge_attr = edge_attr[perm]\n\n    r, c = row[:n_v], col[:n_v]\n    data.val_pos_edge_index = torch.stack([r, c], dim=0)\n    if edge_attr is not None:\n        data.val_pos_edge_attr = edge_attr[:n_v]\n\n    r, c = row[n_v:n_v + n_t], col[n_v:n_v + n_t]\n    data.test_pos_edge_index = torch.stack([r, c], dim=0)\n    if edge_attr is not None:\n        data.test_pos_edge_attr = edge_attr[n_v:n_v + n_t]\n\n    r, c = row[n_v + n_t:], col[n_v + n_t:]\n    data.train_pos_edge_index = torch.stack([r, c], dim=0)\n    if edge_attr is not None:\n        out = to_undirected(data.train_pos_edge_index, edge_attr[n_v + n_t:])\n        data.train_pos_edge_index, data.train_pos_edge_attr = out\n    else:\n        data.train_pos_edge_index = to_undirected(data.train_pos_edge_index)\n\n    # Negative edges.\n    neg_adj_mask = torch.ones(num_nodes, num_nodes, dtype=torch.uint8)\n    neg_adj_mask = neg_adj_mask.triu(diagonal=1).to(torch.bool)\n    neg_adj_mask[row, col] = 0\n\n    neg_row, neg_col = neg_adj_mask.nonzero(as_tuple=False).t()\n    perm = torch.randperm(neg_row.size(0))[:n_v + n_t]\n    neg_row, neg_col = neg_row[perm], neg_col[perm]\n\n    neg_adj_mask[neg_row, neg_col] = 0\n    data.train_neg_adj_mask = neg_adj_mask\n\n    row, col = neg_row[:n_v], neg_col[:n_v]\n    data.val_neg_edge_index = torch.stack([row, col], dim=0)\n\n    row, col = neg_row[n_v:n_v + n_t], neg_col[n_v:n_v + n_t]\n    data.test_neg_edge_index = torch.stack([row, col], dim=0)\n\n    return data\n"
  },
  {
    "path": "torch_geometric/utils/_tree_decomposition.py",
    "content": "from itertools import chain\nfrom typing import Any, List, Literal, Tuple, Union, overload\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import (\n    from_scipy_sparse_matrix,\n    to_scipy_sparse_matrix,\n    to_undirected,\n)\n\n\n@overload\ndef tree_decomposition(mol: Any) -> Tuple[Tensor, Tensor, int]:\n    pass\n\n\n@overload\ndef tree_decomposition(\n    mol: Any,\n    return_vocab: Literal[False],\n) -> Tuple[Tensor, Tensor, int]:\n    pass\n\n\n@overload\ndef tree_decomposition(\n    mol: Any,\n    return_vocab: Literal[True],\n) -> Tuple[Tensor, Tensor, int, Tensor]:\n    pass\n\n\ndef tree_decomposition(\n    mol: Any,\n    return_vocab: bool = False,\n) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, int, Tensor]]:\n    r\"\"\"The tree decomposition algorithm of molecules from the\n    `\"Junction Tree Variational Autoencoder for Molecular Graph Generation\"\n    <https://arxiv.org/abs/1802.04364>`_ paper.\n    Returns the graph connectivity of the junction tree, the assignment\n    mapping of each atom to the clique in the junction tree, and the number\n    of cliques.\n\n    Args:\n        mol (rdkit.Chem.Mol): An :obj:`rdkit` molecule.\n        return_vocab (bool, optional): If set to :obj:`True`, will return an\n            identifier for each clique (ring, bond, bridged compounds, single).\n            (default: :obj:`False`)\n\n    :rtype: :obj:`(LongTensor, LongTensor, int)` if :obj:`return_vocab` is\n        :obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)`\n    \"\"\"\n    import rdkit.Chem as Chem\n    from scipy.sparse.csgraph import minimum_spanning_tree\n\n    # Cliques = rings and bonds.\n    cliques: List[List[int]] = [list(x) for x in Chem.GetSymmSSSR(mol)]\n    xs: List[int] = [0] * len(cliques)\n    for bond in mol.GetBonds():\n        if not bond.IsInRing():\n            cliques.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()])\n            xs.append(1)\n\n    
# Generate `atom2cliques` mappings.\n    atom2cliques: List[List[int]] = [[] for _ in range(mol.GetNumAtoms())]\n    for c in range(len(cliques)):\n        for atom in cliques[c]:\n            atom2cliques[atom].append(c)\n\n    # Merge rings that share more than 2 atoms as they form bridged compounds.\n    for c1 in range(len(cliques)):\n        for atom in cliques[c1]:\n            for c2 in atom2cliques[atom]:\n                if c1 >= c2 or len(cliques[c1]) <= 2 or len(cliques[c2]) <= 2:\n                    continue\n                if len(set(cliques[c1]) & set(cliques[c2])) > 2:\n                    cliques[c1] = list(set(cliques[c1]) | set(cliques[c2]))\n                    xs[c1] = 2\n                    cliques[c2] = []\n                    xs[c2] = -1\n    cliques = [c for c in cliques if len(c) > 0]\n    xs = [x for x in xs if x >= 0]\n\n    # Update `atom2cliques` mappings.\n    atom2cliques = [[] for i in range(mol.GetNumAtoms())]\n    for c in range(len(cliques)):\n        for atom in cliques[c]:\n            atom2cliques[atom].append(c)\n\n    # Add singleton cliques in case there are more than 2 intersecting\n    # cliques. 
We further compute the \"initial\" clique graph.\n    edges = {}\n    for atom in range(mol.GetNumAtoms()):\n        cs = atom2cliques[atom]\n        if len(cs) <= 1:\n            continue\n\n        # Number of bond clusters that the atom lies in.\n        bonds = [c for c in cs if len(cliques[c]) == 2]\n        # Number of ring clusters that the atom lies in.\n        rings = [c for c in cs if len(cliques[c]) > 4]\n\n        if len(bonds) > 2 or (len(bonds) == 2 and len(cs) > 2):\n            cliques.append([atom])\n            xs.append(3)\n            c2 = len(cliques) - 1\n            for c1 in cs:\n                edges[(c1, c2)] = 1\n\n        elif len(rings) > 2:\n            cliques.append([atom])\n            xs.append(3)\n            c2 = len(cliques) - 1\n            for c1 in cs:\n                edges[(c1, c2)] = 99\n\n        else:\n            for i in range(len(cs)):\n                for j in range(i + 1, len(cs)):\n                    c1, c2 = cs[i], cs[j]\n                    count = len(set(cliques[c1]) & set(cliques[c2]))\n                    edges[(c1, c2)] = min(count, edges.get((c1, c2), 99))\n\n    # Update `atom2cliques` mappings.\n    atom2cliques = [[] for i in range(mol.GetNumAtoms())]\n    for c in range(len(cliques)):\n        for atom in cliques[c]:\n            atom2cliques[atom].append(c)\n\n    if len(edges) > 0:\n        edge_index_T, weight = zip(*edges.items())\n        edge_index = torch.tensor(edge_index_T).t()\n        inv_weight = 100 - torch.tensor(weight)\n        graph = to_scipy_sparse_matrix(edge_index, inv_weight, len(cliques))\n        junc_tree = minimum_spanning_tree(graph)\n        edge_index, _ = from_scipy_sparse_matrix(junc_tree)\n        edge_index = to_undirected(edge_index, num_nodes=len(cliques))\n    else:\n        edge_index = torch.empty((2, 0), dtype=torch.long)\n\n    rows = [[i] * len(atom2cliques[i]) for i in range(mol.GetNumAtoms())]\n    row = torch.tensor(list(chain.from_iterable(rows)))\n    col 
= torch.tensor(list(chain.from_iterable(atom2cliques)))\n    atom2clique = torch.stack([row, col], dim=0).to(torch.long)\n\n    if return_vocab:\n        vocab = torch.tensor(xs, dtype=torch.long)\n        return edge_index, atom2clique, len(cliques), vocab\n    else:\n        return edge_index, atom2clique, len(cliques)\n"
  },
  {
    "path": "torch_geometric/utils/_trim_to_layer.py",
    "content": "from typing import Dict, List, Optional, Tuple, Union, overload\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.typing import (\n    Adj,\n    EdgeType,\n    MaybeHeteroAdjTensor,\n    MaybeHeteroEdgeTensor,\n    MaybeHeteroNodeTensor,\n    NodeType,\n    SparseStorage,\n    SparseTensor,\n)\n\n\n@overload\ndef trim_to_layer(\n    layer: int,\n    num_sampled_nodes_per_hop: List[int],\n    num_sampled_edges_per_hop: List[int],\n    x: Tensor,\n    edge_index: Adj,\n    edge_attr: Optional[Tensor] = None,\n) -> Tuple[Tensor, Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef trim_to_layer(\n    layer: int,\n    num_sampled_nodes_per_hop: Dict[NodeType, List[int]],\n    num_sampled_edges_per_hop: Dict[EdgeType, List[int]],\n    x: Dict[NodeType, Tensor],\n    edge_index: Dict[EdgeType, Adj],\n    edge_attr: Optional[Dict[EdgeType, Tensor]] = None,\n) -> Tuple[Dict[NodeType, Tensor], Dict[EdgeType, Adj], Optional[Dict[\n        EdgeType, Tensor]]]:\n    pass\n\n\ndef trim_to_layer(\n    layer: int,\n    num_sampled_nodes_per_hop: Union[List[int], Dict[NodeType, List[int]]],\n    num_sampled_edges_per_hop: Union[List[int], Dict[EdgeType, List[int]]],\n    x: MaybeHeteroNodeTensor,\n    edge_index: MaybeHeteroEdgeTensor,\n    edge_attr: Optional[MaybeHeteroEdgeTensor] = None,\n) -> Tuple[MaybeHeteroNodeTensor, MaybeHeteroAdjTensor,\n           Optional[MaybeHeteroEdgeTensor]]:\n    r\"\"\"Trims the :obj:`edge_index` representation, node features :obj:`x` and\n    edge features :obj:`edge_attr` to a minimal-sized representation for the\n    current GNN layer :obj:`layer` in directed\n    :class:`~torch_geometric.loader.NeighborLoader` scenarios.\n\n    This ensures that no computation is performed for nodes and edges that are\n    not included in the current GNN layer, thus avoiding unnecessary\n    computation within the GNN when performing neighborhood sampling.\n\n    Args:\n        layer 
(int): The current GNN layer.\n        num_sampled_nodes_per_hop (List[int] or Dict[NodeType, List[int]]): The\n            number of sampled nodes per hop.\n        num_sampled_edges_per_hop (List[int] or Dict[EdgeType, List[int]]): The\n            number of sampled edges per hop.\n        x (torch.Tensor or Dict[NodeType, torch.Tensor]): The homogeneous or\n            heterogeneous (hidden) node features.\n        edge_index (torch.Tensor or Dict[EdgeType, torch.Tensor]): The\n            homogeneous or heterogeneous edge indices.\n        edge_attr (torch.Tensor or Dict[EdgeType, torch.Tensor], optional): The\n            homogeneous or heterogeneous (hidden) edge features.\n    \"\"\"\n    if layer <= 0:\n        return x, edge_index, edge_attr\n\n    if isinstance(num_sampled_edges_per_hop, dict):\n        assert isinstance(num_sampled_nodes_per_hop, dict)\n\n        assert isinstance(x, dict)\n        x = {\n            k: trim_feat(v, layer, num_sampled_nodes_per_hop[k])\n            for k, v in x.items()\n        }\n\n        assert isinstance(edge_index, dict)\n        edge_index = {\n            k:\n            trim_adj(\n                v,\n                layer,\n                num_sampled_nodes_per_hop[k[0]],\n                num_sampled_nodes_per_hop[k[-1]],\n                num_sampled_edges_per_hop[k],\n            )\n            for k, v in edge_index.items()\n        }\n\n        if edge_attr is not None:\n            assert isinstance(edge_attr, dict)\n            edge_attr = {\n                k: trim_feat(v, layer, num_sampled_edges_per_hop[k])\n                for k, v in edge_attr.items()\n            }\n\n        return x, edge_index, edge_attr\n\n    assert isinstance(num_sampled_nodes_per_hop, list)\n\n    assert isinstance(x, Tensor)\n    x = trim_feat(x, layer, num_sampled_nodes_per_hop)\n\n    assert isinstance(edge_index, (Tensor, SparseTensor))\n    edge_index = trim_adj(\n        edge_index,\n        layer,\n        
num_sampled_nodes_per_hop,\n        num_sampled_nodes_per_hop,\n        num_sampled_edges_per_hop,\n    )\n\n    if edge_attr is not None:\n        assert isinstance(edge_attr, Tensor)\n        edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop)\n\n    return x, edge_index, edge_attr\n\n\nclass TrimToLayer(torch.nn.Module):\n    @torch.jit.unused\n    def forward(\n        self,\n        layer: int,\n        num_sampled_nodes_per_hop: Optional[List[int]],\n        num_sampled_edges_per_hop: Optional[List[int]],\n        x: Tensor,\n        edge_index: Adj,\n        edge_attr: Optional[Tensor] = None,\n    ) -> Tuple[Tensor, Adj, Optional[Tensor]]:\n\n        if (not isinstance(num_sampled_nodes_per_hop, list)\n                and isinstance(num_sampled_edges_per_hop, list)):\n            raise ValueError(\"'num_sampled_nodes_per_hop' needs to be given\")\n        if (not isinstance(num_sampled_edges_per_hop, list)\n                and isinstance(num_sampled_nodes_per_hop, list)):\n            raise ValueError(\"'num_sampled_edges_per_hop' needs to be given\")\n\n        if num_sampled_nodes_per_hop is None:\n            return x, edge_index, edge_attr\n        if num_sampled_edges_per_hop is None:\n            return x, edge_index, edge_attr\n\n        return trim_to_layer(\n            layer,\n            num_sampled_nodes_per_hop,\n            num_sampled_edges_per_hop,\n            x,\n            edge_index,\n            edge_attr,\n        )\n\n\n# Helper functions ############################################################\n\n\ndef trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor:\n    if layer <= 0:\n        return x\n\n    return x.narrow(\n        dim=0,\n        start=0,\n        length=x.size(0) - num_samples_per_hop[-layer],\n    )\n\n\ndef trim_adj(\n    edge_index: Adj,\n    layer: int,\n    num_sampled_src_nodes_per_hop: List[int],\n    num_sampled_dst_nodes_per_hop: List[int],\n    num_sampled_edges_per_hop: 
List[int],\n) -> Adj:\n\n    if layer <= 0:\n        return edge_index\n\n    if isinstance(edge_index, Tensor):\n        edge_index = edge_index.narrow(\n            dim=1,\n            start=0,\n            length=edge_index.size(1) - num_sampled_edges_per_hop[-layer],\n        )\n        if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n            num_rows, num_cols = edge_index.sparse_size()\n            if num_rows is not None:\n                num_rows -= num_sampled_src_nodes_per_hop[-layer]\n            if num_cols is not None:\n                num_cols -= num_sampled_dst_nodes_per_hop[-layer]\n            edge_index.sparse_resize_(num_rows, num_cols)\n        return edge_index\n\n    elif isinstance(edge_index, SparseTensor):\n        size = (\n            edge_index.size(0) - num_sampled_dst_nodes_per_hop[-layer],\n            edge_index.size(1) - num_sampled_src_nodes_per_hop[-layer],\n        )\n\n        num_seed_nodes = size[0] - num_sampled_dst_nodes_per_hop[-(layer + 1)]\n\n        return trim_sparse_tensor(edge_index, size, num_seed_nodes)\n\n    raise ValueError(f\"Unsupported 'edge_index' type '{type(edge_index)}'\")\n\n\ndef trim_sparse_tensor(src: SparseTensor, size: Tuple[int, int],\n                       num_seed_nodes: int) -> SparseTensor:\n    r\"\"\"Trims a :class:`SparseTensor` along both dimensions to only contain\n    the first :obj:`size[0]` rows and :obj:`size[1]` columns.\n\n    It is assumed that the :class:`SparseTensor` is obtained from a BFS\n    traversal starting from the nodes that have been initially selected.\n\n    Args:\n        src (SparseTensor): The sparse tensor.\n        size (Tuple[int, int]): The number of source and destination nodes to\n            keep.\n        num_seed_nodes (int): The number of seed nodes to compute\n            representations.\n    \"\"\"\n    rowptr, col, value = src.csr()\n\n    rowptr = torch.narrow(rowptr, 0, 0, size[0] + 1).clone()\n    rowptr[num_seed_nodes + 1:] = 
rowptr[num_seed_nodes]\n\n    col = torch.narrow(col, 0, 0, rowptr[-1])  # type: ignore\n\n    if value is not None:\n        value = torch.narrow(value, 0, 0, rowptr[-1])  # type: ignore\n\n    csr2csc = src.storage._csr2csc\n    if csr2csc is not None:\n        csr2csc = csr2csc[csr2csc < len(col)]\n\n    storage = SparseStorage(\n        row=None,\n        rowptr=rowptr,\n        col=col,\n        value=value,\n        sparse_sizes=size,\n        rowcount=None,\n        colptr=None,\n        colcount=None,\n        csr2csc=csr2csc,\n        csc2csr=None,\n        is_sorted=True,\n        trust_data=True,\n    )\n    return src.from_storage(storage)\n"
  },
  {
    "path": "torch_geometric/utils/_unbatch.py",
    "content": "from typing import List, Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import cumsum, degree\n\n\ndef unbatch(\n    src: Tensor,\n    batch: Tensor,\n    dim: int = 0,\n    batch_size: Optional[int] = None,\n) -> List[Tensor]:\n    r\"\"\"Splits :obj:`src` according to a :obj:`batch` vector along dimension\n    :obj:`dim`.\n\n    Args:\n        src (Tensor): The source tensor.\n        batch (LongTensor): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            entry in :obj:`src` to a specific example. Must be ordered.\n        dim (int, optional): The dimension along which to split the :obj:`src`\n            tensor. (default: :obj:`0`)\n        batch_size (int, optional): The batch size. (default: :obj:`None`)\n\n    :rtype: :class:`List[Tensor]`\n\n    Example:\n        >>> src = torch.arange(7)\n        >>> batch = torch.tensor([0, 0, 0, 1, 1, 2, 2])\n        >>> unbatch(src, batch)\n        (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6]))\n    \"\"\"\n    sizes = degree(batch, batch_size, dtype=torch.long).tolist()\n    return src.split(sizes, dim)\n\n\ndef unbatch_edge_index(\n    edge_index: Tensor,\n    batch: Tensor,\n    batch_size: Optional[int] = None,\n) -> List[Tensor]:\n    r\"\"\"Splits the :obj:`edge_index` according to a :obj:`batch` vector.\n\n    Args:\n        edge_index (Tensor): The edge_index tensor. Must be ordered.\n        batch (LongTensor): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. Must be ordered.\n        batch_size (int, optional): The batch size. (default: :obj:`None`)\n\n    :rtype: :class:`List[Tensor]`\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6],\n        ...                            
[1, 0, 2, 1, 3, 2, 5, 4, 6, 5]])\n        >>> batch = torch.tensor([0, 0, 0, 0, 1, 1, 1])\n        >>> unbatch_edge_index(edge_index, batch)\n        (tensor([[0, 1, 1, 2, 2, 3],\n                [1, 0, 2, 1, 3, 2]]),\n        tensor([[0, 1, 1, 2],\n                [1, 0, 2, 1]]))\n    \"\"\"\n    deg = degree(batch, batch_size, dtype=torch.long)\n    ptr = cumsum(deg)\n\n    edge_batch = batch[edge_index[0]]\n    edge_index = edge_index - ptr[edge_batch]\n    sizes = degree(edge_batch, batch_size, dtype=torch.long).cpu().tolist()\n    return edge_index.split(sizes, dim=1)\n"
  },
  {
    "path": "torch_geometric/utils/augmentation.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import cumsum, negative_sampling, scatter\n\n\ndef shuffle_node(\n    x: Tensor,\n    batch: Optional[Tensor] = None,\n    training: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Randomly shuffle the feature matrix :obj:`x` along the\n    first dimension.\n\n    The method returns (1) the shuffled :obj:`x`, (2) the permutation\n    indicating the orders of original nodes after shuffling.\n\n    Args:\n        x (FloatTensor): The feature matrix.\n        batch (LongTensor, optional): Batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            node to a specific example. Must be ordered. (default: :obj:`None`)\n        training (bool, optional): If set to :obj:`False`, this operation is a\n            no-op. (default: :obj:`True`)\n\n    :rtype: (:class:`FloatTensor`, :class:`LongTensor`)\n\n    Example:\n        >>> # Standard case\n        >>> x = torch.tensor([[0, 1, 2],\n        ...                   [3, 4, 5],\n        ...                   [6, 7, 8],\n        ...                   
[9, 10, 11]], dtype=torch.float)\n        >>> x, node_perm = shuffle_node(x)\n        >>> x\n        tensor([[ 3.,  4.,  5.],\n                [ 9., 10., 11.],\n                [ 0.,  1.,  2.],\n                [ 6.,  7.,  8.]])\n        >>> node_perm\n        tensor([1, 3, 0, 2])\n\n        >>> # For batched graphs as inputs\n        >>> batch = torch.tensor([0, 0, 1, 1])\n        >>> x, node_perm = shuffle_node(x, batch)\n        >>> x\n        tensor([[ 3.,  4.,  5.],\n                [ 0.,  1.,  2.],\n                [ 9., 10., 11.],\n                [ 6.,  7.,  8.]])\n        >>> node_perm\n        tensor([1, 0, 3, 2])\n    \"\"\"\n    if not training:\n        perm = torch.arange(x.size(0), device=x.device)\n        return x, perm\n    if batch is None:\n        perm = torch.randperm(x.size(0), device=x.device)\n        return x[perm], perm\n    num_nodes = scatter(batch.new_ones(x.size(0)), batch, dim=0, reduce='sum')\n    ptr = cumsum(num_nodes)\n    perm = torch.cat([\n        torch.randperm(n, device=x.device) + offset\n        for offset, n in zip(ptr[:-1], num_nodes)\n    ])\n    return x[perm], perm\n\n\ndef mask_feature(\n    x: Tensor,\n    p: float = 0.5,\n    mode: str = 'col',\n    fill_value: float = 0.,\n    training: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Randomly masks feature from the feature matrix\n    :obj:`x` with probability :obj:`p` using samples from\n    a Bernoulli distribution.\n\n    The method returns (1) the retained :obj:`x`, (2) the feature\n    mask broadcastable with :obj:`x` (:obj:`mode='row'` and :obj:`mode='col'`)\n    or with the same shape as :obj:`x` (:obj:`mode='all'`),\n    indicating where features are retained.\n\n    Args:\n        x (FloatTensor): The feature matrix.\n        p (float, optional): The masking ratio. 
(default: :obj:`0.5`)\n        mode (str, optional): The masking scheme to use for feature masking.\n            (:obj:`\"row\"`, :obj:`\"col\"` or :obj:`\"all\"`).\n            If :obj:`mode='col'`, will mask entire features of all nodes\n            from the feature matrix. If :obj:`mode='row'`, will mask entire\n            nodes from the feature matrix. If :obj:`mode='all'`, will mask\n            individual features across all nodes. (default: :obj:`'col'`)\n        fill_value (float, optional): The value for masked features in the\n            output tensor. (default: :obj:`0`)\n        training (bool, optional): If set to :obj:`False`, this operation is a\n            no-op. (default: :obj:`True`)\n\n    :rtype: (:class:`FloatTensor`, :class:`BoolTensor`)\n\n    Examples:\n        >>> # Masked features are column-wise sampled\n        >>> x = torch.tensor([[1, 2, 3],\n        ...                   [4, 5, 6],\n        ...                   [7, 8, 9]], dtype=torch.float)\n        >>> x, feat_mask = mask_feature(x)\n        >>> x\n        tensor([[1., 0., 3.],\n                [4., 0., 6.],\n                [7., 0., 9.]])\n        >>> feat_mask\n        tensor([[True, False, True]])\n\n        >>> # Masked features are row-wise sampled\n        >>> x, feat_mask = mask_feature(x, mode='row')\n        >>> x\n        tensor([[1., 2., 3.],\n                [0., 0., 0.],\n                [7., 8., 9.]])\n        >>> feat_mask\n        tensor([[True], [False], [True]])\n\n        >>> # Masked features are uniformly sampled\n        >>> x, feat_mask = mask_feature(x, mode='all')\n        >>> x\n        tensor([[0., 0., 0.],\n                [4., 0., 6.],\n                [0., 0., 9.]])\n        >>> feat_mask\n        tensor([[False, False, False],\n                [True, False,  True],\n                [False, False,  True]])\n    \"\"\"\n    if p < 0. 
or p > 1.:\n        raise ValueError(f'Masking ratio has to be between 0 and 1 '\n                         f'(got {p})')\n    if not training or p == 0.0:\n        return x, torch.ones_like(x, dtype=torch.bool)\n    assert mode in ['row', 'col', 'all']\n\n    if mode == 'row':\n        mask = torch.rand(x.size(0), device=x.device) >= p\n        mask = mask.view(-1, 1)\n    elif mode == 'col':\n        mask = torch.rand(x.size(1), device=x.device) >= p\n        mask = mask.view(1, -1)\n    else:\n        mask = torch.rand_like(x) >= p\n\n    x = x.masked_fill(~mask, fill_value)\n    return x, mask\n\n\ndef add_random_edge(\n    edge_index: Tensor,\n    p: float = 0.5,\n    force_undirected: bool = False,\n    num_nodes: Optional[Union[int, Tuple[int, int]]] = None,\n    training: bool = True,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Randomly adds edges to :obj:`edge_index`.\n\n    The method returns (1) the updated :obj:`edge_index`, (2) the added\n    edge indices.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        p (float): Ratio of added edges to the existing edges.\n            (default: :obj:`0.5`)\n        force_undirected (bool, optional): If set to :obj:`True`,\n            added edges will be undirected.\n            (default: :obj:`False`)\n        num_nodes (int, Tuple[int], optional): The overall number of nodes,\n            *i.e.* :obj:`max_val + 1`, or the number of source and\n            destination nodes, *i.e.* :obj:`(max_src_val + 1, max_dst_val + 1)`\n            of :attr:`edge_index`. (default: :obj:`None`)\n        training (bool, optional): If set to :obj:`False`, this operation is a\n            no-op. (default: :obj:`True`)\n\n    :rtype: (:class:`LongTensor`, :class:`LongTensor`)\n\n    Examples:\n        >>> # Standard case\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            
[1, 0, 2, 1, 3, 2]])\n        >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5)\n        >>> edge_index\n        tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3],\n                [1, 0, 2, 1, 3, 2, 0, 2, 1]])\n        >>> added_edges\n        tensor([[2, 1, 3],\n                [0, 2, 1]])\n\n        >>> # The returned graph is kept undirected\n        >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5,\n        ...                                           force_undirected=True)\n        >>> edge_index\n        tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3, 0, 2, 1],\n                [1, 0, 2, 1, 3, 2, 0, 2, 1, 2, 1, 3]])\n        >>> added_edges\n        tensor([[2, 1, 3, 0, 2, 1],\n                [0, 2, 1, 2, 1, 3]])\n\n        >>> # For bipartite graphs\n        >>> edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],\n        ...                            [2, 3, 1, 4, 2, 1]])\n        >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5,\n        ...                                           num_nodes=(6, 5))\n        >>> edge_index\n        tensor([[0, 1, 2, 3, 4, 5, 3, 4, 1],\n                [2, 3, 1, 4, 2, 1, 1, 3, 2]])\n        >>> added_edges\n        tensor([[3, 4, 1],\n                [1, 3, 2]])\n    \"\"\"\n    if p < 0. 
or p > 1.:\n        raise ValueError(f\"Ratio of added edges has to be between 0 and 1 \"\n                         f\"(got '{p}')\")\n    if force_undirected and isinstance(num_nodes, (tuple, list)):\n        raise RuntimeError(\"'force_undirected' is not supported for \"\n                           \"bipartite graphs\")\n\n    if not training or p == 0.0:\n        # Return an empty `[2, 0]` tensor that shares the dtype and device of\n        # `edge_index`, matching the documented `LongTensor` return type:\n        edge_index_to_add = edge_index.new_empty((2, 0))\n        return edge_index, edge_index_to_add\n\n    edge_index_to_add = negative_sampling(\n        edge_index=edge_index,\n        num_nodes=num_nodes,\n        num_neg_samples=round(edge_index.size(1) * p),\n        force_undirected=force_undirected,\n    )\n\n    edge_index = torch.cat([edge_index, edge_index_to_add], dim=1)\n\n    return edge_index, edge_index_to_add\n"
  },
  {
    "path": "torch_geometric/utils/convert.py",
    "content": "from collections import defaultdict\nfrom typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.utils.dlpack import from_dlpack, to_dlpack\n\nimport torch_geometric\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef to_scipy_sparse_matrix(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Any:\n    r\"\"\"Converts a graph given by edge indices and edge attributes to a scipy\n    sparse matrix.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional\n            edge features. (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`)\n\n    Examples:\n        >>> edge_index = torch.tensor([\n        ...     [0, 1, 1, 2, 2, 3],\n        ...     [1, 0, 2, 1, 3, 2],\n        ... ])\n        >>> to_scipy_sparse_matrix(edge_index)\n        <4x4 sparse matrix of type '<class 'numpy.float32'>'\n            with 6 stored elements in COOrdinate format>\n    \"\"\"\n    import scipy.sparse as sp\n\n    row, col = edge_index.cpu()\n\n    if edge_attr is None:\n        edge_attr = torch.ones(row.size(0), device=\"cpu\")\n    else:\n        edge_attr = edge_attr.view(-1).cpu()\n        assert edge_attr.size(0) == row.size(0)\n\n    N = maybe_num_nodes(edge_index, num_nodes)\n    out = sp.coo_matrix(  #\n        (edge_attr.numpy(), (row.numpy(), col.numpy())), (N, N))\n    return out\n\n\ndef from_scipy_sparse_matrix(A: Any) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Converts a scipy sparse matrix to edge indices and edge attributes.\n\n    Args:\n        A (scipy.sparse): A sparse matrix.\n\n    Examples:\n        >>> edge_index = torch.tensor([\n        ...     [0, 1, 1, 2, 2, 3],\n        ...     [1, 0, 2, 1, 3, 2],\n        ... 
])\n        >>> adj = to_scipy_sparse_matrix(edge_index)\n        >>> # `edge_index` and `edge_weight` are both returned\n        >>> from_scipy_sparse_matrix(adj)\n        (tensor([[0, 1, 1, 2, 2, 3],\n                [1, 0, 2, 1, 3, 2]]),\n        tensor([1., 1., 1., 1., 1., 1.]))\n    \"\"\"\n    A = A.tocoo()\n    row = torch.from_numpy(A.row).to(torch.long)\n    col = torch.from_numpy(A.col).to(torch.long)\n    edge_index = torch.stack([row, col], dim=0)\n    edge_weight = torch.from_numpy(A.data)\n    return edge_index, edge_weight\n\n\ndef to_networkx(\n    data: Union[\n        'torch_geometric.data.Data',\n        'torch_geometric.data.HeteroData',\n    ],\n    node_attrs: Optional[Iterable[str]] = None,\n    edge_attrs: Optional[Iterable[str]] = None,\n    graph_attrs: Optional[Iterable[str]] = None,\n    to_undirected: Optional[Union[bool, str]] = False,\n    to_multi: bool = False,\n    remove_self_loops: bool = False,\n) -> Any:\n    r\"\"\"Converts a :class:`torch_geometric.data.Data` instance to a\n    :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or\n    a directed :obj:`networkx.DiGraph` otherwise.\n\n    Args:\n        data (torch_geometric.data.Data or torch_geometric.data.HeteroData): A\n            homogeneous or heterogeneous data object.\n        node_attrs (iterable of str, optional): The node attributes to be\n            copied. (default: :obj:`None`)\n        edge_attrs (iterable of str, optional): The edge attributes to be\n            copied. (default: :obj:`None`)\n        graph_attrs (iterable of str, optional): The graph attributes to be\n            copied. 
(default: :obj:`None`)\n        to_undirected (bool or str, optional): If set to :obj:`True`, will\n            return a :class:`networkx.Graph` instead of a\n            :class:`networkx.DiGraph`.\n            By default, will include all edges and make them undirected.\n            If set to :obj:`\"upper\"`, the undirected graph will only correspond\n            to the upper triangle of the input adjacency matrix.\n            If set to :obj:`\"lower\"`, the undirected graph will only correspond\n            to the lower triangle of the input adjacency matrix.\n            Only applicable in case the :obj:`data` object holds a homogeneous\n            graph. (default: :obj:`False`)\n        to_multi (bool, optional): If set to :obj:`True`, will return a\n            :class:`networkx.MultiGraph` or a :class:`networkx.MultiDiGraph`\n            (depending on the :obj:`to_undirected` option), which will not drop\n            duplicated edges that may exist in :obj:`data`.\n            (default: :obj:`False`)\n        remove_self_loops (bool, optional): If set to :obj:`True`, will not\n            include self-loops in the resulting graph. (default: :obj:`False`)\n\n    Examples:\n        >>> edge_index = torch.tensor([\n        ...     [0, 1, 1, 2, 2, 3],\n        ...     [1, 0, 2, 1, 3, 2],\n        ... 
])\n        >>> data = Data(edge_index=edge_index, num_nodes=4)\n        >>> to_networkx(data)\n        <networkx.classes.digraph.DiGraph at 0x2713fdb40d0>\n\n    \"\"\"\n    import networkx as nx\n\n    from torch_geometric.data import HeteroData\n\n    to_undirected_upper: bool = to_undirected == 'upper'\n    to_undirected_lower: bool = to_undirected == 'lower'\n\n    to_undirected = to_undirected is True\n    to_undirected |= to_undirected_upper or to_undirected_lower\n    assert isinstance(to_undirected, bool)\n\n    if isinstance(data, HeteroData) and to_undirected:\n        raise ValueError(\"'to_undirected' is not supported in \"\n                         \"'to_networkx' for heterogeneous graphs\")\n\n    if to_undirected:\n        G = nx.MultiGraph() if to_multi else nx.Graph()\n    else:\n        G = nx.MultiDiGraph() if to_multi else nx.DiGraph()\n\n    def to_networkx_value(value: Any) -> Any:\n        return value.tolist() if isinstance(value, Tensor) else value\n\n    for key in graph_attrs or []:\n        G.graph[key] = to_networkx_value(data[key])\n\n    node_offsets = data.node_offsets\n    for node_store in data.node_stores:\n        start = node_offsets[node_store._key]\n        assert node_store.num_nodes is not None\n        for i in range(node_store.num_nodes):\n            node_kwargs: Dict[str, Any] = {}\n            if isinstance(data, HeteroData):\n                node_kwargs['type'] = node_store._key\n            for key in node_attrs or []:\n                node_kwargs[key] = to_networkx_value(node_store[key][i])\n\n            G.add_node(start + i, **node_kwargs)\n\n    for edge_store in data.edge_stores:\n        for i, (v, w) in enumerate(edge_store.edge_index.t().tolist()):\n            if to_undirected_upper and v > w:\n                continue\n            elif to_undirected_lower and v < w:\n                continue\n            elif remove_self_loops and v == w and not edge_store.is_bipartite(\n            ):\n                
continue\n\n            edge_kwargs: Dict[str, Any] = {}\n            if isinstance(data, HeteroData):\n                v = v + node_offsets[edge_store._key[0]]\n                w = w + node_offsets[edge_store._key[-1]]\n                edge_kwargs['type'] = edge_store._key\n            for key in edge_attrs or []:\n                edge_kwargs[key] = to_networkx_value(edge_store[key][i])\n\n            G.add_edge(v, w, **edge_kwargs)\n\n    return G\n\n\ndef from_networkx(\n    G: Any,\n    group_node_attrs: Optional[Union[List[str], Literal['all']]] = None,\n    group_edge_attrs: Optional[Union[List[str], Literal['all']]] = None,\n) -> 'torch_geometric.data.Data':\n    r\"\"\"Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a\n    :class:`torch_geometric.data.Data` instance.\n\n    Args:\n        G (networkx.Graph or networkx.DiGraph): A networkx graph.\n        group_node_attrs (List[str] or \"all\", optional): The node attributes to\n            be concatenated and added to :obj:`data.x`. (default: :obj:`None`)\n        group_edge_attrs (List[str] or \"all\", optional): The edge attributes to\n            be concatenated and added to :obj:`data.edge_attr`.\n            (default: :obj:`None`)\n\n    .. note::\n\n        All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must\n        be numeric.\n\n    Examples:\n        >>> edge_index = torch.tensor([\n        ...     [0, 1, 1, 2, 2, 3],\n        ...     [1, 0, 2, 1, 3, 2],\n        ... 
])\n        >>> data = Data(edge_index=edge_index, num_nodes=4)\n        >>> g = to_networkx(data)\n        >>> # A `Data` object is returned\n        >>> from_networkx(g)\n        Data(edge_index=[2, 6], num_nodes=4)\n    \"\"\"\n    import networkx as nx\n\n    from torch_geometric.data import Data\n\n    G = G.to_directed() if not nx.is_directed(G) else G\n\n    mapping = dict(zip(G.nodes(), range(G.number_of_nodes())))\n    edge_index = torch.empty((2, G.number_of_edges()), dtype=torch.long)\n    for i, (src, dst) in enumerate(G.edges()):\n        edge_index[0, i] = mapping[src]\n        edge_index[1, i] = mapping[dst]\n\n    data_dict: Dict[str, Any] = defaultdict(list)\n    data_dict['edge_index'] = edge_index\n\n    node_attrs: List[str] = []\n    if G.number_of_nodes() > 0:\n        node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys())\n\n    edge_attrs: List[str] = []\n    if G.number_of_edges() > 0:\n        edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys())\n\n    if group_node_attrs is not None and not isinstance(group_node_attrs, list):\n        group_node_attrs = node_attrs\n\n    if group_edge_attrs is not None and not isinstance(group_edge_attrs, list):\n        group_edge_attrs = edge_attrs\n\n    for _, feat_dict in G.nodes(data=True):\n        if set(feat_dict.keys()) != set(node_attrs):\n            raise ValueError('Not all nodes contain the same attributes')\n        for key, value in feat_dict.items():\n            data_dict[str(key)].append(value)\n\n    for _, _, feat_dict in G.edges(data=True):\n        if set(feat_dict.keys()) != set(edge_attrs):\n            raise ValueError('Not all edges contain the same attributes')\n        for key, value in feat_dict.items():\n            key = f'edge_{key}' if key in node_attrs else key\n            data_dict[str(key)].append(value)\n\n    for key, value in G.graph.items():\n        if key == 'node_default' or key == 'edge_default':\n            continue  # Do not load default 
attributes.\n        key = f'graph_{key}' if key in node_attrs else key\n        data_dict[str(key)] = value\n\n    for key, value in data_dict.items():\n        if isinstance(value, (tuple, list)) and isinstance(value[0], Tensor):\n            data_dict[key] = torch.stack(value, dim=0)\n        else:\n            try:\n                data_dict[key] = torch.as_tensor(value)\n            except Exception:\n                pass\n\n    data = Data.from_dict(data_dict)\n\n    if group_node_attrs is not None:\n        xs = []\n        for key in group_node_attrs:\n            x = data[key]\n            x = x.view(-1, 1) if x.dim() <= 1 else x\n            xs.append(x)\n            del data[key]\n        data.x = torch.cat(xs, dim=-1)\n\n    if group_edge_attrs is not None:\n        xs = []\n        for key in group_edge_attrs:\n            key = f'edge_{key}' if key in node_attrs else key\n            x = data[key]\n            x = x.view(-1, 1) if x.dim() <= 1 else x\n            xs.append(x)\n            del data[key]\n        data.edge_attr = torch.cat(xs, dim=-1)\n\n    if data.x is None and data.pos is None:\n        data.num_nodes = G.number_of_nodes()\n\n    return data\n\n\ndef to_networkit(\n    edge_index: Tensor,\n    edge_weight: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n    directed: bool = True,\n) -> Any:\n    r\"\"\"Converts a :obj:`(edge_index, edge_weight)` tuple to a\n    :class:`networkit.Graph`.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices of the graph.\n        edge_weight (torch.Tensor, optional): The edge weights of the graph.\n            (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes in the graph.\n            (default: :obj:`None`)\n        directed (bool, optional): If set to :obj:`False`, the graph will be\n            undirected. 
(default: :obj:`True`)\n    \"\"\"\n    import networkit as nk\n\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    g = nk.graph.Graph(\n        num_nodes,\n        weighted=edge_weight is not None,\n        directed=directed,\n    )\n\n    if edge_weight is None:\n        edge_weight = torch.ones(edge_index.size(1))\n\n    if not directed:\n        mask = edge_index[0] <= edge_index[1]\n        edge_index = edge_index[:, mask]\n        edge_weight = edge_weight[mask]\n\n    for (u, v), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n        g.addEdge(u, v, w)\n\n    return g\n\n\ndef from_networkit(g: Any) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Converts a :class:`networkit.Graph` to a\n    :obj:`(edge_index, edge_weight)` tuple.\n    If the :class:`networkit.Graph` is not weighted, the returned\n    :obj:`edge_weight` will be :obj:`None`.\n\n    Args:\n        g (networkit.graph.Graph): A :obj:`networkit` graph object.\n    \"\"\"\n    is_directed = g.isDirected()\n    is_weighted = g.isWeighted()\n\n    edge_indices, edge_weights = [], []\n    for u, v, w in g.iterEdgesWeights():\n        edge_indices.append([u, v])\n        edge_weights.append(w)\n        if not is_directed:\n            edge_indices.append([v, u])\n            edge_weights.append(w)\n\n    edge_index = torch.tensor(edge_indices).t().contiguous()\n    edge_weight = torch.tensor(edge_weights) if is_weighted else None\n\n    return edge_index, edge_weight\n\n\ndef to_trimesh(data: 'torch_geometric.data.Data') -> Any:\n    r\"\"\"Converts a :class:`torch_geometric.data.Data` instance to a\n    :obj:`trimesh.Trimesh`.\n\n    Args:\n        data (torch_geometric.data.Data): The data object.\n\n    Example:\n        >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],\n        ...                    
dtype=torch.float)\n        >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()\n\n        >>> data = Data(pos=pos, face=face)\n        >>> to_trimesh(data)\n        <trimesh.Trimesh(vertices.shape=(4, 3), faces.shape=(2, 3))>\n    \"\"\"\n    import trimesh\n\n    assert data.pos is not None\n    assert data.face is not None\n\n    return trimesh.Trimesh(\n        vertices=data.pos.detach().cpu().numpy(),\n        faces=data.face.detach().t().cpu().numpy(),\n        process=False,\n    )\n\n\ndef from_trimesh(mesh: Any) -> 'torch_geometric.data.Data':\n    r\"\"\"Converts a :obj:`trimesh.Trimesh` to a\n    :class:`torch_geometric.data.Data` instance.\n\n    Args:\n        mesh (trimesh.Trimesh): A :obj:`trimesh` mesh.\n\n    Example:\n        >>> pos = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]],\n        ...                    dtype=torch.float)\n        >>> face = torch.tensor([[0, 1, 2], [1, 2, 3]]).t()\n\n        >>> data = Data(pos=pos, face=face)\n        >>> mesh = to_trimesh(data)\n        >>> from_trimesh(mesh)\n        Data(pos=[4, 3], face=[3, 2])\n    \"\"\"\n    from torch_geometric.data import Data\n\n    pos = torch.from_numpy(mesh.vertices).to(torch.float)\n    face = torch.from_numpy(mesh.faces).t().contiguous()\n\n    return Data(pos=pos, face=face)\n\n\ndef to_cugraph(\n    edge_index: Tensor,\n    edge_weight: Optional[Tensor] = None,\n    relabel_nodes: bool = True,\n    directed: bool = True,\n) -> Any:\n    r\"\"\"Converts a graph given by :obj:`edge_index` and optional\n    :obj:`edge_weight` into a :obj:`cugraph` graph object.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices of the graph.\n        edge_weight (torch.Tensor, optional): The edge weights of the graph.\n            (default: :obj:`None`)\n        relabel_nodes (bool, optional): If set to :obj:`True`,\n            :obj:`cugraph` will remove any isolated nodes, leading to a\n            relabeling of nodes. 
(default: :obj:`True`)\n        directed (bool, optional): If set to :obj:`False`, the graph will be\n            undirected. (default: :obj:`True`)\n    \"\"\"\n    import cudf\n    import cugraph\n\n    g = cugraph.Graph(directed=directed)\n    df = cudf.from_dlpack(to_dlpack(edge_index.t()))\n\n    df = cudf.DataFrame({\n        'source':\n        cudf.from_dlpack(to_dlpack(edge_index[0])),\n        'destination':\n        cudf.from_dlpack(to_dlpack(edge_index[1])),\n    })\n\n    if edge_weight is not None:\n        assert edge_weight.dim() == 1\n        df['weight'] = cudf.from_dlpack(to_dlpack(edge_weight))\n\n    g.from_cudf_edgelist(\n        df,\n        source='source',\n        destination='destination',\n        edge_attr='weight' if edge_weight is not None else None,\n        renumber=relabel_nodes,\n    )\n\n    return g\n\n\ndef from_cugraph(g: Any) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Converts a :obj:`cugraph` graph object into :obj:`edge_index` and\n    optional :obj:`edge_weight` tensors.\n\n    Args:\n        g (cugraph.Graph): A :obj:`cugraph` graph object.\n    \"\"\"\n    df = g.view_edge_list()\n\n    src = from_dlpack(df[g.source_columns].to_dlpack()).long()\n    dst = from_dlpack(df[g.destination_columns].to_dlpack()).long()\n    edge_index = torch.stack([src, dst], dim=0)\n\n    edge_weight = None\n    if g.weight_column is not None:\n        edge_weight = from_dlpack(df[g.weight_column].to_dlpack())\n\n    return edge_index, edge_weight\n\n\ndef to_dgl(\n    data: Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']\n) -> Any:\n    r\"\"\"Converts a :class:`torch_geometric.data.Data` or\n    :class:`torch_geometric.data.HeteroData` instance to a :obj:`dgl` graph\n    object.\n\n    Args:\n        data (torch_geometric.data.Data or torch_geometric.data.HeteroData):\n            The data object.\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 3, 0], [1, 0, 2, 1, 4, 4]])\n        >>> x = 
torch.randn(5, 3)\n        >>> edge_attr = torch.randn(6, 2)\n        >>> data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n        >>> g = to_dgl(data)\n        >>> g\n        Graph(num_nodes=5, num_edges=6,\n            ndata_schemes={'x': Scheme(shape=(3,))}\n            edata_schemes={'edge_attr': Scheme(shape=(2, ))})\n\n        >>> data = HeteroData()\n        >>> data['paper'].x = torch.randn(5, 3)\n        >>> data['author'].x = torch.ones(5, 3)\n        >>> edge_index = torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]])\n        >>> data['author', 'cites', 'paper'].edge_index = edge_index\n        >>> g = to_dgl(data)\n        >>> g\n        Graph(num_nodes={'author': 5, 'paper': 5},\n            num_edges={('author', 'cites', 'paper'): 5},\n            metagraph=[('author', 'paper', 'cites')])\n    \"\"\"\n    import dgl\n\n    from torch_geometric.data import Data, HeteroData\n\n    if isinstance(data, Data):\n        if data.edge_index is not None:\n            row, col = data.edge_index\n        elif 'adj' in data:\n            row, col, _ = data.adj.coo()\n        elif 'adj_t' in data:\n            row, col, _ = data.adj_t.t().coo()\n        else:\n            row, col = [], []\n\n        g = dgl.graph((row, col), num_nodes=data.num_nodes)\n\n        for attr in data.node_attrs():\n            g.ndata[attr] = data[attr]\n        for attr in data.edge_attrs():\n            if attr in ['edge_index', 'adj_t']:\n                continue\n            g.edata[attr] = data[attr]\n\n        return g\n\n    if isinstance(data, HeteroData):\n        data_dict = {}\n        for edge_type, edge_store in data.edge_items():\n            if edge_store.get('edge_index') is not None:\n                row, col = edge_store.edge_index\n            else:\n                row, col, _ = edge_store['adj_t'].t().coo()\n\n            data_dict[edge_type] = (row, col)\n\n        g = dgl.heterograph(data_dict)\n\n        for node_type, node_store in data.node_items():\n       
     for attr, value in node_store.items():\n                g.nodes[node_type].data[attr] = value\n\n        for edge_type, edge_store in data.edge_items():\n            for attr, value in edge_store.items():\n                if attr in ['edge_index', 'adj_t']:\n                    continue\n                g.edges[edge_type].data[attr] = value\n\n        return g\n\n    raise ValueError(f\"Invalid data type (got '{type(data)}')\")\n\n\ndef from_dgl(\n    g: Any,\n) -> Union['torch_geometric.data.Data', 'torch_geometric.data.HeteroData']:\n    r\"\"\"Converts a :obj:`dgl` graph object to a\n    :class:`torch_geometric.data.Data` or\n    :class:`torch_geometric.data.HeteroData` instance.\n\n    Args:\n        g (dgl.DGLGraph): The :obj:`dgl` graph object.\n\n    Example:\n        >>> g = dgl.graph(([0, 0, 1, 5], [1, 2, 2, 0]))\n        >>> g.ndata['x'] = torch.randn(g.num_nodes(), 3)\n        >>> g.edata['edge_attr'] = torch.randn(g.num_edges(), 2)\n        >>> data = from_dgl(g)\n        >>> data\n        Data(x=[6, 3], edge_attr=[4, 2], edge_index=[2, 4])\n\n        >>> g = dgl.heterograph({\n        ...     ('author', 'writes', 'paper'): ([0, 1, 1, 2, 3, 3, 4],\n        ...                                     
[0, 0, 1, 1, 1, 2, 2])})\n        >>> g.nodes['author'].data['x'] = torch.randn(5, 3)\n        >>> g.nodes['paper'].data['x'] = torch.randn(3, 3)\n        >>> data = from_dgl(g)\n        >>> data\n        HeteroData(\n        author={ x=[5, 3] },\n        paper={ x=[3, 3] },\n        (author, writes, paper)={ edge_index=[2, 7] }\n        )\n    \"\"\"\n    import dgl\n\n    from torch_geometric.data import Data, HeteroData\n\n    if not isinstance(g, dgl.DGLGraph):\n        raise ValueError(f\"Invalid data type (got '{type(g)}')\")\n\n    data: Union[Data, HeteroData]\n\n    if g.is_homogeneous:\n        data = Data()\n        data.edge_index = torch.stack(g.edges(), dim=0)\n\n        for attr, value in g.ndata.items():\n            data[attr] = value\n        for attr, value in g.edata.items():\n            data[attr] = value\n\n        return data\n\n    data = HeteroData()\n\n    for node_type in g.ntypes:\n        for attr, value in g.nodes[node_type].data.items():\n            data[node_type][attr] = value\n\n    for edge_type in g.canonical_etypes:\n        row, col = g.edges(form=\"uv\", etype=edge_type)\n        data[edge_type].edge_index = torch.stack([row, col], dim=0)\n        # Copy the edge feature tensors themselves rather than their schemes:\n        for attr, value in g.edges[edge_type].data.items():\n            data[edge_type][attr] = value\n\n    return data\n"
  },
  {
    "path": "torch_geometric/utils/cross_entropy.py",
    "content": "from typing import Any, Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import scatter\n\n\nclass SparseCrossEntropy(torch.autograd.Function):\n    # We implement our own custom autograd function for this to avoid the\n    # double gradient computation to `inputs`.\n    @staticmethod\n    def forward(\n        ctx: Any,\n        inputs: Tensor,\n        edge_label_index: Tensor,\n        edge_label_weight: Optional[Tensor],\n    ) -> Tensor:\n        assert inputs.dim() == 2\n\n        # Support for both positive and negative weights:\n        # Positive weights scale the logits *after* softmax.\n        # Negative weights scale the denominator *before* softmax:\n        pos_y = edge_label_index\n        neg_y = pos_weight = neg_weight = None\n\n        if edge_label_weight is not None:\n            pos_mask = edge_label_weight >= 0\n            pos_y = edge_label_index[:, pos_mask]\n            pos_weight = edge_label_weight[pos_mask]\n\n            if pos_y.size(1) < edge_label_index.size(1):\n                neg_mask = ~pos_mask\n                neg_y = edge_label_index[:, neg_mask]\n                neg_weight = edge_label_weight[neg_mask]\n\n            if neg_y is not None and neg_weight is not None:\n                inputs = inputs.clone()\n                inputs[\n                    neg_y[0],\n                    neg_y[1],\n                ] += neg_weight.abs().log().clamp(min=1e-12)\n\n        logsumexp = inputs.logsumexp(dim=-1)\n        ctx.save_for_backward(inputs, pos_y, pos_weight, logsumexp)\n\n        out = inputs[pos_y[0], pos_y[1]]\n        out.neg_().add_(logsumexp[pos_y[0]])\n        if pos_weight is not None:\n            out *= pos_weight\n\n        return out.sum() / inputs.size(0)\n\n    @staticmethod\n    @torch.autograd.function.once_differentiable\n    def backward(ctx: Any, grad_out: Tensor) -> Tuple[Tensor, None, None]:\n        inputs, pos_y, pos_weight, logsumexp = 
ctx.saved_tensors\n\n        grad_out = grad_out / inputs.size(0)\n        grad_out = grad_out.expand(pos_y.size(1))\n\n        if pos_weight is not None:\n            grad_out = grad_out * pos_weight\n\n        grad_logsumexp = scatter(grad_out, pos_y[0], dim=0,\n                                 dim_size=inputs.size(0), reduce='sum')\n\n        # Gradient computation of `logsumexp`: `grad * (self - result).exp()`\n        grad_input = (inputs - logsumexp.view(-1, 1))\n        grad_input.exp_()\n        grad_input.mul_(grad_logsumexp.view(-1, 1))\n\n        grad_input[pos_y[0], pos_y[1]] -= grad_out\n\n        return grad_input, None, None\n\n\ndef sparse_cross_entropy(\n    inputs: Tensor,\n    edge_label_index: Tensor,\n    edge_label_weight: Optional[Tensor] = None,\n) -> Tensor:\n    r\"\"\"A sparse-label variant of :func:`torch.nn.functional.cross_entropy`.\n    In particular, the binary target matrix is solely given by sparse indices\n    :obj:`edge_label_index`.\n\n    Args:\n        inputs (torch.Tensor): The predicted unnormalized logits of shape\n            :obj:`[batch_size, num_classes]`.\n        edge_label_index (torch.Tensor): The sparse ground-truth indices with\n            shape :obj:`[2, num_labels]`.\n        edge_label_weight (torch.Tensor, optional): The weight of ground-truth\n            indices with shape :obj:`[num_labels]`. (default: :obj:`None`)\n\n    :rtype: :class:`torch.Tensor`\n\n    Example:\n        >>> inputs = torch.randn(2, 3)\n        >>> edge_label_index = torch.tensor([\n        ...     [0, 0, 1],\n        ...     [0, 1, 2],\n        ... ])\n        >>> loss = sparse_cross_entropy(inputs, edge_label_index)\n        >>> loss\n        tensor(1.2919)\n    \"\"\"\n    if edge_label_weight is not None:\n        assert not edge_label_weight.requires_grad\n\n    return SparseCrossEntropy.apply(\n        inputs,\n        edge_label_index,\n        edge_label_weight,\n    )\n"
  },
  {
    "path": "torch_geometric/utils/dropout.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric import is_compiling\nfrom torch_geometric.deprecation import deprecated\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import cumsum, degree, sort_edge_index, subgraph\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef filter_adj(row: Tensor, col: Tensor, edge_attr: OptTensor,\n               mask: Tensor) -> Tuple[Tensor, Tensor, OptTensor]:\n    return row[mask], col[mask], None if edge_attr is None else edge_attr[mask]\n\n\n@deprecated(\"use 'dropout_edge' instead\")\ndef dropout_adj(\n    edge_index: Tensor,\n    edge_attr: OptTensor = None,\n    p: float = 0.5,\n    force_undirected: bool = False,\n    num_nodes: Optional[int] = None,\n    training: bool = True,\n) -> Tuple[Tensor, OptTensor]:\n    r\"\"\"Randomly drops edges from the adjacency matrix\n    :obj:`(edge_index, edge_attr)` with probability :obj:`p` using samples from\n    a Bernoulli distribution.\n\n    .. warning::\n\n        :class:`~torch_geometric.utils.dropout_adj` is deprecated and will\n        be removed in a future release.\n        Use :class:`torch_geometric.utils.dropout_edge` instead.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional\n            edge features. (default: :obj:`None`)\n        p (float, optional): Dropout probability. (default: :obj:`0.5`)\n        force_undirected (bool, optional): If set to :obj:`True`, will either\n            drop or keep both edges of an undirected edge.\n            (default: :obj:`False`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n        training (bool, optional): If set to :obj:`False`, this operation is a\n            no-op. 
(default: :obj:`True`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            [1, 0, 2, 1, 3, 2]])\n        >>> edge_attr = torch.tensor([1, 2, 3, 4, 5, 6])\n        >>> dropout_adj(edge_index, edge_attr)\n        (tensor([[0, 1, 2, 3],\n                [1, 2, 3, 2]]),\n        tensor([1, 3, 5, 6]))\n\n        >>> # The returned graph is kept undirected\n        >>> dropout_adj(edge_index, edge_attr, force_undirected=True)\n        (tensor([[0, 1, 2, 1, 2, 3],\n                [1, 2, 3, 0, 1, 2]]),\n        tensor([1, 3, 5, 1, 3, 5]))\n    \"\"\"\n    if p < 0. or p > 1.:\n        raise ValueError(f'Dropout probability has to be between 0 and 1 '\n                         f'(got {p})')\n\n    if not training or p == 0.0:\n        return edge_index, edge_attr\n\n    row, col = edge_index\n\n    mask = torch.rand(row.size(0), device=edge_index.device) >= p\n\n    if force_undirected:\n        mask[row > col] = False\n\n    row, col, edge_attr = filter_adj(row, col, edge_attr, mask)\n\n    if force_undirected:\n        edge_index = torch.stack(\n            [torch.cat([row, col], dim=0),\n             torch.cat([col, row], dim=0)], dim=0)\n        if edge_attr is not None:\n            edge_attr = torch.cat([edge_attr, edge_attr], dim=0)\n    else:\n        edge_index = torch.stack([row, col], dim=0)\n\n    return edge_index, edge_attr\n\n\ndef dropout_node(\n    edge_index: Tensor,\n    p: float = 0.5,\n    num_nodes: Optional[int] = None,\n    training: bool = True,\n    relabel_nodes: bool = False,\n) -> Tuple[Tensor, Tensor, Tensor]:\n    r\"\"\"Randomly drops nodes from the adjacency matrix\n    :obj:`edge_index` with probability :obj:`p` using samples from\n    a Bernoulli distribution.\n\n    The method returns (1) the retained :obj:`edge_index`, (2) the edge mask\n    indicating which edges were retained, and 
(3) the node mask indicating\n    which nodes were retained.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        p (float, optional): Dropout probability. (default: :obj:`0.5`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n        training (bool, optional): If set to :obj:`False`, this operation is a\n            no-op. (default: :obj:`True`)\n        relabel_nodes (bool, optional): If set to :obj:`True`, the resulting\n            :obj:`edge_index` will be relabeled to hold consecutive indices\n            starting from zero. (default: :obj:`False`)\n\n    :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            [1, 0, 2, 1, 3, 2]])\n        >>> edge_index, edge_mask, node_mask = dropout_node(edge_index)\n        >>> edge_index\n        tensor([[0, 1],\n                [1, 0]])\n        >>> edge_mask\n        tensor([ True,  True, False, False, False, False])\n        >>> node_mask\n        tensor([ True,  True, False, False])\n    \"\"\"\n    if p < 0. 
or p > 1.:\n        raise ValueError(f'Dropout probability has to be between 0 and 1 '\n                         f'(got {p})')\n\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    if not training or p == 0.0:\n        node_mask = edge_index.new_ones(num_nodes, dtype=torch.bool)\n        edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool)\n        return edge_index, edge_mask, node_mask\n\n    prob = torch.rand(num_nodes, device=edge_index.device)\n    node_mask = prob > p\n    edge_index, _, edge_mask = subgraph(\n        node_mask,\n        edge_index,\n        relabel_nodes=relabel_nodes,\n        num_nodes=num_nodes,\n        return_edge_mask=True,\n    )\n    return edge_index, edge_mask, node_mask\n\n\ndef dropout_edge(edge_index: Tensor, p: float = 0.5,\n                 force_undirected: bool = False,\n                 training: bool = True) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Randomly drops edges from the adjacency matrix\n    :obj:`edge_index` with probability :obj:`p` using samples from\n    a Bernoulli distribution.\n\n    The method returns (1) the retained :obj:`edge_index`, (2) the edge mask\n    or index indicating which edges were retained, depending on the argument\n    :obj:`force_undirected`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        p (float, optional): Dropout probability. (default: :obj:`0.5`)\n        force_undirected (bool, optional): If set to :obj:`True`, will either\n            drop or keep both edges of an undirected edge.\n            (default: :obj:`False`)\n        training (bool, optional): If set to :obj:`False`, this operation is a\n            no-op. (default: :obj:`True`)\n\n    :rtype: (:class:`LongTensor`, :class:`BoolTensor` or :class:`LongTensor`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            
[1, 0, 2, 1, 3, 2]])\n        >>> edge_index, edge_mask = dropout_edge(edge_index)\n        >>> edge_index\n        tensor([[0, 1, 2, 2],\n                [1, 2, 1, 3]])\n        >>> edge_mask # masks indicating which edges are retained\n        tensor([ True, False,  True,  True,  True, False])\n\n        >>> edge_index, edge_id = dropout_edge(edge_index,\n        ...                                    force_undirected=True)\n        >>> edge_index\n        tensor([[0, 1, 2, 1, 2, 3],\n                [1, 2, 3, 0, 1, 2]])\n        >>> edge_id # indices indicating which edges are retained\n        tensor([0, 2, 4, 0, 2, 4])\n    \"\"\"\n    if p < 0. or p > 1.:\n        raise ValueError(f'Dropout probability has to be between 0 and 1 '\n                         f'(got {p})')\n\n    if not training or p == 0.0:\n        edge_mask = edge_index.new_ones(edge_index.size(1), dtype=torch.bool)\n        return edge_index, edge_mask\n\n    row, col = edge_index\n\n    edge_mask = torch.rand(row.size(0), device=edge_index.device) >= p\n\n    if force_undirected:\n        edge_mask[row > col] = False\n\n    edge_index = edge_index[:, edge_mask]\n\n    if force_undirected:\n        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)\n        edge_mask = edge_mask.nonzero().repeat((2, 1)).squeeze()\n\n    return edge_index, edge_mask\n\n\ndef dropout_path(edge_index: Tensor, p: float = 0.2, walks_per_node: int = 1,\n                 walk_length: int = 3, num_nodes: Optional[int] = None,\n                 is_sorted: bool = False,\n                 training: bool = True) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Drops edges from the adjacency matrix :obj:`edge_index`\n    based on random walks. 
The source nodes to start random walks from are\n    sampled from :obj:`edge_index` with probability :obj:`p`, following\n    a Bernoulli distribution.\n\n    The method returns (1) the retained :obj:`edge_index`, (2) the edge mask\n    indicating which edges were retained.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        p (float, optional): Sample probability. (default: :obj:`0.2`)\n        walks_per_node (int, optional): The number of walks per node, same as\n            :class:`~torch_geometric.nn.models.Node2Vec`. (default: :obj:`1`)\n        walk_length (int, optional): The walk length, same as\n            :class:`~torch_geometric.nn.models.Node2Vec`. (default: :obj:`3`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n        is_sorted (bool, optional): If set to :obj:`True`, will expect\n            :obj:`edge_index` to be already sorted row-wise.\n            (default: :obj:`False`)\n        training (bool, optional): If set to :obj:`False`, this operation is a\n            no-op. (default: :obj:`True`)\n\n    :rtype: (:class:`LongTensor`, :class:`BoolTensor`)\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            [1, 0, 2, 1, 3, 2]])\n        >>> edge_index, edge_mask = dropout_path(edge_index)\n        >>> edge_index\n        tensor([[1, 2],\n                [2, 3]])\n        >>> edge_mask # masks indicating which edges are retained\n        tensor([False, False,  True, False,  True, False])\n    \"\"\"\n    if p < 0. 
or p > 1.:\n        raise ValueError(f'Sample probability has to be between 0 and 1 '\n                         f'(got {p})')\n\n    num_edges = edge_index.size(1)\n    edge_mask = edge_index.new_ones(num_edges, dtype=torch.bool)\n    if not training or p == 0.0:\n        return edge_index, edge_mask\n\n    if not torch_geometric.typing.WITH_TORCH_CLUSTER or is_compiling():\n        raise ImportError('`dropout_path` requires `torch-cluster`.')\n\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n    edge_orders = None\n    ori_edge_index = edge_index\n    if not is_sorted:\n        edge_orders = torch.arange(num_edges, device=edge_index.device)\n        edge_index, edge_orders = sort_edge_index(edge_index, edge_orders,\n                                                  num_nodes=num_nodes)\n\n    row, col = edge_index\n    sample_mask = torch.rand(row.size(0), device=edge_index.device) <= p\n    start = row[sample_mask].repeat(walks_per_node)\n\n    rowptr = cumsum(degree(row, num_nodes=num_nodes, dtype=torch.long))\n    n_id, e_id = torch.ops.torch_cluster.random_walk(rowptr, col, start,\n                                                     walk_length, 1.0, 1.0)\n    e_id = e_id[e_id != -1].view(-1)  # filter illegal edges\n\n    if edge_orders is not None:  # Permute edge indices:\n        e_id = edge_orders[e_id]\n    edge_mask[e_id] = False\n    edge_index = ori_edge_index[:, edge_mask]\n\n    return edge_index, edge_mask\n"
  },
  {
    "path": "torch_geometric/utils/embedding.py",
    "content": "import warnings\nfrom typing import Any, Dict, List, Optional, Type\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import NodeType\n\n\ndef get_embeddings(\n    model: torch.nn.Module,\n    *args: Any,\n    **kwargs: Any,\n) -> List[Tensor]:\n    \"\"\"Returns the output embeddings of all\n    :class:`~torch_geometric.nn.conv.MessagePassing` layers in\n    :obj:`model`.\n\n    Internally, this method registers forward hooks on all\n    :class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,\n    and runs the forward pass of the :obj:`model` by calling\n    :obj:`model(*args, **kwargs)`.\n\n    Args:\n        model (torch.nn.Module): The message passing model.\n        *args: Arguments passed to the model.\n        **kwargs (optional): Additional keyword arguments passed to the model.\n    \"\"\"\n    from torch_geometric.nn import MessagePassing\n\n    embeddings: List[Tensor] = []\n\n    def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:\n        # Clone output in case it will be later modified in-place:\n        outputs = outputs[0] if isinstance(outputs, tuple) else outputs\n        assert isinstance(outputs, Tensor)\n        embeddings.append(outputs.clone())\n\n    hook_handles = []\n    for module in model.modules():  # Register forward hooks:\n        if isinstance(module, MessagePassing):\n            hook_handles.append(module.register_forward_hook(hook))\n\n    if len(hook_handles) == 0:\n        warnings.warn(\"The 'model' does not have any 'MessagePassing' layers\",\n                      stacklevel=2)\n\n    training = model.training\n    model.eval()\n    with torch.no_grad():\n        model(*args, **kwargs)\n    model.train(training)\n\n    for handle in hook_handles:  # Remove hooks:\n        handle.remove()\n\n    return embeddings\n\n\ndef get_embeddings_hetero(\n    model: torch.nn.Module,\n    supported_models: Optional[List[Type[torch.nn.Module]]] = None,\n    *args: 
Any,\n    **kwargs: Any,\n) -> Dict[NodeType, List[Tensor]]:\n    \"\"\"Returns the output embeddings of all\n    :class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous\n    :obj:`model`, organized by node type.\n\n    Internally, this method registers forward hooks on all modules that process\n    heterogeneous graphs in the model and runs the forward pass of the model.\n    For heterogeneous models, the output is a dictionary where each key is a\n    node type and each value is a list of embeddings from different layers.\n\n    Args:\n        model (torch.nn.Module): The heterogeneous GNN model.\n        supported_models (List[Type[torch.nn.Module]], optional): A list of\n            supported model classes. If not provided, defaults to\n            [HGTConv, HANConv, HeteroConv].\n        *args: Arguments passed to the model.\n        **kwargs (optional): Additional keyword arguments passed to the model.\n\n    Returns:\n        Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to\n        a list of embeddings from different layers.\n    \"\"\"\n    from torch_geometric.nn import HANConv, HeteroConv, HGTConv\n    if not supported_models:\n        supported_models = [HGTConv, HANConv, HeteroConv]\n\n    # Dictionary to store node embeddings by type\n    node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}\n\n    # Hook function to capture node embeddings\n    def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:\n        # Check if the output is a dictionary mapping node types to embeddings\n        if isinstance(outputs, dict) and outputs:\n            # Store embeddings for each node type\n            for node_type, embedding in outputs.items():\n                # Create the list the first time a node type is seen.\n                if node_type not in node_embeddings_dict:\n                    node_embeddings_dict[node_type] = []\n  
              node_embeddings_dict[node_type].append(embedding.clone())\n\n    # List to store hook handles\n    hook_handles = []\n\n    # Register forward hooks on all heterogeneous modules in the model\n    for _, module in model.named_modules():\n        # Handle the native heterogeneous models, e.g. HGTConv, HANConv\n        # and HeteroConv, etc.\n        if isinstance(module, tuple(supported_models)):\n            hook_handles.append(module.register_forward_hook(hook))\n        else:\n            # Handle the heterogeneous models that are generated by calling\n            # to_hetero() on the homogeneous models.\n            submodules = list(module.children())\n            submodules_contains_module_dict = any([\n                isinstance(submodule, torch.nn.ModuleDict)\n                for submodule in submodules\n            ])\n            if submodules_contains_module_dict:\n                hook_handles.append(module.register_forward_hook(hook))\n\n    if len(hook_handles) == 0:\n        warnings.warn(\n            \"The 'model' does not have any heterogenous \"\n            \"'MessagePassing' layers\", stacklevel=2)\n\n    # Run the model forward pass\n    training = model.training\n    model.eval()\n\n    with torch.no_grad():\n        model(*args, **kwargs)\n    model.train(training)\n\n    # Clean up hooks\n    for handle in hook_handles:\n        handle.remove()\n\n    return node_embeddings_dict\n"
  },
  {
    "path": "torch_geometric/utils/functions.py",
    "content": "import torch\nfrom torch import Tensor\n\n\ndef cumsum(x: Tensor, dim: int = 0) -> Tensor:\n    r\"\"\"Returns the cumulative sum of elements of :obj:`x`.\n    In contrast to :meth:`torch.cumsum`, prepends the output with zero.\n\n    Args:\n        x (torch.Tensor): The input tensor.\n        dim (int, optional): The dimension to do the operation over.\n            (default: :obj:`0`)\n\n    Example:\n        >>> x = torch.tensor([2, 4, 1])\n        >>> cumsum(x)\n        tensor([0, 2, 6, 7])\n\n    \"\"\"\n    size = x.size()[:dim] + (x.size(dim) + 1, ) + x.size()[dim + 1:]\n    out = x.new_empty(size)\n\n    out.narrow(dim, 0, 1).zero_()\n    torch.cumsum(x, dim=dim, out=out.narrow(dim, 1, x.size(dim)))\n\n    return out\n"
  },
  {
    "path": "torch_geometric/utils/geodesic.py",
    "content": "import multiprocessing as mp\nimport warnings\nfrom typing import Optional\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\n\ndef geodesic_distance(  # noqa: D417\n    pos: Tensor,\n    face: Tensor,\n    src: Optional[Tensor] = None,\n    dst: Optional[Tensor] = None,\n    norm: bool = True,\n    max_distance: Optional[float] = None,\n    num_workers: int = 0,\n    # Backward compatibility for `dest`:\n    **kwargs: Optional[Tensor],\n) -> Tensor:\n    r\"\"\"Computes (normalized) geodesic distances of a mesh given by :obj:`pos`\n    and :obj:`face`. If :obj:`src` and :obj:`dst` are given, this method only\n    computes the geodesic distances for the respective source and target\n    node-pairs.\n\n    .. note::\n\n        This function requires the :obj:`gdist` package.\n        To install, run :obj:`pip install cython && pip install gdist`.\n\n    Args:\n        pos (torch.Tensor): The node positions.\n        face (torch.Tensor): The face indices.\n        src (torch.Tensor, optional): If given, only compute geodesic distances\n            for the specified source indices. (default: :obj:`None`)\n        dst (torch.Tensor, optional): If given, only compute geodesic distances\n            for the specified target indices. (default: :obj:`None`)\n        norm (bool, optional): Normalizes geodesic distances by\n            :math:`\\sqrt{\\textrm{area}(\\mathcal{M})}`. (default: :obj:`True`)\n        max_distance (float, optional): If given, only yields results for\n            geodesic distances less than :obj:`max_distance`. This will speed\n            up runtime dramatically. 
(default: :obj:`None`)\n        num_workers (int, optional): How many subprocesses to use for\n            calculating geodesic distances.\n            :obj:`0` means that computation takes place in the main process.\n            :obj:`-1` means that the available amount of CPU cores is used.\n            (default: :obj:`0`)\n\n    :rtype: :class:`Tensor`\n\n    Example:\n        >>> pos = torch.tensor([[0.0, 0.0, 0.0],\n        ...                     [2.0, 0.0, 0.0],\n        ...                     [0.0, 2.0, 0.0],\n        ...                     [2.0, 2.0, 0.0]])\n        >>> face = torch.tensor([[0, 0],\n        ...                      [1, 2],\n        ...                      [3, 3]])\n        >>> geodesic_distance(pos, face)\n        [[0, 1, 1, 1.4142135623730951],\n        [1, 0, 1.4142135623730951, 1],\n        [1, 1.4142135623730951, 0, 1],\n        [1.4142135623730951, 1, 1, 0]]\n    \"\"\"\n    import gdist\n\n    if 'dest' in kwargs:\n        dst = kwargs['dest']\n        warnings.warn(\n            \"'dest' attribute in 'geodesic_distance' is deprecated \"\n            \"and will be removed in a future release. 
Use the 'dst' \"\n            \"argument instead.\", stacklevel=2)\n\n    max_distance = float('inf') if max_distance is None else max_distance\n\n    if norm:\n        area = (pos[face[1]] - pos[face[0]]).cross(\n            pos[face[2]] - pos[face[0]],\n            dim=1,\n        )\n        scale = float((area.norm(p=2, dim=1) / 2).sum().sqrt())\n    else:\n        scale = 1.0\n\n    dtype = pos.dtype\n\n    pos_np = pos.detach().cpu().to(torch.double).numpy()\n    face_np = face.detach().t().cpu().to(torch.int).numpy()\n\n    if src is None and dst is None:\n        out = gdist.local_gdist_matrix(\n            pos_np,\n            face_np,\n            max_distance * scale,\n        ).toarray() / scale\n        return torch.from_numpy(out).to(dtype)\n\n    if src is None:\n        src_np = torch.arange(pos.size(0), dtype=torch.int).numpy()\n    else:\n        src_np = src.detach().cpu().to(torch.int).numpy()\n\n    dst_np = None if dst is None else dst.detach().cpu().to(torch.int).numpy()\n\n    def _parallel_loop(\n        pos_np: np.ndarray,\n        face_np: np.ndarray,\n        src_np: np.ndarray,\n        dst_np: Optional[np.ndarray],\n        max_distance: float,\n        scale: float,\n        i: int,\n        dtype: torch.dtype,\n    ) -> Tensor:\n        s = src_np[i:i + 1]\n        d = None if dst_np is None else dst_np[i:i + 1]\n        out = gdist.compute_gdist(pos_np, face_np, s, d, max_distance * scale)\n        out = out / scale\n        return torch.from_numpy(out).to(dtype)\n\n    num_workers = mp.cpu_count() if num_workers <= -1 else num_workers\n    if num_workers > 0:\n        with mp.Pool(num_workers) as pool:\n            data = [(pos_np, face_np, src_np, dst_np, max_distance, scale, i,\n                     dtype) for i in range(len(src_np))]\n            outs = pool.starmap(_parallel_loop, data)\n    else:\n        outs = [\n            _parallel_loop(pos_np, face_np, src_np, dst_np, max_distance,\n                           scale, i, 
dtype) for i in range(len(src_np))\n        ]\n\n    out = torch.cat(outs, dim=0)\n\n    if dst is None:\n        out = out.view(-1, pos.size(0))\n\n    return out\n"
  },
  {
    "path": "torch_geometric/utils/hetero.py",
    "content": "from typing import Dict, List, Optional, Set, Tuple, Union\n\nimport torch\nfrom torch import Tensor\nfrom torch.nn import ParameterDict\n\nfrom torch_geometric.typing import Adj, EdgeType, NodeType, SparseTensor\nfrom torch_geometric.utils import is_sparse, to_edge_index\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes_dict\n\n\ndef group_hetero_graph(\n    edge_index_dict: Dict[EdgeType, Tensor],\n    num_nodes_dict: Optional[Dict[NodeType, int]] = None,\n) -> Tuple[\n        Tensor,\n        Tensor,\n        Tensor,\n        Tensor,\n        Dict[Union[str, int], Tensor],\n        Dict[Union[NodeType, EdgeType], int],\n]:\n    num_nodes_dict = maybe_num_nodes_dict(edge_index_dict, num_nodes_dict)\n\n    tmp = list(edge_index_dict.values())[0]\n\n    key2int: Dict[Union[NodeType, EdgeType], int] = {}\n\n    cumsum, offset = 0, {}  # Helper data.\n    node_types, local_node_indices = [], []\n    local2global: Dict[Union[str, int], Tensor] = {}\n    for i, (key, N) in enumerate(num_nodes_dict.items()):\n        key2int[key] = i\n        node_types.append(tmp.new_full((N, ), i))\n        local_node_indices.append(torch.arange(N, device=tmp.device))\n        offset[key] = cumsum\n        local2global[key] = local_node_indices[-1] + cumsum\n        local2global[i] = local2global[key]\n        cumsum += N\n\n    node_type = torch.cat(node_types, dim=0)\n    local_node_idx = torch.cat(local_node_indices, dim=0)\n\n    edge_indices, edge_types = [], []\n    for i, (keys, edge_index) in enumerate(edge_index_dict.items()):\n        key2int[keys] = i\n        inc = torch.tensor([offset[keys[0]], offset[keys[-1]]]).view(2, 1)\n        edge_indices.append(edge_index + inc.to(tmp.device))\n        edge_types.append(tmp.new_full((edge_index.size(1), ), i))\n\n    edge_index = torch.cat(edge_indices, dim=-1)\n    edge_type = torch.cat(edge_types, dim=0)\n\n    return (\n        edge_index,\n        edge_type,\n        node_type,\n        
local_node_idx,\n        local2global,\n        key2int,\n    )\n\n\ndef get_unused_node_types(node_types: List[NodeType],\n                          edge_types: List[EdgeType]) -> Set[NodeType]:\n    dst_node_types = {edge_type[-1] for edge_type in edge_types}\n    return set(node_types) - set(dst_node_types)\n\n\ndef check_add_self_loops(\n    module: torch.nn.Module,\n    edge_types: List[EdgeType],\n) -> None:\n    is_bipartite = any([key[0] != key[-1] for key in edge_types])\n    if is_bipartite and getattr(module, 'add_self_loops', False):\n        raise ValueError(\n            f\"'add_self_loops' attribute set to 'True' on module '{module}' \"\n            f\"for use with edge type(s) '{edge_types}'. This will lead to \"\n            f\"incorrect message passing results.\")\n\n\ndef construct_bipartite_edge_index(\n    edge_index_dict: Dict[EdgeType, Adj],\n    src_offset_dict: Dict[EdgeType, int],\n    dst_offset_dict: Dict[NodeType, int],\n    edge_attr_dict: Optional[Dict[EdgeType, Tensor]] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Adj, Optional[Tensor]]:\n    \"\"\"Constructs a tensor of edge indices by concatenating edge indices\n    for each edge type. 
The edge indices are increased by the offset of the\n    source and destination nodes.\n\n    Args:\n        edge_index_dict (Dict[Tuple[str, str, str], torch.Tensor]): A\n            dictionary holding graph connectivity information for each\n            individual edge type, either as a :class:`torch.Tensor` of\n            shape :obj:`[2, num_edges]` or a\n            :class:`torch_sparse.SparseTensor`.\n        src_offset_dict (Dict[Tuple[str, str, str], int]): A dictionary of\n            offsets to apply to the source node type for each edge type.\n        dst_offset_dict (Dict[str, int]): A dictionary of offsets to apply for\n            destination node types.\n        edge_attr_dict (Dict[Tuple[str, str, str], torch.Tensor]): A\n            dictionary holding edge features for each individual edge type.\n            (default: :obj:`None`)\n        num_nodes (int, optional): The final number of nodes in the bipartite\n            adjacency matrix. (default: :obj:`None`)\n    \"\"\"\n    is_sparse_tensor = False\n    edge_indices: List[Tensor] = []\n    edge_attrs: List[Tensor] = []\n    for edge_type, src_offset in src_offset_dict.items():\n        edge_index = edge_index_dict[edge_type]\n        dst_offset = dst_offset_dict[edge_type[-1]]\n\n        # TODO Add support for SparseTensor w/o converting.\n        is_sparse_tensor = isinstance(edge_index, SparseTensor)\n        if is_sparse(edge_index):\n            edge_index, _ = to_edge_index(edge_index)\n            edge_index = edge_index.flip([0])\n        else:\n            edge_index = edge_index.clone()\n\n        edge_index[0] += src_offset\n        edge_index[1] += dst_offset\n        edge_indices.append(edge_index)\n\n        if edge_attr_dict is not None:\n            if isinstance(edge_attr_dict, ParameterDict):\n                value = edge_attr_dict['__'.join(edge_type)]\n            else:\n                value = edge_attr_dict[edge_type]\n            if value.size(0) != edge_index.size(1):\n   
             value = value.expand(edge_index.size(1), -1)\n            edge_attrs.append(value)\n\n    edge_index = torch.cat(edge_indices, dim=1)\n\n    edge_attr: Optional[Tensor] = None\n    if edge_attr_dict is not None:\n        edge_attr = torch.cat(edge_attrs, dim=0)\n\n    if is_sparse_tensor:\n        edge_index = SparseTensor(\n            row=edge_index[1],\n            col=edge_index[0],\n            value=edge_attr,\n            sparse_sizes=(num_nodes, num_nodes),\n        )\n\n    return edge_index, edge_attr\n"
  },
  {
    "path": "torch_geometric/utils/influence.py",
    "content": "from typing import List, Tuple, Union, cast\n\nimport torch\nfrom torch import Tensor\nfrom torch.autograd.functional import jacobian\nfrom tqdm.auto import tqdm\n\nfrom torch_geometric.data import Data\nfrom torch_geometric.utils import k_hop_subgraph\n\n\ndef k_hop_subsets_rough(\n    node_idx: int,\n    num_hops: int,\n    edge_index: Tensor,\n    num_nodes: int,\n) -> List[Tensor]:\n    r\"\"\"Return *rough* (possibly overlapping) *k*-hop node subsets.\n\n    This follows the hop-expansion loop of\n    :func:`torch_geometric.utils.k_hop_subgraph` but returns **all**\n    intermediate hop subsets rather than only their union.\n\n    Parameters\n    ----------\n    node_idx: int\n        Index of the central (seed) node.\n    num_hops: int\n        Number of hops *k*.\n    edge_index: Tensor\n        Edge index in COO format with shape :math:`[2, \\text{num_edges}]`.\n    num_nodes: int\n        Total number of nodes in the graph. Required to allocate the masks.\n\n    Returns:\n    -------\n    List[Tensor]\n        A list ``[H₀, H₁, …, H_k]`` where ``H₀`` contains the seed node and\n        ``H_i`` (for *i*>0) contains **all** nodes reached at expansion step\n        *i* in the *expanded* neighbourhood (i.e. 
overlaps are *not*\n        removed).\n    \"\"\"\n    col, row = edge_index\n\n    node_mask = row.new_empty(num_nodes, dtype=torch.bool)\n    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)\n\n    node_idx_ = torch.tensor([node_idx], device=row.device)\n\n    subsets = [node_idx_]\n    for _ in range(num_hops):\n        node_mask.zero_()\n        node_mask[subsets[-1]] = True\n        torch.index_select(node_mask, 0, row, out=edge_mask)\n        subsets.append(col[edge_mask])\n\n    return subsets\n\n\ndef k_hop_subsets_exact(\n    node_idx: int,\n    num_hops: int,\n    edge_index: Tensor,\n    num_nodes: int,\n    device: Union[torch.device, str],\n) -> List[Tensor]:\n    \"\"\"Return **disjoint** *k*-hop subsets.\n\n    This function refines :func:`k_hop_subsets_rough` by removing nodes that\n    have already appeared in previous hops, ensuring that each subset contains\n    nodes *exactly* *i* hops away from the seed.\n    \"\"\"\n    rough_subsets = k_hop_subsets_rough(node_idx, num_hops, edge_index,\n                                        num_nodes)\n\n    exact_subsets: List[List[int]] = [rough_subsets[0].tolist()]\n    visited: set[int] = set(exact_subsets[0])\n\n    for hop_subset in rough_subsets[1:]:\n        fresh = set(hop_subset.tolist()) - visited\n        visited |= fresh\n        exact_subsets.append(list(fresh))\n\n    return [\n        torch.tensor(s, device=device, dtype=edge_index.dtype)\n        for s in exact_subsets\n    ]\n\n\ndef jacobian_l1(\n    model: torch.nn.Module,\n    data: Data,\n    max_hops: int,\n    node_idx: int,\n    device: Union[torch.device, str],\n    *,\n    vectorize: bool = True,\n) -> Tensor:\n    \"\"\"Compute the **L1 norm** of the Jacobian for a given node.\n\n    The Jacobian is evaluated w.r.t. the node features of the *k*-hop induced\n    sub‑graph centred at ``node_idx``. 
The result is *folded back* onto the\n    **original** node index space so that the returned tensor has length\n    ``data.num_nodes``, where the influence score will be zero for nodes\n    outside the *k*-hop subgraph.\n\n    Notes:\n    -----\n    *   The function assumes that the model *and* ``data.x`` share the same\n        floating‑point precision (e.g. both ``float32`` or both ``float16``).\n\n    \"\"\"\n    # Build the induced *k*-hop sub‑graph (with node re‑labelling).\n    edge_index = cast(Tensor, data.edge_index)\n    x = cast(Tensor, data.x)\n    k_hop_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(\n        node_idx, max_hops, edge_index, relabel_nodes=True)\n    # get the location of the *center* node inside the sub‑graph\n    root_pos = cast(int, mapping[0])\n\n    # Move tensors & model to the correct device\n    device = torch.device(device)\n    sub_x = x[k_hop_nodes].to(device)\n    sub_edge_index = sub_edge_index.to(device)\n    model = model.to(device)\n\n    # Jacobian evaluation\n    def _forward(x: Tensor) -> Tensor:\n        return model(x, sub_edge_index)[root_pos]\n\n    jac = jacobian(_forward, sub_x, vectorize=vectorize)\n    influence_sub = jac.abs().sum(dim=(0, 2))  # Sum of L1 norm\n    num_nodes = cast(int, data.num_nodes)\n    # Scatter the influence scores back to the *global* node space\n    influence_full = torch.zeros(num_nodes, dtype=influence_sub.dtype,\n                                 device=device)\n    influence_full[k_hop_nodes] = influence_sub\n\n    return influence_full\n\n\ndef jacobian_l1_agg_per_hop(\n    model: torch.nn.Module,\n    data: Data,\n    max_hops: int,\n    node_idx: int,\n    device: Union[torch.device, str],\n    vectorize: bool = True,\n) -> Tensor:\n    \"\"\"Aggregate Jacobian L1 norms **per hop** for node_idx.\n\n    Returns a vector ``[I_0, I_1, …, I_k]`` where ``I_i`` is the *total*\n    influence exerted by nodes that are exactly *i* hops away from\n    ``node_idx``.\n    \"\"\"\n    
num_nodes = cast(int, data.num_nodes)\n    edge_index = cast(Tensor, data.edge_index)\n    influence = jacobian_l1(model, data, max_hops, node_idx, device,\n                            vectorize=vectorize)\n    hop_subsets = k_hop_subsets_exact(node_idx, max_hops, edge_index,\n                                      num_nodes, influence.device)\n    single_node_influence_per_hop = [influence[s].sum() for s in hop_subsets]\n    return torch.tensor(single_node_influence_per_hop, device=influence.device)\n\n\ndef avg_total_influence(\n    influence_all_nodes: Tensor,\n    normalize: bool = True,\n) -> Tensor:\n    \"\"\"Average the per-hop influences over all seed nodes.\n\n    If ``normalize`` is set, the result is divided by the hop-0 influence.\n    \"\"\"\n    avg_total_influences = torch.mean(influence_all_nodes, dim=0)\n    if normalize:  # normalize by hop_0 (jacobian of the center node feature)\n        avg_total_influences = avg_total_influences / avg_total_influences[0]\n    return avg_total_influences\n\n\ndef influence_weighted_receptive_field(T: Tensor) -> float:\n    \"\"\"Compute the *influence‑weighted receptive field* ``R``.\n\n    Given an influence matrix ``T`` of shape ``[N, k+1]`` (i‑th row contains\n    the per‑hop influences of node *i*), the receptive field breadth *R* is\n    defined as the expected hop distance when weighting by influence.\n\n    A larger *R* indicates that, on average, influence comes from **farther**\n    hops.\n    \"\"\"\n    normalised = T / torch.sum(T, dim=1, keepdim=True)\n    hops = torch.arange(T.shape[1]).float()  # 0 … k\n    breadth = normalised @ hops  # shape (N,)\n    return breadth.mean().item()\n\n\ndef total_influence(\n    model: torch.nn.Module,\n    data: Data,\n    max_hops: int,\n    num_samples: Union[int, None] = None,\n    normalize: bool = True,\n    average: bool = True,\n    device: Union[torch.device, str] = \"cpu\",\n    vectorize: bool = True,\n) -> Tuple[Tensor, float]:\n    r\"\"\"Compute Jacobian‑based influence aggregates for *multiple* seed nodes,\n    as introduced in 
the\n    `\"Towards Quantifying Long-Range Interactions in Graph Machine Learning:\n    a Large Graph Dataset and a Measurement\"\n    <https://arxiv.org/abs/2503.09008>`_ paper.\n    This measurement quantifies how a GNN model's output at a node is\n    influenced by features of other nodes at increasing hop distances.\n\n    Specifically, for every sampled node :math:`v`, this method\n\n    1. evaluates the **L1‑norm** of the Jacobian of the model output at\n       :math:`v` w.r.t. the node features of its *k*-hop induced sub‑graph;\n    2. sums these scores **per hop** to obtain the influence vector\n       :math:`(I_{0}, I_{1}, \\dots, I_{k})`;\n    3. optionally averages those vectors over all sampled nodes and\n       optionally normalises them by :math:`I_{0}`.\n\n    Please refer to Section 4 of the paper for a more detailed definition.\n\n    Args:\n        model (torch.nn.Module): A PyTorch Geometric‑compatible model with\n            forward signature ``model(x, edge_index) -> Tensor``.\n        data (torch_geometric.data.Data): Graph data object providing at least\n            :obj:`x` (node features) and :obj:`edge_index` (connectivity).\n        max_hops (int): Maximum hop distance :math:`k`.\n        num_samples (int, optional): Number of random seed nodes to evaluate.\n            If :obj:`None`, all nodes are used. (default: :obj:`None`)\n        normalize (bool, optional): If :obj:`True`, normalize each hop‑wise\n            influence by the influence of hop 0. (default: :obj:`True`)\n        average (bool, optional): If :obj:`True`, return the hop‑wise **mean**\n            over all seed nodes (shape ``[k+1]``).\n            If :obj:`False`, return the full influence matrix of shape\n            ``[N, k+1]``. (default: :obj:`True`)\n        device (torch.device or str, optional): Device on which to perform the\n            computation. 
(default: :obj:`\"cpu\"`)\n        vectorize (bool, optional): Forwarded to\n            :func:`torch.autograd.functional.jacobian`.  Keeping this\n            :obj:`True` is often faster but increases memory usage.\n            (default: :obj:`True`)\n\n    Returns:\n        Tuple[Tensor, float]:\n            * **avg_influence** (*Tensor*):\n              shape ``[k+1]`` if :obj:`average=True`;\n              shape ``[N, k+1]`` otherwise.\n            * **R** (*float*): Influence‑weighted receptive‑field breadth\n              returned by :func:`influence_weighted_receptive_field`.\n\n    Example::\n        >>> avg_I, R = total_influence(model, data, max_hops=3,\n        ...                            num_samples=1000)\n        >>> avg_I\n        tensor([1.0000, 0.1273, 0.0142, 0.0019])\n        >>> R\n        0.216\n    \"\"\"\n    num_samples = data.num_nodes if num_samples is None else num_samples\n    num_nodes = cast(int, data.num_nodes)\n    nodes = torch.randperm(num_nodes)[:num_samples].tolist()\n\n    influence_all_nodes: List[Tensor] = [\n        jacobian_l1_agg_per_hop(model, data, max_hops, n, device,\n                                vectorize=vectorize)\n        for n in tqdm(nodes, desc=\"Influence\")\n    ]\n    allnodes = torch.vstack(influence_all_nodes).detach().cpu()\n\n    # Average total influence at each hop\n    if average:\n        avg_influence = avg_total_influence(allnodes, normalize=normalize)\n    else:\n        avg_influence = allnodes\n\n    # Influence‑weighted receptive field\n    R = influence_weighted_receptive_field(allnodes)\n\n    return avg_influence, R\n"
  },
  {
    "path": "torch_geometric/utils/isolated.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import remove_self_loops, segregate_self_loops\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef contains_isolated_nodes(\n    edge_index: Tensor,\n    num_nodes: Optional[int] = None,\n) -> bool:\n    r\"\"\"Returns :obj:`True` if the graph given by :attr:`edge_index` contains\n    isolated nodes.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n    :rtype: bool\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 0],\n        ...                            [1, 0, 0]])\n        >>> contains_isolated_nodes(edge_index)\n        False\n\n        >>> contains_isolated_nodes(edge_index, num_nodes=3)\n        True\n    \"\"\"\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n    edge_index, _ = remove_self_loops(edge_index)\n    return torch.unique(edge_index.view(-1)).numel() < num_nodes\n\n\ndef remove_isolated_nodes(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor], Tensor]:\n    r\"\"\"Removes the isolated nodes from the graph given by :attr:`edge_index`\n    with optional edge attributes :attr:`edge_attr`.\n    In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter\n    out isolated node features later on.\n    Self-loops are preserved for non-isolated nodes.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional\n            edge features. (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. 
(default: :obj:`None`)\n\n    :rtype: (LongTensor, Tensor, BoolTensor)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 0],\n        ...                            [1, 0, 0]])\n        >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index)\n        >>> mask # node mask (2 nodes)\n        tensor([True, True])\n\n        >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index,\n        ...                                                     num_nodes=3)\n        >>> mask # node mask (3 nodes)\n        tensor([True, True, False])\n    \"\"\"\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    out = segregate_self_loops(edge_index, edge_attr)\n    edge_index, edge_attr, loop_edge_index, loop_edge_attr = out\n\n    mask = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)\n    mask[edge_index.view(-1)] = 1\n\n    assoc = torch.full((num_nodes, ), -1, dtype=torch.long, device=mask.device)\n    assoc[mask] = torch.arange(mask.sum(), device=assoc.device)  # type: ignore\n    edge_index = assoc[edge_index]\n\n    loop_mask = torch.zeros_like(mask)\n    loop_mask[loop_edge_index[0]] = 1\n    loop_mask = loop_mask & mask\n    loop_assoc = torch.full_like(assoc, -1)\n    loop_assoc[loop_edge_index[0]] = torch.arange(loop_edge_index.size(1),\n                                                  device=loop_assoc.device)\n    loop_idx = loop_assoc[loop_mask]\n    loop_edge_index = assoc[loop_edge_index[:, loop_idx]]\n\n    edge_index = torch.cat([edge_index, loop_edge_index], dim=1)\n\n    if edge_attr is not None:\n        assert loop_edge_attr is not None\n        loop_edge_attr = loop_edge_attr[loop_idx]\n        edge_attr = torch.cat([edge_attr, loop_edge_attr], dim=0)\n\n    return edge_index, edge_attr, mask\n"
  },
  {
    "path": "torch_geometric/utils/laplacian.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import add_self_loops, remove_self_loops, scatter\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\n\ndef get_laplacian(\n    edge_index: Tensor,\n    edge_weight: OptTensor = None,\n    normalization: Optional[str] = None,\n    dtype: Optional[torch.dtype] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Computes the graph Laplacian of the graph given by :obj:`edge_index`\n    and optional :obj:`edge_weight`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_weight (Tensor, optional): One-dimensional edge weights.\n            (default: :obj:`None`)\n        normalization (str, optional): The normalization scheme for the graph\n            Laplacian (default: :obj:`None`):\n\n            1. :obj:`None`: No normalization\n            :math:`\\mathbf{L} = \\mathbf{D} - \\mathbf{A}`\n\n            2. :obj:`\"sym\"`: Symmetric normalization\n            :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1/2} \\mathbf{A}\n            \\mathbf{D}^{-1/2}`\n\n            3. :obj:`\"rw\"`: Random-walk normalization\n            :math:`\\mathbf{L} = \\mathbf{I} - \\mathbf{D}^{-1} \\mathbf{A}`\n        dtype (torch.dtype, optional): The desired data type of returned tensor\n            in case :obj:`edge_weight=None`. (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2],\n        ...                            
[1, 0, 2, 1]])\n        >>> edge_weight = torch.tensor([1., 2., 2., 4.])\n\n        >>> # No normalization\n        >>> lap = get_laplacian(edge_index, edge_weight)\n\n        >>> # Symmetric normalization\n        >>> lap_sym = get_laplacian(edge_index, edge_weight,\n        ...                         normalization='sym')\n\n        >>> # Random-walk normalization\n        >>> lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw')\n    \"\"\"\n    if normalization is not None:\n        assert normalization in ['sym', 'rw'], 'Invalid normalization'\n\n    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)\n\n    if edge_weight is None:\n        edge_weight = torch.ones(edge_index.size(1), dtype=dtype,\n                                 device=edge_index.device)\n\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    row, col = edge_index[0], edge_index[1]\n    deg = scatter(edge_weight, row, 0, dim_size=num_nodes, reduce='sum')\n\n    if normalization is None:\n        # L = D - A.\n        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)\n        edge_weight = torch.cat([-edge_weight, deg], dim=0)\n    elif normalization == 'sym':\n        # Compute A_norm = D^{-1/2} A D^{-1/2}.\n        deg_inv_sqrt = deg.pow_(-0.5)\n        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)\n        edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]\n\n        # L = I - A_norm.\n        assert isinstance(edge_weight, Tensor)\n        edge_index, edge_weight = add_self_loops(  #\n            edge_index, -edge_weight, fill_value=1., num_nodes=num_nodes)\n    else:\n        # Compute A_norm = D^{-1} A.\n        deg_inv = 1.0 / deg\n        deg_inv.masked_fill_(deg_inv == float('inf'), 0)\n        edge_weight = deg_inv[row] * edge_weight\n\n        # L = I - A_norm.\n        assert isinstance(edge_weight, Tensor)\n        edge_index, edge_weight = add_self_loops(  #\n            edge_index, 
-edge_weight, fill_value=1., num_nodes=num_nodes)\n\n    return edge_index, edge_weight\n"
  },
  {
    "path": "torch_geometric/utils/loop.py",
    "content": "import typing\nfrom typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.utils import scatter\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\nfrom torch_geometric.utils.sparse import (\n    is_torch_sparse_tensor,\n    to_edge_index,\n    to_torch_coo_tensor,\n    to_torch_csr_tensor,\n)\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload as overload\n\n\ndef contains_self_loops(edge_index: Tensor) -> bool:\n    r\"\"\"Returns :obj:`True` if the graph given by :attr:`edge_index` contains\n    self-loops.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n\n    :rtype: bool\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 0],\n        ...                            [1, 0, 0]])\n        >>> contains_self_loops(edge_index)\n        True\n\n        >>> edge_index = torch.tensor([[0, 1, 1],\n        ...                            
[1, 0, 2]])\n        >>> contains_self_loops(edge_index)\n        False\n    \"\"\"\n    mask = edge_index[0] == edge_index[1]\n    return mask.sum().item() > 0\n\n\n@overload\ndef remove_self_loops(\n    edge_index: Tensor,\n    edge_attr: None = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef remove_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef remove_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\ndef remove_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Removes every self-loop in the graph given by :attr:`edge_index`, so\n    that :math:`(i,i) \\not\\in \\mathcal{E}` for every :math:`i \\in \\mathcal{V}`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional\n            edge features. (default: :obj:`None`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 0],\n        ...                            
[1, 0, 0]])\n        >>> edge_attr = [[1, 2], [3, 4], [5, 6]]\n        >>> edge_attr = torch.tensor(edge_attr)\n        >>> remove_self_loops(edge_index, edge_attr)\n        (tensor([[0, 1],\n                [1, 0]]),\n        tensor([[1, 2],\n                [3, 4]]))\n    \"\"\"\n    size: Optional[Tuple[int, int]] = None\n    if not typing.TYPE_CHECKING and torch.jit.is_scripting():\n        layout: Optional[int] = None\n    else:\n        layout: Optional[torch.layout] = None\n\n    value: Optional[Tensor] = None\n    if is_torch_sparse_tensor(edge_index):\n        layout = edge_index.layout\n        size = (edge_index.size(0), edge_index.size(1))\n        edge_index, value = to_edge_index(edge_index)\n\n    is_undirected = False\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        is_undirected = edge_index.is_undirected\n\n    mask = edge_index[0] != edge_index[1]\n    edge_index = edge_index[:, mask]\n\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        edge_index._is_undirected = is_undirected\n\n    if layout is not None:\n        assert edge_attr is None\n        assert value is not None\n        value = value[mask]\n        if str(layout) == 'torch.sparse_coo':  # str(...) 
for TorchScript :(\n            return to_torch_coo_tensor(edge_index, value, size, True), None\n        elif str(layout) == 'torch.sparse_csr':\n            return to_torch_csr_tensor(edge_index, value, size, True), None\n        raise ValueError(f\"Unexpected sparse tensor layout (got '{layout}')\")\n\n    if edge_attr is None:\n        return edge_index, None\n    else:\n        return edge_index, edge_attr[mask]\n\n\n@overload\ndef segregate_self_loops(\n    edge_index: Tensor,\n    edge_attr: None = None,\n) -> Tuple[Tensor, None, Tensor, None]:\n    pass\n\n\n@overload\ndef segregate_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n) -> Tuple[Tensor, Tensor, Tensor, Tensor]:\n    pass\n\n\n@overload\ndef segregate_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:\n    pass\n\n\ndef segregate_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]:\n    r\"\"\"Segregates self-loops from the graph.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional\n            edge features. (default: :obj:`None`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`, :class:`LongTensor`,\n        :class:`Tensor`)\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 0, 1],\n        ...                            [0, 1, 0]])\n        >>> (edge_index, edge_attr,\n        ...  loop_edge_index,\n        ...  
loop_edge_attr) = segregate_self_loops(edge_index)\n        >>>  loop_edge_index\n        tensor([[0],\n                [0]])\n    \"\"\"\n    mask = edge_index[0] != edge_index[1]\n    inv_mask = ~mask\n\n    is_undirected = False\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        is_undirected = edge_index.is_undirected\n\n    loop_edge_index = edge_index[:, inv_mask]\n    loop_edge_attr = None if edge_attr is None else edge_attr[inv_mask]\n    edge_index = edge_index[:, mask]\n    edge_attr = None if edge_attr is None else edge_attr[mask]\n\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        assert isinstance(loop_edge_index, EdgeIndex)\n        edge_index._is_undirected = is_undirected\n        loop_edge_index._is_undirected = is_undirected\n\n    return edge_index, edge_attr, loop_edge_index, loop_edge_attr\n\n\n@overload\ndef add_self_loops(\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n    
pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[Tuple[int, int]] 
= None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[Tuple[int, int]] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\ndef add_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    fill_value: Optional[Union[float, Tensor, str]] = None,\n    num_nodes: Optional[Union[int, Tuple[int, int]]] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Adds a self-loop :math:`(i,i) \\in \\mathcal{E}` to every node\n    :math:`i \\in \\mathcal{V}` in the graph given by :attr:`edge_index`.\n    In case the graph is weighted or has multi-dimensional edge features\n    (:obj:`edge_attr != None`), edge features of self-loops will be added\n    according to :obj:`fill_value`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n            features. 
(default: :obj:`None`)\n        fill_value (float or Tensor or str, optional): The way to generate\n            edge features of self-loops (in case :obj:`edge_attr != None`).\n            If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n            self-loops will be directly given by :obj:`fill_value`.\n            If given as :obj:`str`, edge features of self-loops are computed by\n            aggregating all features of edges that point to the specific node,\n            according to a reduce operation. (:obj:`\"add\"`, :obj:`\"mean\"`,\n            :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). (default: :obj:`1.`)\n        num_nodes (int or Tuple[int, int], optional): The number of nodes,\n            *i.e.* :obj:`max_val + 1` of :attr:`edge_index`.\n            If given as a tuple, then :obj:`edge_index` is interpreted as a\n            bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`.\n            (default: :obj:`None`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 0],\n        ...                            [1, 0, 0]])\n        >>> edge_weight = torch.tensor([0.5, 0.5, 0.5])\n        >>> add_self_loops(edge_index)\n        (tensor([[0, 1, 0, 0, 1],\n                [1, 0, 0, 0, 1]]),\n        None)\n\n        >>> add_self_loops(edge_index, edge_weight)\n        (tensor([[0, 1, 0, 0, 1],\n                [1, 0, 0, 0, 1]]),\n        tensor([0.5000, 0.5000, 0.5000, 1.0000, 1.0000]))\n\n        >>> # edge features of self-loops are filled by constant `2.0`\n        >>> add_self_loops(edge_index, edge_weight,\n        ...                fill_value=2.)\n        (tensor([[0, 1, 0, 0, 1],\n                [1, 0, 0, 0, 1]]),\n        tensor([0.5000, 0.5000, 0.5000, 2.0000, 2.0000]))\n\n        >>> # Use 'add' operation to merge edge features for self-loops\n        >>> add_self_loops(edge_index, edge_weight,\n        ...                
fill_value='add')\n        (tensor([[0, 1, 0, 0, 1],\n                [1, 0, 0, 0, 1]]),\n        tensor([0.5000, 0.5000, 0.5000, 1.0000, 0.5000]))\n    \"\"\"\n    if not typing.TYPE_CHECKING and torch.jit.is_scripting():\n        layout: Optional[int] = None\n    else:\n        layout: Optional[torch.layout] = None\n    is_sparse = is_torch_sparse_tensor(edge_index)\n\n    value: Optional[Tensor] = None\n    if is_sparse:\n        assert edge_attr is None\n        layout = edge_index.layout\n        size = (edge_index.size(0), edge_index.size(1))\n        N = min(size)\n        edge_index, value = to_edge_index(edge_index)\n    elif isinstance(num_nodes, (tuple, list)):\n        size = (num_nodes[0], num_nodes[1])\n        N = min(size)\n    else:\n        N = maybe_num_nodes(edge_index, num_nodes)\n        size = (N, N)\n\n    device = edge_index.device\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        loop_index: Tensor = EdgeIndex(\n            torch.arange(0, N, device=device).view(1, -1).repeat(2, 1),\n            sparse_size=(N, N),\n            is_undirected=True,\n        )\n    else:\n        loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)\n\n    full_edge_index = torch.cat([edge_index, loop_index], dim=1)\n\n    if is_sparse:\n        assert edge_attr is None\n        assert value is not None\n        loop_attr = compute_loop_attr(  #\n            edge_index, value, N, is_sparse, fill_value)\n        value = torch.cat([value, loop_attr], dim=0)\n\n        if str(layout) == 'torch.sparse_coo':  # str(...) 
for TorchScript :(\n            return to_torch_coo_tensor(full_edge_index, value, size), None\n        elif str(layout) == 'torch.sparse_csr':\n            return to_torch_csr_tensor(full_edge_index, value, size), None\n        raise ValueError(f\"Unexpected sparse tensor layout (got '{layout}')\")\n\n    if edge_attr is not None:\n        loop_attr = compute_loop_attr(  #\n            edge_index, edge_attr, N, is_sparse, fill_value)\n        edge_attr = torch.cat([edge_attr, loop_attr], dim=0)\n\n    return full_edge_index, edge_attr\n\n\n@overload\ndef add_remaining_self_loops(\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: None = None,\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, None]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: 
Optional[Tensor],\n    fill_value: Optional[float] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    fill_value: Optional[str] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\ndef add_remaining_self_loops(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    fill_value: Optional[Union[float, Tensor, str]] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Adds remaining self-loop :math:`(i,i) \\in \\mathcal{E}` to every node\n    :math:`i \\in \\mathcal{V}` in the graph given by :attr:`edge_index`.\n    In case the graph is weighted or has multi-dimensional edge features\n    (:obj:`edge_attr != None`), edge features of non-existing self-loops will\n    be added according to :obj:`fill_value`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n            features. (default: :obj:`None`)\n        fill_value (float or Tensor or str, optional): The way to generate\n            edge features of self-loops (in case :obj:`edge_attr != None`).\n            If given as :obj:`float` or :class:`torch.Tensor`, edge features of\n            self-loops will be directly given by :obj:`fill_value`.\n            If given as :obj:`str`, edge features of self-loops are computed by\n            aggregating all features of edges that point to the specific node,\n            according to a reduce operation. 
(:obj:`\"add\"`, :obj:`\"mean\"`,\n            :obj:`\"min\"`, :obj:`\"max\"`, :obj:`\"mul\"`). (default: :obj:`1.`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1],\n        ...                            [1, 0]])\n        >>> edge_weight = torch.tensor([0.5, 0.5])\n        >>> add_remaining_self_loops(edge_index, edge_weight)\n        (tensor([[0, 1, 0, 1],\n                [1, 0, 0, 1]]),\n        tensor([0.5000, 0.5000, 1.0000, 1.0000]))\n    \"\"\"\n    N = maybe_num_nodes(edge_index, num_nodes)\n    mask = edge_index[0] != edge_index[1]\n\n    device = edge_index.device\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        loop_index: Tensor = EdgeIndex(\n            torch.arange(0, N, device=device).view(1, -1).repeat(2, 1),\n            sparse_size=(N, N),\n            is_undirected=True,\n        )\n    else:\n        loop_index = torch.arange(0, N, device=device).view(1, -1).repeat(2, 1)\n\n    if edge_attr is not None:\n\n        loop_attr = compute_loop_attr(  #\n            edge_index, edge_attr, N, False, fill_value)\n\n        inv_mask = ~mask\n        loop_attr[edge_index[0][inv_mask]] = edge_attr[inv_mask]\n\n        edge_attr = torch.cat([edge_attr[mask], loop_attr], dim=0)\n\n    is_undirected = False\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        is_undirected = edge_index.is_undirected\n\n    edge_index = edge_index[:, mask]\n\n    if not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        edge_index._is_undirected = is_undirected\n\n    edge_index = torch.cat([edge_index, loop_index], dim=1)\n\n    return edge_index, edge_attr\n\n\ndef get_self_loop_attr(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    num_nodes: 
Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Returns the edge features or weights of self-loops\n    :math:`(i, i)` of every node :math:`i \\in \\mathcal{V}` in the\n    graph given by :attr:`edge_index`. Edge features of missing self-loops not\n    present in :attr:`edge_index` will be filled with zeros. If\n    :attr:`edge_attr` is not given, it will be the vector of ones.\n\n    .. note::\n        This operation is analogous to getting the diagonal elements of the\n        dense adjacency matrix.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge\n            features. (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`)\n\n    :rtype: :class:`Tensor`\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 0],\n        ...                            [1, 0, 0]])\n        >>> edge_weight = torch.tensor([0.2, 0.3, 0.5])\n        >>> get_self_loop_attr(edge_index, edge_weight)\n        tensor([0.5000, 0.0000])\n\n        >>> get_self_loop_attr(edge_index, edge_weight, num_nodes=4)\n        tensor([0.5000, 0.0000, 0.0000, 0.0000])\n    \"\"\"\n    loop_mask = edge_index[0] == edge_index[1]\n    loop_index = edge_index[0][loop_mask]\n\n    if edge_attr is not None:\n        loop_attr = edge_attr[loop_mask]\n    else:  # A vector of ones:\n        loop_attr = torch.ones(loop_index.numel(), device=edge_index.device)\n\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n    full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:])\n    full_loop_attr[loop_index] = loop_attr\n\n    return full_loop_attr\n\n\n@overload\ndef compute_loop_attr(\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: int,\n    is_sparse: bool,\n    fill_value: Optional[float] = None,\n) -> Tensor:\n    pass\n\n\n@overload\ndef compute_loop_attr(  # 
noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: int,\n    is_sparse: bool,\n    fill_value: Optional[Tensor] = None,\n) -> Tensor:\n    pass\n\n\n@overload\ndef compute_loop_attr(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: int,\n    is_sparse: bool,\n    fill_value: Optional[str] = None,\n) -> Tensor:\n    pass\n\n\ndef compute_loop_attr(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: int,\n    is_sparse: bool,\n    fill_value: Optional[Union[float, Tensor, str]] = None,\n) -> Tensor:\n\n    if fill_value is None:\n        size = (num_nodes, ) + edge_attr.size()[1:]\n        return edge_attr.new_ones(size)\n\n    elif isinstance(fill_value, (int, float)):\n        size = (num_nodes, ) + edge_attr.size()[1:]\n        return edge_attr.new_full(size, fill_value)\n\n    elif isinstance(fill_value, Tensor):\n        size = (num_nodes, ) + edge_attr.size()[1:]\n        loop_attr = fill_value.to(edge_attr.device, edge_attr.dtype)\n        if edge_attr.dim() != loop_attr.dim():\n            loop_attr = loop_attr.unsqueeze(0)\n        return loop_attr.expand(size).contiguous()\n\n    elif isinstance(fill_value, str):\n        col = edge_index[0] if is_sparse else edge_index[1]\n        return scatter(edge_attr, col, 0, num_nodes, fill_value)\n\n    raise AttributeError(\"No valid 'fill_value' provided\")\n"
  },
  {
    "path": "torch_geometric/utils/map.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\nfrom torch.utils.dlpack import from_dlpack\n\nfrom torch_geometric.warnings import WarningCache\n\n_warning_cache = WarningCache()\n\n\ndef map_index(\n    src: Tensor,\n    index: Tensor,\n    max_index: Optional[Union[int, Tensor]] = None,\n    inclusive: bool = False,\n) -> Tuple[Tensor, Optional[Tensor]]:\n    r\"\"\"Maps indices in :obj:`src` to the positional value of their\n    corresponding occurrence in :obj:`index`.\n    Indices must be strictly positive.\n\n    Args:\n        src (torch.Tensor): The source tensor to map.\n        index (torch.Tensor): The index tensor that denotes the new mapping.\n        max_index (int, optional): The maximum index value.\n            (default :obj:`None`)\n        inclusive (bool, optional): If set to :obj:`True`, it is assumed that\n            every entry in :obj:`src` has a valid entry in :obj:`index`.\n            Can speed-up computation. (default: :obj:`False`)\n\n    :rtype: (:class:`torch.Tensor`, :class:`torch.BoolTensor`)\n\n    Examples:\n        >>> src = torch.tensor([2, 0, 1, 0, 3])\n        >>> index = torch.tensor([3, 2, 0, 1])\n\n        >>> map_index(src, index)\n        (tensor([1, 2, 3, 2, 0]), tensor([True, True, True, True, True]))\n\n        >>> src = torch.tensor([2, 0, 1, 0, 3])\n        >>> index = torch.tensor([3, 2, 0])\n\n        >>> map_index(src, index)\n        (tensor([1, 2, 2, 0]), tensor([True, True, False, True, True]))\n\n    .. note::\n\n        If inputs are on GPU and :obj:`cudf` is available, consider using RMM\n        for significant speed boosts.\n        Proceed with caution as RMM may conflict with other allocators or\n        fragments.\n\n        .. 
code-block:: python\n\n            import rmm\n            rmm.reinitialize(pool_allocator=True)\n            torch.cuda.memory.change_current_allocator(rmm.rmm_torch_allocator)\n    \"\"\"\n    if src.is_floating_point():\n        raise ValueError(f\"Expected 'src' to be an index (got '{src.dtype}')\")\n    if index.is_floating_point():\n        raise ValueError(f\"Expected 'index' to be an index (got \"\n                         f\"'{index.dtype}')\")\n    if src.device != index.device:\n        raise ValueError(f\"Both 'src' and 'index' must be on the same device \"\n                         f\"(got '{src.device}' and '{index.device}')\")\n\n    if max_index is None:\n        max_index = torch.maximum(src.max(), index.max())\n\n    # If the `max_index` is in a reasonable range, we can accelerate this\n    # operation by creating a helper vector to perform the mapping.\n    # NOTE This will potentially consume a large chunk of memory\n    # (max_index=10 million => ~75MB), so we cap it at a reasonable size:\n    THRESHOLD = 40_000_000 if src.is_cuda else 10_000_000\n    if max_index <= THRESHOLD:\n        if inclusive:\n            assoc = src.new_empty(max_index + 1)  # type: ignore\n        else:\n            assoc = src.new_full((max_index + 1, ), -1)  # type: ignore\n        assoc[index] = torch.arange(index.numel(), dtype=src.dtype,\n                                    device=src.device)\n        out = assoc[src]\n\n        if inclusive:\n            return out, None\n        else:\n            mask = out != -1\n            return out[mask], mask\n\n    WITH_CUDF = False\n    if src.is_cuda:\n        try:\n            import cudf\n            WITH_CUDF = True\n        except ImportError:\n            import pandas as pd\n            _warning_cache.warn(\"Using CPU-based processing within \"\n                                \"'map_index' which may cause slowdowns and \"\n                                \"device synchronization. 
Consider installing \"\n                                \"'cudf' to accelerate computation\")\n    else:\n        import pandas as pd\n\n    if not WITH_CUDF:\n        left_ser = pd.Series(src.cpu().numpy(), name='left_ser')\n        right_ser = pd.Series(\n            index=index.cpu().numpy(),\n            data=pd.RangeIndex(0, index.size(0)),\n            name='right_ser',\n        )\n\n        result = pd.merge(left_ser, right_ser, how='left', left_on='left_ser',\n                          right_index=True)\n\n        out_numpy = result['right_ser'].values\n        if (index.device.type == 'mps'  # MPS does not support `float64`\n                and issubclass(out_numpy.dtype.type, np.floating)):\n            out_numpy = out_numpy.astype(np.float32)\n\n        out = torch.from_numpy(out_numpy).to(index.device)\n\n        if out.is_floating_point() and inclusive:\n            raise ValueError(\"Found invalid entries in 'src' that do not have \"\n                             \"a corresponding entry in 'index'. 
Set \"\n                             \"`inclusive=False` to ignore these entries.\")\n\n        if out.is_floating_point():\n            mask = torch.isnan(out).logical_not_()\n            out = out[mask].to(index.dtype)\n            return out, mask\n\n        if inclusive:\n            return out, None\n        else:\n            mask = out != -1\n            return out[mask], mask\n\n    else:\n        left_ser = cudf.Series(src, name='left_ser')\n        right_ser = cudf.Series(\n            index=index,\n            data=cudf.RangeIndex(0, index.size(0)),\n            name='right_ser',\n        )\n\n        result = cudf.merge(left_ser, right_ser, how='left',\n                            left_on='left_ser', right_index=True, sort=True)\n\n        if inclusive:\n            try:\n                out = from_dlpack(result['right_ser'].to_dlpack())\n            except ValueError as e:\n                raise ValueError(\n                    \"Found invalid entries in 'src' that do not \"\n                    \"have a corresponding entry in 'index'. Set \"\n                    \"`inclusive=False` to ignore these entries.\") from e\n        else:\n            out = from_dlpack(result['right_ser'].fillna(-1).to_dlpack())\n\n        out = out[src.argsort().argsort()]  # Restore original order.\n\n        if inclusive:\n            return out, None\n        else:\n            mask = out != -1\n            return out[mask], mask\n"
  },
  {
    "path": "torch_geometric/utils/mask.py",
    "content": "from typing import Optional\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import TensorFrame\n\n\ndef mask_select(src: Tensor, dim: int, mask: Tensor) -> Tensor:\n    r\"\"\"Returns a new tensor which masks the :obj:`src` tensor along the\n    dimension :obj:`dim` according to the boolean mask :obj:`mask`.\n\n    Args:\n        src (torch.Tensor): The input tensor.\n        dim (int): The dimension in which to mask.\n        mask (torch.BoolTensor): The 1-D tensor containing the binary mask to\n            index with.\n    \"\"\"\n    assert mask.dim() == 1\n\n    if not torch.jit.is_scripting():\n        if isinstance(src, TensorFrame):\n            assert dim == 0 and src.num_rows == mask.numel()\n            return src[mask]\n\n    assert src.size(dim) == mask.numel()\n    dim = dim + src.dim() if dim < 0 else dim\n    assert dim >= 0 and dim < src.dim()\n\n    # Applying a 1-dimensional mask in the first dimension is significantly\n    # faster than broadcasting the mask and utilizing `masked_select`.\n    # As such, we transpose in the first dimension, perform the masking, and\n    # then transpose back to the original shape.\n    src = src.transpose(0, dim) if dim != 0 else src\n    out = src[mask]\n    out = out.transpose(0, dim) if dim != 0 else out\n\n    return out\n\n\ndef index_to_mask(index: Tensor, size: Optional[int] = None) -> Tensor:\n    r\"\"\"Converts indices to a mask representation.\n\n    Args:\n        index (Tensor): The indices.\n        size (int, optional): The size of the mask. 
If set to :obj:`None`, a\n            minimally sized output mask is returned.\n\n    Example:\n        >>> index = torch.tensor([1, 3, 5])\n        >>> index_to_mask(index)\n        tensor([False,  True, False,  True, False,  True])\n\n        >>> index_to_mask(index, size=7)\n        tensor([False,  True, False,  True, False,  True, False])\n    \"\"\"\n    index = index.view(-1)\n    size = int(index.max()) + 1 if size is None else size\n    mask = index.new_zeros(size, dtype=torch.bool)\n    mask[index] = True\n    return mask\n\n\ndef mask_to_index(mask: Tensor) -> Tensor:\n    r\"\"\"Converts a mask to an index representation.\n\n    Args:\n        mask (Tensor): The mask.\n\n    Example:\n        >>> mask = torch.tensor([False, True, False])\n        >>> mask_to_index(mask)\n        tensor([1])\n    \"\"\"\n    return mask.nonzero(as_tuple=False).view(-1)\n"
  },
  {
    "path": "torch_geometric/utils/mesh_laplacian.py",
    "content": "from typing import Optional, Tuple\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import add_self_loops, scatter, to_undirected\n\n\ndef get_mesh_laplacian(\n    pos: Tensor,\n    face: Tensor,\n    normalization: Optional[str] = None,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Computes the mesh Laplacian of a mesh given by :obj:`pos` and\n    :obj:`face`.\n\n    Computation is based on the cotangent matrix defined as\n\n    .. math::\n        \\mathbf{C}_{ij} = \\begin{cases}\n            \\frac{\\cot \\angle_{ikj}~+\\cot \\angle_{ilj}}{2} &\n            \\text{if } i, j \\text{ is an edge} \\\\\n            -\\sum_{j \\in N(i)}{C_{ij}} &\n            \\text{if } i \\text{ is in the diagonal} \\\\\n            0 & \\text{otherwise}\n      \\end{cases}\n\n    Normalization depends on the mass matrix defined as\n\n    .. math::\n        \\mathbf{M}_{ij} = \\begin{cases}\n            a(i) & \\text{if } i \\text{ is in the diagonal} \\\\\n            0 & \\text{otherwise}\n      \\end{cases}\n\n    where :math:`a(i)` is obtained by joining the barycenters of the\n    triangles around vertex :math:`i`.\n\n    Args:\n        pos (Tensor): The node positions.\n        face (LongTensor): The face indices.\n        normalization (str, optional): The normalization scheme for the mesh\n            Laplacian (default: :obj:`None`):\n\n            1. :obj:`None`: No normalization\n            :math:`\\mathbf{L} = \\mathbf{C}`\n\n            2. :obj:`\"sym\"`: Symmetric normalization\n            :math:`\\mathbf{L} = \\mathbf{M}^{-1/2} \\mathbf{C}\\mathbf{M}^{-1/2}`\n\n            3. 
:obj:`\"rw\"`: Row-wise normalization\n            :math:`\\mathbf{L} = \\mathbf{M}^{-1} \\mathbf{C}`\n    \"\"\"\n    assert pos.size(1) == 3 and face.size(0) == 3\n\n    num_nodes = pos.shape[0]\n\n    def get_cots(left: Tensor, centre: Tensor, right: Tensor) -> Tensor:\n        left_pos, central_pos, right_pos = pos[left], pos[centre], pos[right]\n        left_vec = left_pos - central_pos\n        right_vec = right_pos - central_pos\n        dot = torch.einsum('ij, ij -> i', left_vec, right_vec)\n        cross = torch.norm(torch.cross(left_vec, right_vec, dim=1), dim=1)\n        cot = dot / cross  # cot = cos / sin\n        return cot / 2.0  # by definition\n\n    # For each triangle face, get all three cotangents:\n    cot_021 = get_cots(face[0], face[2], face[1])\n    cot_102 = get_cots(face[1], face[0], face[2])\n    cot_012 = get_cots(face[0], face[1], face[2])\n    cot_weight = torch.cat([cot_021, cot_102, cot_012])\n\n    # Face to edge:\n    cot_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)\n    cot_index, cot_weight = to_undirected(cot_index, cot_weight)\n\n    # Compute the diagonal part:\n    cot_deg = scatter(cot_weight, cot_index[0], 0, num_nodes, reduce='sum')\n    edge_index, _ = add_self_loops(cot_index, num_nodes=num_nodes)\n    edge_weight = torch.cat([cot_weight, -cot_deg], dim=0)\n\n    if normalization is not None:\n\n        def get_areas(left: Tensor, centre: Tensor, right: Tensor) -> Tensor:\n            central_pos = pos[centre]\n            left_vec = pos[left] - central_pos\n            right_vec = pos[right] - central_pos\n            cross = torch.norm(torch.cross(left_vec, right_vec, dim=1), dim=1)\n            area = cross / 6.0  # one-third of a triangle's area is cross / 6.0\n            return area / 2.0  # since each corresponding area is counted twice\n\n        # Like before, but here we only need the diagonal (the mass matrix):\n        area_021 = get_areas(face[0], face[2], face[1])\n        area_102 = 
get_areas(face[1], face[0], face[2])\n        area_012 = get_areas(face[0], face[1], face[2])\n        area_weight = torch.cat([area_021, area_102, area_012])\n        area_index = torch.cat([face[:2], face[1:], face[::2]], dim=1)\n        area_index, area_weight = to_undirected(area_index, area_weight)\n        area_deg = scatter(area_weight, area_index[0], 0, num_nodes, 'sum')\n\n        if normalization == 'sym':\n            area_deg_inv_sqrt = area_deg.pow_(-0.5)\n            area_deg_inv_sqrt[area_deg_inv_sqrt == float('inf')] = 0.0\n            edge_weight = (area_deg_inv_sqrt[edge_index[0]] * edge_weight *\n                           area_deg_inv_sqrt[edge_index[1]])\n        elif normalization == 'rw':\n            area_deg_inv = 1.0 / area_deg\n            area_deg_inv[area_deg_inv == float('inf')] = 0.0\n            edge_weight = area_deg_inv[edge_index[0]] * edge_weight\n\n    return edge_index, edge_weight\n"
  },
  {
    "path": "torch_geometric/utils/mixin.py",
    "content": "from typing import Any, Iterator, TypeVar\n\nT = TypeVar('T')\n\n\nclass CastMixin:\n    @classmethod\n    def cast(cls: T, *args: Any, **kwargs: Any) -> T:\n        if len(args) == 1 and len(kwargs) == 0:\n            elem = args[0]\n            if elem is None:\n                return None  # type: ignore\n            if isinstance(elem, CastMixin):\n                return elem  # type: ignore\n            if isinstance(elem, tuple):\n                return cls(*elem)  # type: ignore\n            if isinstance(elem, dict):\n                return cls(**elem)  # type: ignore\n        return cls(*args, **kwargs)  # type: ignore\n\n    def __iter__(self) -> Iterator:\n        return iter(self.__dict__.values())\n"
  },
  {
    "path": "torch_geometric/utils/nested.py",
    "content": "from typing import Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.utils import scatter\n\n\ndef to_nested_tensor(\n    x: Tensor,\n    batch: Optional[Tensor] = None,\n    ptr: Optional[Tensor] = None,\n    batch_size: Optional[int] = None,\n) -> Tensor:\n    r\"\"\"Given a contiguous batch of tensors\n    :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times *}`\n    (with :math:`N_i` indicating the number of elements in example :math:`i`),\n    creates a `nested PyTorch tensor\n    <https://pytorch.org/docs/stable/nested.html>`__.\n    Reverse operation of :meth:`from_nested_tensor`.\n\n    Args:\n        x (torch.Tensor): The input tensor\n            :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times *}`.\n        batch (torch.Tensor, optional): The batch vector\n            :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`, which assigns each\n            element to a specific example. Must be ordered.\n            (default: :obj:`None`)\n        ptr (torch.Tensor, optional): Alternative representation of\n            :obj:`batch` in compressed format. 
(default: :obj:`None`)\n        batch_size (int, optional): The batch size :math:`B`.\n            (default: :obj:`None`)\n    \"\"\"\n    if ptr is not None:\n        offsets = ptr[1:] - ptr[:-1]\n        sizes = offsets.tolist()\n        xs = list(torch.split(x, sizes, dim=0))\n    elif batch is not None:\n        offsets = scatter(torch.ones_like(batch), batch, dim_size=batch_size)\n        sizes = offsets.tolist()\n        xs = list(torch.split(x, sizes, dim=0))\n    else:\n        xs = [x]\n\n    # This currently copies the data, although `x` is already contiguous.\n    # Sadly, there does not exist any (public) API to prevent this :(\n    return torch.nested.as_nested_tensor(xs)\n\n\ndef from_nested_tensor(\n    x: Tensor,\n    return_batch: bool = False,\n) -> Union[Tensor, Tuple[Tensor, Tensor]]:\n    r\"\"\"Given a `nested PyTorch tensor\n    <https://pytorch.org/docs/stable/nested.html>`__, creates a contiguous\n    batch of tensors\n    :math:`\\mathbf{X} \\in \\mathbb{R}^{(N_1 + \\ldots + N_B) \\times *}`, and\n    optionally a batch vector which assigns each element to a specific example.\n    Reverse operation of :meth:`to_nested_tensor`.\n\n    Args:\n        x (torch.Tensor): The nested input tensor. 
The sizes of nested tensors\n            need to match except for the first dimension.\n        return_batch (bool, optional): If set to :obj:`True`, will also return\n            the batch vector :math:`\\mathbf{b} \\in {\\{ 0, \\ldots, B-1\\}}^N`.\n            (default: :obj:`False`)\n    \"\"\"\n    if not x.is_nested:\n        raise ValueError(\"Input tensor in 'from_nested_tensor' is not nested\")\n\n    sizes = x._nested_tensor_size()\n\n    for dim, (a, b) in enumerate(zip(sizes[0, 1:], sizes.t()[1:])):\n        if not torch.equal(a.expand_as(b), b):\n            raise ValueError(f\"Not all nested tensors have the same size \"\n                             f\"in dimension {dim + 1} \"\n                             f\"(expected size {a.item()} for all tensors)\")\n\n    out = x.contiguous().values()\n    out = out.view(-1, *sizes[0, 1:].tolist())\n\n    if not return_batch:\n        return out\n\n    batch = torch.arange(x.size(0), device=x.device)\n    batch = batch.repeat_interleave(sizes[:, 0].to(batch.device))\n\n    return out, batch\n"
  },
  {
    "path": "torch_geometric/utils/noise_scheduler.py",
    "content": "import math\nfrom typing import Literal, Optional\n\nimport torch\nfrom torch import Tensor\n\n\ndef get_smld_sigma_schedule(\n    sigma_min: float,\n    sigma_max: float,\n    num_scales: int,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n    r\"\"\"Generates a set of noise values on a logarithmic scale for \"Score\n    Matching with Langevin Dynamics\" from the `\"Generative Modeling by\n    Estimating Gradients of the Data Distribution\"\n    <https://arxiv.org/abs/1907.05600>`_ paper.\n\n    This function returns a vector of sigma values that define the schedule of\n    noise levels used during Score Matching with Langevin Dynamics.\n    The sigma values are determined on a logarithmic scale from\n    :obj:`sigma_max` to :obj:`sigma_min`, inclusive.\n\n    Args:\n        sigma_min (float): The minimum value of sigma, corresponding to the\n            lowest noise level.\n        sigma_max (float): The maximum value of sigma, corresponding to the\n            highest noise level.\n        num_scales (int): The number of sigma values to generate, defining the\n            granularity of the noise schedule.\n        dtype (torch.dtype, optional): The output data type.\n            (default: :obj:`None`)\n        device (torch.device, optional): The output device.\n            (default: :obj:`None`)\n    \"\"\"\n    return torch.linspace(\n        math.log(sigma_max),\n        math.log(sigma_min),\n        num_scales,\n        dtype=dtype,\n        device=device,\n    ).exp()\n\n\ndef get_diffusion_beta_schedule(\n    schedule_type: Literal['linear', 'quadratic', 'constant', 'sigmoid'],\n    beta_start: float,\n    beta_end: float,\n    num_diffusion_timesteps: int,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n    r\"\"\"Generates a schedule of beta values according to the specified strategy\n    for the diffusion process from the 
`\"Denoising Diffusion Probabilistic\n    Models\" <https://arxiv.org/abs/2006.11239>`_ paper.\n\n    Beta values are used to scale the noise added during the diffusion process\n    in generative models. This function creates an array of beta values\n    according to a pre-defined schedule, which can be either :obj:`\"linear\"`,\n    :obj:`\"quadratic\"`, :obj:`\"constant\"`, or :obj:`\"sigmoid\"`.\n\n    Args:\n        schedule_type (str): The type of schedule to use for beta values.\n        beta_start (float): The starting value of beta.\n        beta_end (float): The ending value of beta.\n        num_diffusion_timesteps (int): The number of timesteps for the\n            diffusion process.\n        dtype (torch.dtype, optional): The output data type.\n            (default: :obj:`None`)\n        device (torch.device, optional): The output device.\n            (default: :obj:`None`)\n    \"\"\"\n    if schedule_type == 'linear':\n        return torch.linspace(\n            beta_start,\n            beta_end,\n            num_diffusion_timesteps,\n            dtype=dtype,\n            device=device,\n        )\n\n    if schedule_type == 'quadratic':\n        return torch.linspace(\n            beta_start**0.5,\n            beta_end**0.5,\n            num_diffusion_timesteps,\n            dtype=dtype,\n            device=device,\n        )**2\n\n    if schedule_type == 'constant':\n        return torch.full(\n            (num_diffusion_timesteps, ),\n            fill_value=beta_end,\n            dtype=dtype,\n            device=device,\n        )\n\n    if schedule_type == 'sigmoid':\n        return torch.linspace(\n            -6,\n            6,\n            num_diffusion_timesteps,\n            dtype=dtype,\n            device=device,\n        ).sigmoid() * (beta_end - beta_start) + beta_start\n\n    raise ValueError(f\"Found invalid 'schedule_type' (got '{schedule_type}')\")\n"
  },
  {
    "path": "torch_geometric/utils/num_nodes.py",
    "content": "from copy import copy\nfrom typing import Dict, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.typing import EdgeType, NodeType, SparseTensor\n\n\ndef maybe_num_nodes(\n    edge_index: Union[Tensor, Tuple[Tensor, Tensor], SparseTensor],\n    num_nodes: Optional[int] = None,\n) -> int:\n    if num_nodes is not None:\n        return num_nodes\n    elif not torch.jit.is_scripting() and isinstance(edge_index, EdgeIndex):\n        return max(edge_index.get_sparse_size())\n    elif isinstance(edge_index, Tensor):\n        if torch_geometric.utils.is_torch_sparse_tensor(edge_index):\n            return max(edge_index.size(0), edge_index.size(1))\n\n        if torch.jit.is_tracing():\n            # Avoid non-traceable if-check for empty `edge_index` tensor:\n            tmp = torch.concat([\n                edge_index.view(-1),\n                edge_index.new_full((1, ), fill_value=-1)\n            ])\n            return tmp.max() + 1  # type: ignore\n\n        return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0\n    elif isinstance(edge_index, tuple):\n        return max(\n            int(edge_index[0].max()) + 1 if edge_index[0].numel() > 0 else 0,\n            int(edge_index[1].max()) + 1 if edge_index[1].numel() > 0 else 0,\n        )\n    elif isinstance(edge_index, SparseTensor):\n        return max(edge_index.size(0), edge_index.size(1))\n    raise NotImplementedError\n\n\ndef maybe_num_nodes_dict(\n    edge_index_dict: Dict[EdgeType, Tensor],\n    num_nodes_dict: Optional[Dict[NodeType, int]] = None,\n) -> Dict[NodeType, int]:\n    num_nodes_dict = {} if num_nodes_dict is None else copy(num_nodes_dict)\n\n    found_types = list(num_nodes_dict.keys())\n\n    for keys, edge_index in edge_index_dict.items():\n\n        key = keys[0]\n        if key not in found_types:\n            N = int(edge_index[0].max() + 1)\n            
num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))\n\n        key = keys[-1]\n        if key not in found_types:\n            N = int(edge_index[1].max() + 1)\n            num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N))\n\n    return num_nodes_dict\n"
  },
  {
    "path": "torch_geometric/utils/ppr.py",
    "content": "from itertools import chain\nfrom typing import Callable, List, Optional, Tuple\n\nimport numpy as np\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric import EdgeIndex\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\ntry:\n    import numba\n    WITH_NUMBA = True\nexcept Exception:  # pragma: no cover\n    WITH_NUMBA = False\n\n\ndef _get_ppr(  # pragma: no cover\n    rowptr: np.ndarray,\n    col: np.ndarray,\n    alpha: float,\n    eps: float,\n    target: Optional[np.ndarray] = None,\n) -> Tuple[List[List[int]], List[List[float]]]:\n\n    num_nodes = len(rowptr) - 1 if target is None else len(target)\n    alpha_eps = alpha * eps\n    js = [[0]] * num_nodes\n    vals = [[0.]] * num_nodes\n\n    for inode_uint in numba.prange(num_nodes):\n        if target is None:\n            inode = numba.int64(inode_uint)\n        else:\n            inode = target[inode_uint]\n\n        p = {inode: 0.0}\n        r = {}\n        r[inode] = alpha\n        q = [inode]\n\n        while len(q) > 0:\n            unode = q.pop()\n\n            res = r[unode] if unode in r else 0\n            if unode in p:\n                p[unode] += res\n            else:\n                p[unode] = res\n\n            r[unode] = 0\n            start, end = rowptr[unode], rowptr[unode + 1]\n            ucount = end - start\n\n            for vnode in col[start:end]:\n                _val = (1 - alpha) * res / ucount\n                if vnode in r:\n                    r[vnode] += _val\n                else:\n                    r[vnode] = _val\n\n                res_vnode = r[vnode] if vnode in r else 0\n                vcount = rowptr[vnode + 1] - rowptr[vnode]\n                if res_vnode >= alpha_eps * vcount:\n                    if vnode not in q:\n                        q.append(vnode)\n\n        js[inode_uint] = list(p.keys())\n        vals[inode_uint] = list(p.values())\n\n    return js, vals\n\n\n_get_ppr_numba: Optional[Callable] = 
None\n\n\ndef get_ppr(\n    edge_index: Tensor,\n    alpha: float = 0.2,\n    eps: float = 1e-5,\n    target: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Calculates the personalized PageRank (PPR) vector for all or a subset\n    of nodes using a variant of the `Andersen algorithm\n    <https://mathweb.ucsd.edu/~fan/wp/localpartition.pdf>`_.\n\n    Args:\n        edge_index (torch.Tensor): The indices of the graph.\n        alpha (float, optional): The alpha value of the PageRank algorithm.\n            (default: :obj:`0.2`)\n        eps (float, optional): The threshold for stopping the PPR calculation\n            (:obj:`edge_weight >= eps * out_degree`). (default: :obj:`1e-5`)\n        target (torch.Tensor, optional): The target nodes to compute PPR for.\n            If not given, calculates PPR vectors for all nodes.\n            (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes. (default: :obj:`None`)\n\n    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`)\n    \"\"\"\n    if not WITH_NUMBA:  # pragma: no cover\n        raise ImportError(\"'get_ppr' requires the 'numba' package\")\n\n    global _get_ppr_numba\n    if _get_ppr_numba is None:\n        _get_ppr_numba = numba.jit(nopython=True, parallel=True)(_get_ppr)\n\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n    edge_index = EdgeIndex(edge_index, sparse_size=(num_nodes, num_nodes))\n    edge_index = edge_index.sort_by('row')[0]\n    (rowptr, col), _ = edge_index.get_csr()\n\n    cols, weights = _get_ppr_numba(\n        rowptr.cpu().numpy(),\n        col.cpu().numpy(),\n        alpha,\n        eps,\n        None if target is None else target.cpu().numpy(),\n    )\n\n    device = edge_index.device\n    col = torch.tensor(list(chain.from_iterable(cols)), device=device)\n    weight = torch.tensor(list(chain.from_iterable(weights)), device=device)\n    deg = torch.tensor([len(value) for value in cols], 
device=device)\n\n    row = torch.arange(num_nodes, device=device) if target is None else target\n    row = row.repeat_interleave(deg, output_size=col.numel())\n\n    edge_index = torch.stack([row, col], dim=0)\n\n    return edge_index, weight\n"
  },
  {
    "path": "torch_geometric/utils/random.py",
    "content": "import warnings\nfrom typing import List, Union\n\nimport numpy as np\nimport torch\n\nfrom torch_geometric.utils import remove_self_loops, to_undirected\n\n\ndef erdos_renyi_graph(\n    num_nodes: int,\n    edge_prob: float,\n    directed: bool = False,\n) -> torch.Tensor:\n    r\"\"\"Returns the :obj:`edge_index` of a random Erdos-Renyi graph.\n\n    Args:\n        num_nodes (int): The number of nodes.\n        edge_prob (float): Probability of an edge.\n        directed (bool, optional): If set to :obj:`True`, will return a\n            directed graph. (default: :obj:`False`)\n\n    Examples:\n        >>> erdos_renyi_graph(5, 0.2, directed=False)\n        tensor([[0, 1, 1, 4],\n                [1, 0, 4, 1]])\n\n        >>> erdos_renyi_graph(5, 0.2, directed=True)\n        tensor([[0, 1, 3, 3, 4, 4],\n                [4, 3, 1, 2, 1, 3]])\n    \"\"\"\n    if directed:\n        idx = torch.arange((num_nodes - 1) * num_nodes)\n        idx = idx.view(num_nodes - 1, num_nodes)\n        idx = idx + torch.arange(1, num_nodes).view(-1, 1)\n        idx = idx.view(-1)\n    else:\n        warnings.filterwarnings('ignore', '.*pass the indexing argument.*')\n        idx = torch.combinations(torch.arange(num_nodes), r=2)\n\n    # Filter edges.\n    mask = torch.rand(idx.size(0)) < edge_prob\n    idx = idx[mask]\n\n    if directed:\n        row = idx.div(num_nodes, rounding_mode='floor')\n        col = idx % num_nodes\n        edge_index = torch.stack([row, col], dim=0)\n    else:\n        edge_index = to_undirected(idx.t(), num_nodes=num_nodes)\n\n    return edge_index\n\n\ndef stochastic_blockmodel_graph(\n    block_sizes: Union[List[int], torch.Tensor],\n    edge_probs: Union[List[List[float]], torch.Tensor],\n    directed: bool = False,\n) -> torch.Tensor:\n    r\"\"\"Returns the :obj:`edge_index` of a stochastic blockmodel graph.\n\n    Args:\n        block_sizes ([int] or LongTensor): The sizes of blocks.\n        edge_probs ([[float]] or FloatTensor): The 
density of edges going\n            from each block to each other block. Must be symmetric if the\n            graph is undirected.\n        directed (bool, optional): If set to :obj:`True`, will return a\n            directed graph. (default: :obj:`False`)\n\n    Examples:\n        >>> block_sizes = [2, 2, 4]\n        >>> edge_probs = [[0.25, 0.05, 0.02],\n        ...               [0.05, 0.35, 0.07],\n        ...               [0.02, 0.07, 0.40]]\n        >>> stochastic_blockmodel_graph(block_sizes, edge_probs,\n        ...                             directed=False)\n        tensor([[2, 4, 4, 5, 5, 6, 7, 7],\n                [5, 6, 7, 2, 7, 4, 4, 5]])\n\n        >>> stochastic_blockmodel_graph(block_sizes, edge_probs,\n        ...                             directed=True)\n        tensor([[0, 2, 3, 4, 4, 5, 5],\n                [3, 4, 1, 5, 6, 6, 7]])\n    \"\"\"\n    size, prob = block_sizes, edge_probs\n\n    if not isinstance(size, torch.Tensor):\n        size = torch.tensor(size, dtype=torch.long)\n    if not isinstance(prob, torch.Tensor):\n        prob = torch.tensor(prob, dtype=torch.float)\n\n    assert size.dim() == 1\n    assert prob.dim() == 2 and prob.size(0) == prob.size(1)\n    assert size.size(0) == prob.size(0)\n    if not directed:\n        assert torch.allclose(prob, prob.t())\n\n    node_idx = torch.cat([size.new_full((b, ), i) for i, b in enumerate(size)])\n    num_nodes = node_idx.size(0)\n\n    if directed:\n        idx = torch.arange((num_nodes - 1) * num_nodes)\n        idx = idx.view(num_nodes - 1, num_nodes)\n        idx = idx + torch.arange(1, num_nodes).view(-1, 1)\n        idx = idx.view(-1)\n        row = idx.div(num_nodes, rounding_mode='floor')\n        col = idx % num_nodes\n    else:\n        row, col = torch.combinations(torch.arange(num_nodes), r=2).t()\n\n    mask = torch.bernoulli(prob[node_idx[row], node_idx[col]]).to(torch.bool)\n    edge_index = torch.stack([row[mask], col[mask]], dim=0)\n\n    if not directed:\n        
edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n\n    return edge_index\n\n\ndef barabasi_albert_graph(num_nodes: int, num_edges: int) -> torch.Tensor:\n    r\"\"\"Returns the :obj:`edge_index` of a Barabasi-Albert preferential\n    attachment model, where a graph of :obj:`num_nodes` nodes grows by\n    attaching new nodes with :obj:`num_edges` edges that are preferentially\n    attached to existing nodes with high degree.\n\n    Args:\n        num_nodes (int): The number of nodes.\n        num_edges (int): The number of edges from a new node to existing nodes.\n\n    Example:\n        >>> barabasi_albert_graph(num_nodes=4, num_edges=3)\n        tensor([[0, 0, 0, 1, 1, 2, 2, 3],\n                [1, 2, 3, 0, 2, 0, 1, 0]])\n    \"\"\"\n    assert num_edges > 0 and num_edges < num_nodes\n\n    row, col = torch.arange(num_edges), torch.randperm(num_edges)\n\n    for i in range(num_edges, num_nodes):\n        row = torch.cat([row, torch.full((num_edges, ), i, dtype=torch.long)])\n        choice = np.random.choice(torch.cat([row, col]).numpy(), num_edges)\n        col = torch.cat([col, torch.from_numpy(choice)])\n\n    edge_index = torch.stack([row, col], dim=0)\n    edge_index, _ = remove_self_loops(edge_index)\n    edge_index = to_undirected(edge_index, num_nodes=num_nodes)\n\n    return edge_index\n"
  },
  {
    "path": "torch_geometric/utils/repeat.py",
    "content": "import itertools\nimport numbers\nfrom typing import Any\n\nimport torch\nfrom torch import Tensor\n\n\ndef repeat(src: Any, length: int) -> Any:\n    if src is None:\n        return None\n\n    if isinstance(src, Tensor):\n        if src.numel() == 1:\n            return src.repeat(length)\n\n        if src.numel() > length:\n            return src[:length]\n\n        if src.numel() < length:\n            last_elem = src[-1].unsqueeze(0)\n            padding = last_elem.repeat(length - src.numel())\n            return torch.cat([src, padding])\n\n        return src\n\n    if isinstance(src, numbers.Number):\n        return list(itertools.repeat(src, length))\n\n    if (len(src) > length):\n        return src[:length]\n\n    if (len(src) < length):\n        return src + list(itertools.repeat(src[-1], length - len(src)))\n\n    return src\n"
  },
  {
    "path": "torch_geometric/utils/smiles.py",
    "content": "from typing import Any, Dict, List\n\nimport torch\n\nimport torch_geometric\n\nx_map: Dict[str, List[Any]] = {\n    'atomic_num':\n    list(range(0, 119)),\n    'chirality': [\n        'CHI_UNSPECIFIED',\n        'CHI_TETRAHEDRAL_CW',\n        'CHI_TETRAHEDRAL_CCW',\n        'CHI_OTHER',\n        'CHI_TETRAHEDRAL',\n        'CHI_ALLENE',\n        'CHI_SQUAREPLANAR',\n        'CHI_TRIGONALBIPYRAMIDAL',\n        'CHI_OCTAHEDRAL',\n    ],\n    'degree':\n    list(range(0, 11)),\n    'formal_charge':\n    list(range(-5, 7)),\n    'num_hs':\n    list(range(0, 9)),\n    'num_radical_electrons':\n    list(range(0, 5)),\n    'hybridization': [\n        'UNSPECIFIED',\n        'S',\n        'SP',\n        'SP2',\n        'SP3',\n        'SP3D',\n        'SP3D2',\n        'OTHER',\n    ],\n    'is_aromatic': [False, True],\n    'is_in_ring': [False, True],\n}\n\ne_map: Dict[str, List[Any]] = {\n    'bond_type': [\n        'UNSPECIFIED',\n        'SINGLE',\n        'DOUBLE',\n        'TRIPLE',\n        'QUADRUPLE',\n        'QUINTUPLE',\n        'HEXTUPLE',\n        'ONEANDAHALF',\n        'TWOANDAHALF',\n        'THREEANDAHALF',\n        'FOURANDAHALF',\n        'FIVEANDAHALF',\n        'AROMATIC',\n        'IONIC',\n        'HYDROGEN',\n        'THREECENTER',\n        'DATIVEONE',\n        'DATIVE',\n        'DATIVEL',\n        'DATIVER',\n        'OTHER',\n        'ZERO',\n    ],\n    'stereo': [\n        'STEREONONE',\n        'STEREOANY',\n        'STEREOZ',\n        'STEREOE',\n        'STEREOCIS',\n        'STEREOTRANS',\n    ],\n    'is_conjugated': [False, True],\n}\n\n\ndef from_rdmol(mol: Any) -> 'torch_geometric.data.Data':\n    r\"\"\"Converts a :class:`rdkit.Chem.Mol` instance to a\n    :class:`torch_geometric.data.Data` instance.\n\n    Args:\n        mol (rdkit.Chem.Mol): The :class:`rdkit` molecule.\n    \"\"\"\n    from rdkit import Chem\n\n    from torch_geometric.data import Data\n\n    assert isinstance(mol, Chem.Mol)\n\n    xs: 
List[List[int]] = []\n    for atom in mol.GetAtoms():\n        row: List[int] = []\n        row.append(x_map['atomic_num'].index(atom.GetAtomicNum()))\n        row.append(x_map['chirality'].index(str(atom.GetChiralTag())))\n        row.append(x_map['degree'].index(atom.GetTotalDegree()))\n        row.append(x_map['formal_charge'].index(atom.GetFormalCharge()))\n        row.append(x_map['num_hs'].index(atom.GetTotalNumHs()))\n        row.append(x_map['num_radical_electrons'].index(\n            atom.GetNumRadicalElectrons()))\n        row.append(x_map['hybridization'].index(str(atom.GetHybridization())))\n        row.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))\n        row.append(x_map['is_in_ring'].index(atom.IsInRing()))\n        xs.append(row)\n\n    x = torch.tensor(xs, dtype=torch.long).view(-1, 9)\n\n    edge_indices, edge_attrs = [], []\n    for bond in mol.GetBonds():\n        i = bond.GetBeginAtomIdx()\n        j = bond.GetEndAtomIdx()\n\n        e = []\n        e.append(e_map['bond_type'].index(str(bond.GetBondType())))\n        e.append(e_map['stereo'].index(str(bond.GetStereo())))\n        e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))\n\n        edge_indices += [[i, j], [j, i]]\n        edge_attrs += [e, e]\n\n    edge_index = torch.tensor(edge_indices)\n    edge_index = edge_index.t().to(torch.long).view(2, -1)\n    edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)\n\n    if edge_index.numel() > 0:  # Sort indices.\n        perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()\n        edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]\n\n    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)\n\n\ndef from_smiles(\n    smiles: str,\n    with_hydrogen: bool = False,\n    kekulize: bool = False,\n) -> 'torch_geometric.data.Data':\n    r\"\"\"Converts a SMILES string to a :class:`torch_geometric.data.Data`\n    instance.\n\n    Args:\n        smiles (str): The SMILES string.\n        
with_hydrogen (bool, optional): If set to :obj:`True`, will store\n            hydrogens in the molecule graph. (default: :obj:`False`)\n        kekulize (bool, optional): If set to :obj:`True`, converts aromatic\n            bonds to single/double bonds. (default: :obj:`False`)\n    \"\"\"\n    from rdkit import Chem, RDLogger\n\n    RDLogger.DisableLog('rdApp.*')  # type: ignore[attr-defined]\n\n    mol = Chem.MolFromSmiles(smiles)\n\n    if mol is None:\n        mol = Chem.MolFromSmiles('')\n    if with_hydrogen:\n        mol = Chem.AddHs(mol)\n    if kekulize:\n        Chem.Kekulize(mol)\n\n    data = from_rdmol(mol)\n    data.smiles = smiles\n    return data\n\n\ndef to_rdmol(\n    data: 'torch_geometric.data.Data',\n    kekulize: bool = False,\n) -> Any:\n    \"\"\"Converts a :class:`torch_geometric.data.Data` instance to a\n    :class:`rdkit.Chem.Mol` instance.\n\n    Args:\n        data (torch_geometric.data.Data): The molecular graph data.\n        kekulize (bool, optional): If set to :obj:`True`, converts aromatic\n            bonds to single/double bonds. 
(default: :obj:`False`)\n    \"\"\"\n    from rdkit import Chem\n\n    mol = Chem.RWMol()\n\n    assert data.x is not None\n    assert data.num_nodes is not None\n    assert data.edge_index is not None\n    assert data.edge_attr is not None\n    for i in range(data.num_nodes):\n        atom = Chem.Atom(int(data.x[i, 0]))\n        atom.SetChiralTag(Chem.rdchem.ChiralType.values[int(data.x[i, 1])])\n        atom.SetFormalCharge(x_map['formal_charge'][int(data.x[i, 3])])\n        atom.SetNumExplicitHs(x_map['num_hs'][int(data.x[i, 4])])\n        atom.SetNumRadicalElectrons(x_map['num_radical_electrons'][int(\n            data.x[i, 5])])\n        atom.SetHybridization(Chem.rdchem.HybridizationType.values[int(\n            data.x[i, 6])])\n        atom.SetIsAromatic(bool(data.x[i, 7]))\n        mol.AddAtom(atom)\n\n    edges = [tuple(i) for i in data.edge_index.t().tolist()]\n    visited = set()\n\n    for i in range(len(edges)):\n        src, dst = edges[i]\n        if tuple(sorted(edges[i])) in visited:\n            continue\n\n        bond_type = Chem.BondType.values[int(data.edge_attr[i, 0])]\n        mol.AddBond(src, dst, bond_type)\n\n        # Set stereochemistry:\n        stereo = Chem.rdchem.BondStereo.values[int(data.edge_attr[i, 1])]\n        if stereo != Chem.rdchem.BondStereo.STEREONONE:\n            db = mol.GetBondBetweenAtoms(src, dst)\n            db.SetStereoAtoms(dst, src)\n            db.SetStereo(stereo)\n\n        # Set conjugation:\n        is_conjugated = bool(data.edge_attr[i, 2])\n        mol.GetBondBetweenAtoms(src, dst).SetIsConjugated(is_conjugated)\n\n        visited.add(tuple(sorted(edges[i])))\n\n    mol = mol.GetMol()\n\n    if kekulize:\n        Chem.Kekulize(mol)\n\n    Chem.SanitizeMol(mol)\n    Chem.AssignStereochemistry(mol)\n\n    return mol\n\n\ndef to_smiles(\n    data: 'torch_geometric.data.Data',\n    kekulize: bool = False,\n) -> str:\n    \"\"\"Converts a :class:`torch_geometric.data.Data` instance to a SMILES\n    
string.\n\n    Args:\n        data (torch_geometric.data.Data): The molecular graph.\n        kekulize (bool, optional): If set to :obj:`True`, converts aromatic\n            bonds to single/double bonds. (default: :obj:`False`)\n    \"\"\"\n    from rdkit import Chem\n    mol = to_rdmol(data, kekulize=kekulize)\n    return Chem.MolToSmiles(mol, isomericSmiles=True)\n"
  },
  {
    "path": "torch_geometric/utils/sparse.py",
    "content": "import warnings\nfrom typing import Any, List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nimport torch_geometric.typing\nfrom torch_geometric.index import index2ptr, ptr2index\nfrom torch_geometric.typing import SparseTensor\nfrom torch_geometric.utils import coalesce, cumsum\n\n\ndef dense_to_sparse(\n    adj: Tensor,\n    mask: Optional[Tensor] = None,\n) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Converts a dense adjacency matrix to a sparse adjacency matrix defined\n    by edge indices and edge attributes.\n\n    Args:\n        adj (torch.Tensor): The dense adjacency matrix of shape\n            :obj:`[num_nodes, num_nodes]` or\n            :obj:`[batch_size, num_nodes, num_nodes]`.\n        mask (torch.Tensor, optional): A boolean tensor of shape\n            :obj:`[batch_size, num_nodes]` holding information about which\n            nodes in each example are valid. (default: :obj:`None`)\n\n    :rtype: (:class:`LongTensor`, :class:`Tensor`)\n\n    Examples:\n        >>> # For a single adjacency matrix:\n        >>> adj = torch.tensor([[3, 1],\n        ...                     [2, 0]])\n        >>> dense_to_sparse(adj)\n        (tensor([[0, 0, 1],\n                [0, 1, 0]]),\n        tensor([3, 1, 2]))\n\n        >>> # For two adjacency matrices:\n        >>> adj = torch.tensor([[[3, 1],\n        ...                      [2, 0]],\n        ...                     [[0, 1],\n        ...                      [0, 2]]])\n        >>> dense_to_sparse(adj)\n        (tensor([[0, 0, 1, 2, 3],\n                [0, 1, 0, 3, 3]]),\n        tensor([3, 1, 2, 1, 2]))\n\n        >>> # First graph with two nodes, second with three:\n        >>> adj = torch.tensor([[\n        ...         [3, 1, 0],\n        ...         [2, 0, 0],\n        ...         [0, 0, 0]\n        ...     ], [\n        ...         [0, 1, 0],\n        ...         [0, 2, 3],\n        ...         [0, 5, 0]\n        ...     
]])\n        >>> mask = torch.tensor([\n        ...         [True, True, False],\n        ...         [True, True, True]\n        ...     ])\n        >>> dense_to_sparse(adj, mask)\n        (tensor([[0, 0, 1, 2, 3, 3, 4],\n                [0, 1, 0, 3, 3, 4, 3]]),\n        tensor([3, 1, 2, 1, 2, 3, 5]))\n    \"\"\"\n    if adj.dim() < 2 or adj.dim() > 3:\n        raise ValueError(f\"Dense adjacency matrix 'adj' must be two- or \"\n                         f\"three-dimensional (got {adj.dim()} dimensions)\")\n\n    if mask is not None and adj.dim() == 2:\n        warnings.warn(\n            \"Mask should not be provided in case the dense \"\n            \"adjacency matrix is two-dimensional\", stacklevel=2)\n        mask = None\n\n    if mask is not None and mask.dim() != 2:\n        raise ValueError(f\"Mask must be two-dimensional \"\n                         f\"(got {mask.dim()} dimensions)\")\n\n    if mask is not None and adj.size(-2) != adj.size(-1):\n        raise ValueError(f\"Mask is only supported on quadratic adjacency \"\n                         f\"matrices (got [*, {adj.size(-2)}, {adj.size(-1)}])\")\n\n    if adj.dim() == 2:\n        edge_index = adj.nonzero().t()\n        edge_attr = adj[edge_index[0], edge_index[1]]\n        return edge_index, edge_attr\n    else:\n        flatten_adj = adj.view(-1, adj.size(-1))\n        if mask is not None:\n            flatten_adj = flatten_adj[mask.view(-1)]\n        edge_index = flatten_adj.nonzero().t()\n        edge_attr = flatten_adj[edge_index[0], edge_index[1]]\n\n        if mask is None:\n            offset = torch.arange(\n                start=0,\n                end=adj.size(0) * adj.size(2),\n                step=adj.size(2),\n                device=adj.device,\n            )\n            offset = offset.repeat_interleave(adj.size(1))\n        else:\n            count = mask.sum(dim=-1)\n            offset = cumsum(count)[:-1]\n            offset = offset.repeat_interleave(count)\n\n        
edge_index[1] += offset[edge_index[0]]\n\n        return edge_index, edge_attr\n\n\ndef is_torch_sparse_tensor(src: Any) -> bool:\n    r\"\"\"Returns :obj:`True` if the input :obj:`src` is a\n    :class:`torch.sparse.Tensor` (in any sparse layout).\n\n    Args:\n        src (Any): The input object to be checked.\n    \"\"\"\n    if isinstance(src, Tensor):\n        if src.layout == torch.sparse_coo:\n            return True\n        if src.layout == torch.sparse_csr:\n            return True\n        if src.layout == torch.sparse_csc:\n            return True\n    return False\n\n\ndef is_sparse(src: Any) -> bool:\n    r\"\"\"Returns :obj:`True` if the input :obj:`src` is of type\n    :class:`torch.sparse.Tensor` (in any sparse layout) or of type\n    :class:`torch_sparse.SparseTensor`.\n\n    Args:\n        src (Any): The input object to be checked.\n    \"\"\"\n    return is_torch_sparse_tensor(src) or isinstance(src, SparseTensor)\n\n\ndef to_torch_coo_tensor(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None,\n    is_coalesced: bool = False,\n) -> Tensor:\n    r\"\"\"Converts a sparse adjacency matrix defined by edge indices and edge\n    attributes to a :class:`torch.sparse.Tensor` with layout\n    `torch.sparse_coo`.\n    See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): The edge attributes.\n            (default: :obj:`None`)\n        size (int or (int, int), optional): The size of the sparse matrix.\n            If given as an integer, will create a quadratic sparse matrix.\n            If set to :obj:`None`, will infer a quadratic sparse matrix based\n            on :obj:`edge_index.max() + 1`. 
(default: :obj:`None`)\n        is_coalesced (bool): If set to :obj:`True`, will assume that\n            :obj:`edge_index` is already coalesced and thus avoids expensive\n            computation. (default: :obj:`False`)\n\n    :rtype: :class:`torch.sparse.Tensor`\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            [1, 0, 2, 1, 3, 2]])\n        >>> to_torch_coo_tensor(edge_index)\n        tensor(indices=tensor([[0, 1, 1, 2, 2, 3],\n                               [1, 0, 2, 1, 3, 2]]),\n               values=tensor([1., 1., 1., 1., 1., 1.]),\n               size=(4, 4), nnz=6, layout=torch.sparse_coo)\n\n    \"\"\"\n    if size is None:\n        size = int(edge_index.max()) + 1\n\n    if isinstance(size, (tuple, list)):\n        num_src_nodes, num_dst_nodes = size\n        if num_src_nodes is None:\n            num_src_nodes = int(edge_index[0].max()) + 1\n        if num_dst_nodes is None:\n            num_dst_nodes = int(edge_index[1].max()) + 1\n        size = (num_src_nodes, num_dst_nodes)\n    else:\n        size = (size, size)\n\n    if not is_coalesced:\n        edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size))\n\n    if edge_attr is None:\n        # Expanded tensors are not yet supported in all PyTorch code paths :(\n        # edge_attr = torch.ones(1, device=edge_index.device)\n        # edge_attr = edge_attr.expand(edge_index.size(1))\n        edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)\n\n    if not torch_geometric.typing.WITH_PT21:\n        adj = torch.sparse_coo_tensor(\n            indices=edge_index,\n            values=edge_attr,\n            size=tuple(size) + edge_attr.size()[1:],\n            device=edge_index.device,\n        )\n        adj = adj._coalesced_(True)\n        return adj\n\n    return torch.sparse_coo_tensor(\n        indices=edge_index,\n        values=edge_attr,\n        size=tuple(size) + edge_attr.size()[1:],\n        
device=edge_index.device,\n        is_coalesced=True,\n    )\n\n\ndef to_torch_csr_tensor(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None,\n    is_coalesced: bool = False,\n) -> Tensor:\n    r\"\"\"Converts a sparse adjacency matrix defined by edge indices and edge\n    attributes to a :class:`torch.sparse.Tensor` with layout\n    `torch.sparse_csr`.\n    See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): The edge attributes.\n            (default: :obj:`None`)\n        size (int or (int, int), optional): The size of the sparse matrix.\n            If given as an integer, will create a quadratic sparse matrix.\n            If set to :obj:`None`, will infer a quadratic sparse matrix based\n            on :obj:`edge_index.max() + 1`. (default: :obj:`None`)\n        is_coalesced (bool): If set to :obj:`True`, will assume that\n            :obj:`edge_index` is already coalesced and thus avoids expensive\n            computation. (default: :obj:`False`)\n\n    :rtype: :class:`torch.sparse.Tensor`\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            
[1, 0, 2, 1, 3, 2]])\n        >>> to_torch_csr_tensor(edge_index)\n        tensor(crow_indices=tensor([0, 1, 3, 5, 6]),\n               col_indices=tensor([1, 0, 2, 1, 3, 2]),\n               values=tensor([1., 1., 1., 1., 1., 1.]),\n               size=(4, 4), nnz=6, layout=torch.sparse_csr)\n\n    \"\"\"\n    if size is None:\n        size = int(edge_index.max()) + 1\n\n    if isinstance(size, (tuple, list)):\n        num_src_nodes, num_dst_nodes = size\n        if num_src_nodes is None:\n            num_src_nodes = int(edge_index[0].max()) + 1\n        if num_dst_nodes is None:\n            num_dst_nodes = int(edge_index[1].max()) + 1\n        size = (num_src_nodes, num_dst_nodes)\n    else:\n        size = (size, size)\n\n    if not is_coalesced:\n        edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size))\n\n    if edge_attr is None:\n        # Expanded tensors are not yet supported in all PyTorch code paths :(\n        # edge_attr = torch.ones(1, device=edge_index.device)\n        # edge_attr = edge_attr.expand(edge_index.size(1))\n        edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)\n\n    adj = torch.sparse_csr_tensor(\n        crow_indices=index2ptr(edge_index[0], size[0]),\n        col_indices=edge_index[1],\n        values=edge_attr,\n        size=tuple(size) + edge_attr.size()[1:],\n        device=edge_index.device,\n    )\n\n    return adj\n\n\ndef to_torch_csc_tensor(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None,\n    is_coalesced: bool = False,\n) -> Tensor:\n    r\"\"\"Converts a sparse adjacency matrix defined by edge indices and edge\n    attributes to a :class:`torch.sparse.Tensor` with layout\n    `torch.sparse_csc`.\n    See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): The edge 
attributes.\n            (default: :obj:`None`)\n        size (int or (int, int), optional): The size of the sparse matrix.\n            If given as an integer, will create a quadratic sparse matrix.\n            If set to :obj:`None`, will infer a quadratic sparse matrix based\n            on :obj:`edge_index.max() + 1`. (default: :obj:`None`)\n        is_coalesced (bool): If set to :obj:`True`, will assume that\n            :obj:`edge_index` is already coalesced and thus avoids expensive\n            computation. (default: :obj:`False`)\n\n    :rtype: :class:`torch.sparse.Tensor`\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            [1, 0, 2, 1, 3, 2]])\n        >>> to_torch_csc_tensor(edge_index)\n        tensor(ccol_indices=tensor([0, 1, 3, 5, 6]),\n               row_indices=tensor([1, 0, 2, 1, 3, 2]),\n               values=tensor([1., 1., 1., 1., 1., 1.]),\n               size=(4, 4), nnz=6, layout=torch.sparse_csc)\n\n    \"\"\"\n    if size is None:\n        size = int(edge_index.max()) + 1\n\n    if isinstance(size, (tuple, list)):\n        num_src_nodes, num_dst_nodes = size\n        if num_src_nodes is None:\n            num_src_nodes = int(edge_index[0].max()) + 1\n        if num_dst_nodes is None:\n            num_dst_nodes = int(edge_index[1].max()) + 1\n        size = (num_src_nodes, num_dst_nodes)\n    else:\n        size = (size, size)\n\n    if not is_coalesced:\n        edge_index, edge_attr = coalesce(edge_index, edge_attr, max(size),\n                                         sort_by_row=False)\n\n    if edge_attr is None:\n        # Expanded tensors are not yet supported in all PyTorch code paths :(\n        # edge_attr = torch.ones(1, device=edge_index.device)\n        # edge_attr = edge_attr.expand(edge_index.size(1))\n        edge_attr = torch.ones(edge_index.size(1), device=edge_index.device)\n\n    adj = torch.sparse_csc_tensor(\n        ccol_indices=index2ptr(edge_index[1], 
size[1]),\n        row_indices=edge_index[0],\n        values=edge_attr,\n        size=tuple(size) + edge_attr.size()[1:],\n        device=edge_index.device,\n    )\n\n    return adj\n\n\ndef to_torch_sparse_tensor(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    size: Optional[Union[int, Tuple[Optional[int], Optional[int]]]] = None,\n    is_coalesced: bool = False,\n    layout: torch.layout = torch.sparse_coo,\n) -> Tensor:\n    r\"\"\"Converts a sparse adjacency matrix defined by edge indices and edge\n    attributes to a :class:`torch.sparse.Tensor` with custom :obj:`layout`.\n    See :meth:`~torch_geometric.utils.to_edge_index` for the reverse operation.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor, optional): The edge attributes.\n            (default: :obj:`None`)\n        size (int or (int, int), optional): The size of the sparse matrix.\n            If given as an integer, will create a quadratic sparse matrix.\n            If set to :obj:`None`, will infer a quadratic sparse matrix based\n            on :obj:`edge_index.max() + 1`. (default: :obj:`None`)\n        is_coalesced (bool): If set to :obj:`True`, will assume that\n            :obj:`edge_index` is already coalesced and thus avoids expensive\n            computation. (default: :obj:`False`)\n        layout (torch.layout, optional): The layout of the output sparse tensor\n            (:obj:`torch.sparse_coo`, :obj:`torch.sparse_csr`,\n            :obj:`torch.sparse_csc`). 
(default: :obj:`torch.sparse_coo`)\n\n    :rtype: :class:`torch.sparse.Tensor`\n    \"\"\"\n    if layout == torch.sparse_coo:\n        return to_torch_coo_tensor(edge_index, edge_attr, size, is_coalesced)\n    if layout == torch.sparse_csr:\n        return to_torch_csr_tensor(edge_index, edge_attr, size, is_coalesced)\n    if layout == torch.sparse_csc:\n        return to_torch_csc_tensor(edge_index, edge_attr, size, is_coalesced)\n\n    raise ValueError(f\"Unexpected sparse tensor layout (got '{layout}')\")\n\n\ndef to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:\n    r\"\"\"Converts a :class:`torch.sparse.Tensor` or a\n    :class:`torch_sparse.SparseTensor` to edge indices and edge attributes.\n\n    Args:\n        adj (torch.sparse.Tensor or SparseTensor): The adjacency matrix.\n\n    :rtype: (:class:`torch.Tensor`, :class:`torch.Tensor`)\n\n    Example:\n        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],\n        ...                            [1, 0, 2, 1, 3, 2]])\n        >>> adj = to_torch_coo_tensor(edge_index)\n        >>> to_edge_index(adj)\n        (tensor([[0, 1, 1, 2, 2, 3],\n                [1, 0, 2, 1, 3, 2]]),\n        tensor([1., 1., 1., 1., 1., 1.]))\n    \"\"\"\n    if isinstance(adj, SparseTensor):\n        row, col, value = adj.coo()\n        if value is None:\n            value = torch.ones(row.size(0), device=row.device)\n        return torch.stack([row, col], dim=0).long(), value\n\n    if adj.layout == torch.sparse_coo:\n        adj = adj._coalesced_(True)\n        return adj.indices().detach().long(), adj.values()\n\n    if adj.layout == torch.sparse_csr:\n        row = ptr2index(adj.crow_indices().detach())\n        col = adj.col_indices().detach()\n        return torch.stack([row, col], dim=0).long(), adj.values()\n\n    if adj.layout == torch.sparse_csc:\n        col = ptr2index(adj.ccol_indices().detach())\n        row = adj.row_indices().detach()\n        return torch.stack([row, col], dim=0).long(), 
adj.values()\n\n    raise ValueError(f\"Unexpected sparse tensor layout (got '{adj.layout}')\")\n\n\n# Helper functions ############################################################\n\n\ndef get_sparse_diag(\n    size: int,\n    fill_value: float = 1.0,\n    layout: Optional[int] = None,\n    dtype: Optional[torch.dtype] = None,\n    device: Optional[torch.device] = None,\n) -> Tensor:\n    return torch.sparse.spdiags(\n        torch.full((1, size), fill_value, dtype=dtype, device=device),\n        offsets=torch.zeros(1, dtype=torch.long, device=device),\n        shape=(size, size),\n        layout=layout,\n    )\n\n\ndef set_sparse_value(adj: Tensor, value: Tensor) -> Tensor:\n    if value.dim() > 1:\n        size = adj.size() + value.size()[1:]\n    else:\n        size = adj.size()\n\n    if adj.layout == torch.sparse_coo:\n        return torch.sparse_coo_tensor(\n            indices=adj.indices(),\n            values=value,\n            size=size,\n            device=value.device,\n        ).coalesce()\n\n    if adj.layout == torch.sparse_csr:\n        return torch.sparse_csr_tensor(\n            crow_indices=adj.crow_indices(),\n            col_indices=adj.col_indices(),\n            values=value,\n            size=size,\n            device=value.device,\n        )\n\n    if adj.layout == torch.sparse_csc:\n        return torch.sparse_csc_tensor(\n            ccol_indices=adj.ccol_indices(),\n            row_indices=adj.row_indices(),\n            values=value,\n            size=size,\n            device=value.device,\n        )\n\n    raise ValueError(f\"Unexpected sparse tensor layout (got '{adj.layout}')\")\n\n\ndef cat_coo(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:\n    assert dim in {0, 1, (0, 1)}\n    assert tensors[0].layout == torch.sparse_coo\n\n    indices, values = [], []\n    num_rows = num_cols = 0\n    is_coalesced = True\n\n    if dim == 0:\n        for i, tensor in enumerate(tensors):\n            if i == 0:\n             
   indices.append(tensor._indices())\n            else:\n                offset = torch.tensor([[num_rows], [0]], device=tensor.device)\n                indices.append(tensor._indices() + offset)\n            values.append(tensor._values())\n            num_rows += tensor.size(0)\n            num_cols = max(num_cols, tensor.size(1))\n            if not tensor.is_coalesced():\n                is_coalesced = False\n\n    elif dim == 1:\n        for i, tensor in enumerate(tensors):\n            if i == 0:\n                indices.append(tensor._indices())\n            else:\n                offset = torch.tensor([[0], [num_cols]], device=tensor.device)\n                indices.append(tensor._indices() + offset)\n            values.append(tensor._values())\n            num_rows = max(num_rows, tensor.size(0))\n            num_cols += tensor.size(1)\n            is_coalesced = False\n\n    else:\n        for i, tensor in enumerate(tensors):\n            if i == 0:\n                indices.append(tensor._indices())\n            else:\n                offset = torch.tensor([[num_rows], [num_cols]],\n                                      device=tensor.device)\n                indices.append(tensor._indices() + offset)\n            values.append(tensor._values())\n            num_rows += tensor.size(0)\n            num_cols += tensor.size(1)\n            if not tensor.is_coalesced():\n                is_coalesced = False\n\n    if not torch_geometric.typing.WITH_PT21:\n        out = torch.sparse_coo_tensor(\n            indices=torch.cat(indices, dim=-1),\n            values=torch.cat(values),\n            size=(num_rows, num_cols) + values[-1].size()[1:],\n            device=tensor.device,\n        )\n        if is_coalesced:\n            out = out._coalesced_(True)\n        return out\n\n    return torch.sparse_coo_tensor(\n        indices=torch.cat(indices, dim=-1),\n        values=torch.cat(values),\n        size=(num_rows, num_cols) + values[-1].size()[1:],\n        
device=tensor.device,\n        is_coalesced=True if is_coalesced else None,\n    )\n\n\ndef cat_csr(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:\n    assert dim in {0, 1, (0, 1)}\n    assert tensors[0].layout == torch.sparse_csr\n\n    rows, cols, values = [], [], []\n    num_rows = num_cols = nnz = 0\n\n    if dim == 0:\n        for i, tensor in enumerate(tensors):\n            if i == 0:\n                rows.append(tensor.crow_indices())\n            else:\n                rows.append(tensor.crow_indices()[1:] + nnz)\n            cols.append(tensor.col_indices())\n            values.append(tensor.values())\n            num_rows += tensor.size(0)\n            num_cols = max(num_cols, tensor.size(1))\n            nnz += cols[-1].numel()\n\n        return torch.sparse_csr_tensor(\n            crow_indices=torch.cat(rows),\n            col_indices=torch.cat(cols),\n            values=torch.cat(values),\n            size=(num_rows, num_cols) + values[-1].size()[1:],\n            device=tensor.device,\n        )\n\n    elif dim == 1:\n        for i, tensor in enumerate(tensors):\n            rows.append(ptr2index(tensor.crow_indices()))\n            if i == 0:\n                cols.append(tensor.col_indices())\n            else:\n                cols.append(tensor.col_indices() + num_cols)\n            values.append(tensor.values())\n            num_rows = max(num_rows, tensor.size(0))\n            num_cols += tensor.size(1)\n\n        return torch.sparse_coo_tensor(\n            indices=torch.stack((torch.cat(rows), torch.cat(cols)), 0),\n            values=torch.cat(values),\n            size=(num_rows, num_cols) + values[-1].size()[1:],\n            device=tensor.device,\n        )\n\n    else:\n        for i, tensor in enumerate(tensors):\n            if i == 0:\n                rows.append(tensor.crow_indices())\n                cols.append(tensor.col_indices())\n            else:\n                rows.append(tensor.crow_indices()[1:] + 
nnz)\n                cols.append(tensor.col_indices() + num_cols)\n            values.append(tensor.values())\n            num_rows += tensor.size(0)\n            num_cols += tensor.size(1)\n            nnz += cols[-1].numel()\n\n        return torch.sparse_csr_tensor(\n            crow_indices=torch.cat(rows),\n            col_indices=torch.cat(cols),\n            values=torch.cat(values),\n            size=(num_rows, num_cols) + values[-1].size()[1:],\n            device=tensor.device,\n        )\n\n\ndef cat_csc(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:\n    assert dim in {0, 1, (0, 1)}\n    assert tensors[0].layout == torch.sparse_csc\n\n    rows, cols, values = [], [], []\n    num_rows = num_cols = nnz = 0\n\n    if dim == 0:\n        for i, tensor in enumerate(tensors):\n            cols.append(ptr2index(tensor.ccol_indices()))\n            if i == 0:\n                rows.append(tensor.row_indices())\n            else:\n                rows.append(tensor.row_indices() + num_rows)\n            values.append(tensor.values())\n            num_rows += tensor.size(0)\n            num_cols = max(num_cols, tensor.size(1))\n\n        return torch.sparse_coo_tensor(\n            indices=torch.stack((torch.cat(rows), torch.cat(cols)), 0),\n            values=torch.cat(values),\n            size=(num_rows, num_cols) + values[-1].size()[1:],\n            device=tensor.device,\n        )\n\n    elif dim == 1:\n        for i, tensor in enumerate(tensors):\n            if i == 0:\n                cols.append(tensor.ccol_indices())\n            else:\n                cols.append(tensor.ccol_indices()[1:] + nnz)\n            rows.append(tensor.row_indices())\n            values.append(tensor.values())\n            num_rows = max(num_rows, tensor.size(0))\n            num_cols += tensor.size(1)\n            nnz += rows[-1].numel()\n\n        return torch.sparse_csc_tensor(\n            row_indices=torch.cat(rows),\n            
ccol_indices=torch.cat(cols),\n            values=torch.cat(values),\n            size=(num_rows, num_cols) + values[-1].size()[1:],\n            device=tensor.device,\n        )\n\n    else:\n        for i, tensor in enumerate(tensors):\n            if i == 0:\n                rows.append(tensor.row_indices())\n                cols.append(tensor.ccol_indices())\n            else:\n                rows.append(tensor.row_indices() + num_rows)\n                cols.append(tensor.ccol_indices()[1:] + nnz)\n            values.append(tensor.values())\n            num_rows += tensor.size(0)\n            num_cols += tensor.size(1)\n            nnz += rows[-1].numel()\n\n        return torch.sparse_csc_tensor(\n            row_indices=torch.cat(rows),\n            ccol_indices=torch.cat(cols),\n            values=torch.cat(values),\n            size=(num_rows, num_cols) + values[-1].size()[1:],\n            device=tensor.device,\n        )\n\n\ndef cat(tensors: List[Tensor], dim: Union[int, Tuple[int, int]]) -> Tensor:\n    assert is_torch_sparse_tensor(tensors[0])\n\n    if tensors[0].layout == torch.sparse_coo:\n        return cat_coo(tensors, dim)\n    elif tensors[0].layout == torch.sparse_csr:\n        return cat_csr(tensors, dim)\n    else:\n        return cat_csc(tensors, dim)\n"
  },
  {
    "path": "torch_geometric/utils/undirected.py",
    "content": "import typing\nfrom typing import List, Optional, Tuple, Union\n\nimport torch\nfrom torch import Tensor\n\nfrom torch_geometric.typing import OptTensor\nfrom torch_geometric.utils import coalesce, sort_edge_index\nfrom torch_geometric.utils.num_nodes import maybe_num_nodes\n\nif typing.TYPE_CHECKING:\n    from typing import overload\nelse:\n    from torch.jit import _overload as overload\n\nMISSING = '???'\n\n\n@overload\ndef is_undirected(\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor] = None,\n    num_nodes: Optional[int] = None,\n) -> bool:\n    pass\n\n\n@overload\ndef is_undirected(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: List[Tensor],\n    num_nodes: Optional[int] = None,\n) -> bool:\n    pass\n\n\ndef is_undirected(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Union[Optional[Tensor], List[Tensor]] = None,\n    num_nodes: Optional[int] = None,\n) -> bool:\n    r\"\"\"Returns :obj:`True` if the graph given by :attr:`edge_index` is\n    undirected.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-\n            dimensional edge features.\n            If given as a list, will check for equivalence in all its entries.\n            (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max(edge_index) + 1`. (default: :obj:`None`)\n\n    :rtype: bool\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 0],\n        ...                         
[1, 0, 0]])\n        >>> weight = torch.tensor([0, 0, 1])\n        >>> is_undirected(edge_index, weight)\n        True\n\n        >>> weight = torch.tensor([0, 1, 1])\n        >>> is_undirected(edge_index, weight)\n        False\n\n    \"\"\"\n    num_nodes = maybe_num_nodes(edge_index, num_nodes)\n\n    edge_attrs: List[Tensor] = []\n    if isinstance(edge_attr, Tensor):\n        edge_attrs.append(edge_attr)\n    elif isinstance(edge_attr, (list, tuple)):\n        edge_attrs = edge_attr\n\n    edge_index1, edge_attrs1 = sort_edge_index(\n        edge_index,\n        edge_attrs,\n        num_nodes=num_nodes,\n        sort_by_row=True,\n    )\n    edge_index2, edge_attrs2 = sort_edge_index(\n        edge_index,\n        edge_attrs,\n        num_nodes=num_nodes,\n        sort_by_row=False,\n    )\n\n    if not torch.equal(edge_index1[0], edge_index2[1]):\n        return False\n\n    if not torch.equal(edge_index1[1], edge_index2[0]):\n        return False\n\n    assert isinstance(edge_attrs1, list) and isinstance(edge_attrs2, list)\n    for edge_attr1, edge_attr2 in zip(edge_attrs1, edge_attrs2):\n        if not torch.equal(edge_attr1, edge_attr2):\n            return False\n\n    return True\n\n\n@overload\ndef to_undirected(\n    edge_index: Tensor,\n    edge_attr: str = MISSING,\n    num_nodes: Optional[int] = None,\n    reduce: str = 'add',\n) -> Tensor:\n    pass\n\n\n@overload\ndef to_undirected(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Tensor,\n    num_nodes: Optional[int] = None,\n    reduce: str = 'add',\n) -> Tuple[Tensor, Tensor]:\n    pass\n\n\n@overload\ndef to_undirected(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Optional[Tensor],\n    num_nodes: Optional[int] = None,\n    reduce: str = 'add',\n) -> Tuple[Tensor, Optional[Tensor]]:\n    pass\n\n\n@overload\ndef to_undirected(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: List[Tensor],\n    num_nodes: Optional[int] = None,\n    reduce: str = 'add',\n) -> Tuple[Tensor, 
List[Tensor]]:\n    pass\n\n\ndef to_undirected(  # noqa: F811\n    edge_index: Tensor,\n    edge_attr: Union[Optional[Tensor], List[Tensor], str] = MISSING,\n    num_nodes: Optional[int] = None,\n    reduce: str = 'add',\n) -> Union[Tensor, Tuple[Tensor, OptTensor], Tuple[Tensor, List[Tensor]]]:\n    r\"\"\"Converts the graph given by :attr:`edge_index` to an undirected graph\n    such that :math:`(j,i) \\in \\mathcal{E}` for every edge :math:`(i,j) \\in\n    \\mathcal{E}`.\n\n    Args:\n        edge_index (LongTensor): The edge indices.\n        edge_attr (Tensor or List[Tensor], optional): Edge weights or multi-\n            dimensional edge features.\n            If given as a list, will remove duplicates for all its entries.\n            (default: :obj:`None`)\n        num_nodes (int, optional): The number of nodes, *i.e.*\n            :obj:`max(edge_index) + 1`. (default: :obj:`None`)\n        reduce (str, optional): The reduce operation to use for merging edge\n            features (:obj:`\"add\"`, :obj:`\"mean\"`, :obj:`\"min\"`, :obj:`\"max\"`,\n            :obj:`\"mul\"`). (default: :obj:`\"add\"`)\n\n    :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else\n        (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]`)\n\n    .. warning::\n\n        From :pyg:`PyG >= 2.3.0` onwards, this function will always return a\n        tuple whenever :obj:`edge_attr` is passed as an argument (even in case\n        it is set to :obj:`None`).\n\n    Examples:\n        >>> edge_index = torch.tensor([[0, 1, 1],\n        ...                            [1, 0, 2]])\n        >>> to_undirected(edge_index)\n        tensor([[0, 1, 1, 2],\n                [1, 0, 2, 1]])\n\n        >>> edge_index = torch.tensor([[0, 1, 1],\n        ...                            
[1, 0, 2]])\n        >>> edge_weight = torch.tensor([1., 1., 1.])\n        >>> to_undirected(edge_index, edge_weight)\n        (tensor([[0, 1, 1, 2],\n                [1, 0, 2, 1]]),\n        tensor([2., 2., 1., 1.]))\n\n        >>> # Use 'mean' operation to merge edge features\n        >>> to_undirected(edge_index, edge_weight, reduce='mean')\n        (tensor([[0, 1, 1, 2],\n                [1, 0, 2, 1]]),\n        tensor([1., 1., 1., 1.]))\n    \"\"\"\n    # Maintain backward compatibility to `to_undirected(edge_index, num_nodes)`\n    if isinstance(edge_attr, int):\n        num_nodes = edge_attr\n        edge_attr = MISSING\n\n    row, col = edge_index[0], edge_index[1]\n    row, col = torch.cat([row, col], dim=0), torch.cat([col, row], dim=0)\n    edge_index = torch.stack([row, col], dim=0)\n\n    if isinstance(edge_attr, Tensor):\n        edge_attr = torch.cat([edge_attr, edge_attr], dim=0)\n    elif isinstance(edge_attr, (list, tuple)):\n        edge_attr = [torch.cat([e, e], dim=0) for e in edge_attr]\n\n    return coalesce(edge_index, edge_attr, num_nodes, reduce)\n"
  },
  {
    "path": "torch_geometric/visualization/__init__.py",
    "content": "r\"\"\"Visualization package.\"\"\"\n\nfrom .graph import visualize_graph, visualize_hetero_graph\nfrom .influence import influence\n\n__all__ = [\n    'visualize_graph',\n    'visualize_hetero_graph',\n    'influence',\n]\n"
  },
  {
    "path": "torch_geometric/visualization/graph.py",
    "content": "from math import sqrt\nfrom typing import Any, Dict, List, Optional, Set, Tuple\n\nimport torch\nfrom torch import Tensor\n\nBACKENDS = {'graphviz', 'networkx'}\n\n\ndef has_graphviz() -> bool:\n    try:\n        import graphviz\n    except ImportError:\n        return False\n\n    try:\n        graphviz.Digraph().pipe()\n    except graphviz.backend.ExecutableNotFound:\n        return False\n\n    return True\n\n\ndef visualize_graph(\n    edge_index: Tensor,\n    edge_weight: Optional[Tensor] = None,\n    path: Optional[str] = None,\n    backend: Optional[str] = None,\n    node_labels: Optional[List[str]] = None,\n) -> Any:\n    r\"\"\"Visualizes the graph given via :obj:`edge_index` and (optional)\n    :obj:`edge_weight`.\n\n    Args:\n        edge_index (torch.Tensor): The edge indices.\n        edge_weight (torch.Tensor, optional): The edge weights.\n        path (str, optional): The path to where the plot is saved.\n            If set to :obj:`None`, will visualize the plot on-the-fly.\n            (default: :obj:`None`)\n        backend (str, optional): The graph drawing backend to use for\n            visualization (:obj:`\"graphviz\"`, :obj:`\"networkx\"`).\n            If set to :obj:`None`, will use the most appropriate\n            visualization backend based on available system packages.\n            (default: :obj:`None`)\n        node_labels (List[str], optional): The labels/IDs of nodes.\n            (default: :obj:`None`)\n    \"\"\"\n    if edge_weight is not None:  # Normalize edge weights.\n        edge_weight = edge_weight - edge_weight.min()\n        edge_weight = edge_weight / edge_weight.max()\n\n    if edge_weight is not None:  # Discard any edges with zero edge weight:\n        mask = edge_weight > 1e-7\n        edge_index = edge_index[:, mask]\n        edge_weight = edge_weight[mask]\n\n    if edge_weight is None:\n        edge_weight = torch.ones(edge_index.size(1))\n\n    if backend is None:\n        backend = 'graphviz' 
if has_graphviz() else 'networkx'\n\n    if backend.lower() == 'networkx':\n        return _visualize_graph_via_networkx(edge_index, edge_weight, path,\n                                             node_labels)\n    elif backend.lower() == 'graphviz':\n        return _visualize_graph_via_graphviz(edge_index, edge_weight, path,\n                                             node_labels)\n\n    raise ValueError(f\"Expected graph drawing backend to be in \"\n                     f\"{BACKENDS} (got '{backend}')\")\n\n\ndef _visualize_graph_via_graphviz(\n    edge_index: Tensor,\n    edge_weight: Tensor,\n    path: Optional[str] = None,\n    node_labels: Optional[List[str]] = None,\n) -> Any:\n    import graphviz\n\n    suffix = path.split('.')[-1] if path is not None else None\n    g = graphviz.Digraph('graph', format=suffix)\n    g.attr('node', shape='circle', fontsize='11pt')\n\n    for node in edge_index.view(-1).unique().tolist():\n        g.node(str(node) if node_labels is None else node_labels[node])\n\n    for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n        hex_color = hex(255 - round(255 * w))[2:]\n        hex_color = f'0{hex_color}' if len(hex_color) == 1 else hex_color\n        if node_labels is not None:\n            src = node_labels[src]\n            dst = node_labels[dst]\n        g.edge(str(src), str(dst), color=f'#{hex_color}{hex_color}{hex_color}')\n\n    if path is not None:\n        path = '.'.join(path.split('.')[:-1])\n        g.render(path, cleanup=True)\n    else:\n        g.view()\n\n    return g\n\n\ndef _visualize_graph_via_networkx(\n    edge_index: Tensor,\n    edge_weight: Tensor,\n    path: Optional[str] = None,\n    node_labels: Optional[List[str]] = None,\n) -> Any:\n    import matplotlib.pyplot as plt\n    import networkx as nx\n\n    g = nx.DiGraph()\n    node_size = 800\n\n    for node in edge_index.view(-1).unique().tolist():\n        g.add_node(node if node_labels is None else node_labels[node])\n\n    
for (src, dst), w in zip(edge_index.t().tolist(), edge_weight.tolist()):\n        if node_labels is not None:\n            src = node_labels[src]\n            dst = node_labels[dst]\n        g.add_edge(src, dst, alpha=w)\n\n    ax = plt.gca()\n    pos = nx.spring_layout(g)\n    for src, dst, data in g.edges(data=True):\n        ax.annotate(\n            '',\n            xy=pos[src],\n            xytext=pos[dst],\n            arrowprops=dict(\n                arrowstyle=\"<-\",\n                alpha=data['alpha'],\n                shrinkA=sqrt(node_size) / 2.0,\n                shrinkB=sqrt(node_size) / 2.0,\n                connectionstyle=\"arc3,rad=0.1\",\n            ),\n        )\n\n    nx.draw_networkx_nodes(g, pos, node_size=node_size, node_color='white',\n                           margins=0.1, edgecolors='black')\n    nx.draw_networkx_labels(g, pos, font_size=10)\n\n    if path is not None:\n        plt.savefig(path)\n    else:\n        plt.show()\n\n    plt.close()\n\n\ndef visualize_hetero_graph(\n        edge_index_dict: Dict[Tuple[str, str, str], Tensor],\n        edge_weight_dict: Dict[Tuple[str, str, str], Tensor],\n        path: Optional[str] = None,\n        backend: Optional[str] = None,\n        node_labels_dict: Optional[Dict[str, List[str]]] = None,\n        node_weight_dict: Optional[Dict[str, Tensor]] = None,\n        node_size_range: Tuple[float, float] = (50, 500),\n        node_opacity_range: Tuple[float, float] = (1.0, 1.0),\n        edge_width_range: Tuple[float, float] = (0.1, 2.0),\n        edge_opacity_range: Tuple[float, float] = (1.0, 1.0),\n) -> Any:\n    \"\"\"Visualizes a heterogeneous graph using networkx.\"\"\"\n    if backend is not None and backend != \"networkx\":\n        raise ValueError(\"Only 'networkx' backend is supported\")\n\n    # Filter out edges with 0 weight\n    filtered_edge_index_dict = {}\n    filtered_edge_weight_dict = {}\n    for edge_type in edge_index_dict.keys():\n        mask = 
edge_weight_dict[edge_type] > 0\n        if mask.sum() > 0:\n            filtered_edge_index_dict[edge_type] = edge_index_dict[\n                edge_type][:, mask]\n            filtered_edge_weight_dict[edge_type] = edge_weight_dict[edge_type][\n                mask]\n\n    # Get all unique nodes that are still in the filtered edges\n    remaining_nodes: Dict[str, Set[int]] = {}\n    for edge_type, edge_index in filtered_edge_index_dict.items():\n        src_type, _, dst_type = edge_type\n        if src_type not in remaining_nodes:\n            remaining_nodes[src_type] = set()\n        if dst_type not in remaining_nodes:\n            remaining_nodes[dst_type] = set()\n        remaining_nodes[src_type].update(edge_index[0].tolist())\n        remaining_nodes[dst_type].update(edge_index[1].tolist())\n\n    # Filter node weights to only include remaining nodes\n    if node_weight_dict is not None:\n        filtered_node_weight_dict = {}\n        for node_type, weights in node_weight_dict.items():\n            if node_type in remaining_nodes:\n                mask = torch.zeros(len(weights), dtype=torch.bool)\n                mask[list(remaining_nodes[node_type])] = True\n                filtered_node_weight_dict[node_type] = weights[mask]\n        node_weight_dict = filtered_node_weight_dict\n\n    # Filter node labels to only include remaining nodes\n    if node_labels_dict is not None:\n        filtered_node_labels_dict = {}\n        for node_type, labels in node_labels_dict.items():\n            if node_type in remaining_nodes:\n                filtered_node_labels_dict[node_type] = [\n                    label for i, label in enumerate(labels)\n                    if i in remaining_nodes[node_type]\n                ]\n        node_labels_dict = filtered_node_labels_dict\n\n    return _visualize_hetero_graph_via_networkx(\n        filtered_edge_index_dict,\n        filtered_edge_weight_dict,\n        path,\n        node_labels_dict,\n        node_weight_dict,\n    
    node_size_range,\n        node_opacity_range,\n        edge_width_range,\n        edge_opacity_range,\n    )\n\n\ndef _visualize_hetero_graph_via_networkx(\n        edge_index_dict: Dict[Tuple[str, str, str], Tensor],\n        edge_weight_dict: Dict[Tuple[str, str, str], Tensor],\n        path: Optional[str] = None,\n        node_labels_dict: Optional[Dict[str, List[str]]] = None,\n        node_weight_dict: Optional[Dict[str, Tensor]] = None,\n        node_size_range: Tuple[float, float] = (50, 500),\n        node_opacity_range: Tuple[float, float] = (1.0, 1.0),\n        edge_width_range: Tuple[float, float] = (0.1, 2.0),\n        edge_opacity_range: Tuple[float, float] = (1.0, 1.0),\n) -> Any:\n    import matplotlib.pyplot as plt\n    import networkx as nx\n\n    g = nx.DiGraph()\n    node_offsets: Dict[str, int] = {}\n    current_offset = 0\n\n    # First, collect all unique node types and their counts\n    node_types = set()\n    node_counts: Dict[str, int] = {}\n    remaining_nodes: Dict[str, Set[int]] = {\n    }  # Track which nodes are actually present in edges\n\n    # Get all unique nodes that are in the edges\n    for edge_type in edge_index_dict.keys():\n        src_type, _, dst_type = edge_type\n        node_types.add(src_type)\n        node_types.add(dst_type)\n\n        if src_type not in remaining_nodes:\n            remaining_nodes[src_type] = set()\n        if dst_type not in remaining_nodes:\n            remaining_nodes[dst_type] = set()\n\n        remaining_nodes[src_type].update(\n            edge_index_dict[edge_type][0].tolist())\n        remaining_nodes[dst_type].update(\n            edge_index_dict[edge_type][1].tolist())\n\n    # Set node counts based on remaining nodes\n    for node_type in node_types:\n        node_counts[node_type] = len(remaining_nodes[node_type])\n\n    # Add nodes for each node type\n    for node_type in node_types:\n        num_nodes = node_counts[node_type]\n        node_offsets[node_type] = current_offset\n\n    
    # Get node weights if provided\n        weights = None\n        if node_weight_dict is not None and node_type in node_weight_dict:\n            weights = node_weight_dict[node_type]\n            if len(weights) != num_nodes:\n                raise ValueError(f\"Number of weights for node type \"\n                                 f\"{node_type} ({len(weights)}) does not \"\n                                 f\"match number of nodes ({num_nodes})\")\n\n        for i in range(num_nodes):\n            node_id = current_offset + i\n            label = (node_labels_dict[node_type][i]\n                     if node_labels_dict is not None\n                     and node_type in node_labels_dict else \"\")\n\n            # Calculate node size and opacity if weights provided\n            size = node_size_range[1]\n            opacity = node_opacity_range[1]\n            if weights is not None:\n                w = weights[i].item()\n                size = node_size_range[0] + w * \\\n                    (node_size_range[1] - node_size_range[0])\n                opacity = node_opacity_range[0] + w * \\\n                    (node_opacity_range[1] - node_opacity_range[0])\n\n            g.add_node(node_id, label=label, type=node_type, size=size,\n                       alpha=opacity)\n\n        current_offset += num_nodes\n\n    # Add edges with remapped node indices\n    for edge_type, edge_index in edge_index_dict.items():\n        src_type, _, dst_type = edge_type\n        edge_weight = edge_weight_dict[edge_type]\n        src_offset = node_offsets[src_type]\n        dst_offset = node_offsets[dst_type]\n\n        # Create mappings for source and target nodes\n        src_mapping = {\n            old_idx: new_idx\n            for new_idx, old_idx in enumerate(sorted(\n                remaining_nodes[src_type]))\n        }\n        dst_mapping = {\n            old_idx: new_idx\n            for new_idx, old_idx in enumerate(sorted(\n                
remaining_nodes[dst_type]))\n        }\n\n        for (src, dst), w in zip(edge_index.t().tolist(),\n                                 edge_weight.tolist()):\n            # Remap node indices\n            new_src = src_mapping[src] + src_offset\n            new_dst = dst_mapping[dst] + dst_offset\n\n            # Calculate edge width and opacity based on weight\n            width = edge_width_range[0] + w * \\\n                (edge_width_range[1] - edge_width_range[0])\n            opacity = edge_opacity_range[0] + w * \\\n                (edge_opacity_range[1] - edge_opacity_range[0])\n            g.add_edge(new_src, new_dst, width=width, alpha=opacity)\n\n    # Draw the graph\n    ax = plt.gca()\n    pos = nx.arf_layout(g)\n\n    # Draw edges with arrows\n    for src, dst, data in g.edges(data=True):\n        ax.annotate(\n            '',\n            xy=pos[src],\n            xytext=pos[dst],\n            arrowprops=dict(\n                arrowstyle=\"<-\",\n                alpha=data['alpha'],\n                linewidth=data['width'],\n                shrinkA=sqrt(g.nodes[src]['size']) / 2.0,\n                shrinkB=sqrt(g.nodes[dst]['size']) / 2.0,\n                connectionstyle=\"arc3,rad=0.1\",\n            ),\n        )\n\n    # Draw nodes colored by type\n    node_colors = []\n    node_sizes = []\n    node_alphas = []\n\n    # Use matplotlib tab10 colormap for consistent coloring\n    tab10_cmap = plt.cm.tab10  # type: ignore[attr-defined]\n    node_type_colors: Dict[str, Any] = {}  # Store color for each node type\n    for node in g.nodes():\n        node_type = g.nodes[node]['type']\n        # Assign a consistent color for each node type\n        if node_type not in node_type_colors:\n            color_idx = len(node_type_colors) % 10  # Cycle through colors\n            node_type_colors[node_type] = tab10_cmap(color_idx)\n        node_colors.append(node_type_colors[node_type])\n        node_sizes.append(g.nodes[node]['size'])\n        
node_alphas.append(g.nodes[node]['alpha'])\n\n    nx.draw_networkx_nodes(g, pos, node_size=node_sizes,\n                           node_color=node_colors, margins=0.1,\n                           alpha=node_alphas)\n\n    # Draw labels\n    labels = nx.get_node_attributes(g, 'label')\n    nx.draw_networkx_labels(g, pos, labels, font_size=10)\n\n    # Add legend\n    legend_elements = []\n    for node_type, color in node_type_colors.items():\n        legend_elements.append(\n            plt.Line2D([0], [0], marker='o', color='w', label=node_type,\n                       markerfacecolor=color, markersize=10))\n    ax.legend(handles=legend_elements, loc='upper right',\n              bbox_to_anchor=(0.9, 1))\n\n    if path is not None:\n        plt.savefig(path, bbox_inches='tight')\n    else:\n        plt.show()\n\n    plt.close()\n"
  },
  {
    "path": "torch_geometric/visualization/influence.py",
    "content": "from typing import Any\n\nimport torch\nfrom torch import Tensor\nfrom torch.autograd import grad\n\n\ndef influence(model: torch.nn.Module, src: Tensor, *args: Any) -> Tensor:\n    x = src.clone().requires_grad_()\n    out = model(x, *args).sum(dim=-1)\n\n    influences = []\n    for j in range(src.size(0)):\n        influence = grad([out[j]], [x], retain_graph=True)[0].abs().sum(dim=-1)\n        influences.append(influence / influence.sum())\n\n    return torch.stack(influences, dim=0)\n"
  },
  {
    "path": "torch_geometric/warnings.py",
    "content": "import warnings\nfrom typing import Literal\n\nimport torch_geometric\n\n\ndef warn(message: str, stacklevel: int = 5) -> None:\n    if torch_geometric.is_compiling():\n        return\n\n    warnings.warn(message, stacklevel=stacklevel)\n\n\ndef filterwarnings(\n    action: Literal['default', 'error', 'ignore', 'always', 'module', 'once'],\n    message: str,\n) -> None:\n    if torch_geometric.is_compiling():\n        return\n\n    warnings.filterwarnings(action, message)\n\n\nclass WarningCache(set):\n    \"\"\"Cache for warnings.\"\"\"\n    def warn(self, message: str, stacklevel: int = 5) -> None:\n        \"\"\"Trigger warning message.\"\"\"\n        if message not in self:\n            self.add(message)\n            warn(message, stacklevel=stacklevel)\n"
  }
]